Improve error handling in AutoSigninFilter

This commit is contained in:
Yang Luo 2023-10-15 12:43:36 +08:00
parent 1c296e9b6f
commit 1055d7781b
5 changed files with 27 additions and 16 deletions

View File

@ -35,14 +35,14 @@ type Object struct {
func getUsername(ctx *context.Context) (username string) { func getUsername(ctx *context.Context) (username string) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
username = getUsernameByClientIdSecret(ctx) username, _ = getUsernameByClientIdSecret(ctx)
} }
}() }()
username = ctx.Input.Session("username").(string) username = ctx.Input.Session("username").(string)
if username == "" { if username == "" {
username = getUsernameByClientIdSecret(ctx) username, _ = getUsernameByClientIdSecret(ctx)
} }
if username == "" { if username == "" {

View File

@ -45,19 +45,21 @@ func AutoSigninFilter(ctx *context.Context) {
} }
if token == nil { if token == nil {
responseError(ctx, "Access token doesn't exist") responseError(ctx, "Access token doesn't exist in database")
return return
} }
if util.IsTokenExpired(token.CreatedTime, token.ExpiresIn) { isExpired, expireTime := util.IsTokenExpired(token.CreatedTime, token.ExpiresIn)
responseError(ctx, "Access token has expired") if isExpired {
responseError(ctx, fmt.Sprintf("Access token has expired, expireTime = %s", expireTime))
return return
} }
userId := util.GetId(token.Organization, token.User) userId := util.GetId(token.Organization, token.User)
application, err := object.GetApplicationByUserId(fmt.Sprintf("app/%s", token.Application)) application, err := object.GetApplicationByUserId(fmt.Sprintf("app/%s", token.Application))
if err != nil { if err != nil {
panic(err) responseError(ctx, err.Error())
return
} }
setSessionUser(ctx, userId) setSessionUser(ctx, userId)
@ -66,7 +68,11 @@ func AutoSigninFilter(ctx *context.Context) {
} }
// "/page?clientId=123&clientSecret=456" // "/page?clientId=123&clientSecret=456"
userId := getUsernameByClientIdSecret(ctx) userId, err := getUsernameByClientIdSecret(ctx)
if err != nil {
responseError(ctx, err.Error())
return
}
if userId != "" { if userId != "" {
setSessionUser(ctx, userId) setSessionUser(ctx, userId)
return return

View File

@ -66,7 +66,7 @@ func denyRequest(ctx *context.Context) {
responseError(ctx, T(ctx, "auth:Unauthorized operation")) responseError(ctx, T(ctx, "auth:Unauthorized operation"))
} }
func getUsernameByClientIdSecret(ctx *context.Context) string { func getUsernameByClientIdSecret(ctx *context.Context) (string, error) {
clientId, clientSecret, ok := ctx.Request.BasicAuth() clientId, clientSecret, ok := ctx.Request.BasicAuth()
if !ok { if !ok {
clientId = ctx.Input.Query("clientId") clientId = ctx.Input.Query("clientId")
@ -74,19 +74,22 @@ func getUsernameByClientIdSecret(ctx *context.Context) string {
} }
if clientId == "" || clientSecret == "" { if clientId == "" || clientSecret == "" {
return "" return "", nil
} }
application, err := object.GetApplicationByClientId(clientId) application, err := object.GetApplicationByClientId(clientId)
if err != nil { if err != nil {
panic(err) return "", err
}
if application == nil {
return "", fmt.Errorf("Application not found for client ID: %s", clientId)
} }
if application == nil || application.ClientSecret != clientSecret { if application.ClientSecret != clientSecret {
return "" return "", fmt.Errorf("Incorrect client secret for application: %s", application.Name)
} }
return fmt.Sprintf("app/%s", application.Name) return fmt.Sprintf("app/%s", application.Name), nil
} }
func getUsernameByKeys(ctx *context.Context) string { func getUsernameByKeys(ctx *context.Context) string {

View File

@ -58,8 +58,10 @@ func Time2String(timestamp time.Time) string {
return timestamp.Format(time.RFC3339) return timestamp.Format(time.RFC3339)
} }
func IsTokenExpired(createdTime string, expiresIn int) bool { func IsTokenExpired(createdTime string, expiresIn int) (bool, string) {
createdTimeObj, _ := time.Parse(time.RFC3339, createdTime) createdTimeObj, _ := time.Parse(time.RFC3339, createdTime)
expiresAtObj := createdTimeObj.Add(time.Duration(expiresIn) * time.Second) expiresAtObj := createdTimeObj.Add(time.Duration(expiresIn) * time.Second)
return time.Now().After(expiresAtObj) isExpired := time.Now().After(expiresAtObj)
expireTime := expiresAtObj.Local().Format(time.RFC3339)
return isExpired, expireTime
} }

View File

@ -102,7 +102,7 @@ func Test_IsTokenExpired(t *testing.T) {
}, },
} { } {
t.Run(scenario.description, func(t *testing.T) { t.Run(scenario.description, func(t *testing.T) {
result := IsTokenExpired(scenario.input.createdTime, scenario.input.expiresIn) result, _ := IsTokenExpired(scenario.input.createdTime, scenario.input.expiresIn)
assert.Equal(t, scenario.expected, result, fmt.Sprintf("Expected %t, but was founded %t", scenario.expected, result)) assert.Equal(t, scenario.expected, result, fmt.Sprintf("Expected %t, but was founded %t", scenario.expected, result))
}) })
} }