From 811999b6cc9df558a9d49d66c8436739dc98a762 Mon Sep 17 00:00:00 2001 From: Yang Luo Date: Sun, 19 Nov 2023 19:58:07 +0800 Subject: [PATCH] feat: fix error handling in CheckPassword() related functions --- controllers/auth.go | 80 ++++++++++++++++++++++------------- controllers/user.go | 20 ++++----- ldap/server.go | 12 +++--- ldap/util.go | 14 +++--- object/check.go | 70 +++++++++++++++--------------- object/check_util.go | 24 +++++++---- object/token.go | 12 +++--- object/verification.go | 18 ++++---- radius/server.go | 7 ++- routers/auto_signin_filter.go | 7 ++- 10 files changed, 150 insertions(+), 114 deletions(-) diff --git a/controllers/auth.go b/controllers/auth.go index d1a028b5..e6d3fa4b 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -34,6 +34,7 @@ import ( "github.com/casdoor/casdoor/proxy" "github.com/casdoor/casdoor/util" "github.com/google/uuid" + "golang.org/x/oauth2" ) var ( @@ -331,8 +332,6 @@ func (c *ApiController) Login() { } var user *object.User - var msg string - if authForm.Password == "" { if user, err = object.GetUserByFields(authForm.Organization, authForm.Username); err != nil { c.ResponseError(err.Error(), nil) @@ -354,20 +353,21 @@ func (c *ApiController) Login() { } // check result through Email or Phone - checkResult := object.CheckSigninCode(user, checkDest, authForm.Code, c.GetAcceptLanguage()) - if len(checkResult) != 0 { - c.ResponseError(fmt.Sprintf("%s - %s", verificationCodeType, checkResult)) + err = object.CheckSigninCode(user, checkDest, authForm.Code, c.GetAcceptLanguage()) + if err != nil { + c.ResponseError(fmt.Sprintf("%s - %s", verificationCodeType, err.Error())) return } // disable the verification code - err := object.DisableVerificationCode(checkDest) + err = object.DisableVerificationCode(checkDest) if err != nil { c.ResponseError(err.Error(), nil) return } } else { - application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + var application *object.Application + application, err = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) if err != nil { c.ResponseError(err.Error(), nil) return @@ -386,7 +386,8 @@ func (c *ApiController) Login() { c.ResponseError(err.Error()) return } else if enableCaptcha { - isHuman, err := captcha.VerifyCaptchaByCaptchaType(authForm.CaptchaType, authForm.CaptchaToken, authForm.ClientSecret) + var isHuman bool + isHuman, err = captcha.VerifyCaptchaByCaptchaType(authForm.CaptchaType, authForm.CaptchaToken, authForm.ClientSecret) if err != nil { c.ResponseError(err.Error()) return @@ -399,13 +400,15 @@ func (c *ApiController) Login() { } password := authForm.Password - user, msg = object.CheckUserPassword(authForm.Organization, authForm.Username, password, c.GetAcceptLanguage(), enableCaptcha) + user, err = object.CheckUserPassword(authForm.Organization, authForm.Username, password, c.GetAcceptLanguage(), enableCaptcha) } - if msg != "" { - resp = &Response{Status: "error", Msg: msg} + if err != nil { + c.ResponseError(err.Error()) + return } else { - application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + var application *object.Application + application, err = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) if err != nil { c.ResponseError(err.Error()) return @@ -416,7 +419,8 @@ func (c *ApiController) Login() { return } - organization, err := object.GetOrganizationByUser(user) + var organization *object.Organization + organization, err = object.GetOrganizationByUser(user) if err != nil { c.ResponseError(err.Error()) } @@ -461,12 +465,15 @@ func (c *ApiController) Login() { c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) return } - organization, err := object.GetOrganization(util.GetId("admin", application.Organization)) + + var organization *object.Organization + organization, err = object.GetOrganization(util.GetId("admin", application.Organization)) if err != nil { c.ResponseError(c.T(err.Error())) } - provider, err := object.GetProvider(util.GetId("admin", authForm.Provider)) + var provider *object.Provider + provider, err = object.GetProvider(util.GetId("admin", authForm.Provider)) if err != nil { c.ResponseError(err.Error()) return @@ -488,7 +495,8 @@ func (c *ApiController) Login() { } else if provider.Category == "OAuth" || provider.Category == "Web3" { // OAuth idpInfo := object.FromProviderToIdpInfo(c.Ctx, provider) - idProvider, err := idp.GetIdProvider(idpInfo, authForm.RedirectUri) + var idProvider idp.IdProvider + idProvider, err = idp.GetIdProvider(idpInfo, authForm.RedirectUri) if err != nil { c.ResponseError(err.Error()) return @@ -506,7 +514,8 @@ func (c *ApiController) Login() { } // https://github.com/golang/oauth2/issues/123#issuecomment-103715338 - token, err := idProvider.GetToken(authForm.Code) + var token *oauth2.Token + token, err = idProvider.GetToken(authForm.Code) if err != nil { c.ResponseError(err.Error()) return @@ -548,7 +557,7 @@ func (c *ApiController) Login() { c.ResponseError(c.T("check:The user is forbidden to sign in, please contact the administrator")) } // sync info from 3rd-party if possible - _, err := object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) + _, err = object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) if err != nil { c.ResponseError(err.Error()) return @@ -593,14 +602,16 @@ func (c *ApiController) Login() { } // Handle username conflicts - tmpUser, err := object.GetUser(util.GetId(application.Organization, userInfo.Username)) + var tmpUser *object.User + tmpUser, err = object.GetUser(util.GetId(application.Organization, userInfo.Username)) if err != nil { c.ResponseError(err.Error()) return } if tmpUser != nil { - uid, err := uuid.NewRandom() + var uid uuid.UUID + uid, err = uuid.NewRandom() if err != nil { c.ResponseError(err.Error()) return @@ -611,14 +622,16 @@ func (c *ApiController) Login() { } properties := map[string]string{} - count, err := object.GetUserCount(application.Organization, "", "", "") + var count int64 + count, err = object.GetUserCount(application.Organization, "", "", "") if err != nil { c.ResponseError(err.Error()) return } properties["no"] = strconv.Itoa(int(count + 2)) - initScore, err := organization.GetInitScore() + var initScore int + initScore, err = organization.GetInitScore() if err != nil { c.ResponseError(fmt.Errorf(c.T("account:Get init score failed, error: %w"), err).Error()) return @@ -650,7 +663,8 @@ func (c *ApiController) Login() { Properties: properties, } - affected, err := object.AddUser(user) + var affected bool + affected, err = object.AddUser(user) if err != nil { c.ResponseError(err.Error()) return @@ -672,7 +686,7 @@ func (c *ApiController) Login() { } // sync info from 3rd-party if possible - _, err := object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) + _, err = object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) if err != nil { c.ResponseError(err.Error()) return @@ -708,7 +722,8 @@ func (c *ApiController) Login() { return } - oldUser, err := object.GetUserByField(application.Organization, provider.Type, userInfo.Id) + var oldUser *object.User + oldUser, err = object.GetUserByField(application.Organization, provider.Type, userInfo.Id) if err != nil { c.ResponseError(err.Error()) return @@ -719,7 +734,8 @@ func (c *ApiController) Login() { return } - user, err := object.GetUser(userId) + var user *object.User + user, err = object.GetUser(userId) if err != nil { c.ResponseError(err.Error()) return @@ -732,7 +748,8 @@ func (c *ApiController) Login() { return } - isLinked, err := object.LinkUserAccount(user, provider.Type, userInfo.Id) + var isLinked bool + isLinked, err = object.LinkUserAccount(user, provider.Type, userInfo.Id) if err != nil { c.ResponseError(err.Error()) return @@ -745,7 +762,8 @@ func (c *ApiController) Login() { } } } else if c.getMfaUserSession() != "" { - user, err := object.GetUser(c.getMfaUserSession()) + var user *object.User + user, err = object.GetUser(c.getMfaUserSession()) if err != nil { c.ResponseError(err.Error()) return @@ -778,7 +796,8 @@ func (c *ApiController) Login() { return } - application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + var application *object.Application + application, err = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) if err != nil { c.ResponseError(err.Error()) return @@ -799,7 +818,8 @@ func (c *ApiController) Login() { } else { if c.GetSessionUsername() != "" { // user already signed in to Casdoor, so let the user click the avatar button to do the quick sign-in - application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + var application *object.Application + application, err = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) if err != nil { c.ResponseError(err.Error()) return diff --git a/controllers/user.go b/controllers/user.go index d91fdedb..1df6d8a3 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -476,16 +476,16 @@ func (c *ApiController) SetPassword() { isAdmin := c.IsAdmin() if isAdmin { if oldPassword != "" { - msg := object.CheckPassword(targetUser, oldPassword, c.GetAcceptLanguage()) - if msg != "" { - c.ResponseError(msg) + err = object.CheckPassword(targetUser, oldPassword, c.GetAcceptLanguage()) + if err != nil { + c.ResponseError(err.Error()) return } } } else if code == "" { - msg := object.CheckPassword(targetUser, oldPassword, c.GetAcceptLanguage()) - if msg != "" { - c.ResponseError(msg) + err = object.CheckPassword(targetUser, oldPassword, c.GetAcceptLanguage()) + if err != nil { + c.ResponseError(err.Error()) return } } @@ -518,11 +518,11 @@ func (c *ApiController) CheckUserPassword() { return } - _, msg := object.CheckUserPassword(user.Owner, user.Name, user.Password, c.GetAcceptLanguage()) - if msg == "" { - c.ResponseOk() + _, err = object.CheckUserPassword(user.Owner, user.Name, user.Password, c.GetAcceptLanguage()) + if err != nil { + c.ResponseError(err.Error()) } else { - c.ResponseError(msg) + c.ResponseOk() } } diff --git a/ldap/server.go b/ldap/server.go index 3b2835f8..4ec78375 100644 --- a/ldap/server.go +++ b/ldap/server.go @@ -49,20 +49,20 @@ func handleBind(w ldap.ResponseWriter, m *ldap.Message) { if r.AuthenticationChoice() == "simple" { bindUsername, bindOrg, err := getNameAndOrgFromDN(string(r.Name())) - if err != "" { - log.Printf("Bind failed ,ErrMsg=%s", err) + if err != nil { + log.Printf("getNameAndOrgFromDN() error: %s", err.Error()) res.SetResultCode(ldap.LDAPResultInvalidDNSyntax) - res.SetDiagnosticMessage("bind failed ErrMsg: " + err) + res.SetDiagnosticMessage(fmt.Sprintf("getNameAndOrgFromDN() error: %s", err.Error())) w.Write(res) return } bindPassword := string(r.AuthenticationSimple()) bindUser, err := object.CheckUserPassword(bindOrg, bindUsername, bindPassword, "en") - if err != "" { + if err != nil { log.Printf("Bind failed User=%s, Pass=%#v, ErrMsg=%s", string(r.Name()), r.Authentication(), err) res.SetResultCode(ldap.LDAPResultInvalidCredentials) - res.SetDiagnosticMessage("invalid credentials ErrMsg: " + err) + res.SetDiagnosticMessage("invalid credentials ErrMsg: " + err.Error()) w.Write(res) return } @@ -78,7 +78,7 @@ func handleBind(w ldap.ResponseWriter, m *ldap.Message) { m.Client.OrgName = bindOrg } else { res.SetResultCode(ldap.LDAPResultAuthMethodNotSupported) - res.SetDiagnosticMessage("Authentication method not supported,Please use Simple Authentication") + res.SetDiagnosticMessage("Authentication method not supported, please use Simple Authentication") } w.Write(res) } diff --git a/ldap/util.go b/ldap/util.go index f2ae7b80..d76b59aa 100644 --- a/ldap/util.go +++ b/ldap/util.go @@ -26,7 +26,7 @@ import ( ldap "github.com/forestmgy/ldapserver" ) -func getNameAndOrgFromDN(DN string) (string, string, string) { +func getNameAndOrgFromDN(DN string) (string, string, error) { DNFields := strings.Split(DN, ",") params := make(map[string]string, len(DNFields)) for _, field := range DNFields { @@ -37,12 +37,12 @@ func getNameAndOrgFromDN(DN string) (string, string, string) { } if params["cn"] == "" { - return "", "", "please use Admin Name format like cn=xxx,ou=xxx,dc=example,dc=com" + return "", "", fmt.Errorf("please use Admin Name format like cn=xxx,ou=xxx,dc=example,dc=com") } if params["ou"] == "" { - return params["cn"], object.CasdoorOrganization, "" + return params["cn"], object.CasdoorOrganization, nil } - return params["cn"], params["ou"], "" + return params["cn"], params["ou"], nil } func getNameAndOrgFromFilter(baseDN, filter string) (string, string, int) { @@ -50,7 +50,11 @@ func getNameAndOrgFromFilter(baseDN, filter string) (string, string, int) { return "", "", ldap.LDAPResultInvalidDNSyntax } - name, org, _ := getNameAndOrgFromDN(fmt.Sprintf("cn=%s,", getUsername(filter)) + baseDN) + name, org, err := getNameAndOrgFromDN(fmt.Sprintf("cn=%s,", getUsername(filter)) + baseDN) + if err != nil { + panic(err) + } + return name, org, ldap.LDAPResultSuccess } diff --git a/object/check.go b/object/check.go index 04c8c6c2..18804428 100644 --- a/object/check.go +++ b/object/check.go @@ -142,7 +142,7 @@ func CheckUserSignup(application *Application, organization *Organization, form return "" } -func checkSigninErrorTimes(user *User, lang string) string { +func checkSigninErrorTimes(user *User, lang string) error { if user.SigninWrongTimes >= SigninWrongTimesLimit { lastSignWrongTime, _ := time.Parse(time.RFC3339, user.LastSigninWrongTime) passedTime := time.Now().UTC().Sub(lastSignWrongTime) @@ -150,37 +150,39 @@ func checkSigninErrorTimes(user *User, lang string) string { // deny the login if the error times is greater than the limit and the last login time is less than the duration if minutes > 0 { - return fmt.Sprintf(i18n.Translate(lang, "check:You have entered the wrong password or code too many times, please wait for %d minutes and try again"), minutes) + return fmt.Errorf(i18n.Translate(lang, "check:You have entered the wrong password or code too many times, please wait for %d minutes and try again"), minutes) } // reset the error times user.SigninWrongTimes = 0 - UpdateUser(user.GetId(), user, []string{"signin_wrong_times"}, false) + _, err := UpdateUser(user.GetId(), user, []string{"signin_wrong_times"}, false) + return err } - return "" + return nil } -func CheckPassword(user *User, password string, lang string, options ...bool) string { +func CheckPassword(user *User, password string, lang string, options ...bool) error { enableCaptcha := false if len(options) > 0 { enableCaptcha = options[0] } // check the login error times if !enableCaptcha { - if msg := checkSigninErrorTimes(user, lang); msg != "" { - return msg + err := checkSigninErrorTimes(user, lang) + if err != nil { + return err } } organization, err := GetOrganizationByUser(user) if err != nil { - panic(err) + return err } if organization == nil { - return i18n.Translate(lang, "check:Organization does not exist") + return fmt.Errorf(i18n.Translate(lang, "check:Organization does not exist")) } passwordType := user.PasswordType @@ -191,19 +193,17 @@ func CheckPassword(user *User, password string, lang string, options ...bool) st if credManager != nil { if organization.MasterPassword != "" { if credManager.IsPasswordCorrect(password, organization.MasterPassword, "", organization.PasswordSalt) { - resetUserSigninErrorTimes(user) - return "" + return resetUserSigninErrorTimes(user) } } if credManager.IsPasswordCorrect(password, user.Password, user.PasswordSalt, organization.PasswordSalt) { - resetUserSigninErrorTimes(user) - return "" + return resetUserSigninErrorTimes(user) } return recordSigninErrorInfo(user, lang, enableCaptcha) } else { - return fmt.Sprintf(i18n.Translate(lang, "check:unsupported password type: %s"), organization.PasswordType) + return fmt.Errorf(i18n.Translate(lang, "check:unsupported password type: %s"), organization.PasswordType) } } @@ -217,10 +217,10 @@ func CheckPasswordComplexity(user *User, password string) string { return CheckPasswordComplexityByOrg(organization, password) } -func checkLdapUserPassword(user *User, password string, lang string) string { +func checkLdapUserPassword(user *User, password string, lang string) error { ldaps, err := GetLdaps(user.Owner) if err != nil { - return err.Error() + return err } ldapLoginSuccess := false @@ -237,14 +237,14 @@ func checkLdapUserPassword(user *User, password string, lang string) string { searchResult, err := conn.Conn.Search(searchReq) if err != nil { - return err.Error() + return err } if len(searchResult.Entries) == 0 { continue } if len(searchResult.Entries) > 1 { - return i18n.Translate(lang, "check:Multiple accounts with same uid, please check your ldap server") + return fmt.Errorf(i18n.Translate(lang, "check:Multiple accounts with same uid, please check your ldap server")) } hit = true @@ -257,45 +257,47 @@ func checkLdapUserPassword(user *User, password string, lang string) string { if !ldapLoginSuccess { if !hit { - return "user not exist" + return fmt.Errorf("user not exist") } - return i18n.Translate(lang, "check:LDAP user name or password incorrect") + return fmt.Errorf(i18n.Translate(lang, "check:LDAP user name or password incorrect")) } - return "" + return nil } -func CheckUserPassword(organization string, username string, password string, lang string, options ...bool) (*User, string) { +func CheckUserPassword(organization string, username string, password string, lang string, options ...bool) (*User, error) { enableCaptcha := false if len(options) > 0 { enableCaptcha = options[0] } user, err := GetUserByFields(organization, username) if err != nil { - panic(err) + return nil, err } if user == nil || user.IsDeleted { - return nil, fmt.Sprintf(i18n.Translate(lang, "general:The user: %s doesn't exist"), util.GetId(organization, username)) + return nil, fmt.Errorf(i18n.Translate(lang, "general:The user: %s doesn't exist"), util.GetId(organization, username)) } if user.IsForbidden { - return nil, i18n.Translate(lang, "check:The user is forbidden to sign in, please contact the administrator") + return nil, fmt.Errorf(i18n.Translate(lang, "check:The user is forbidden to sign in, please contact the administrator")) } if user.Ldap != "" { - // ONLY for ldap users - if msg := checkLdapUserPassword(user, password, lang); msg != "" { - if msg == "user not exist" { - return nil, fmt.Sprintf(i18n.Translate(lang, "check:The user: %s doesn't exist in LDAP server"), username) + // only for LDAP users + err = checkLdapUserPassword(user, password, lang) + if err != nil { + if err.Error() == "user not exist" { + return nil, fmt.Errorf(i18n.Translate(lang, "check:The user: %s doesn't exist in LDAP server"), username) } - return nil, msg + return nil, err } } else { - if msg := CheckPassword(user, password, lang, enableCaptcha); msg != "" { - return nil, msg + err = CheckPassword(user, password, lang, enableCaptcha) + if err != nil { + return nil, err } } - return user, "" + return user, nil } func CheckUserPermission(requestUserId, userId string, strict bool, lang string) (bool, error) { @@ -308,7 +310,7 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string) if userId != "" { targetUser, err := GetUser(userId) if err != nil { - panic(err) + return false, err } if targetUser == nil { diff --git a/object/check_util.go b/object/check_util.go index 822cfd0d..51216e6a 100644 --- a/object/check_util.go +++ b/object/check_util.go @@ -36,20 +36,23 @@ func isValidRealName(s string) bool { return reRealName.MatchString(s) } -func resetUserSigninErrorTimes(user *User) { +func resetUserSigninErrorTimes(user *User) error { // if the password is correct and wrong times is not zero, reset the error times if user.SigninWrongTimes == 0 { - return + return nil } + user.SigninWrongTimes = 0 - UpdateUser(user.GetId(), user, []string{"signin_wrong_times", "last_signin_wrong_time"}, false) + _, err := UpdateUser(user.GetId(), user, []string{"signin_wrong_times", "last_signin_wrong_time"}, false) + return err } -func recordSigninErrorInfo(user *User, lang string, options ...bool) string { +func recordSigninErrorInfo(user *User, lang string, options ...bool) error { enableCaptcha := false if len(options) > 0 { enableCaptcha = options[0] } + // increase failed login count if user.SigninWrongTimes < SigninWrongTimesLimit { user.SigninWrongTimes++ @@ -61,13 +64,18 @@ func recordSigninErrorInfo(user *User, lang string, options ...bool) string { } // update user - UpdateUser(user.GetId(), user, []string{"signin_wrong_times", "last_signin_wrong_time"}, false) + _, err := UpdateUser(user.GetId(), user, []string{"signin_wrong_times", "last_signin_wrong_time"}, false) + if err != nil { + return err + } + leftChances := SigninWrongTimesLimit - user.SigninWrongTimes if leftChances == 0 && enableCaptcha { - return fmt.Sprint(i18n.Translate(lang, "check:password or code is incorrect")) + return fmt.Errorf(i18n.Translate(lang, "check:password or code is incorrect")) } else if leftChances >= 0 { - return fmt.Sprintf(i18n.Translate(lang, "check:password or code is incorrect, you have %d remaining chances"), leftChances) + return fmt.Errorf(i18n.Translate(lang, "check:password or code is incorrect, you have %d remaining chances"), leftChances) } + // don't show the chance error message if the user has no chance left - return fmt.Sprintf(i18n.Translate(lang, "check:You have entered the wrong password or code too many times, please wait for %d minutes and try again"), int(LastSignWrongTimeDuration.Minutes())) + return fmt.Errorf(i18n.Translate(lang, "check:You have entered the wrong password or code too many times, please wait for %d minutes and try again"), int(LastSignWrongTimeDuration.Minutes())) } diff --git a/object/token.go b/object/token.go index 3a81c765..b68f62ec 100644 --- a/object/token.go +++ b/object/token.go @@ -621,25 +621,25 @@ func GetPasswordToken(application *Application, username string, password string if err != nil { return nil, nil, err } - if user == nil { return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "the user does not exist", }, nil } - var msg string + if user.Ldap != "" { - msg = checkLdapUserPassword(user, password, "en") + err = checkLdapUserPassword(user, password, "en") } else { - msg = CheckPassword(user, password, "en") + err = CheckPassword(user, password, "en") } - if msg != "" { + if err != nil { return nil, &TokenError{ Error: InvalidGrant, - ErrorDescription: "invalid username or password", + ErrorDescription: fmt.Sprintf("invalid username or password: %s", err.Error()), }, nil } + if user.IsForbidden { return nil, &TokenError{ Error: InvalidGrant, diff --git a/object/verification.go b/object/verification.go index b4c030d0..6374add2 100644 --- a/object/verification.go +++ b/object/verification.go @@ -192,32 +192,32 @@ func CheckVerificationCode(dest string, code string, lang string) *VerifyResult return &VerifyResult{VerificationSuccess, ""} } -func DisableVerificationCode(dest string) (err error) { +func DisableVerificationCode(dest string) error { record, err := getVerificationRecord(dest) if record == nil || err != nil { - return + return nil } record.IsUsed = true _, err = ormer.Engine.ID(core.PK{record.Owner, record.Name}).AllCols().Update(record) - return + return err } -func CheckSigninCode(user *User, dest, code, lang string) string { +func CheckSigninCode(user *User, dest, code, lang string) error { // check the login error times - if msg := checkSigninErrorTimes(user, lang); msg != "" { - return msg + err := checkSigninErrorTimes(user, lang) + if err != nil { + return err } result := CheckVerificationCode(dest, code, lang) switch result.Code { case VerificationSuccess: - resetUserSigninErrorTimes(user) - return "" + return resetUserSigninErrorTimes(user) case wrongCodeError: return recordSigninErrorInfo(user, lang) default: - return result.Msg + return fmt.Errorf(result.Msg) } } diff --git a/radius/server.go b/radius/server.go index 5a25dc2c..1cd01b1c 100644 --- a/radius/server.go +++ b/radius/server.go @@ -55,15 +55,18 @@ func handleAccessRequest(w radius.ResponseWriter, r *radius.Request) { password := rfc2865.UserPassword_GetString(r.Packet) organization := rfc2865.Class_GetString(r.Packet) log.Printf("handleAccessRequest() username=%v, org=%v, password=%v", username, organization, password) + if organization == "" { w.Write(r.Response(radius.CodeAccessReject)) return } - _, msg := object.CheckUserPassword(organization, username, password, "en") - if msg != "" { + + _, err := object.CheckUserPassword(organization, username, password, "en") + if err != nil { w.Write(r.Response(radius.CodeAccessReject)) return } + w.Write(r.Response(radius.CodeAccessAccept)) } diff --git a/routers/auto_signin_filter.go b/routers/auto_signin_filter.go index a5a9e120..7c09f2e1 100644 --- a/routers/auto_signin_filter.go +++ b/routers/auto_signin_filter.go @@ -83,13 +83,12 @@ func AutoSigninFilter(ctx *context.Context) { password := ctx.Input.Query("password") if userId != "" && password != "" && ctx.Input.Query("grant_type") == "" { owner, name := util.GetOwnerAndNameFromId(userId) - _, msg := object.CheckUserPassword(owner, name, password, "en") - if msg != "" { - responseError(ctx, msg) + _, err = object.CheckUserPassword(owner, name, password, "en") + if err != nil { + responseError(ctx, err.Error()) return } setSessionUser(ctx, userId) - return } }