Improve error handling for GetSamlResponse()

This commit is contained in:
Yang Luo
2024-03-08 02:17:50 +08:00
parent c532a5d54d
commit 49fb269170
2 changed files with 39 additions and 19 deletions

View File

@ -273,7 +273,7 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
// base64 decode
defated, err := base64.StdEncoding.DecodeString(samlRequest)
if err != nil {
return "", "", method, fmt.Errorf("err: Failed to decode SAML request , %s", err.Error())
return "", "", "", fmt.Errorf("err: Failed to decode SAML request, %s", err.Error())
}
// decompress
@ -281,7 +281,7 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
rdr := flate.NewReader(bytes.NewReader(defated))
for {
_, err := io.CopyN(&buffer, rdr, 1024)
_, err = io.CopyN(&buffer, rdr, 1024)
if err != nil {
if err == io.EOF {
break
@ -293,12 +293,12 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
var authnRequest saml.AuthNRequest
err = xml.Unmarshal(buffer.Bytes(), &authnRequest)
if err != nil {
return "", "", method, fmt.Errorf("err: Failed to unmarshal AuthnRequest, please check the SAML request. %s", err.Error())
return "", "", "", fmt.Errorf("err: Failed to unmarshal AuthnRequest, please check the SAML request, %s", err.Error())
}
// verify samlRequest
if isValid := application.IsRedirectUriValid(authnRequest.Issuer); !isValid {
return "", "", method, fmt.Errorf("err: Issuer URI: %s doesn't exist in the allowed Redirect URI list", authnRequest.Issuer)
return "", "", "", fmt.Errorf("err: Issuer URI: %s doesn't exist in the allowed Redirect URI list", authnRequest.Issuer)
}
// get certificate string
@ -323,8 +323,13 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
}
_, originBackend := getOriginFromHost(host)
// build signedResponse
samlResponse, _ := NewSamlResponse(application, user, originBackend, certificate, authnRequest.AssertionConsumerServiceURL, authnRequest.Issuer, authnRequest.ID, application.RedirectUris)
samlResponse, err := NewSamlResponse(application, user, originBackend, certificate, authnRequest.AssertionConsumerServiceURL, authnRequest.Issuer, authnRequest.ID, application.RedirectUris)
if err != nil {
return "", "", "", fmt.Errorf("err: NewSamlResponse() error, %s", err.Error())
}
randomKeyStore := &X509Key{
PrivateKey: cert.PrivateKey,
X509Certificate: certificate,
@ -336,18 +341,23 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
ctx.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList("")
}
//signedXML, err := ctx.SignEnvelopedLimix(samlResponse)
//if err != nil {
// signedXML, err := ctx.SignEnvelopedLimix(samlResponse)
// if err != nil {
// return "", "", fmt.Errorf("err: %s", err.Error())
//}
// }
sig, err := ctx.ConstructSignature(samlResponse, true)
if err != nil {
return "", "", "", fmt.Errorf("err: Failed to serializes the SAML request into bytes, %s", err.Error())
}
samlResponse.InsertChildAt(1, sig)
doc := etree.NewDocument()
doc.SetRoot(samlResponse)
xmlBytes, err := doc.WriteToBytes()
if err != nil {
return "", "", method, fmt.Errorf("err: Failed to serializes the SAML request into bytes, %s", err.Error())
return "", "", "", fmt.Errorf("err: Failed to serializes the SAML request into bytes, %s", err.Error())
}
// compress
@ -355,16 +365,19 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
flated := bytes.NewBuffer(nil)
writer, err := flate.NewWriter(flated, flate.DefaultCompression)
if err != nil {
return "", "", method, err
return "", "", "", err
}
_, err = writer.Write(xmlBytes)
if err != nil {
return "", "", "", err
}
err = writer.Close()
if err != nil {
return "", "", "", err
}
xmlBytes = flated.Bytes()
}
// base64 encode
@ -373,12 +386,12 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
}
// NewSamlResponse11 return a saml1.1 response(not 2.0)
func NewSamlResponse11(user *User, requestID string, host string) *etree.Element {
func NewSamlResponse11(user *User, requestID string, host string) (*etree.Element, error) {
samlResponse := &etree.Element{
Space: "samlp",
Tag: "Response",
}
// create samlresponse
samlResponse.CreateAttr("xmlns:samlp", "urn:oasis:names:tc:SAML:1.0:protocol")
samlResponse.CreateAttr("MajorVersion", "1")
samlResponse.CreateAttr("MinorVersion", "1")
@ -431,11 +444,15 @@ func NewSamlResponse11(user *User, requestID string, host string) *etree.Element
subjectConfirmationInAttribute := subjectInAttribute.CreateElement("saml:SubjectConfirmation")
subjectConfirmationInAttribute.CreateElement("saml:ConfirmationMethod").SetText("urn:oasis:names:tc:SAML:1.0:cm:artifact")
data, _ := json.Marshal(user)
tmp := map[string]string{}
err := json.Unmarshal(data, &tmp)
data, err := json.Marshal(user)
if err != nil {
panic(err)
return nil, err
}
tmp := map[string]string{}
err = json.Unmarshal(data, &tmp)
if err != nil {
return nil, err
}
for k, v := range tmp {
@ -447,7 +464,7 @@ func NewSamlResponse11(user *User, requestID string, host string) *etree.Element
}
}
return samlResponse
return samlResponse, nil
}
func GetSamlRedirectAddress(owner string, application string, relayState string, samlRequest string, host string) string {

View File

@ -256,7 +256,7 @@ func GetValidationBySaml(samlRequest string, host string) (string, string, error
ticket := request.AssertionArtifact.InnerXML
if ticket == "" {
return "", "", fmt.Errorf("samlp:AssertionArtifact field not found")
return "", "", fmt.Errorf("request.AssertionArtifact.InnerXML error, AssertionArtifact field not found")
}
ok, _, service, userId := GetCasTokenByTicket(ticket)
@ -282,7 +282,10 @@ func GetValidationBySaml(samlRequest string, host string) (string, string, error
return "", "", fmt.Errorf("application for user %s found", userId)
}
samlResponse := NewSamlResponse11(user, request.RequestID, host)
samlResponse, err := NewSamlResponse11(user, request.RequestID, host)
if err != nil {
return "", "", err
}
cert, err := getCertByApplication(application)
if err != nil {