Improve permission error handling

This commit is contained in:
Yang Luo 2023-10-22 02:30:29 +08:00
parent 98defe617b
commit 37744d6cd7
5 changed files with 170 additions and 66 deletions

View File

@ -379,7 +379,10 @@ func CheckLoginPermission(userId string, application *Application) (bool, error)
continue continue
} }
enforcer := getPermissionEnforcer(permission) enforcer, err := getPermissionEnforcer(permission)
if err != nil {
return false, err
}
var isAllowed bool var isAllowed bool
isAllowed, err = enforcer.Enforce(userId, application.Name, "Read") isAllowed, err = enforcer.Enforce(userId, application.Name, "Read")

View File

@ -113,11 +113,15 @@ func GetPermission(id string) (*Permission, error) {
// checkPermissionValid verifies if the permission is valid // checkPermissionValid verifies if the permission is valid
func checkPermissionValid(permission *Permission) error { func checkPermissionValid(permission *Permission) error {
enforcer := getPermissionEnforcer(permission) enforcer, err := getPermissionEnforcer(permission)
if err != nil {
return err
}
enforcer.EnableAutoSave(false) enforcer.EnableAutoSave(false)
policies := getPolicies(permission) policies := getPolicies(permission)
_, err := enforcer.AddPolicies(policies) _, err = enforcer.AddPolicies(policies)
if err != nil { if err != nil {
return err return err
} }
@ -129,7 +133,7 @@ func checkPermissionValid(permission *Permission) error {
groupingPolicies := getGroupingPolicies(permission) groupingPolicies := getGroupingPolicies(permission)
if len(groupingPolicies) > 0 { if len(groupingPolicies) > 0 {
_, err := enforcer.AddGroupingPolicies(groupingPolicies) _, err = enforcer.AddGroupingPolicies(groupingPolicies)
if err != nil { if err != nil {
return err return err
} }
@ -174,8 +178,16 @@ func UpdatePermission(id string, permission *Permission) (bool, error) {
} }
if affected != 0 { if affected != 0 {
removeGroupingPolicies(oldPermission) err = removeGroupingPolicies(oldPermission)
removePolicies(oldPermission) if err != nil {
return false, err
}
err = removePolicies(oldPermission)
if err != nil {
return false, err
}
if oldPermission.Adapter != "" && oldPermission.Adapter != permission.Adapter { if oldPermission.Adapter != "" && oldPermission.Adapter != permission.Adapter {
isEmpty, _ := ormer.Engine.IsTableEmpty(oldPermission.Adapter) isEmpty, _ := ormer.Engine.IsTableEmpty(oldPermission.Adapter)
if isEmpty { if isEmpty {
@ -185,8 +197,16 @@ func UpdatePermission(id string, permission *Permission) (bool, error) {
} }
} }
} }
addGroupingPolicies(permission)
addPolicies(permission) err = addGroupingPolicies(permission)
if err != nil {
return false, err
}
err = addPolicies(permission)
if err != nil {
return false, err
}
} }
return affected != 0, nil return affected != 0, nil
@ -199,40 +219,54 @@ func AddPermission(permission *Permission) (bool, error) {
} }
if affected != 0 { if affected != 0 {
addGroupingPolicies(permission) err = addGroupingPolicies(permission)
addPolicies(permission) if err != nil {
return false, err
}
err = addPolicies(permission)
if err != nil {
return false, err
}
} }
return affected != 0, nil return affected != 0, nil
} }
func AddPermissions(permissions []*Permission) bool { func AddPermissions(permissions []*Permission) (bool, error) {
if len(permissions) == 0 { if len(permissions) == 0 {
return false return false, nil
} }
affected, err := ormer.Engine.Insert(permissions) affected, err := ormer.Engine.Insert(permissions)
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "Duplicate entry") { if !strings.Contains(err.Error(), "Duplicate entry") {
panic(err) return false, err
} }
} }
for _, permission := range permissions { for _, permission := range permissions {
// add using for loop // add using for loop
if affected != 0 { if affected != 0 {
addGroupingPolicies(permission) err = addGroupingPolicies(permission)
addPolicies(permission) if err != nil {
return false, err
}
err = addPolicies(permission)
if err != nil {
return false, err
}
} }
} }
return affected != 0 return affected != 0, nil
} }
func AddPermissionsInBatch(permissions []*Permission) bool { func AddPermissionsInBatch(permissions []*Permission) (bool, error) {
batchSize := conf.GetConfigBatchSize() batchSize := conf.GetConfigBatchSize()
if len(permissions) == 0 { if len(permissions) == 0 {
return false return false, nil
} }
affected := false affected := false
@ -245,12 +279,18 @@ func AddPermissionsInBatch(permissions []*Permission) bool {
tmp := permissions[start:end] tmp := permissions[start:end]
fmt.Printf("The syncer adds permissions: [%d - %d]\n", start, end) fmt.Printf("The syncer adds permissions: [%d - %d]\n", start, end)
if AddPermissions(tmp) {
b, err := AddPermissions(tmp)
if err != nil {
return false, err
}
if b {
affected = true affected = true
} }
} }
return affected return affected, nil
} }
func DeletePermission(permission *Permission) (bool, error) { func DeletePermission(permission *Permission) (bool, error) {
@ -260,8 +300,16 @@ func DeletePermission(permission *Permission) (bool, error) {
} }
if affected != 0 { if affected != 0 {
removeGroupingPolicies(permission) err = removeGroupingPolicies(permission)
removePolicies(permission) if err != nil {
return false, err
}
err = removePolicies(permission)
if err != nil {
return false, err
}
if permission.Adapter != "" && permission.Adapter != "permission_rule" { if permission.Adapter != "" && permission.Adapter != "permission_rule" {
isEmpty, _ := ormer.Engine.IsTableEmpty(permission.Adapter) isEmpty, _ := ormer.Engine.IsTableEmpty(permission.Adapter)
if isEmpty { if isEmpty {

View File

@ -26,23 +26,23 @@ import (
xormadapter "github.com/casdoor/xorm-adapter/v3" xormadapter "github.com/casdoor/xorm-adapter/v3"
) )
func getPermissionEnforcer(p *Permission, permissionIDs ...string) *casbin.Enforcer { func getPermissionEnforcer(p *Permission, permissionIDs ...string) (*casbin.Enforcer, error) {
// Init an enforcer instance without specifying a model or adapter. // Init an enforcer instance without specifying a model or adapter.
// If you specify an adapter, it will load all policies, which is a // If you specify an adapter, it will load all policies, which is a
// heavy process that can slow down the application. // heavy process that can slow down the application.
enforcer, err := casbin.NewEnforcer(&log.DefaultLogger{}, false) enforcer, err := casbin.NewEnforcer(&log.DefaultLogger{}, false)
if err != nil { if err != nil {
panic(err) return nil, err
} }
err = p.setEnforcerModel(enforcer) err = p.setEnforcerModel(enforcer)
if err != nil { if err != nil {
panic(err) return nil, err
} }
err = p.setEnforcerAdapter(enforcer) err = p.setEnforcerAdapter(enforcer)
if err != nil { if err != nil {
panic(err) return nil, err
} }
policyFilterV5 := []string{p.GetId()} policyFilterV5 := []string{p.GetId()}
@ -60,10 +60,10 @@ func getPermissionEnforcer(p *Permission, permissionIDs ...string) *casbin.Enfor
err = enforcer.LoadFilteredPolicy(policyFilter) err = enforcer.LoadFilteredPolicy(policyFilter)
if err != nil { if err != nil {
panic(err) return nil, err
} }
return enforcer return enforcer, nil
} }
func (p *Permission) setEnforcerAdapter(enforcer *casbin.Enforcer) error { func (p *Permission) setEnforcerAdapter(enforcer *casbin.Enforcer) error {
@ -201,72 +201,96 @@ func getGroupingPolicies(permission *Permission) [][]string {
return groupingPolicies return groupingPolicies
} }
func addPolicies(permission *Permission) { func addPolicies(permission *Permission) error {
enforcer := getPermissionEnforcer(permission) enforcer, err := getPermissionEnforcer(permission)
if err != nil {
return err
}
policies := getPolicies(permission) policies := getPolicies(permission)
_, err := enforcer.AddPolicies(policies) _, err = enforcer.AddPolicies(policies)
if err != nil { return err
panic(err)
}
} }
func addGroupingPolicies(permission *Permission) { func removePolicies(permission *Permission) error {
enforcer := getPermissionEnforcer(permission) enforcer, err := getPermissionEnforcer(permission)
if err != nil {
return err
}
policies := getPolicies(permission)
_, err = enforcer.RemovePolicies(policies)
return err
}
func addGroupingPolicies(permission *Permission) error {
enforcer, err := getPermissionEnforcer(permission)
if err != nil {
return err
}
groupingPolicies := getGroupingPolicies(permission) groupingPolicies := getGroupingPolicies(permission)
if len(groupingPolicies) > 0 { if len(groupingPolicies) > 0 {
_, err := enforcer.AddGroupingPolicies(groupingPolicies) _, err = enforcer.AddGroupingPolicies(groupingPolicies)
if err != nil { if err != nil {
panic(err) return err
} }
} }
return nil
} }
func removeGroupingPolicies(permission *Permission) { func removeGroupingPolicies(permission *Permission) error {
enforcer := getPermissionEnforcer(permission) enforcer, err := getPermissionEnforcer(permission)
if err != nil {
return err
}
groupingPolicies := getGroupingPolicies(permission) groupingPolicies := getGroupingPolicies(permission)
if len(groupingPolicies) > 0 { if len(groupingPolicies) > 0 {
_, err := enforcer.RemoveGroupingPolicies(groupingPolicies) _, err = enforcer.RemoveGroupingPolicies(groupingPolicies)
if err != nil { if err != nil {
panic(err) return err
} }
} }
}
func removePolicies(permission *Permission) { return nil
enforcer := getPermissionEnforcer(permission)
policies := getPolicies(permission)
_, err := enforcer.RemovePolicies(policies)
if err != nil {
panic(err)
}
} }
type CasbinRequest = []interface{} type CasbinRequest = []interface{}
func Enforce(permission *Permission, request *CasbinRequest, permissionIds ...string) (bool, error) { func Enforce(permission *Permission, request *CasbinRequest, permissionIds ...string) (bool, error) {
enforcer := getPermissionEnforcer(permission, permissionIds...) enforcer, err := getPermissionEnforcer(permission, permissionIds...)
if err != nil {
return false, err
}
return enforcer.Enforce(*request...) return enforcer.Enforce(*request...)
} }
func BatchEnforce(permission *Permission, requests *[]CasbinRequest, permissionIds ...string) ([]bool, error) { func BatchEnforce(permission *Permission, requests *[]CasbinRequest, permissionIds ...string) ([]bool, error) {
enforcer := getPermissionEnforcer(permission, permissionIds...) enforcer, err := getPermissionEnforcer(permission, permissionIds...)
if err != nil {
return nil, err
}
return enforcer.BatchEnforce(*requests) return enforcer.BatchEnforce(*requests)
} }
func getAllValues(userId string, fn func(enforcer *casbin.Enforcer) []string) []string { func getAllValues(userId string, fn func(enforcer *casbin.Enforcer) []string) ([]string, error) {
permissions, _, err := getPermissionsAndRolesByUser(userId) permissions, _, err := getPermissionsAndRolesByUser(userId)
if err != nil { if err != nil {
panic(err) return nil, err
} }
for _, role := range GetAllRoles(userId) { for _, role := range GetAllRoles(userId) {
permissionsByRole, err := GetPermissionsByRole(role) permissionsByRole, err := GetPermissionsByRole(role)
if err != nil { if err != nil {
panic(err) return nil, err
} }
permissions = append(permissions, permissionsByRole...) permissions = append(permissions, permissionsByRole...)
@ -274,19 +298,24 @@ func getAllValues(userId string, fn func(enforcer *casbin.Enforcer) []string) []
var values []string var values []string
for _, permission := range permissions { for _, permission := range permissions {
enforcer := getPermissionEnforcer(permission) enforcer, err := getPermissionEnforcer(permission)
if err != nil {
return nil, err
}
values = append(values, fn(enforcer)...) values = append(values, fn(enforcer)...)
} }
return values
return values, nil
} }
func GetAllObjects(userId string) []string { func GetAllObjects(userId string) ([]string, error) {
return getAllValues(userId, func(enforcer *casbin.Enforcer) []string { return getAllValues(userId, func(enforcer *casbin.Enforcer) []string {
return enforcer.GetAllObjects() return enforcer.GetAllObjects()
}) })
} }
func GetAllActions(userId string) []string { func GetAllActions(userId string) ([]string, error) {
return getAllValues(userId, func(enforcer *casbin.Enforcer) []string { return getAllValues(userId, func(enforcer *casbin.Enforcer) []string {
return enforcer.GetAllActions() return enforcer.GetAllActions()
}) })
@ -330,17 +359,23 @@ m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act`
// load [policy_definition] // load [policy_definition]
policyDefinition := strings.Split(cfg.String("policy_definition::p"), ",") policyDefinition := strings.Split(cfg.String("policy_definition::p"), ",")
fieldsNum := len(policyDefinition) fieldsNum := len(policyDefinition)
if fieldsNum > builtInAvailableField { if fieldsNum > builtInAvailableField {
panic(fmt.Errorf("the maximum policy_definition field number cannot exceed %d, got %d", builtInAvailableField, fieldsNum)) return nil, fmt.Errorf("the maximum policy_definition field number cannot exceed %d, got %d", builtInAvailableField, fieldsNum)
} }
// filled empty field with "" and V5 with "permissionId" // filled empty field with "" and V5 with "permissionId"
for i := builtInAvailableField - fieldsNum; i > 0; i-- { for i := builtInAvailableField - fieldsNum; i > 0; i-- {
policyDefinition = append(policyDefinition, "") policyDefinition = append(policyDefinition, "")
} }
policyDefinition = append(policyDefinition, "permissionId") policyDefinition = append(policyDefinition, "permissionId")
m, _ := model.NewModelFromString(modelText) m, err := model.NewModelFromString(modelText)
if err != nil {
return nil, err
}
m.AddDef("p", "p", strings.Join(policyDefinition, ",")) m.AddDef("p", "p", strings.Join(policyDefinition, ","))
return m, err return m, err

View File

@ -83,5 +83,10 @@ func UploadPermissions(owner string, path string) (bool, error) {
return false, nil return false, nil
} }
return AddPermissionsInBatch(newPermissions), nil affected, err := AddPermissionsInBatch(newPermissions)
if err != nil {
return false, err
}
return affected, nil
} }

View File

@ -151,8 +151,16 @@ func UpdateRole(id string, role *Role) (bool, error) {
} }
for _, permission := range permissions { for _, permission := range permissions {
addGroupingPolicies(permission) err = addGroupingPolicies(permission)
addPolicies(permission) if err != nil {
return false, err
}
err = addPolicies(permission)
if err != nil {
return false, err
}
visited[permission.GetId()] = struct{}{} visited[permission.GetId()] = struct{}{}
} }
@ -166,10 +174,15 @@ func UpdateRole(id string, role *Role) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
for _, permission := range permissions { for _, permission := range permissions {
permissionId := permission.GetId() permissionId := permission.GetId()
if _, ok := visited[permissionId]; !ok { if _, ok := visited[permissionId]; !ok {
addGroupingPolicies(permission) err = addGroupingPolicies(permission)
if err != nil {
return false, err
}
visited[permissionId] = struct{}{} visited[permissionId] = struct{}{}
} }
} }