diff --git a/controllers/enforcer.go b/controllers/enforcer.go index 40e985a7..0d6ce21e 100644 --- a/controllers/enforcer.go +++ b/controllers/enforcer.go @@ -76,8 +76,10 @@ func (c *ApiController) Enforce() { } res := []bool{} - for _, permission := range permissions { - enforceResult, err := object.Enforce(permission.GetId(), &request) + + listPermissionIdMap := object.GroupPermissionsByModelAdapter(permissions) + for _, permissionIds := range listPermissionIdMap { + enforceResult, err := object.Enforce(permissionIds[0], &request, permissionIds...) if err != nil { c.ResponseError(err.Error()) return @@ -85,6 +87,7 @@ func (c *ApiController) Enforce() { res = append(res, enforceResult) } + c.ResponseOk(res) } @@ -135,18 +138,8 @@ func (c *ApiController) BatchEnforce() { } res := [][]bool{} - listPermissionIdMap := map[string][]string{} - - for _, permission := range permissions { - key := permission.Model + permission.Adapter - permissionIds, ok := listPermissionIdMap[key] - if !ok { - listPermissionIdMap[key] = []string{permission.GetId()} - } else { - listPermissionIdMap[key] = append(permissionIds, permission.GetId()) - } - } + listPermissionIdMap := object.GroupPermissionsByModelAdapter(permissions) for _, permissionIds := range listPermissionIdMap { enforceResult, err := object.BatchEnforce(permissionIds[0], &requests, permissionIds...) if err != nil { diff --git a/object/permission.go b/object/permission.go index 5ce010ce..08d4100a 100644 --- a/object/permission.go +++ b/object/permission.go @@ -370,3 +370,24 @@ func GetMaskedPermissions(permissions []*Permission) []*Permission { return permissions } + +// GroupPermissionsByModelAdapter group permissions by model and adapter. +// Every model and adapter will be a key, and the value is a list of permission ids. +// With each list of permission ids have the same key, we just need to init the +// enforcer and do the enforce/batch-enforce once (with list of permission ids +// as the policyFilter when the enforcer load policy). +func GroupPermissionsByModelAdapter(permissions []*Permission) map[string][]string { + m := make(map[string][]string) + + for _, permission := range permissions { + key := permission.Model + permission.Adapter + permissionIds, ok := m[key] + if !ok { + m[key] = []string{permission.GetId()} + } else { + m[key] = append(permissionIds, permission.GetId()) + } + } + + return m +} diff --git a/object/permission_enforcer.go b/object/permission_enforcer.go index ca93ecc2..abbd3f82 100644 --- a/object/permission_enforcer.go +++ b/object/permission_enforcer.go @@ -246,13 +246,13 @@ func removePolicies(permission *Permission) { type CasbinRequest = []interface{} -func Enforce(permissionId string, request *CasbinRequest) (bool, error) { +func Enforce(permissionId string, request *CasbinRequest, permissionIds ...string) (bool, error) { permission, err := GetPermission(permissionId) if err != nil { return false, err } - enforcer := getEnforcer(permission) + enforcer := getEnforcer(permission, permissionIds...) return enforcer.Enforce(*request...) }