diff --git a/object/user.go b/object/user.go index 7c2bfd22..b901697a 100644 --- a/object/user.go +++ b/object/user.go @@ -16,8 +16,6 @@ package object import ( "fmt" - "reflect" - "strings" "github.com/casdoor/casdoor/util" "xorm.io/core" @@ -88,6 +86,20 @@ func GetUser(id string) *User { return getUser(owner, name) } +func GetMaskedUser(user *User) *User { + if user.Password != "" { + user.Password = "***" + } + return user +} + +func GetMaskedUsers(users []*User) []*User { + for _, user := range users { + user = GetMaskedUser(user) + } + return users +} + func UpdateUser(id string, user *User) bool { owner, name := util.GetOwnerAndNameFromId(id) if getUser(owner, name) == nil { @@ -185,109 +197,10 @@ func DeleteUser(user *User) bool { return affected != 0 } -func GetUserByField(organizationName string, field string, value string) *User { - user := User{Owner: organizationName} - existed, err := adapter.Engine.Where(fmt.Sprintf("%s=?", field), value).Get(&user) - if err != nil { - panic(err) - } - - if existed { - return &user - } else { - return nil - } -} - -func HasUserByField(organizationName string, field string, value string) bool { - return GetUserByField(organizationName, field, value) != nil -} - -func GetUserByFields(organization string, field string) *User { - // check username - user := GetUserByField(organization, "name", field) - if user != nil { - return user - } - - // check email - user = GetUserByField(organization, "email", field) - if user != nil { - return user - } - - // check phone - user = GetUserByField(organization, "phone", field) - if user != nil { - return user - } - - return nil -} - -func SetUserField(user *User, field string, value string) bool { - if field == "password" { - organization := GetOrganizationByUser(user) - user.UpdateUserPassword(organization) - value = user.Password - } - - affected, err := adapter.Engine.Table(user).ID(core.PK{user.Owner, user.Name}).Update(map[string]interface{}{field: value}) - if err != nil { - panic(err) - } - - user = getUser(user.Owner, user.Name) - user.UpdateUserHash() - _, err = adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("hash").Update(user) - if err != nil { - panic(err) - } - - return affected != 0 -} - func LinkUserAccount(user *User, field string, value string) bool { return SetUserField(user, field, value) } -func GetUserField(user *User, field string) string { - // https://socketloop.com/tutorials/golang-how-to-get-struct-field-and-value-by-name - u := reflect.ValueOf(user) - f := reflect.Indirect(u).FieldByName(field) - return f.String() -} - -func GetMaskedUser(user *User) *User { - if user.Password != "" { - user.Password = "***" - } - return user -} - -func GetMaskedUsers(users []*User) []*User { - for _, user := range users { - user = GetMaskedUser(user) - } - return users -} - -func calculateHash(user *User) string { - s := strings.Join([]string{user.Id, user.Password, user.DisplayName, user.Avatar, user.Phone}, "|") - return util.GetMd5Hash(s) -} - -func (user *User) UpdateUserHash() { - hash := calculateHash(user) - user.Hash = hash -} - -func (user *User) UpdateUserPassword(organization *Organization) { - if organization.PasswordType == "salt" { - user.Password = getSaltedPassword(user.Password, organization.PasswordSalt) - } -} - func (user *User) GetId() string { return fmt.Sprintf("%s/%s", user.Owner, user.Name) } diff --git a/object/user_util.go b/object/user_util.go new file mode 100644 index 00000000..c86a090b --- /dev/null +++ b/object/user_util.go @@ -0,0 +1,109 @@ +// Copyright 2021 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package object + +import ( + "fmt" + "reflect" + "strings" + + "github.com/casdoor/casdoor/util" + "xorm.io/core" +) + +func GetUserByField(organizationName string, field string, value string) *User { + user := User{Owner: organizationName} + existed, err := adapter.Engine.Where(fmt.Sprintf("%s=?", field), value).Get(&user) + if err != nil { + panic(err) + } + + if existed { + return &user + } else { + return nil + } +} + +func HasUserByField(organizationName string, field string, value string) bool { + return GetUserByField(organizationName, field, value) != nil +} + +func GetUserByFields(organization string, field string) *User { + // check username + user := GetUserByField(organization, "name", field) + if user != nil { + return user + } + + // check email + user = GetUserByField(organization, "email", field) + if user != nil { + return user + } + + // check phone + user = GetUserByField(organization, "phone", field) + if user != nil { + return user + } + + return nil +} + +func SetUserField(user *User, field string, value string) bool { + if field == "password" { + organization := GetOrganizationByUser(user) + user.UpdateUserPassword(organization) + value = user.Password + } + + affected, err := adapter.Engine.Table(user).ID(core.PK{user.Owner, user.Name}).Update(map[string]interface{}{field: value}) + if err != nil { + panic(err) + } + + user = getUser(user.Owner, user.Name) + user.UpdateUserHash() + _, err = adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("hash").Update(user) + if err != nil { + panic(err) + } + + return affected != 0 +} + +func GetUserField(user *User, field string) string { + // https://socketloop.com/tutorials/golang-how-to-get-struct-field-and-value-by-name + u := reflect.ValueOf(user) + f := reflect.Indirect(u).FieldByName(field) + return f.String() +} + +func calculateHash(user *User) string { + s := strings.Join([]string{user.Id, user.Password, user.DisplayName, user.Avatar, user.Phone}, "|") + return util.GetMd5Hash(s) +} + +func (user *User) UpdateUserHash() { + hash := calculateHash(user) + user.Hash = hash +} + +func (user *User) UpdateUserPassword(organization *Organization) { + if organization.PasswordType == "salt" { + user.Password = getSaltedPassword(user.Password, organization.PasswordSalt) + } +}