From 1a9d02be46fd3ad2d3262c6e5d1619605f41d7bb Mon Sep 17 00:00:00 2001 From: Yaodong Yu <2814461814@qq.com> Date: Fri, 11 Aug 2023 10:59:18 +0800 Subject: [PATCH] feat: use the casbin model to store relationships between users and groups (#2178) * fix:reslove conflict * fix: remove interface --- controllers/user.go | 22 ++++++++-- main.go | 1 + object/adapter.go | 13 ++++-- object/group.go | 46 +++++++++---------- object/user.go | 28 ++++++++++++ object/user_enforcer.go | 95 ++++++++++++++++++++++++++++++++++++++++ web/src/GroupTreePage.js | 1 + 7 files changed, 173 insertions(+), 33 deletions(-) create mode 100644 object/user_enforcer.go diff --git a/controllers/user.go b/controllers/user.go index 4dc2a5ab..356d9b05 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -90,7 +90,7 @@ func (c *ApiController) GetUsers() { if limit == "" || page == "" { if groupName != "" { - maskedUsers, err := object.GetMaskedUsers(object.GetGroupUsers(groupName)) + maskedUsers, err := object.GetMaskedUsers(object.GetGroupUsers(util.GetId(owner, groupName))) if err != nil { c.ResponseError(err.Error()) return @@ -567,6 +567,22 @@ func (c *ApiController) RemoveUserFromGroup() { name := c.Ctx.Request.Form.Get("name") groupName := c.Ctx.Request.Form.Get("groupName") - c.Data["json"] = wrapActionResponse(object.RemoveUserFromGroup(owner, name, util.GetId(owner, groupName))) - c.ServeJSON() + organization, err := object.GetOrganization(util.GetId("admin", owner)) + if err != nil { + return + } + item := object.GetAccountItemByName("Groups", organization) + res, msg := object.CheckAccountItemModifyRule(item, c.IsAdmin(), c.GetAcceptLanguage()) + if !res { + c.ResponseError(msg) + return + } + + affected, err := object.DeleteGroupForUser(util.GetId(owner, name), groupName) + if err != nil { + c.ResponseError(err.Error()) + return + } + + c.ResponseOk(affected) } diff --git a/main.go b/main.go index 07bbc5a9..96598bdd 100644 --- a/main.go +++ b/main.go @@ -49,6 +49,7 @@ func main() { object.InitLdapAutoSynchronizer() proxy.InitHttpClient() authz.InitApi() + object.InitUserManager() util.SafeGoroutine(func() { object.RunSyncUsersJob() }) diff --git a/object/adapter.go b/object/adapter.go index 79dec2f1..7c53e825 100644 --- a/object/adapter.go +++ b/object/adapter.go @@ -23,6 +23,7 @@ import ( "github.com/casdoor/casdoor/util" xormadapter "github.com/casdoor/xorm-adapter/v3" "github.com/xorm-io/core" + "github.com/xorm-io/xorm" ) type Adapter struct { @@ -155,14 +156,17 @@ func (adapter *Adapter) initAdapter() error { if adapter.builtInAdapter() { dataSourceName = conf.GetConfigString("dataSourceName") + if adapter.DatabaseType == "mysql" { + dataSourceName = dataSourceName + adapter.Database + } } else { switch adapter.DatabaseType { case "mssql": dataSourceName = fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s", adapter.User, adapter.Password, adapter.Host, adapter.Port, adapter.Database) case "mysql": - dataSourceName = fmt.Sprintf("%s:%s@tcp(%s:%d)/", adapter.User, - adapter.Password, adapter.Host, adapter.Port) + dataSourceName = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", adapter.User, + adapter.Password, adapter.Host, adapter.Port, adapter.Database) case "postgres": dataSourceName = fmt.Sprintf("user=%s password=%s host=%s port=%d sslmode=disable dbname=%s", adapter.User, adapter.Password, adapter.Host, adapter.Port, adapter.Database) @@ -181,7 +185,8 @@ func (adapter *Adapter) initAdapter() error { } var err error - adapter.Adapter, err = xormadapter.NewAdapterByEngineWithTableName(NewAdapter(adapter.DatabaseType, dataSourceName, adapter.Database).Engine, adapter.getTable(), adapter.TableNamePrefix) + engine, err := xorm.NewEngine(adapter.DatabaseType, dataSourceName) + adapter.Adapter, err = xormadapter.NewAdapterByEngineWithTableName(engine, adapter.getTable(), adapter.TableNamePrefix) if err != nil { return err } @@ -327,7 +332,7 @@ func (adapter *Adapter) builtInAdapter() bool { return false } - return adapter.Name == "permission-adapter-built-in" || adapter.Name == "api-adapter-built-in" + return adapter.Name == "user-adapter-built-in" || adapter.Name == "api-adapter-built-in" } func getModelDef() model.Model { diff --git a/object/group.go b/object/group.go index c726b007..7aacba84 100644 --- a/object/group.go +++ b/object/group.go @@ -214,30 +214,18 @@ func ConvertToTreeData(groups []*Group, parentId string) []*Group { return treeData } -func RemoveUserFromGroup(owner, name, groupId string) (bool, error) { - user, err := getUser(owner, name) - if err != nil { - return false, err - } - if user == nil { - return false, errors.New("user not exist") - } - - user.Groups = util.DeleteVal(user.Groups, groupId) - affected, err := updateUser(user.GetId(), user, []string{"groups"}) - if err != nil { - return false, err - } - return affected != 0, err -} - func GetGroupUserCount(groupId string, field, value string) (int64, error) { + owner, _ := util.GetOwnerAndNameFromId(groupId) + names, err := userEnforcer.GetUserNamesByGroupName(groupId) + if err != nil { + return 0, err + } + if field == "" && value == "" { - return ormer.Engine.Where(builder.Like{"`groups`", groupId}). - Count(&User{}) + return int64(len(names)), nil } else { return ormer.Engine.Table("user"). - Where(builder.Like{"`groups`", groupId}). + Where("owner = ?", owner).In("name", names). And(fmt.Sprintf("user.%s LIKE ?", util.CamelToSnakeCase(field)), "%"+value+"%"). Count() } @@ -245,8 +233,14 @@ func GetGroupUserCount(groupId string, field, value string) (int64, error) { func GetPaginationGroupUsers(groupId string, offset, limit int, field, value, sortField, sortOrder string) ([]*User, error) { users := []*User{} + owner, _ := util.GetOwnerAndNameFromId(groupId) + names, err := userEnforcer.GetUserNamesByGroupName(groupId) + if err != nil { + return nil, err + } + session := ormer.Engine.Table("user"). - Where(builder.Like{"`groups`", groupId + "\""}) + Where("owner = ?", owner).In("name", names) if offset != -1 && limit != -1 { session.Limit(limit, offset) @@ -265,7 +259,7 @@ func GetPaginationGroupUsers(groupId string, offset, limit int, field, value, so session = session.Desc(fmt.Sprintf("user.%s", util.SnakeString(sortField))) } - err := session.Find(&users) + err = session.Find(&users) if err != nil { return nil, err } @@ -275,13 +269,13 @@ func GetPaginationGroupUsers(groupId string, offset, limit int, field, value, so func GetGroupUsers(groupId string) ([]*User, error) { users := []*User{} - err := ormer.Engine.Table("user"). - Where(builder.Like{"`groups`", groupId + "\""}). - Find(&users) + owner, _ := util.GetOwnerAndNameFromId(groupId) + names, err := userEnforcer.GetUserNamesByGroupName(groupId) + + err = ormer.Engine.Where("owner = ?", owner).In("name", names).Find(&users) if err != nil { return nil, err } - return users, nil } diff --git a/object/user.go b/object/user.go index 5faf47ae..023e2905 100644 --- a/object/user.go +++ b/object/user.go @@ -29,6 +29,23 @@ const ( UserPropertiesWechatOpenId = "wechatOpenId" ) +const UserEnforcerId = "built-in/user-enforcer-built-in" + +var userEnforcer *UserGroupEnforcer + +func InitUserManager() { + enforcer, err := GetEnforcer(UserEnforcerId) + if err != nil { + panic(err) + } + err = enforcer.InitEnforcer() + if err != nil { + panic(err) + } + + userEnforcer = NewUserGroupEnforcer(enforcer.Enforcer) +} + type User struct { Owner string `xorm:"varchar(100) notnull pk" json:"owner"` Name string `xorm:"varchar(100) notnull pk" json:"name"` @@ -531,6 +548,13 @@ func UpdateUser(id string, user *User, columns []string, isAdmin bool) (bool, er columns = append(columns, "name", "email", "phone", "country_code") } + if util.ContainsString(columns, "groups") { + _, err := userEnforcer.UpdateGroupsForUser(user.GetId(), user.Groups) + if err != nil { + return false, err + } + } + affected, err := updateUser(id, user, columns) if err != nil { return false, err @@ -778,6 +802,10 @@ func ExtendUserWithRolesAndPermissions(user *User) (err error) { return } +func DeleteGroupForUser(user string, group string) (bool, error) { + return userEnforcer.DeleteGroupForUser(user, group) +} + func userChangeTrigger(oldName string, newName string) error { session := ormer.Engine.NewSession() defer session.Close() diff --git a/object/user_enforcer.go b/object/user_enforcer.go new file mode 100644 index 00000000..0e70c2ba --- /dev/null +++ b/object/user_enforcer.go @@ -0,0 +1,95 @@ +package object + +import ( + "github.com/casbin/casbin/v2" + "github.com/casbin/casbin/v2/errors" + "github.com/casdoor/casdoor/util" +) + +type UserGroupEnforcer struct { + // use rbac model implement use group, the enforcer can also implement user role + enforcer *casbin.Enforcer +} + +func NewUserGroupEnforcer(enforcer *casbin.Enforcer) *UserGroupEnforcer { + return &UserGroupEnforcer{ + enforcer: enforcer, + } +} + +func (e *UserGroupEnforcer) AddGroupForUser(user string, group string) (bool, error) { + return e.enforcer.AddRoleForUser(user, GetGroupWithPrefix(group)) +} + +func (e *UserGroupEnforcer) AddGroupsForUser(user string, groups []string) (bool, error) { + g := make([]string, len(groups)) + for i, group := range groups { + g[i] = GetGroupWithPrefix(group) + } + return e.enforcer.AddRolesForUser(user, g) +} + +func (e *UserGroupEnforcer) DeleteGroupForUser(user string, group string) (bool, error) { + return e.enforcer.DeleteRoleForUser(user, GetGroupWithPrefix(group)) +} + +func (e *UserGroupEnforcer) DeleteGroupsForUser(user string) (bool, error) { + return e.enforcer.DeleteRolesForUser(user) +} + +func (e *UserGroupEnforcer) GetGroupsForUser(user string) ([]string, error) { + groups, err := e.enforcer.GetRolesForUser(user) + for i, group := range groups { + groups[i] = GetGroupWithoutPrefix(group) + } + return groups, err +} + +func (e *UserGroupEnforcer) GetAllUsersByGroup(group string) ([]string, error) { + users, err := e.enforcer.GetUsersForRole(GetGroupWithPrefix(group)) + if err != nil { + if err == errors.ERR_NAME_NOT_FOUND { + return []string{}, nil + } + return nil, err + } + return users, nil +} + +func GetGroupWithPrefix(group string) string { + return "group:" + group +} + +func GetGroupWithoutPrefix(group string) string { + return group[len("group:"):] +} + +func (e *UserGroupEnforcer) GetUserNamesByGroupName(groupName string) ([]string, error) { + var names []string + + userIds, err := e.GetAllUsersByGroup(groupName) + if err != nil { + return nil, err + } + + for _, userId := range userIds { + _, name := util.GetOwnerAndNameFromIdNoCheck(userId) + names = append(names, name) + } + + return names, nil +} + +func (e *UserGroupEnforcer) UpdateGroupsForUser(user string, groups []string) (bool, error) { + _, err := e.DeleteGroupsForUser(user) + if err != nil { + return false, err + } + + affected, err := e.AddGroupsForUser(user, groups) + if err != nil { + return false, err + } + + return affected, nil +} diff --git a/web/src/GroupTreePage.js b/web/src/GroupTreePage.js index c4cf76e5..35f04d58 100644 --- a/web/src/GroupTreePage.js +++ b/web/src/GroupTreePage.js @@ -221,6 +221,7 @@ class GroupTreePage extends React.Component { onChange={(value) => { this.setState({ organizationName: value, + groupName: "", }); this.props.history.push(`/trees/${value}`); }}