diff --git a/object/saml_idp.go b/object/saml_idp.go index bd35215d..6253b7f9 100644 --- a/object/saml_idp.go +++ b/object/saml_idp.go @@ -70,7 +70,9 @@ func NewSamlResponse(application *Application, user *User, host string, certific if application.UseEmailAsSamlNameId { nameIDValue = user.Email } - subject.CreateElement("saml:NameID").SetText(nameIDValue) + nameId := subject.CreateElement("saml:NameID") + nameId.CreateAttr("Format", "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent") + nameId.SetText(nameIDValue) subjectConfirmation := subject.CreateElement("saml:SubjectConfirmation") subjectConfirmation.CreateAttr("Method", "urn:oasis:names:tc:SAML:2.0:cm:bearer") subjectConfirmationData := subjectConfirmation.CreateElement("saml:SubjectConfirmationData") @@ -108,20 +110,46 @@ func NewSamlResponse(application *Application, user *User, host string, certific displayName.CreateAttr("NameFormat", "urn:oasis:names:tc:SAML:2.0:attrname-format:basic") displayName.CreateElement("saml:AttributeValue").CreateAttr("xsi:type", "xs:string").Element().SetText(user.DisplayName) + err := ExtendUserWithRolesAndPermissions(user) + if err != nil { + return nil, err + } + for _, item := range application.SamlAttributes { role := attributes.CreateElement("saml:Attribute") role.CreateAttr("Name", item.Name) role.CreateAttr("NameFormat", item.NameFormat) - role.CreateElement("saml:AttributeValue").CreateAttr("xsi:type", "xs:string").Element().SetText(item.Value) + + valueList := []string{item.Value} + if strings.Contains(item.Value, "$user.roles") { + valueList = replaceSamlAttributeValuesWithList("$user.roles", getUserRoleNames(user), valueList) + } + + if strings.Contains(item.Value, "$user.permissions") { + valueList = replaceSamlAttributeValuesWithList("$user.permissions", getUserPermissionNames(user), valueList) + } + + if strings.Contains(item.Value, "$user.groups") { + valueList = replaceSamlAttributeValuesWithList("$user.groups", user.Groups, valueList) + } + + valueList = replaceSamlAttributeValues("$user.owner", user.Owner, valueList) + valueList = replaceSamlAttributeValues("$user.name", user.Name, valueList) + valueList = replaceSamlAttributeValues("$user.email", user.Email, valueList) + valueList = replaceSamlAttributeValues("$user.id", user.Id, valueList) + valueList = replaceSamlAttributeValues("$user.phone", user.Phone, valueList) + + for _, value := range valueList { + av := role.CreateElement("saml:AttributeValue") + av.CreateAttr("xmlns:xs", "http://www.w3.org/2001/XMLSchema") + av.CreateAttr("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") + av.CreateAttr("xsi:type", "xs:string").Element().SetText(value) + } } roles := attributes.CreateElement("saml:Attribute") roles.CreateAttr("Name", "Roles") roles.CreateAttr("NameFormat", "urn:oasis:names:tc:SAML:2.0:attrname-format:basic") - err := ExtendUserWithRolesAndPermissions(user) - if err != nil { - return nil, err - } for _, role := range user.Roles { roles.CreateElement("saml:AttributeValue").CreateAttr("xsi:type", "xs:string").Element().SetText(role.Name) @@ -130,6 +158,26 @@ func NewSamlResponse(application *Application, user *User, host string, certific return samlResponse, nil } +func replaceSamlAttributeValues(val string, replaceVal string, values []string) []string { + newValues := []string{} + for _, value := range values { + newValues = append(newValues, strings.ReplaceAll(value, val, replaceVal)) + } + + return newValues +} + +func replaceSamlAttributeValuesWithList(val string, replaceVals []string, values []string) []string { + newValues := []string{} + for _, value := range values { + for _, rVal := range replaceVals { + newValues = append(newValues, strings.ReplaceAll(value, val, rVal)) + } + } + + return newValues +} + type X509Key struct { X509Certificate string PrivateKey string diff --git a/object/user_util.go b/object/user_util.go index a33df0b4..1ed926f7 100644 --- a/object/user_util.go +++ b/object/user_util.go @@ -248,6 +248,20 @@ func SetUserOAuthProperties(organization *Organization, user *User, providerType return UpdateUserForAllFields(user.GetId(), user) } +func getUserRoleNames(user *User) (res []string) { + for _, role := range user.Roles { + res = append(res, role.Name) + } + return res +} + +func getUserPermissionNames(user *User) (res []string) { + for _, permission := range user.Permissions { + res = append(res, permission.Name) + } + return res +} + func ClearUserOAuthProperties(user *User, providerType string) (bool, error) { for k := range user.Properties { prefix := fmt.Sprintf("oauth_%s_", providerType)