diff --git a/controllers/auth.go b/controllers/auth.go index d7ef5b7f..083216e3 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -199,7 +199,7 @@ func (c *ApiController) Login() { userInfo := &idp.UserInfo{} if provider.Category == "SAML" { // SAML - userInfo.Id, err = object.ParseSamlResponse(form.SamlResponse) + userInfo.Id, err = object.ParseSamlResponse(form.SamlResponse, provider.Type) if err != nil { c.ResponseError(err.Error()) return @@ -241,7 +241,7 @@ func (c *ApiController) Login() { if form.Method == "signup" { user := &object.User{} if provider.Category == "SAML" { - user = object.GetUserByField(application.Organization, "id", userInfo.Id) + user = object.GetUser(fmt.Sprintf("%s/%s", application.Organization, userInfo.Id)) } else if provider.Category == "OAuth" { user = object.GetUserByField(application.Organization, provider.Type, userInfo.Id) if user == nil { diff --git a/object/saml.go b/object/saml.go index c932ef44..8505e9e3 100644 --- a/object/saml.go +++ b/object/saml.go @@ -27,9 +27,9 @@ import ( dsig "github.com/russellhaering/goxmldsig" ) -func ParseSamlResponse(samlResponse string) (string, error) { +func ParseSamlResponse(samlResponse string, providerType string) (string, error) { samlResponse, _ = url.QueryUnescape(samlResponse) - sp, err := buildSp(nil, samlResponse) + sp, err := buildSp(&Provider{Type: providerType}, samlResponse) if err != nil { return "", err } @@ -63,15 +63,8 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide origin := beego.AppConfig.String("origin") certEncodedData := "" if samlResponse != "" { - de, err := base64.StdEncoding.DecodeString(samlResponse) - if err != nil { - panic(err) - } - deStr := strings.Replace(string(de), "\n", "", -1) - res := regexp.MustCompile(`(.*?)`).FindAllStringSubmatch(deStr, -1) - str := res[0][0] - certEncodedData = str[20 : len(str)-21] - } else if provider != nil { + certEncodedData = parseSamlResponse(samlResponse, provider.Type) + } else if provider.IdP != "" { certEncodedData = provider.IdP } certData, err := base64.StdEncoding.DecodeString(certEncodedData) @@ -88,7 +81,7 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide AssertionConsumerServiceURL: fmt.Sprintf("%s/api/acs", origin), IDPCertificateStore: &certStore, } - if provider != nil { + if provider.Endpoint != "" { randomKeyStore := dsig.RandomKeyStoreForTest() sp.IdentityProviderSSOURL = provider.Endpoint sp.IdentityProviderIssuer = provider.IssuerUrl @@ -97,3 +90,19 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide } return sp, nil } + +func parseSamlResponse(samlResponse string, providerType string) string { + de, err := base64.StdEncoding.DecodeString(samlResponse) + if err != nil { + panic(err) + } + deStr := strings.Replace(string(de), "\n", "", -1) + tagMap := map[string]string{ + "Aliyun IDaaS": "ds", + "Keycloak": "dsig", + } + tag := tagMap[providerType] + expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)", tag, tag) + res := regexp.MustCompile(expression).FindStringSubmatch(deStr) + return res[1] +} \ No newline at end of file diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js index 10b82aa9..5c185de0 100644 --- a/web/src/ProviderEditPage.js +++ b/web/src/ProviderEditPage.js @@ -111,6 +111,7 @@ class ProviderEditPage extends React.Component { } else if (provider.category === "SAML") { return ([ {id: 'Aliyun IDaaS', name: 'Aliyun IDaaS'}, + {id: 'Keycloak', name: 'Keycloak'}, ]); } else { return []; diff --git a/web/src/Setting.js b/web/src/Setting.js index ea89a4e6..7ee66ac7 100644 --- a/web/src/Setting.js +++ b/web/src/Setting.js @@ -375,7 +375,7 @@ export function getClickable(text) { } export function getProviderLogo(provider) { - const idp = provider.type.toLowerCase(); + const idp = provider.type.toLowerCase().trim().split(' ')[0]; const url = `${StaticBaseUrl}/img/social_${idp}.png`; return ( {idp} diff --git a/web/src/UserEditPage.js b/web/src/UserEditPage.js index 885f224c..a2e919a7 100644 --- a/web/src/UserEditPage.js +++ b/web/src/UserEditPage.js @@ -29,6 +29,7 @@ import SelectRegionBox from "./SelectRegionBox"; import {Controlled as CodeMirror} from 'react-codemirror2'; import "codemirror/lib/codemirror.css"; +import SamlWidget from "./common/SamlWidget"; require('codemirror/theme/material-darker.css'); require("codemirror/mode/javascript/javascript"); @@ -302,7 +303,13 @@ class UserEditPage extends React.Component {
{ (this.state.application === null || this.state.user === null) ? null : ( - this.state.application?.providers.filter(providerItem => Setting.isProviderVisible(providerItem)).map((providerItem, index) => { return this.unlinked()}} />) + this.state.application?.providers.filter(providerItem => Setting.isProviderVisible(providerItem)).map((providerItem, index) => + (providerItem.category === "OAuth") ? ( + { return this.unlinked()}} /> + ) : ( + { return this.unlinked()}} /> + ) + ) ) }
diff --git a/web/src/auth/LoginPage.js b/web/src/auth/LoginPage.js index 4ffaa0ab..731b0be4 100644 --- a/web/src/auth/LoginPage.js +++ b/web/src/auth/LoginPage.js @@ -194,14 +194,14 @@ class LoginPage extends React.Component { return text; } - getSamlUrl(providerId) { + getSamlUrl(provider) { const params = new URLSearchParams(this.props.location.search); - let clientId = params.get("client_id") + let clientId = params.get("client_id"); let application = params.get("state"); let realRedirectUri = params.get("redirect_uri"); - let redirectUri = `${window.location.origin}/callback/saml` - let providerName = providerId.split('/')[1]; - AuthBackend.getSamlLogin(providerId).then((res) => { + let redirectUri = `${window.location.origin}/callback/saml`; + let providerName = provider.name; + AuthBackend.getSamlLogin(`${provider.owner}/${providerName}`).then((res) => { const replyState = `${clientId}&${application}&${providerName}&${realRedirectUri}&${redirectUri}`; window.location.href = `${res.data}&RelayState=${btoa(replyState)}`; }); @@ -217,7 +217,7 @@ class LoginPage extends React.Component { ) } else if (provider.category === "SAML") { return ( - + {provider.displayName} ) diff --git a/web/src/auth/Provider.js b/web/src/auth/Provider.js index dc24223f..cb281364 100644 --- a/web/src/auth/Provider.js +++ b/web/src/auth/Provider.js @@ -125,6 +125,10 @@ const otherProviderInfo = { logo: `${StaticBaseUrl}/img/social_aliyun.png`, url: "https://aliyun.com/product/idaas" }, + "Keycloak": { + logo: `${StaticBaseUrl}/img/social_keycloak.png`, + url: "https://www.keycloak.org/" + }, }, }; diff --git a/web/src/common/SamlWidget.js b/web/src/common/SamlWidget.js new file mode 100644 index 00000000..f96083d4 --- /dev/null +++ b/web/src/common/SamlWidget.js @@ -0,0 +1,45 @@ +import React from "react"; +import {Col, Row} from "antd"; +import * as Setting from "../Setting"; + +class SamlWidget extends React.Component { + constructor(props) { + super(props); + this.state = { + classes: props, + addressOptions: [], + affiliationOptions: [], + }; + } + + renderIdp(user, application, providerItem) { + const provider = providerItem.provider; + const name = user.name; + + return ( + + + { + Setting.getProviderLogo(provider) + } + + { + `${provider.type}:` + } + + + + {name} + + + ) + } + + render() { + return this.renderIdp(this.props.user, this.props.application, this.props.providerItem) + } +} + +export default SamlWidget; \ No newline at end of file