feat: support SAML Custom provider (#2430)

* 111

* feat: support custom saml provider

* feat: gofumpt code

* feat: gofumpt code

* feat: remove comment

---------

Co-authored-by: hsluoyz <hsluoyz@qq.com>
This commit is contained in:
haiwu 2023-10-20 21:11:36 +08:00 committed by GitHub
parent 9960b4933b
commit b68e291f37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 113 additions and 64 deletions

View File

@ -477,11 +477,10 @@ func (c *ApiController) Login() {
c.ResponseError(fmt.Sprintf(c.T("auth:The provider: %s is not enabled for the application"), provider.Name)) c.ResponseError(fmt.Sprintf(c.T("auth:The provider: %s is not enabled for the application"), provider.Name))
return return
} }
userInfo := &idp.UserInfo{} userInfo := &idp.UserInfo{}
if provider.Category == "SAML" { if provider.Category == "SAML" {
// SAML // SAML
userInfo.Id, err = object.ParseSamlResponse(authForm.SamlResponse, provider, c.Ctx.Request.Host) userInfo, err = object.ParseSamlResponse(authForm.SamlResponse, provider, c.Ctx.Request.Host)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
@ -524,7 +523,8 @@ func (c *ApiController) Login() {
if authForm.Method == "signup" { if authForm.Method == "signup" {
user := &object.User{} user := &object.User{}
if provider.Category == "SAML" { if provider.Category == "SAML" {
user, err = object.GetUser(util.GetId(application.Organization, userInfo.Id)) // The userInfo.Id is the NameID in SAML response, it could be name / email / phone
user, err = object.GetUserByFields(application.Organization, userInfo.Id)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
@ -679,6 +679,7 @@ func (c *ApiController) Login() {
record2.User = user.Name record2.User = user.Name
util.SafeGoroutine(func() { object.AddRecord(record2) }) util.SafeGoroutine(func() { object.AddRecord(record2) })
} else if provider.Category == "SAML" { } else if provider.Category == "SAML" {
// TODO: since we get the user info from SAML response, we can try to create the user
resp = &Response{Status: "error", Msg: fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(application.Organization, userInfo.Id))} resp = &Response{Status: "error", Msg: fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(application.Organization, userInfo.Id))}
} }
// resp = &Response{Status: "ok", Msg: "", Data: res} // resp = &Response{Status: "ok", Msg: "", Data: res}

View File

@ -23,23 +23,49 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/casdoor/casdoor/idp"
"github.com/mitchellh/mapstructure"
"github.com/casdoor/casdoor/i18n" "github.com/casdoor/casdoor/i18n"
saml2 "github.com/russellhaering/gosaml2" saml2 "github.com/russellhaering/gosaml2"
dsig "github.com/russellhaering/goxmldsig" dsig "github.com/russellhaering/goxmldsig"
) )
func ParseSamlResponse(samlResponse string, provider *Provider, host string) (string, error) { func ParseSamlResponse(samlResponse string, provider *Provider, host string) (*idp.UserInfo, error) {
samlResponse, _ = url.QueryUnescape(samlResponse) samlResponse, _ = url.QueryUnescape(samlResponse)
sp, err := buildSp(provider, samlResponse, host) sp, err := buildSp(provider, samlResponse, host)
if err != nil { if err != nil {
return "", err return nil, err
} }
assertionInfo, err := sp.RetrieveAssertionInfo(samlResponse) assertionInfo, err := sp.RetrieveAssertionInfo(samlResponse)
if err != nil { if err != nil {
return "", err return nil, err
} }
return assertionInfo.NameID, err
userInfoMap := make(map[string]string)
for spAttr, idpAttr := range provider.UserMapping {
for _, attr := range assertionInfo.Values {
if attr.Name == idpAttr {
userInfoMap[spAttr] = attr.Values[0].Value
}
}
}
userInfoMap["id"] = assertionInfo.NameID
customUserInfo := &idp.CustomUserInfo{}
err = mapstructure.Decode(userInfoMap, customUserInfo)
if err != nil {
return nil, err
}
userInfo := &idp.UserInfo{
Id: customUserInfo.Id,
Username: customUserInfo.Username,
DisplayName: customUserInfo.DisplayName,
Email: customUserInfo.Email,
AvatarUrl: customUserInfo.AvatarUrl,
}
return userInfo, err
} }
func GenerateSamlRequest(id, relayState, host, lang string) (auth string, method string, err error) { func GenerateSamlRequest(id, relayState, host, lang string) (auth string, method string, err error) {
@ -146,14 +172,24 @@ func getCertificateFromSamlResponse(samlResponse string, providerType string) (s
if err != nil { if err != nil {
return "", err return "", err
} }
var (
deStr := strings.Replace(string(de), "\n", "", -1) expression string
tagMap := map[string]string{ deStr = strings.Replace(string(de), "\n", "", -1)
tagMap = map[string]string{
"Aliyun IDaaS": "ds", "Aliyun IDaaS": "ds",
"Keycloak": "dsig", "Keycloak": "dsig",
} }
)
tag := tagMap[providerType] tag := tagMap[providerType]
expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)</%s:X509Certificate>", tag, tag) if tag == "" {
// <ds:X509Certificate>...</ds:X509Certificate>
// <dsig:X509Certificate>...</dsig:X509Certificate>
// <X509Certificate>...</X509Certificate>
// ...
expression = "<[^>]*:?X509Certificate>([\\s\\S]*?)<[^>]*:?X509Certificate>"
} else {
expression = fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)</%s:X509Certificate>", tag, tag)
}
res := regexp.MustCompile(expression).FindStringSubmatch(deStr) res := regexp.MustCompile(expression).FindStringSubmatch(deStr)
return res[1], nil return res[1], nil
} }

