diff --git a/controllers/auth.go b/controllers/auth.go
index d7ef5b7f..083216e3 100644
--- a/controllers/auth.go
+++ b/controllers/auth.go
@@ -199,7 +199,7 @@ func (c *ApiController) Login() {
userInfo := &idp.UserInfo{}
if provider.Category == "SAML" {
// SAML
- userInfo.Id, err = object.ParseSamlResponse(form.SamlResponse)
+ userInfo.Id, err = object.ParseSamlResponse(form.SamlResponse, provider.Type)
if err != nil {
c.ResponseError(err.Error())
return
@@ -241,7 +241,7 @@ func (c *ApiController) Login() {
if form.Method == "signup" {
user := &object.User{}
if provider.Category == "SAML" {
- user = object.GetUserByField(application.Organization, "id", userInfo.Id)
+ user = object.GetUser(fmt.Sprintf("%s/%s", application.Organization, userInfo.Id))
} else if provider.Category == "OAuth" {
user = object.GetUserByField(application.Organization, provider.Type, userInfo.Id)
if user == nil {
diff --git a/object/saml.go b/object/saml.go
index c932ef44..8505e9e3 100644
--- a/object/saml.go
+++ b/object/saml.go
@@ -27,9 +27,9 @@ import (
dsig "github.com/russellhaering/goxmldsig"
)
-func ParseSamlResponse(samlResponse string) (string, error) {
+func ParseSamlResponse(samlResponse string, providerType string) (string, error) {
samlResponse, _ = url.QueryUnescape(samlResponse)
- sp, err := buildSp(nil, samlResponse)
+ sp, err := buildSp(&Provider{Type: providerType}, samlResponse)
if err != nil {
return "", err
}
@@ -63,15 +63,8 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide
origin := beego.AppConfig.String("origin")
certEncodedData := ""
if samlResponse != "" {
- de, err := base64.StdEncoding.DecodeString(samlResponse)
- if err != nil {
- panic(err)
- }
- deStr := strings.Replace(string(de), "\n", "", -1)
- res := regexp.MustCompile(`(.*?)`).FindAllStringSubmatch(deStr, -1)
- str := res[0][0]
- certEncodedData = str[20 : len(str)-21]
- } else if provider != nil {
+ certEncodedData = parseSamlResponse(samlResponse, provider.Type)
+ } else if provider.IdP != "" {
certEncodedData = provider.IdP
}
certData, err := base64.StdEncoding.DecodeString(certEncodedData)
@@ -88,7 +81,7 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide
AssertionConsumerServiceURL: fmt.Sprintf("%s/api/acs", origin),
IDPCertificateStore: &certStore,
}
- if provider != nil {
+ if provider.Endpoint != "" {
randomKeyStore := dsig.RandomKeyStoreForTest()
sp.IdentityProviderSSOURL = provider.Endpoint
sp.IdentityProviderIssuer = provider.IssuerUrl
@@ -97,3 +90,19 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide
}
return sp, nil
}
+
+func parseSamlResponse(samlResponse string, providerType string) string {
+ de, err := base64.StdEncoding.DecodeString(samlResponse)
+ if err != nil {
+ panic(err)
+ }
+ deStr := strings.Replace(string(de), "\n", "", -1)
+ tagMap := map[string]string{
+ "Aliyun IDaaS": "ds",
+ "Keycloak": "dsig",
+ }
+ tag := tagMap[providerType]
+ expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)%s:X509Certificate>", tag, tag)
+ res := regexp.MustCompile(expression).FindStringSubmatch(deStr)
+ return res[1]
+}
\ No newline at end of file
diff --git a/web/src/ProviderEditPage.js b/web/src/ProviderEditPage.js
index 10b82aa9..5c185de0 100644
--- a/web/src/ProviderEditPage.js
+++ b/web/src/ProviderEditPage.js
@@ -111,6 +111,7 @@ class ProviderEditPage extends React.Component {
} else if (provider.category === "SAML") {
return ([
{id: 'Aliyun IDaaS', name: 'Aliyun IDaaS'},
+ {id: 'Keycloak', name: 'Keycloak'},
]);
} else {
return [];
diff --git a/web/src/Setting.js b/web/src/Setting.js
index ea89a4e6..7ee66ac7 100644
--- a/web/src/Setting.js
+++ b/web/src/Setting.js
@@ -375,7 +375,7 @@ export function getClickable(text) {
}
export function getProviderLogo(provider) {
- const idp = provider.type.toLowerCase();
+ const idp = provider.type.toLowerCase().trim().split(' ')[0];
const url = `${StaticBaseUrl}/img/social_${idp}.png`;
return (
diff --git a/web/src/UserEditPage.js b/web/src/UserEditPage.js
index 885f224c..a2e919a7 100644
--- a/web/src/UserEditPage.js
+++ b/web/src/UserEditPage.js
@@ -29,6 +29,7 @@ import SelectRegionBox from "./SelectRegionBox";
import {Controlled as CodeMirror} from 'react-codemirror2';
import "codemirror/lib/codemirror.css";
+import SamlWidget from "./common/SamlWidget";
require('codemirror/theme/material-darker.css');
require("codemirror/mode/javascript/javascript");
@@ -302,7 +303,13 @@ class UserEditPage extends React.Component {
{
(this.state.application === null || this.state.user === null) ? null : (
- this.state.application?.providers.filter(providerItem => Setting.isProviderVisible(providerItem)).map((providerItem, index) => { return this.unlinked()}} />)
+ this.state.application?.providers.filter(providerItem => Setting.isProviderVisible(providerItem)).map((providerItem, index) =>
+ (providerItem.category === "OAuth") ? (
+ { return this.unlinked()}} />
+ ) : (
+ { return this.unlinked()}} />
+ )
+ )
)
}
diff --git a/web/src/auth/LoginPage.js b/web/src/auth/LoginPage.js
index 4ffaa0ab..731b0be4 100644
--- a/web/src/auth/LoginPage.js
+++ b/web/src/auth/LoginPage.js
@@ -194,14 +194,14 @@ class LoginPage extends React.Component {
return text;
}
- getSamlUrl(providerId) {
+ getSamlUrl(provider) {
const params = new URLSearchParams(this.props.location.search);
- let clientId = params.get("client_id")
+ let clientId = params.get("client_id");
let application = params.get("state");
let realRedirectUri = params.get("redirect_uri");
- let redirectUri = `${window.location.origin}/callback/saml`
- let providerName = providerId.split('/')[1];
- AuthBackend.getSamlLogin(providerId).then((res) => {
+ let redirectUri = `${window.location.origin}/callback/saml`;
+ let providerName = provider.name;
+ AuthBackend.getSamlLogin(`${provider.owner}/${providerName}`).then((res) => {
const replyState = `${clientId}&${application}&${providerName}&${realRedirectUri}&${redirectUri}`;
window.location.href = `${res.data}&RelayState=${btoa(replyState)}`;
});
@@ -217,7 +217,7 @@ class LoginPage extends React.Component {
)
} else if (provider.category === "SAML") {
return (
-
+
)
diff --git a/web/src/auth/Provider.js b/web/src/auth/Provider.js
index dc24223f..cb281364 100644
--- a/web/src/auth/Provider.js
+++ b/web/src/auth/Provider.js
@@ -125,6 +125,10 @@ const otherProviderInfo = {
logo: `${StaticBaseUrl}/img/social_aliyun.png`,
url: "https://aliyun.com/product/idaas"
},
+ "Keycloak": {
+ logo: `${StaticBaseUrl}/img/social_keycloak.png`,
+ url: "https://www.keycloak.org/"
+ },
},
};
diff --git a/web/src/common/SamlWidget.js b/web/src/common/SamlWidget.js
new file mode 100644
index 00000000..f96083d4
--- /dev/null
+++ b/web/src/common/SamlWidget.js
@@ -0,0 +1,45 @@
+import React from "react";
+import {Col, Row} from "antd";
+import * as Setting from "../Setting";
+
+class SamlWidget extends React.Component {
+ constructor(props) {
+ super(props);
+ this.state = {
+ classes: props,
+ addressOptions: [],
+ affiliationOptions: [],
+ };
+ }
+
+ renderIdp(user, application, providerItem) {
+ const provider = providerItem.provider;
+ const name = user.name;
+
+ return (
+
+
+ {
+ Setting.getProviderLogo(provider)
+ }
+
+ {
+ `${provider.type}:`
+ }
+
+
+
+ {name}
+
+
+ )
+ }
+
+ render() {
+ return this.renderIdp(this.props.user, this.props.application, this.props.providerItem)
+ }
+}
+
+export default SamlWidget;
\ No newline at end of file