From f672045b45ca9bd20d09b4804ec15526fe333775 Mon Sep 17 00:00:00 2001 From: Yang Luo Date: Wed, 9 Jun 2021 21:27:20 +0800 Subject: [PATCH] Allow to sign up with OAuth. --- controllers/account.go | 1 + controllers/auth.go | 53 +++++++++++++++++----------------------- controllers/link.go | 2 +- object/user_cred.go | 38 +++++++++++++++++++++++++++++ object/user_util.go | 55 +++++++++++++++++++++++++++--------------- 5 files changed, 97 insertions(+), 52 deletions(-) create mode 100644 object/user_cred.go diff --git a/controllers/account.go b/controllers/account.go index b0244cff..14f52e49 100644 --- a/controllers/account.go +++ b/controllers/account.go @@ -132,6 +132,7 @@ func (c *ApiController) Signup() { IsAdmin: false, IsGlobalAdmin: false, IsForbidden: false, + Properties: map[string]string{}, } object.AddUser(user) diff --git a/controllers/auth.go b/controllers/auth.go index 5cad0f2c..111846e5 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -245,7 +245,26 @@ func (c *ApiController) Login() { // object.LinkUserAccount(userId, provider.Type, userInfo.Id) //} - if !application.EnableSignUp { + // sign up via OAuth + if provider.EnableSignUp { + user := &object.User{ + Owner: application.Organization, + Name: userInfo.Username, + CreatedTime: util.GetCurrentTime(), + Id: util.GenerateId(), + Type: "normal-user", + DisplayName: userInfo.DisplayName, + Avatar: userInfo.AvatarUrl, + Email: userInfo.Email, + IsAdmin: false, + IsGlobalAdmin: false, + IsForbidden: false, + Properties: map[string]string{}, + } + object.AddUser(user) + + resp = c.HandleLoggedIn(user, &form) + } else if !application.EnableSignUp { resp = &Response{Status: "error", Msg: fmt.Sprintf("The account for provider: %s and username: %s does not exist and is not allowed to sign up as new account, please contact your IT support", provider.Type, userInfo.Username)} c.Data["json"] = resp c.ServeJSON() @@ -258,7 +277,7 @@ func (c *ApiController) Login() { } } //resp = &Response{Status: "ok", Msg: "", Data: res} - } else { + } else { // form.Method != "signup" userId := c.GetSessionUser() if userId == "" { resp = &Response{Status: "error", Msg: "The account does not exist", Data: userInfo} @@ -270,35 +289,7 @@ func (c *ApiController) Login() { user := object.GetUser(userId) // sync info from 3rd-party if possible - if userInfo.Id != "" { - propertyName := fmt.Sprintf("oauth_%s_id", provider.Type) - object.SetUserProperty(user, propertyName, userInfo.Id) - } - if userInfo.Username != "" { - propertyName := fmt.Sprintf("oauth_%s_username", provider.Type) - object.SetUserProperty(user, propertyName, userInfo.Username) - } - if userInfo.DisplayName != "" { - propertyName := fmt.Sprintf("oauth_%s_displayName", provider.Type) - object.SetUserProperty(user, propertyName, userInfo.DisplayName) - if user.DisplayName == "" { - object.SetUserField(user, "display_name", userInfo.DisplayName) - } - } - if userInfo.Email != "" { - propertyName := fmt.Sprintf("oauth_%s_email", provider.Type) - object.SetUserProperty(user, propertyName, userInfo.Email) - if user.Email == "" { - object.SetUserField(user, "email", userInfo.Email) - } - } - if userInfo.AvatarUrl != "" { - propertyName := fmt.Sprintf("oauth_%s_avatarUrl", provider.Type) - object.SetUserProperty(user, propertyName, userInfo.AvatarUrl) - if user.Avatar == "" { - object.SetUserField(user, "avatar", userInfo.AvatarUrl) - } - } + object.SetUserOAuthProperties(user, provider.Type, userInfo) isLinked := object.LinkUserAccount(user, provider.Type, userInfo.Id) if isLinked { diff --git a/controllers/link.go b/controllers/link.go index 01c12108..6f2654c2 100644 --- a/controllers/link.go +++ b/controllers/link.go @@ -49,7 +49,7 @@ func (c *ApiController) Unlink() { return } - object.ClearUserProperties(user, providerType) + object.ClearUserOAuthProperties(user, providerType) object.LinkUserAccount(user, providerType, "") resp = Response{Status: "ok", Msg: ""} diff --git a/object/user_cred.go b/object/user_cred.go new file mode 100644 index 00000000..46a25486 --- /dev/null +++ b/object/user_cred.go @@ -0,0 +1,38 @@ +// Copyright 2021 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package object + +import ( + "strconv" + "strings" + + "github.com/casdoor/casdoor/util" +) + +func calculateHash(user *User) string { + s := strings.Join([]string{user.Id, user.Password, user.DisplayName, user.Avatar, user.Phone, strconv.Itoa(user.Score)}, "|") + return util.GetMd5Hash(s) +} + +func (user *User) UpdateUserHash() { + hash := calculateHash(user) + user.Hash = hash +} + +func (user *User) UpdateUserPassword(organization *Organization) { + if organization.PasswordType == "salt" { + user.Password = getSaltedPassword(user.Password, organization.PasswordSalt) + } +} diff --git a/object/user_util.go b/object/user_util.go index 6afc90cc..489483b9 100644 --- a/object/user_util.go +++ b/object/user_util.go @@ -17,10 +17,9 @@ package object import ( "fmt" "reflect" - "strconv" "strings" - "github.com/casdoor/casdoor/util" + "github.com/casdoor/casdoor/idp" "xorm.io/core" ) @@ -93,12 +92,44 @@ func GetUserField(user *User, field string) string { return f.String() } -func SetUserProperty(user *User, field string, value string) bool { +func setUserProperty(user *User, field string, value string) { if value == "" { delete(user.Properties, field) } else { user.Properties[field] = value } +} + +func SetUserOAuthProperties(user *User, providerType string, userInfo *idp.UserInfo) bool { + if userInfo.Id != "" { + propertyName := fmt.Sprintf("oauth_%s_id", providerType) + setUserProperty(user, propertyName, userInfo.Id) + } + if userInfo.Username != "" { + propertyName := fmt.Sprintf("oauth_%s_username", providerType) + setUserProperty(user, propertyName, userInfo.Username) + } + if userInfo.DisplayName != "" { + propertyName := fmt.Sprintf("oauth_%s_displayName", providerType) + setUserProperty(user, propertyName, userInfo.DisplayName) + if user.DisplayName == "" { + SetUserField(user, "display_name", userInfo.DisplayName) + } + } + if userInfo.Email != "" { + propertyName := fmt.Sprintf("oauth_%s_email", providerType) + setUserProperty(user, propertyName, userInfo.Email) + if user.Email == "" { + SetUserField(user, "email", userInfo.Email) + } + } + if userInfo.AvatarUrl != "" { + propertyName := fmt.Sprintf("oauth_%s_avatarUrl", providerType) + setUserProperty(user, propertyName, userInfo.AvatarUrl) + if user.Avatar == "" { + SetUserField(user, "avatar", userInfo.AvatarUrl) + } + } affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("properties").Update(user) if err != nil { @@ -108,7 +139,7 @@ func SetUserProperty(user *User, field string, value string) bool { return affected != 0 } -func ClearUserProperties(user *User, providerType string) bool { +func ClearUserOAuthProperties(user *User, providerType string) bool { for k := range user.Properties { prefix := fmt.Sprintf("oauth_%s_", providerType) if strings.HasPrefix(k, prefix) { @@ -123,19 +154,3 @@ func ClearUserProperties(user *User, providerType string) bool { return affected != 0 } - -func calculateHash(user *User) string { - s := strings.Join([]string{user.Id, user.Password, user.DisplayName, user.Avatar, user.Phone, strconv.Itoa(user.Score)}, "|") - return util.GetMd5Hash(s) -} - -func (user *User) UpdateUserHash() { - hash := calculateHash(user) - user.Hash = hash -} - -func (user *User) UpdateUserPassword(organization *Organization) { - if organization.PasswordType == "salt" { - user.Password = getSaltedPassword(user.Password, organization.PasswordSalt) - } -}