Allow to sign up with OAuth.

This commit is contained in:
Yang Luo
2021-06-09 21:27:20 +08:00
parent 440aad2369
commit f672045b45
5 changed files with 97 additions and 52 deletions

View File

@ -132,6 +132,7 @@ func (c *ApiController) Signup() {
IsAdmin: false, IsAdmin: false,
IsGlobalAdmin: false, IsGlobalAdmin: false,
IsForbidden: false, IsForbidden: false,
Properties: map[string]string{},
} }
object.AddUser(user) object.AddUser(user)

View File

@ -245,7 +245,26 @@ func (c *ApiController) Login() {
// object.LinkUserAccount(userId, provider.Type, userInfo.Id) // object.LinkUserAccount(userId, provider.Type, userInfo.Id)
//} //}
if !application.EnableSignUp { // sign up via OAuth
if provider.EnableSignUp {
user := &object.User{
Owner: application.Organization,
Name: userInfo.Username,
CreatedTime: util.GetCurrentTime(),
Id: util.GenerateId(),
Type: "normal-user",
DisplayName: userInfo.DisplayName,
Avatar: userInfo.AvatarUrl,
Email: userInfo.Email,
IsAdmin: false,
IsGlobalAdmin: false,
IsForbidden: false,
Properties: map[string]string{},
}
object.AddUser(user)
resp = c.HandleLoggedIn(user, &form)
} else if !application.EnableSignUp {
resp = &Response{Status: "error", Msg: fmt.Sprintf("The account for provider: %s and username: %s does not exist and is not allowed to sign up as new account, please contact your IT support", provider.Type, userInfo.Username)} resp = &Response{Status: "error", Msg: fmt.Sprintf("The account for provider: %s and username: %s does not exist and is not allowed to sign up as new account, please contact your IT support", provider.Type, userInfo.Username)}
c.Data["json"] = resp c.Data["json"] = resp
c.ServeJSON() c.ServeJSON()
@ -258,7 +277,7 @@ func (c *ApiController) Login() {
} }
} }
//resp = &Response{Status: "ok", Msg: "", Data: res} //resp = &Response{Status: "ok", Msg: "", Data: res}
} else { } else { // form.Method != "signup"
userId := c.GetSessionUser() userId := c.GetSessionUser()
if userId == "" { if userId == "" {
resp = &Response{Status: "error", Msg: "The account does not exist", Data: userInfo} resp = &Response{Status: "error", Msg: "The account does not exist", Data: userInfo}
@ -270,35 +289,7 @@ func (c *ApiController) Login() {
user := object.GetUser(userId) user := object.GetUser(userId)
// sync info from 3rd-party if possible // sync info from 3rd-party if possible
if userInfo.Id != "" { object.SetUserOAuthProperties(user, provider.Type, userInfo)
propertyName := fmt.Sprintf("oauth_%s_id", provider.Type)
object.SetUserProperty(user, propertyName, userInfo.Id)
}
if userInfo.Username != "" {
propertyName := fmt.Sprintf("oauth_%s_username", provider.Type)
object.SetUserProperty(user, propertyName, userInfo.Username)
}
if userInfo.DisplayName != "" {
propertyName := fmt.Sprintf("oauth_%s_displayName", provider.Type)
object.SetUserProperty(user, propertyName, userInfo.DisplayName)
if user.DisplayName == "" {
object.SetUserField(user, "display_name", userInfo.DisplayName)
}
}
if userInfo.Email != "" {
propertyName := fmt.Sprintf("oauth_%s_email", provider.Type)
object.SetUserProperty(user, propertyName, userInfo.Email)
if user.Email == "" {
object.SetUserField(user, "email", userInfo.Email)
}
}
if userInfo.AvatarUrl != "" {
propertyName := fmt.Sprintf("oauth_%s_avatarUrl", provider.Type)
object.SetUserProperty(user, propertyName, userInfo.AvatarUrl)
if user.Avatar == "" {
object.SetUserField(user, "avatar", userInfo.AvatarUrl)
}
}
isLinked := object.LinkUserAccount(user, provider.Type, userInfo.Id) isLinked := object.LinkUserAccount(user, provider.Type, userInfo.Id)
if isLinked { if isLinked {

View File

@ -49,7 +49,7 @@ func (c *ApiController) Unlink() {
return return
} }
object.ClearUserProperties(user, providerType) object.ClearUserOAuthProperties(user, providerType)
object.LinkUserAccount(user, providerType, "") object.LinkUserAccount(user, providerType, "")
resp = Response{Status: "ok", Msg: ""} resp = Response{Status: "ok", Msg: ""}

38
object/user_cred.go Normal file
View File

@ -0,0 +1,38 @@
// Copyright 2021 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package object
import (
"strconv"
"strings"
"github.com/casdoor/casdoor/util"
)
func calculateHash(user *User) string {
s := strings.Join([]string{user.Id, user.Password, user.DisplayName, user.Avatar, user.Phone, strconv.Itoa(user.Score)}, "|")
return util.GetMd5Hash(s)
}
func (user *User) UpdateUserHash() {
hash := calculateHash(user)
user.Hash = hash
}
func (user *User) UpdateUserPassword(organization *Organization) {
if organization.PasswordType == "salt" {
user.Password = getSaltedPassword(user.Password, organization.PasswordSalt)
}
}

View File

@ -17,10 +17,9 @@ package object
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"strings" "strings"
"github.com/casdoor/casdoor/util" "github.com/casdoor/casdoor/idp"
"xorm.io/core" "xorm.io/core"
) )
@ -93,12 +92,44 @@ func GetUserField(user *User, field string) string {
return f.String() return f.String()
} }
func SetUserProperty(user *User, field string, value string) bool { func setUserProperty(user *User, field string, value string) {
if value == "" { if value == "" {
delete(user.Properties, field) delete(user.Properties, field)
} else { } else {
user.Properties[field] = value user.Properties[field] = value
} }
}
func SetUserOAuthProperties(user *User, providerType string, userInfo *idp.UserInfo) bool {
if userInfo.Id != "" {
propertyName := fmt.Sprintf("oauth_%s_id", providerType)
setUserProperty(user, propertyName, userInfo.Id)
}
if userInfo.Username != "" {
propertyName := fmt.Sprintf("oauth_%s_username", providerType)
setUserProperty(user, propertyName, userInfo.Username)
}
if userInfo.DisplayName != "" {
propertyName := fmt.Sprintf("oauth_%s_displayName", providerType)
setUserProperty(user, propertyName, userInfo.DisplayName)
if user.DisplayName == "" {
SetUserField(user, "display_name", userInfo.DisplayName)
}
}
if userInfo.Email != "" {
propertyName := fmt.Sprintf("oauth_%s_email", providerType)
setUserProperty(user, propertyName, userInfo.Email)
if user.Email == "" {
SetUserField(user, "email", userInfo.Email)
}
}
if userInfo.AvatarUrl != "" {
propertyName := fmt.Sprintf("oauth_%s_avatarUrl", providerType)
setUserProperty(user, propertyName, userInfo.AvatarUrl)
if user.Avatar == "" {
SetUserField(user, "avatar", userInfo.AvatarUrl)
}
}
affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("properties").Update(user) affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("properties").Update(user)
if err != nil { if err != nil {
@ -108,7 +139,7 @@ func SetUserProperty(user *User, field string, value string) bool {
return affected != 0 return affected != 0
} }
func ClearUserProperties(user *User, providerType string) bool { func ClearUserOAuthProperties(user *User, providerType string) bool {
for k := range user.Properties { for k := range user.Properties {
prefix := fmt.Sprintf("oauth_%s_", providerType) prefix := fmt.Sprintf("oauth_%s_", providerType)
if strings.HasPrefix(k, prefix) { if strings.HasPrefix(k, prefix) {
@ -123,19 +154,3 @@ func ClearUserProperties(user *User, providerType string) bool {
return affected != 0 return affected != 0
} }
func calculateHash(user *User) string {
s := strings.Join([]string{user.Id, user.Password, user.DisplayName, user.Avatar, user.Phone, strconv.Itoa(user.Score)}, "|")
return util.GetMd5Hash(s)
}
func (user *User) UpdateUserHash() {
hash := calculateHash(user)
user.Hash = hash
}
func (user *User) UpdateUserPassword(organization *Organization) {
if organization.PasswordType == "salt" {
user.Password = getSaltedPassword(user.Password, organization.PasswordSalt)
}
}