diff --git a/object/permission_enforcer.go b/object/permission_enforcer.go index 0e958673..0ae9daaf 100644 --- a/object/permission_enforcer.go +++ b/object/permission_enforcer.go @@ -118,35 +118,53 @@ func getPolicies(permission *Permission) [][]string { return policies } +func getRolesInRole(roleId string, visited map[string]struct{}) []*Role { + role := GetRole(roleId) + if role == nil { + return []*Role{} + } + visited[roleId] = struct{}{} + + roles := []*Role{role} + for _, subRole := range role.Roles { + if _, ok := visited[subRole]; !ok { + roles = append(roles, getRolesInRole(subRole, visited)...) + } + } + + return roles +} + func getGroupingPolicies(permission *Permission) [][]string { var groupingPolicies [][]string domainExist := len(permission.Domains) > 0 permissionId := permission.GetId() - for _, role := range permission.Roles { - roleObj := GetRole(role) - if roleObj == nil { - continue - } + for _, roleId := range permission.Roles { + visited := map[string]struct{}{} + rolesInRole := getRolesInRole(roleId, visited) - for _, subUser := range roleObj.Users { - if domainExist { - for _, domain := range permission.Domains { - groupingPolicies = append(groupingPolicies, []string{subUser, role, domain, "", "", permissionId}) + for _, role := range rolesInRole { + roleId := role.GetId() + for _, subUser := range role.Users { + if domainExist { + for _, domain := range permission.Domains { + groupingPolicies = append(groupingPolicies, []string{subUser, roleId, domain, "", "", permissionId}) + } + } else { + groupingPolicies = append(groupingPolicies, []string{subUser, roleId, "", "", "", permissionId}) } - } else { - groupingPolicies = append(groupingPolicies, []string{subUser, role, "", "", "", permissionId}) } - } - for _, subRole := range roleObj.Roles { - if domainExist { - for _, domain := range permission.Domains { - groupingPolicies = append(groupingPolicies, []string{subRole, role, domain, "", "", permissionId}) + for _, subRole := range role.Roles { + if domainExist { + for _, domain := range permission.Domains { + groupingPolicies = append(groupingPolicies, []string{subRole, roleId, domain, "", "", permissionId}) + } + } else { + groupingPolicies = append(groupingPolicies, []string{subRole, roleId, "", "", "", permissionId}) } - } else { - groupingPolicies = append(groupingPolicies, []string{subRole, role, "", "", "", permissionId}) } } } diff --git a/object/role.go b/object/role.go index e544de6f..5d744166 100644 --- a/object/role.go +++ b/object/role.go @@ -94,10 +94,25 @@ func UpdateRole(id string, role *Role) bool { return false } + visited := map[string]struct{}{} + permissions := GetPermissionsByRole(id) for _, permission := range permissions { removeGroupingPolicies(permission) removePolicies(permission) + visited[permission.GetId()] = struct{}{} + } + + ancestorRoles := GetAncestorRoles(id) + for _, r := range ancestorRoles { + permissions := GetPermissionsByRole(r.GetId()) + for _, permission := range permissions { + permissionId := permission.GetId() + if _, ok := visited[permissionId]; !ok { + removeGroupingPolicies(permission) + visited[permissionId] = struct{}{} + } + } } if name != role.Name { @@ -112,11 +127,25 @@ func UpdateRole(id string, role *Role) bool { panic(err) } + visited = map[string]struct{}{} newRoleID := role.GetId() permissions = GetPermissionsByRole(newRoleID) for _, permission := range permissions { addGroupingPolicies(permission) addPolicies(permission) + visited[permission.GetId()] = struct{}{} + } + + ancestorRoles = GetAncestorRoles(newRoleID) + for _, r := range ancestorRoles { + permissions := GetPermissionsByRole(r.GetId()) + for _, permission := range permissions { + permissionId := permission.GetId() + if _, ok := visited[permissionId]; !ok { + addGroupingPolicies(permission) + visited[permissionId] = struct{}{} + } + } } return affected != 0 @@ -221,3 +250,64 @@ func GetMaskedRoles(roles []*Role) []*Role { return roles } + +func GetRolesByNamePrefix(owner string, prefix string) []*Role { + roles := []*Role{} + err := adapter.Engine.Where("owner=? and name like ?", owner, prefix+"%").Find(&roles) + if err != nil { + panic(err) + } + + return roles +} + +func GetAncestorRoles(roleId string) []*Role { + var ( + result []*Role + roleMap = make(map[string]*Role) + visited = make(map[string]bool) + ) + + owner, _ := util.GetOwnerAndNameFromIdNoCheck(roleId) + + allRoles := GetRoles(owner) + for _, r := range allRoles { + roleMap[r.GetId()] = r + } + + // Second, find all the roles that contain father roles + for _, r := range allRoles { + isContain, ok := visited[r.GetId()] + if isContain { + result = append(result, r) + } else if !ok { + rId := r.GetId() + visited[rId] = containsRole(r, roleId, roleMap, visited) + if visited[rId] { + result = append(result, r) + } + } + } + + return result +} + +// containsRole is a helper function to check if a slice of roles contains a specific roleId +func containsRole(role *Role, roleId string, roleMap map[string]*Role, visited map[string]bool) bool { + if isContain, ok := visited[role.GetId()]; ok { + return isContain + } + + for _, subRole := range role.Roles { + if subRole == roleId { + return true + } + + r, ok := roleMap[subRole] + if ok && containsRole(r, roleId, roleMap, visited) { + return true + } + } + + return false +}