feat: improve saml idp err message (#1584)

This commit is contained in:
Yaodong Yu
2023-02-24 21:20:57 +08:00
committed by GitHub
parent 910816c7a3
commit 3b6ec3e7c4

View File

@ -230,16 +230,20 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
// base64 decode
defated, err := base64.StdEncoding.DecodeString(samlRequest)
if err != nil {
return "", "", method, fmt.Errorf("err: %s", err.Error())
return "", "", method, fmt.Errorf("err: Failed to decode SAML request , %s", err.Error())
}
// decompress
var buffer bytes.Buffer
rdr := flate.NewReader(bytes.NewReader(defated))
io.Copy(&buffer, rdr)
_, err = io.Copy(&buffer, rdr)
if err != nil {
return "", "", "", err
}
var authnRequest saml.AuthnRequest
err = xml.Unmarshal(buffer.Bytes(), &authnRequest)
if err != nil {
return "", "", method, fmt.Errorf("err: %s", err.Error())
return "", "", method, fmt.Errorf("err: Failed to unmarshal AuthnRequest, please check the SAML request. %s", err.Error())
}
// verify samlRequest
@ -252,14 +256,15 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
block, _ := pem.Decode([]byte(cert.Certificate))
certificate := base64.StdEncoding.EncodeToString(block.Bytes)
_, originBackend := getOriginFromHost(host)
// redirect Url (Assertion Consumer Url)
if application.SamlReplyUrl != "" {
method = "POST"
authnRequest.AssertionConsumerServiceURL = application.SamlReplyUrl
} else if authnRequest.AssertionConsumerServiceURL == "" {
return "", "", "", fmt.Errorf("err: SAML request don't has attribute 'AssertionConsumerServiceURL' in <samlp:AuthnRequest>")
}
_, originBackend := getOriginFromHost(host)
// build signedResponse
samlResponse, _ := NewSamlResponse(user, originBackend, certificate, authnRequest.AssertionConsumerServiceURL, authnRequest.Issuer.Url, authnRequest.ID, application.RedirectUris)
randomKeyStore := &X509Key{
@ -279,7 +284,7 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
doc.SetRoot(samlResponse)
xmlBytes, err := doc.WriteToBytes()
if err != nil {
return "", "", method, fmt.Errorf("err: %s", err.Error())
return "", "", method, fmt.Errorf("err: Failed to serializes the SAML request into bytes, %s", err.Error())
}
// compress
@ -287,15 +292,21 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
flated := bytes.NewBuffer(nil)
writer, err := flate.NewWriter(flated, flate.DefaultCompression)
if err != nil {
return "", "", method, fmt.Errorf("err: %s", err.Error())
return "", "", method, err
}
_, err = writer.Write(xmlBytes)
if err != nil {
return "", "", "", err
}
err = writer.Close()
if err != nil {
return "", "", "", err
}
writer.Write(xmlBytes)
writer.Close()
xmlBytes = flated.Bytes()
}
// base64 encode
res := base64.StdEncoding.EncodeToString(xmlBytes)
return res, authnRequest.AssertionConsumerServiceURL, method, nil
return res, authnRequest.AssertionConsumerServiceURL, method, err
}
// NewSamlResponse11 return a saml1.1 response(not 2.0)
@ -359,7 +370,10 @@ func NewSamlResponse11(user *User, requestID string, host string) *etree.Element
data, _ := json.Marshal(user)
tmp := map[string]string{}
json.Unmarshal(data, &tmp)
err := json.Unmarshal(data, &tmp)
if err != nil {
panic(err)
}
for k, v := range tmp {
if v != "" {