diff --git a/controllers/auth.go b/controllers/auth.go index 8e60fad5..6770873e 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -15,7 +15,6 @@ package controllers import ( - "context" "encoding/json" "fmt" @@ -23,7 +22,6 @@ import ( "github.com/casdoor/casdoor/idp" "github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/util" - "golang.org/x/oauth2" ) func codeToResponse(code *object.Code) *Response { @@ -105,36 +103,30 @@ func (c *ApiController) Login() { application := object.GetApplication(fmt.Sprintf("admin/%s", form.Application)) provider := object.GetProvider(fmt.Sprintf("admin/%s", form.Provider)) - idProvider := idp.GetIdProvider(provider.Type, provider.ClientId) - oauthConfig := idProvider.GetConfig() - oauthConfig.ClientID = provider.ClientId - oauthConfig.ClientSecret = provider.ClientSecret - oauthConfig.RedirectURL = form.RedirectUri - - var res authResponse + idProvider := idp.GetIdProvider(provider.Type, provider.ClientId, provider.ClientSecret, form.RedirectUri) + idProvider.SetHttpClient(httpClient) if form.State != beego.AppConfig.String("AuthState") && form.State != application.Name { - resp = &Response{Status: "error", Msg: fmt.Sprintf("state expected: \"%s\", but got: \"%s\"", beego.AppConfig.String("AuthState"), form.State), Data: res} + resp = &Response{Status: "error", Msg: fmt.Sprintf("state expected: \"%s\", but got: \"%s\"", beego.AppConfig.String("AuthState"), form.State)} c.Data["json"] = resp c.ServeJSON() return } // https://github.com/golang/oauth2/issues/123#issuecomment-103715338 - ctx := context.WithValue(oauth2.NoContext, oauth2.HTTPClient, httpClient) - token, err := oauthConfig.Exchange(ctx, form.Code) + token, err := idProvider.GetToken(form.Code) if err != nil { panic(err) } if !token.Valid() { - resp = &Response{Status: "error", Msg: "invalid token", Data: res} + resp = &Response{Status: "error", Msg: "invalid token"} c.Data["json"] = resp c.ServeJSON() return } - res.Email, res.Method, res.Avatar, err = idProvider.GetUserInfo(httpClient, token) + userInfo, err := idProvider.GetUserInfo(token) if err != nil { resp = &Response{Status: "error", Msg: "login failed, please try again."} c.Data["json"] = resp @@ -145,9 +137,9 @@ func (c *ApiController) Login() { if form.Method == "signup" { userId := "" if provider.Type == "github" { - userId = object.GetUserIdByField(application, "github", res.Method) + userId = object.GetUserIdByField(application, "github", userInfo.Username) } else if provider.Type == "google" { - userId = object.GetUserIdByField(application, "google", res.Email) + userId = object.GetUserIdByField(application, "google", userInfo.Email) } if userId != "" { @@ -168,13 +160,13 @@ func (c *ApiController) Login() { // return //} - if userId := object.GetUserIdByField(application, "email", res.Email); userId != "" { + if userId := object.GetUserIdByField(application, "email", userInfo.Email); userId != "" { resp = c.HandleLoggedIn(userId, &form) if provider.Type == "github" { - _ = object.LinkUserAccount(userId, "github", res.Method) + _ = object.LinkUserAccount(userId, "github", userInfo.Username) } else if provider.Type == "google" { - _ = object.LinkUserAccount(userId, "google", res.Email) + _ = object.LinkUserAccount(userId, "google", userInfo.Email) } } } @@ -182,7 +174,7 @@ func (c *ApiController) Login() { } else { userId := c.GetSessionUser() if userId == "" { - resp = &Response{Status: "error", Msg: "user doesn't exist", Data: res} + resp = &Response{Status: "error", Msg: "user doesn't exist", Data: userInfo} c.Data["json"] = resp c.ServeJSON() return @@ -190,9 +182,9 @@ func (c *ApiController) Login() { linkRes := false if provider.Type == "github" { - linkRes = object.LinkUserAccount(userId, "github", res.Method) + linkRes = object.LinkUserAccount(userId, "github", userInfo.Username) } else if provider.Type == "google" { - linkRes = object.LinkUserAccount(userId, "google", res.Email) + linkRes = object.LinkUserAccount(userId, "google", userInfo.Email) } if linkRes { resp = &Response{Status: "ok", Msg: "", Data: linkRes} diff --git a/idp/github.go b/idp/github.go index 16c04fd4..87463138 100644 --- a/idp/github.go +++ b/idp/github.go @@ -15,6 +15,7 @@ package idp import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -23,9 +24,35 @@ import ( "golang.org/x/oauth2" ) -type GithubIdProvider struct{} +type GithubIdProvider struct { + Client *http.Client + Config *oauth2.Config + ClientId string + ClientSecret string + RedirectUrl string +} -func (idp *GithubIdProvider) GetConfig() *oauth2.Config { +func NewGithubIdProvider(clientId string, clientSecret string, redirectUrl string) *GithubIdProvider { + idp := &GithubIdProvider{ + ClientId: clientId, + ClientSecret: clientSecret, + RedirectUrl: redirectUrl, + } + + config := idp.getConfig() + config.ClientID = clientId + config.ClientSecret = clientSecret + config.RedirectURL = redirectUrl + idp.Config = config + + return idp +} + +func (idp *GithubIdProvider) SetHttpClient(client *http.Client) { + idp.Client = client +} + +func (idp *GithubIdProvider) getConfig() *oauth2.Config { var endpoint = oauth2.Endpoint{ AuthURL: "https://github.com/login/oauth/authorize", TokenURL: "https://github.com/login/oauth/access_token", @@ -39,7 +66,12 @@ func (idp *GithubIdProvider) GetConfig() *oauth2.Config { return config } -func (idp *GithubIdProvider) getEmail(httpClient *http.Client, token *oauth2.Token) string { +func (idp *GithubIdProvider) GetToken(code string) (*oauth2.Token, error) { + ctx := context.WithValue(oauth2.NoContext, oauth2.HTTPClient, idp.Client) + return idp.Config.Exchange(ctx, code) +} + +func (idp *GithubIdProvider) getEmail(token *oauth2.Token) string { res := "" type GithubEmail struct { @@ -55,7 +87,7 @@ func (idp *GithubIdProvider) getEmail(httpClient *http.Client, token *oauth2.Tok panic(err) } req.Header.Add("Authorization", "token "+token.AccessToken) - response, err := httpClient.Do(req) + response, err := idp.Client.Do(req) if err != nil { panic(err) } @@ -75,7 +107,7 @@ func (idp *GithubIdProvider) getEmail(httpClient *http.Client, token *oauth2.Tok return res } -func (idp *GithubIdProvider) getLoginAndAvatar(httpClient *http.Client, token *oauth2.Token) (string, string) { +func (idp *GithubIdProvider) getLoginAndAvatar(token *oauth2.Token) (string, string) { type GithubUser struct { Login string `json:"login"` AvatarUrl string `json:"avatar_url"` @@ -87,7 +119,7 @@ func (idp *GithubIdProvider) getLoginAndAvatar(httpClient *http.Client, token *o panic(err) } req.Header.Add("Authorization", "token "+token.AccessToken) - resp, err := httpClient.Do(req) + resp, err := idp.Client.Do(req) if err != nil { panic(err) } @@ -101,20 +133,20 @@ func (idp *GithubIdProvider) getLoginAndAvatar(httpClient *http.Client, token *o return githubUser.Login, githubUser.AvatarUrl } -func (idp *GithubIdProvider) GetUserInfo(httpClient *http.Client, token *oauth2.Token) (string, string, string, error) { - var email, username, avatarUrl string +func (idp *GithubIdProvider) GetUserInfo(token *oauth2.Token) (*UserInfo, error) { + userInfo := &UserInfo{} var wg sync.WaitGroup wg.Add(2) go func() { - email = idp.getEmail(httpClient, token) + userInfo.Email = idp.getEmail(token) wg.Done() }() go func() { - username, avatarUrl = idp.getLoginAndAvatar(httpClient, token) + userInfo.Username, userInfo.AvatarUrl = idp.getLoginAndAvatar(token) wg.Done() }() wg.Wait() - return email, username, avatarUrl, nil + return userInfo, nil } diff --git a/idp/google.go b/idp/google.go index dc777de1..13a995cd 100644 --- a/idp/google.go +++ b/idp/google.go @@ -15,6 +15,7 @@ package idp import ( + "context" "encoding/json" "errors" "io/ioutil" @@ -23,9 +24,35 @@ import ( "golang.org/x/oauth2" ) -type GoogleIdProvider struct{} +type GoogleIdProvider struct { + Client *http.Client + Config *oauth2.Config + ClientId string + ClientSecret string + RedirectUrl string +} -func (idp *GoogleIdProvider) GetConfig() *oauth2.Config { +func NewGoogleIdProvider(clientId string, clientSecret string, redirectUrl string) *GithubIdProvider { + idp := &GithubIdProvider{ + ClientId: clientId, + ClientSecret: clientSecret, + RedirectUrl: redirectUrl, + } + + config := idp.getConfig() + config.ClientID = clientId + config.ClientSecret = clientSecret + config.RedirectURL = redirectUrl + idp.Config = config + + return idp +} + +func (idp *GoogleIdProvider) SetHttpClient(client *http.Client) { + idp.Client = client +} + +func (idp *GoogleIdProvider) getConfig() *oauth2.Config { var endpoint = oauth2.Endpoint{ AuthURL: "https://accounts.google.com/o/oauth2/auth", TokenURL: "https://accounts.google.com/o/oauth2/token", @@ -39,15 +66,20 @@ func (idp *GoogleIdProvider) GetConfig() *oauth2.Config { return config } -func (idp *GoogleIdProvider) GetUserInfo(httpClient *http.Client, token *oauth2.Token) (string, string, string, error) { - var email, username, avatarUrl string +func (idp *GoogleIdProvider) GetToken(code string) (*oauth2.Token, error) { + ctx := context.WithValue(oauth2.NoContext, oauth2.HTTPClient, idp.Client) + return idp.Config.Exchange(ctx, code) +} + +func (idp *GoogleIdProvider) GetUserInfo(token *oauth2.Token) (*UserInfo, error) { + userInfo := &UserInfo{} type userInfoFromGoogle struct { Picture string `json:"picture"` Email string `json:"email"` } - resp, err := httpClient.Get("https://www.googleapis.com/oauth2/v2/userinfo?alt=json&access_token=" + token.AccessToken) + resp, err := idp.Client.Get("https://www.googleapis.com/oauth2/v2/userinfo?alt=json&access_token=" + token.AccessToken) defer resp.Body.Close() contents, err := ioutil.ReadAll(resp.Body) var tempUser userInfoFromGoogle @@ -55,12 +87,12 @@ func (idp *GoogleIdProvider) GetUserInfo(httpClient *http.Client, token *oauth2. if err != nil { panic(err) } - email = tempUser.Email - avatarUrl = tempUser.Picture + userInfo.Email = tempUser.Email + userInfo.AvatarUrl = tempUser.Picture - if email == "" { - return email, username, avatarUrl, errors.New("google email is empty, please try again") + if userInfo.Email == "" { + return userInfo, errors.New("google email is empty, please try again") } - return email, username, avatarUrl, nil + return userInfo, nil } diff --git a/idp/provider.go b/idp/provider.go index 14987291..1dd413c1 100644 --- a/idp/provider.go +++ b/idp/provider.go @@ -20,18 +20,25 @@ import ( "golang.org/x/oauth2" ) -type IdProvider interface { - GetConfig() *oauth2.Config - GetUserInfo(httpClient *http.Client, token *oauth2.Token) (string, string, string, error) +type UserInfo struct { + Username string + Email string + AvatarUrl string } -func GetIdProvider(providerType string, clientId string) IdProvider { +type IdProvider interface { + SetHttpClient(client *http.Client) + GetToken(code string) (*oauth2.Token, error) + GetUserInfo(token *oauth2.Token) (*UserInfo, error) +} + +func GetIdProvider(providerType string, clientId string, clientSecret string, redirectUrl string) IdProvider { if providerType == "github" { - return &GithubIdProvider{} + return NewGithubIdProvider(clientId, clientSecret, redirectUrl) } else if providerType == "google" { - return &GoogleIdProvider{} + return NewGoogleIdProvider(clientId, clientSecret, redirectUrl) } else if providerType == "qq" { - return &QqIdProvider{ClientId: clientId} + return NewQqIdProvider(clientId, clientSecret, redirectUrl) } return nil diff --git a/idp/qq.go b/idp/qq.go index 24ff7266..e0f734a7 100644 --- a/idp/qq.go +++ b/idp/qq.go @@ -15,6 +15,7 @@ package idp import ( + "context" "encoding/json" "errors" "fmt" @@ -26,10 +27,34 @@ import ( ) type QqIdProvider struct { + Client *http.Client + Config *oauth2.Config ClientId string + ClientSecret string + RedirectUrl string } -func (idp *QqIdProvider) GetConfig() *oauth2.Config { +func NewQqIdProvider(clientId string, clientSecret string, redirectUrl string) *QqIdProvider { + idp := &QqIdProvider{ + ClientId: clientId, + ClientSecret: clientSecret, + RedirectUrl: redirectUrl, + } + + config := idp.getConfig() + config.ClientID = clientId + config.ClientSecret = clientSecret + config.RedirectURL = redirectUrl + idp.Config = config + + return idp +} + +func (idp *QqIdProvider) SetHttpClient(client *http.Client) { + idp.Client = client +} + +func (idp *QqIdProvider) getConfig() *oauth2.Config { var endpoint = oauth2.Endpoint{ TokenURL: "https://graph.qq.com/oauth2.0/token", } @@ -42,8 +67,13 @@ func (idp *QqIdProvider) GetConfig() *oauth2.Config { return config } -func (idp *QqIdProvider) GetUserInfo(httpClient *http.Client, token *oauth2.Token) (string, string, string, error) { - var email, username, avatarUrl string +func (idp *QqIdProvider) GetToken(code string) (*oauth2.Token, error) { + ctx := context.WithValue(oauth2.NoContext, oauth2.HTTPClient, idp.Client) + return idp.Config.Exchange(ctx, code) +} + +func (idp *QqIdProvider) GetUserInfo(token *oauth2.Token) (*UserInfo, error) { + userInfo := &UserInfo{} type userInfoFromQq struct { Ret int `json:"ret"` @@ -53,7 +83,7 @@ func (idp *QqIdProvider) GetUserInfo(httpClient *http.Client, token *oauth2.Toke getOpenIdUrl := fmt.Sprintf("https://graph.qq.com/oauth2.0/me?access_token=%s", token) - openIdResponse, err := httpClient.Get(getOpenIdUrl) + openIdResponse, err := idp.Client.Get(getOpenIdUrl) if err != nil { panic(err) } @@ -65,25 +95,24 @@ func (idp *QqIdProvider) GetUserInfo(httpClient *http.Client, token *oauth2.Toke openId := openIdRegRes[0][1] if openId == "" { - return "", "", "", errors.New("openId is empty") + return userInfo, errors.New("openId is empty") } getUserInfoUrl := fmt.Sprintf("https://graph.qq.com/user/get_user_info?access_token=%s&oauth_consumer_key=%s&openid=%s", token, idp.ClientId, openId) - getUserInfoResponse, err := httpClient.Get(getUserInfoUrl) + getUserInfoResponse, err := idp.Client.Get(getUserInfoUrl) if err != nil { panic(err) } defer getUserInfoResponse.Body.Close() userInfoContent, err := ioutil.ReadAll(getUserInfoResponse.Body) - var userInfo userInfoFromQq - err = json.Unmarshal(userInfoContent, &userInfo) - if err != nil || userInfo.Ret != 0 { - return "", "", "", err + var info userInfoFromQq + err = json.Unmarshal(userInfoContent, &info) + if err != nil || info.Ret != 0 { + return userInfo, err } - email = "" - username = userInfo.Nickname - avatarUrl = userInfo.AvatarUrl + userInfo.Username = info.Nickname + userInfo.AvatarUrl = userInfo.AvatarUrl - return email, username, avatarUrl, nil + return userInfo, nil }