View File

@ -379,10 +379,11 @@ class ProviderEditPage extends React.Component {
loadSamlConfiguration() { loadSamlConfiguration() {
const parser = new DOMParser(); const parser = new DOMParser();
const xmlDoc = parser.parseFromString(this.state.provider.metadata, "text/xml"); const rawXml = this.state.provider.metadata.replace("\n", "");
const cert = xmlDoc.getElementsByTagName("ds:X509Certificate")[0].childNodes[0].nodeValue; const xmlDoc = parser.parseFromString(rawXml, "text/xml");
const endpoint = xmlDoc.getElementsByTagName("md:SingleSignOnService")[0].getAttribute("Location"); const cert = xmlDoc.querySelector("X509Certificate").childNodes[0].nodeValue.replace(" ", "");
const issuerUrl = xmlDoc.getElementsByTagName("md:EntityDescriptor")[0].getAttribute("entityID"); const endpoint = xmlDoc.querySelector("SingleSignOnService").getAttribute("Location");
const issuerUrl = xmlDoc.querySelector("EntityDescriptor").getAttribute("entityID");
this.updateProviderField("idP", cert); this.updateProviderField("idP", cert);
this.updateProviderField("endpoint", endpoint); this.updateProviderField("endpoint", endpoint);
this.updateProviderField("issuerUrl", issuerUrl); this.updateProviderField("issuerUrl", issuerUrl);
@ -491,7 +492,7 @@ class ProviderEditPage extends React.Component {
this.updateProviderField("type", value); this.updateProviderField("type", value);
if (value === "Local File System") { if (value === "Local File System") {
this.updateProviderField("domain", Setting.getFullServerUrl()); this.updateProviderField("domain", Setting.getFullServerUrl());
} else if (value === "Custom") { } else if (value === "Custom" && this.state.provider.category === "OAuth") {
this.updateProviderField("customAuthUrl", "https://door.casdoor.com/login/oauth/authorize"); this.updateProviderField("customAuthUrl", "https://door.casdoor.com/login/oauth/authorize");
this.updateProviderField("scopes", "openid profile email"); this.updateProviderField("scopes", "openid profile email");
this.updateProviderField("customTokenUrl", "https://door.casdoor.com/api/login/oauth/access_token"); this.updateProviderField("customTokenUrl", "https://door.casdoor.com/api/login/oauth/access_token");
@ -553,8 +554,11 @@ class ProviderEditPage extends React.Component {
) )
} }
{ {
this.state.provider.type !== "Custom" ? null : ( this.state.provider.type === "Custom" ? (
<React.Fragment> <React.Fragment>
{
this.state.provider.category === "OAuth" ? (
<Col>
<Row style={{marginTop: "20px"}} > <Row style={{marginTop: "20px"}} >
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{Setting.getLabel(i18next.t("provider:Auth URL"), i18next.t("provider:Auth URL - Tooltip"))} {Setting.getLabel(i18next.t("provider:Auth URL"), i18next.t("provider:Auth URL - Tooltip"))}
@ -595,6 +599,9 @@ class ProviderEditPage extends React.Component {
}} /> }} />
</Col> </Col>
</Row> </Row>
</Col>
) : null
}
<Row style={{marginTop: "20px"}} > <Row style={{marginTop: "20px"}} >
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}> <Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
{Setting.getLabel(i18next.t("provider:User mapping"), i18next.t("provider:User mapping - Tooltip"))} : {Setting.getLabel(i18next.t("provider:User mapping"), i18next.t("provider:User mapping - Tooltip"))} :
@ -631,7 +638,7 @@ class ProviderEditPage extends React.Component {
</Col> </Col>
</Row> </Row>
</React.Fragment> </React.Fragment>
) ) : null
} }
{ {
(this.state.provider.category === "Captcha" && this.state.provider.type === "Default") || (this.state.provider.category === "Captcha" && this.state.provider.type === "Default") ||

View File

@ -209,6 +209,10 @@ export const OtherProviderInfo = {
logo: `${StaticBaseUrl}/img/social_keycloak.png`, logo: `${StaticBaseUrl}/img/social_keycloak.png`,
url: "https://www.keycloak.org/", url: "https://www.keycloak.org/",
}, },
"Custom": {
logo: `${StaticBaseUrl}/img/social_custom.png`,
url: "https://door.casdoor.com/",
},
}, },
Payment: { Payment: {
"Dummy": { "Dummy": {
@ -866,10 +870,10 @@ export function getClickable(text) {
} }
export function getProviderLogoURL(provider) { export function getProviderLogoURL(provider) {
if (provider.category === "OAuth") {
if (provider.type === "Custom" && provider.customLogo) { if (provider.type === "Custom" && provider.customLogo) {
return provider.customLogo; return provider.customLogo;
} }
if (provider.category === "OAuth") {
return `${StaticBaseUrl}/img/social_${provider.type.toLowerCase()}.png`; return `${StaticBaseUrl}/img/social_${provider.type.toLowerCase()}.png`;
} else { } else {
const info = OtherProviderInfo[provider.category][provider.type]; const info = OtherProviderInfo[provider.category][provider.type];
@ -1014,6 +1018,7 @@ export function getProviderTypeOptions(category) {
return ([ return ([
{id: "Aliyun IDaaS", name: "Aliyun IDaaS"}, {id: "Aliyun IDaaS", name: "Aliyun IDaaS"},
{id: "Keycloak", name: "Keycloak"}, {id: "Keycloak", name: "Keycloak"},
{id: "Custom", name: "Custom"},
]); ]);
} else if (category === "Payment") { } else if (category === "Payment") {
return ([ return ([