diff --git a/idp/github.go b/idp/github.go index 9521becf..f2b274a0 100644 --- a/idp/github.go +++ b/idp/github.go @@ -15,11 +15,13 @@ package idp import ( - "context" "encoding/json" + "fmt" + "io" "io/ioutil" "net/http" "strconv" + "strings" "time" "golang.org/x/oauth2" @@ -60,9 +62,38 @@ func (idp *GithubIdProvider) getConfig() *oauth2.Config { return config } +type GithubToken struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Error string `json:"error"` +} + func (idp *GithubIdProvider) GetToken(code string) (*oauth2.Token, error) { - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, idp.Client) - return idp.Config.Exchange(ctx, code) + params := &struct { + Code string `json:"code"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` + }{code, idp.Config.ClientID, idp.Config.ClientSecret} + data, err := idp.postWithBody(params, idp.Config.Endpoint.TokenURL) + if err != nil { + return nil, err + } + pToken := &GithubToken{} + if err = json.Unmarshal(data, pToken); err != nil { + return nil, err + } + if pToken.Error != "" { + return nil, fmt.Errorf("err: %s", pToken.Error) + } + + token := &oauth2.Token{ + AccessToken: pToken.AccessToken, + TokenType: "Bearer", + } + + return token, nil + } //{ @@ -192,3 +223,30 @@ func (idp *GithubIdProvider) GetUserInfo(token *oauth2.Token) (*UserInfo, error) } return &userInfo, nil } + +func (idp *GithubIdProvider) postWithBody(body interface{}, url string) ([]byte, error) { + bs, err := json.Marshal(body) + if err != nil { + return nil, err + } + r := strings.NewReader(string(bs)) + req, _ := http.NewRequest("POST", url, r) + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + resp, err := idp.Client.Do(req) + if err != nil { + return nil, err + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + return + } + }(resp.Body) + + return data, nil +} diff --git a/idp/goth.go b/idp/goth.go index 2b9fe5fb..fb9a4b39 100644 --- a/idp/goth.go +++ b/idp/goth.go @@ -231,6 +231,10 @@ func (idp *GothIdProvider) GetToken(code string) (*oauth2.Token, error) { value.Add("code", code) } accessToken, err := idp.Session.Authorize(idp.Provider, value) + if err != nil { + return nil, err + } + //Get ExpiresAt's value valueOfExpire := reflect.ValueOf(idp.Session).Elem().FieldByName("ExpiresAt") if valueOfExpire.IsValid() { @@ -240,7 +244,8 @@ func (idp *GothIdProvider) GetToken(code string) (*oauth2.Token, error) { AccessToken: accessToken, Expiry: expireAt, } - return &token, err + + return &token, nil } func (idp *GothIdProvider) GetUserInfo(token *oauth2.Token) (*UserInfo, error) {