diff --git a/routers/static_filter.go b/routers/static_filter.go index 240a25b4..9fed9aa2 100644 --- a/routers/static_filter.go +++ b/routers/static_filter.go @@ -26,6 +26,7 @@ import ( "github.com/beego/beego/context" "github.com/casdoor/casdoor/conf" + "github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/util" ) @@ -46,6 +47,46 @@ func getWebBuildFolder() string { return path } +func fastAutoSignin(ctx *context.Context) (string, error) { + userId := getSessionUser(ctx) + if userId == "" { + return "", nil + } + + clientId := ctx.Input.Query("client_id") + responseType := ctx.Input.Query("response_type") + redirectUri := ctx.Input.Query("redirect_uri") + scope := ctx.Input.Query("scope") + state := ctx.Input.Query("state") + nonce := "" + codeChallenge := "" + if clientId == "" || responseType != "code" || redirectUri == "" { + return "", nil + } + + application, err := object.GetApplicationByClientId(clientId) + if err != nil { + return "", err + } + if application == nil { + return "", nil + } + + if !application.EnableAutoSignin { + return "", nil + } + + code, err := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge, ctx.Request.Host, getAcceptLanguage(ctx)) + if err != nil { + return "", err + } else if code.Message != "" { + return "", fmt.Errorf(code.Message) + } + + res := fmt.Sprintf("%s?code=%s&state=%s", redirectUri, code.Code, state) + return res, nil +} + func StaticFilter(ctx *context.Context) { urlPath := ctx.Request.URL.Path @@ -63,6 +104,19 @@ func StaticFilter(ctx *context.Context) { return } + if urlPath == "/login/oauth/authorize" { + redirectUrl, err := fastAutoSignin(ctx) + if err != nil { + responseError(ctx, err.Error()) + return + } + + if redirectUrl != "" { + http.Redirect(ctx.ResponseWriter, ctx.Request, redirectUrl, http.StatusFound) + return + } + } + webBuildFolder := getWebBuildFolder() path := webBuildFolder if urlPath == "/" {