From 1eeeb64a0c8fd1a5c60d735f93f23c337ae892e7 Mon Sep 17 00:00:00 2001 From: Yang Luo Date: Thu, 24 Aug 2023 18:16:23 +0800 Subject: [PATCH] Add checkModel() for UserGroupEnforcer --- object/user_enforcer.go | 52 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/object/user_enforcer.go b/object/user_enforcer.go index 0e70c2ba..f883246d 100644 --- a/object/user_enforcer.go +++ b/object/user_enforcer.go @@ -1,6 +1,8 @@ package object import ( + "fmt" + "github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2/errors" "github.com/casdoor/casdoor/util" @@ -17,11 +19,28 @@ func NewUserGroupEnforcer(enforcer *casbin.Enforcer) *UserGroupEnforcer { } } +func (e *UserGroupEnforcer) checkModel() error { + if _, ok := e.enforcer.GetModel()["g"]; !ok { + return fmt.Errorf("The Casbin model used by enforcer doesn't support RBAC (\"[role_definition]\" section not found), please use a RBAC enabled Casbin model for the enforcer") + } + return nil +} + func (e *UserGroupEnforcer) AddGroupForUser(user string, group string) (bool, error) { + err := e.checkModel() + if err != nil { + return false, err + } + return e.enforcer.AddRoleForUser(user, GetGroupWithPrefix(group)) } func (e *UserGroupEnforcer) AddGroupsForUser(user string, groups []string) (bool, error) { + err := e.checkModel() + if err != nil { + return false, err + } + g := make([]string, len(groups)) for i, group := range groups { g[i] = GetGroupWithPrefix(group) @@ -30,14 +49,29 @@ func (e *UserGroupEnforcer) AddGroupsForUser(user string, groups []string) (bool } func (e *UserGroupEnforcer) DeleteGroupForUser(user string, group string) (bool, error) { + err := e.checkModel() + if err != nil { + return false, err + } + return e.enforcer.DeleteRoleForUser(user, GetGroupWithPrefix(group)) } func (e *UserGroupEnforcer) DeleteGroupsForUser(user string) (bool, error) { + err := e.checkModel() + if err != nil { + return false, err + } + return e.enforcer.DeleteRolesForUser(user) } func (e *UserGroupEnforcer) GetGroupsForUser(user string) ([]string, error) { + err := e.checkModel() + if err != nil { + return nil, err + } + groups, err := e.enforcer.GetRolesForUser(user) for i, group := range groups { groups[i] = GetGroupWithoutPrefix(group) @@ -46,6 +80,11 @@ func (e *UserGroupEnforcer) GetGroupsForUser(user string) ([]string, error) { } func (e *UserGroupEnforcer) GetAllUsersByGroup(group string) ([]string, error) { + err := e.checkModel() + if err != nil { + return nil, err + } + users, err := e.enforcer.GetUsersForRole(GetGroupWithPrefix(group)) if err != nil { if err == errors.ERR_NAME_NOT_FOUND { @@ -65,13 +104,17 @@ func GetGroupWithoutPrefix(group string) string { } func (e *UserGroupEnforcer) GetUserNamesByGroupName(groupName string) ([]string, error) { - var names []string + err := e.checkModel() + if err != nil { + return nil, err + } userIds, err := e.GetAllUsersByGroup(groupName) if err != nil { return nil, err } + names := []string{} for _, userId := range userIds { _, name := util.GetOwnerAndNameFromIdNoCheck(userId) names = append(names, name) @@ -81,7 +124,12 @@ func (e *UserGroupEnforcer) GetUserNamesByGroupName(groupName string) ([]string, } func (e *UserGroupEnforcer) UpdateGroupsForUser(user string, groups []string) (bool, error) { - _, err := e.DeleteGroupsForUser(user) + err := e.checkModel() + if err != nil { + return false, err + } + + _, err = e.DeleteGroupsForUser(user) if err != nil { return false, err }