diff --git a/controllers/token.go b/controllers/token.go index ffb17164..b47747aa 100644 --- a/controllers/token.go +++ b/controllers/token.go @@ -261,7 +261,7 @@ func (c *ApiController) TokenLogout() { flag, application := object.DeleteTokenByAccessToken(token) redirectUri := c.Input().Get("post_logout_redirect_uri") state := c.Input().Get("state") - if application != nil && object.CheckRedirectUriValid(application, redirectUri) { + if application != nil && application.IsRedirectUriValid(redirectUri) { c.Ctx.Redirect(http.StatusFound, redirectUri+"?state="+state) return } diff --git a/object/application.go b/object/application.go index 6ab37392..cfc6286b 100644 --- a/object/application.go +++ b/object/application.go @@ -16,7 +16,6 @@ package object import ( "fmt" - "net/url" "regexp" "strings" @@ -354,52 +353,26 @@ func (application *Application) GetId() string { return fmt.Sprintf("%s/%s", application.Owner, application.Name) } -func CheckRedirectUriValid(application *Application, redirectUri string) bool { - validUri := false - for _, tmpUri := range application.RedirectUris { - tmpUriRegex := regexp.MustCompile(tmpUri) - if tmpUriRegex.MatchString(redirectUri) || strings.Contains(redirectUri, tmpUri) { - validUri = true +func (application *Application) IsRedirectUriValid(redirectUri string) bool { + isValid := false + for _, targetUri := range application.RedirectUris { + targetUriRegex := regexp.MustCompile(targetUri) + if targetUriRegex.MatchString(redirectUri) || strings.Contains(redirectUri, targetUri) { + isValid = true break } } - return validUri + return isValid } -func IsAllowOrigin(origin string) bool { - allowOrigin := false - originUrl, err := url.Parse(origin) - if err != nil { - return false - } - - rows, err := adapter.Engine.Cols("redirect_uris").Rows(&Application{}) - if err != nil { - panic(err) - } - - application := Application{} - for rows.Next() { - err := rows.Scan(&application) - if err != nil { - panic(err) - } - for _, tmpRedirectUri := range application.RedirectUris { - u1, err := url.Parse(tmpRedirectUri) - if err != nil { - continue - } - if u1.Scheme == originUrl.Scheme && u1.Host == originUrl.Host { - allowOrigin = true - break - } - } - if allowOrigin { - break +func IsOriginAllowed(origin string) bool { + applications := GetApplications("") + for _, application := range applications { + if application.IsRedirectUriValid(origin) { + return true } } - - return allowOrigin + return false } func getApplicationMap(organization string) map[string]*Application { diff --git a/object/saml_idp.go b/object/saml_idp.go index 6a6f4243..92598f3c 100644 --- a/object/saml_idp.go +++ b/object/saml_idp.go @@ -240,8 +240,8 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h } // verify samlRequest - if valid := CheckRedirectUriValid(application, authnRequest.Issuer.Url); !valid { - return "", "", fmt.Errorf("err: invalid issuer url") + if isValid := application.IsRedirectUriValid(authnRequest.Issuer.Url); !isValid { + return "", "", fmt.Errorf("err: Issuer URI: %s doesn't exist in the allowed Redirect URI list", authnRequest.Issuer.Url) } // get certificate string diff --git a/object/token.go b/object/token.go index 5955cb2e..53b90fc6 100644 --- a/object/token.go +++ b/object/token.go @@ -18,7 +18,6 @@ import ( "crypto/sha256" "encoding/base64" "fmt" - "strings" "time" "github.com/casdoor/casdoor/i18n" @@ -253,14 +252,7 @@ func CheckOAuthLogin(clientId string, responseType string, redirectUri string, s return i18n.Translate(lang, "token:Invalid client_id"), nil } - validUri := false - for _, tmpUri := range application.RedirectUris { - if strings.Contains(redirectUri, tmpUri) { - validUri = true - break - } - } - if !validUri { + if !application.IsRedirectUriValid(redirectUri) { return fmt.Sprintf(i18n.Translate(lang, "token:Redirect URI: %s doesn't exist in the allowed Redirect URI list"), redirectUri), application } diff --git a/routers/cors_filter.go b/routers/cors_filter.go index cbda29b1..6058f07c 100644 --- a/routers/cors_filter.go +++ b/routers/cors_filter.go @@ -34,7 +34,7 @@ func CorsFilter(ctx *context.Context) { originConf := conf.GetConfigString("origin") if origin != "" && originConf != "" && origin != originConf { - if object.IsAllowOrigin(origin) { + if object.IsOriginAllowed(origin) { ctx.Output.Header(headerAllowOrigin, origin) ctx.Output.Header(headerAllowMethods, "POST, GET, OPTIONS") ctx.Output.Header(headerAllowHeaders, "Content-Type, Authorization")