diff --git a/object/saml_idp.go b/object/saml_idp.go index b318a96b..999cb2de 100644 --- a/object/saml_idp.go +++ b/object/saml_idp.go @@ -273,7 +273,7 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h // base64 decode defated, err := base64.StdEncoding.DecodeString(samlRequest) if err != nil { - return "", "", method, fmt.Errorf("err: Failed to decode SAML request , %s", err.Error()) + return "", "", "", fmt.Errorf("err: Failed to decode SAML request, %s", err.Error()) } // decompress @@ -281,7 +281,7 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h rdr := flate.NewReader(bytes.NewReader(defated)) for { - _, err := io.CopyN(&buffer, rdr, 1024) + _, err = io.CopyN(&buffer, rdr, 1024) if err != nil { if err == io.EOF { break @@ -293,12 +293,12 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h var authnRequest saml.AuthNRequest err = xml.Unmarshal(buffer.Bytes(), &authnRequest) if err != nil { - return "", "", method, fmt.Errorf("err: Failed to unmarshal AuthnRequest, please check the SAML request. %s", err.Error()) + return "", "", "", fmt.Errorf("err: Failed to unmarshal AuthnRequest, please check the SAML request, %s", err.Error()) } // verify samlRequest if isValid := application.IsRedirectUriValid(authnRequest.Issuer); !isValid { - return "", "", method, fmt.Errorf("err: Issuer URI: %s doesn't exist in the allowed Redirect URI list", authnRequest.Issuer) + return "", "", "", fmt.Errorf("err: Issuer URI: %s doesn't exist in the allowed Redirect URI list", authnRequest.Issuer) } // get certificate string @@ -323,8 +323,13 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h } _, originBackend := getOriginFromHost(host) + // build signedResponse - samlResponse, _ := NewSamlResponse(application, user, originBackend, certificate, authnRequest.AssertionConsumerServiceURL, authnRequest.Issuer, authnRequest.ID, application.RedirectUris) + samlResponse, err := NewSamlResponse(application, user, originBackend, certificate, authnRequest.AssertionConsumerServiceURL, authnRequest.Issuer, authnRequest.ID, application.RedirectUris) + if err != nil { + return "", "", "", fmt.Errorf("err: NewSamlResponse() error, %s", err.Error()) + } + randomKeyStore := &X509Key{ PrivateKey: cert.PrivateKey, X509Certificate: certificate, @@ -336,18 +341,23 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h ctx.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList("") } - //signedXML, err := ctx.SignEnvelopedLimix(samlResponse) - //if err != nil { + // signedXML, err := ctx.SignEnvelopedLimix(samlResponse) + // if err != nil { // return "", "", fmt.Errorf("err: %s", err.Error()) - //} + // } + sig, err := ctx.ConstructSignature(samlResponse, true) + if err != nil { + return "", "", "", fmt.Errorf("err: Failed to serializes the SAML request into bytes, %s", err.Error()) + } + samlResponse.InsertChildAt(1, sig) doc := etree.NewDocument() doc.SetRoot(samlResponse) xmlBytes, err := doc.WriteToBytes() if err != nil { - return "", "", method, fmt.Errorf("err: Failed to serializes the SAML request into bytes, %s", err.Error()) + return "", "", "", fmt.Errorf("err: Failed to serializes the SAML request into bytes, %s", err.Error()) } // compress @@ -355,16 +365,19 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h flated := bytes.NewBuffer(nil) writer, err := flate.NewWriter(flated, flate.DefaultCompression) if err != nil { - return "", "", method, err + return "", "", "", err } + _, err = writer.Write(xmlBytes) if err != nil { return "", "", "", err } + err = writer.Close() if err != nil { return "", "", "", err } + xmlBytes = flated.Bytes() } // base64 encode @@ -373,12 +386,12 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h } // NewSamlResponse11 return a saml1.1 response(not 2.0) -func NewSamlResponse11(user *User, requestID string, host string) *etree.Element { +func NewSamlResponse11(user *User, requestID string, host string) (*etree.Element, error) { samlResponse := &etree.Element{ Space: "samlp", Tag: "Response", } - // create samlresponse + samlResponse.CreateAttr("xmlns:samlp", "urn:oasis:names:tc:SAML:1.0:protocol") samlResponse.CreateAttr("MajorVersion", "1") samlResponse.CreateAttr("MinorVersion", "1") @@ -431,11 +444,15 @@ func NewSamlResponse11(user *User, requestID string, host string) *etree.Element subjectConfirmationInAttribute := subjectInAttribute.CreateElement("saml:SubjectConfirmation") subjectConfirmationInAttribute.CreateElement("saml:ConfirmationMethod").SetText("urn:oasis:names:tc:SAML:1.0:cm:artifact") - data, _ := json.Marshal(user) - tmp := map[string]string{} - err := json.Unmarshal(data, &tmp) + data, err := json.Marshal(user) if err != nil { - panic(err) + return nil, err + } + + tmp := map[string]string{} + err = json.Unmarshal(data, &tmp) + if err != nil { + return nil, err } for k, v := range tmp { @@ -447,7 +464,7 @@ func NewSamlResponse11(user *User, requestID string, host string) *etree.Element } } - return samlResponse + return samlResponse, nil } func GetSamlRedirectAddress(owner string, application string, relayState string, samlRequest string, host string) string { diff --git a/object/token_cas.go b/object/token_cas.go index ecbbf5d7..866ba9b4 100644 --- a/object/token_cas.go +++ b/object/token_cas.go @@ -256,7 +256,7 @@ func GetValidationBySaml(samlRequest string, host string) (string, string, error ticket := request.AssertionArtifact.InnerXML if ticket == "" { - return "", "", fmt.Errorf("samlp:AssertionArtifact field not found") + return "", "", fmt.Errorf("request.AssertionArtifact.InnerXML error, AssertionArtifact field not found") } ok, _, service, userId := GetCasTokenByTicket(ticket) @@ -282,7 +282,10 @@ func GetValidationBySaml(samlRequest string, host string) (string, string, error return "", "", fmt.Errorf("application for user %s found", userId) } - samlResponse := NewSamlResponse11(user, request.RequestID, host) + samlResponse, err := NewSamlResponse11(user, request.RequestID, host) + if err != nil { + return "", "", err + } cert, err := getCertByApplication(application) if err != nil {