From 370e835499ebc6e6dfeb533fd276be6d626032cd Mon Sep 17 00:00:00 2001 From: Yixiang Zhao Date: Wed, 15 Dec 2021 21:38:00 +0800 Subject: [PATCH] feat: support AuthnRequest in SAML (#372) Signed-off-by: Yixiang Zhao --- controllers/auth.go | 5 ++-- object/provider.go | 7 ++--- object/saml.go | 51 ++++++++++++++++++++++++++++--------- web/src/ProviderEditPage.js | 12 ++++++++- web/src/auth/AuthBackend.js | 4 +-- web/src/auth/LoginPage.js | 10 +++++--- 6 files changed, 66 insertions(+), 23 deletions(-) diff --git a/controllers/auth.go b/controllers/auth.go index 083216e3..186de111 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -364,11 +364,12 @@ func (c *ApiController) Login() { func (c *ApiController) GetSamlLogin() { providerId := c.Input().Get("id") - authURL, err := object.GenerateSamlLoginUrl(providerId) + relayState := c.Input().Get("relayState") + authURL, method, err := object.GenerateSamlLoginUrl(providerId, relayState) if err != nil { c.ResponseError(err.Error()) } - c.ResponseOk(authURL) + c.ResponseOk(authURL, method) } func (c *ApiController) HandleSamlLogin() { diff --git a/object/provider.go b/object/provider.go index edf949af..160c86d3 100644 --- a/object/provider.go +++ b/object/provider.go @@ -48,9 +48,10 @@ type Provider struct { Domain string `xorm:"varchar(100)" json:"domain"` Bucket string `xorm:"varchar(100)" json:"bucket"` - Metadata string `xorm:"mediumtext" json:"metadata"` - IdP string `xorm:"mediumtext" json:"idP"` - IssuerUrl string `xorm:"varchar(100)" json:"issuerUrl"` + Metadata string `xorm:"mediumtext" json:"metadata"` + IdP string `xorm:"mediumtext" json:"idP"` + IssuerUrl string `xorm:"varchar(100)" json:"issuerUrl"` + EnableSignAuthnRequest bool `json:"enableSignAuthnRequest"` ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"` } diff --git a/object/saml.go b/object/saml.go index 8505e9e3..535bbcf2 100644 --- a/object/saml.go +++ b/object/saml.go @@ -15,6 +15,7 @@ package object import ( + "crypto/tls" "crypto/x509" "encoding/base64" "fmt" @@ -40,20 +41,32 @@ func ParseSamlResponse(samlResponse string, providerType string) (string, error) return assertionInfo.NameID, nil } -func GenerateSamlLoginUrl(id string) (string, error) { +func GenerateSamlLoginUrl(id, relayState string) (string, string, error) { provider := GetProvider(id) if provider.Category != "SAML" { - return "", fmt.Errorf("Provider %s's category is not SAML", provider.Name) + return "", "", fmt.Errorf("Provider %s's category is not SAML", provider.Name) } sp, err := buildSp(provider, "") if err != nil { - return "", err + return "", "", err } - authURL, err := sp.BuildAuthURL("") - if err != nil { - return "", err + auth := "" + method := "" + if provider.EnableSignAuthnRequest { + post, err := sp.BuildAuthBodyPost(relayState) + if err != nil { + return "", "", err + } + auth = string(post[:]) + method = "POST" + } else { + auth, err = sp.BuildAuthURL(relayState) + if err != nil { + return "", "", err + } + method = "GET" } - return authURL, nil + return auth, method, nil } func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvider, error) { @@ -80,13 +93,16 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide ServiceProviderIssuer: fmt.Sprintf("%s/api/acs", origin), AssertionConsumerServiceURL: fmt.Sprintf("%s/api/acs", origin), IDPCertificateStore: &certStore, + SignAuthnRequests: false, + SPKeyStore: dsig.RandomKeyStoreForTest(), } if provider.Endpoint != "" { - randomKeyStore := dsig.RandomKeyStoreForTest() sp.IdentityProviderSSOURL = provider.Endpoint sp.IdentityProviderIssuer = provider.IssuerUrl - sp.SignAuthnRequests = false - sp.SPKeyStore = randomKeyStore + } + if provider.EnableSignAuthnRequest { + sp.SignAuthnRequests = true + sp.SPKeyStore = buildSpKeyStore() } return sp, nil } @@ -99,10 +115,21 @@ func parseSamlResponse(samlResponse string, providerType string) string { deStr := strings.Replace(string(de), "\n", "", -1) tagMap := map[string]string{ "Aliyun IDaaS": "ds", - "Keycloak": "dsig", + "Keycloak": "dsig", } tag := tagMap[providerType] expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)", tag, tag) res := regexp.MustCompile(expression).FindStringSubmatch(deStr) return res[1] -} \ No newline at end of file +} + +func buildSpKeyStore() dsig.X509KeyStore { + keyPair, err := tls.LoadX509KeyPair("object/token_jwt_key.pem", "object/token_jwt_key.key") + if err != nil { + panic(err) + } + return &dsig.TLSCertKeyStore { + PrivateKey: keyPair.PrivateKey, + Certificate: keyPair.Certificate, + } +} diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js index 5c185de0..05b04d60 100644 --- a/web/src/ProviderEditPage.js +++ b/web/src/ProviderEditPage.js @@ -13,7 +13,7 @@ // limitations under the License. import React from "react"; -import {Button, Card, Col, Input, InputNumber, Row, Select} from 'antd'; +import {Button, Card, Col, Input, InputNumber, Row, Select, Switch} from 'antd'; import {LinkOutlined} from "@ant-design/icons"; import * as ProviderBackend from "./backend/ProviderBackend"; import * as Setting from "./Setting"; @@ -418,6 +418,16 @@ class ProviderEditPage extends React.Component { ) : this.state.provider.category === "SAML" ? ( + + + {Setting.getLabel(i18next.t("provider:Sign request"), i18next.t("provider:Sign request - Tooltip"))} : + + + { + this.updateProviderField('enableSignAuthnRequest', checked); + }} /> + + {Setting.getLabel(i18next.t("provider:Metadata"), i18next.t("provider:Metadata - Tooltip"))} : diff --git a/web/src/auth/AuthBackend.js b/web/src/auth/AuthBackend.js index 137dd68e..16c2a54e 100644 --- a/web/src/auth/AuthBackend.js +++ b/web/src/auth/AuthBackend.js @@ -77,8 +77,8 @@ export function unlink(values) { }).then(res => res.json()); } -export function getSamlLogin(providerId) { - return fetch(`${authConfig.serverUrl}/api/get-saml-login?id=${providerId}`, { +export function getSamlLogin(providerId, relayState) { + return fetch(`${authConfig.serverUrl}/api/get-saml-login?id=${providerId}&relayState=${relayState}`, { method: 'GET', credentials: 'include', }).then(res => res.json()); diff --git a/web/src/auth/LoginPage.js b/web/src/auth/LoginPage.js index 731b0be4..77fbb9dc 100644 --- a/web/src/auth/LoginPage.js +++ b/web/src/auth/LoginPage.js @@ -201,9 +201,13 @@ class LoginPage extends React.Component { let realRedirectUri = params.get("redirect_uri"); let redirectUri = `${window.location.origin}/callback/saml`; let providerName = provider.name; - AuthBackend.getSamlLogin(`${provider.owner}/${providerName}`).then((res) => { - const replyState = `${clientId}&${application}&${providerName}&${realRedirectUri}&${redirectUri}`; - window.location.href = `${res.data}&RelayState=${btoa(replyState)}`; + let relayState = `${clientId}&${application}&${providerName}&${realRedirectUri}&${redirectUri}`; + AuthBackend.getSamlLogin(`${provider.owner}/${providerName}`, btoa(relayState)).then((res) => { + if (res.data2 === "POST") { + document.write(res.data) + } else { + window.location.href = res.data + } }); }