diff --git a/authz/authz.go b/authz/authz.go index 151ed05b..2c38c347 100644 --- a/authz/authz.go +++ b/authz/authz.go @@ -154,7 +154,11 @@ func IsAllowed(subOwner string, subName string, method string, urlPath string, o } } - user := object.GetUser(util.GetId(subOwner, subName)) + user, err := object.GetUser(util.GetId(subOwner, subName)) + if err != nil { + panic(err) + } + if user != nil && user.IsAdmin && (subOwner == objOwner || (objOwner == "admin")) { return true } diff --git a/conf/conf.go b/conf/conf.go index c06a68cb..adbd9e8a 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -16,7 +16,6 @@ package conf import ( "encoding/json" - "fmt" "os" "runtime" "strconv" @@ -73,14 +72,13 @@ func GetConfigString(key string) string { return res } -func GetConfigBool(key string) (bool, error) { +func GetConfigBool(key string) bool { value := GetConfigString(key) if value == "true" { - return true, nil - } else if value == "false" { - return false, nil + return true + } else { + return false } - return false, fmt.Errorf("value %s cannot be converted into bool", value) } func GetConfigInt64(key string) (int64, error) { diff --git a/conf/conf_test.go b/conf/conf_test.go index 0ee5653c..4062bc12 100644 --- a/conf/conf_test.go +++ b/conf/conf_test.go @@ -87,7 +87,7 @@ func TestGetConfBool(t *testing.T) { assert.Nil(t, err) for _, scenery := range scenarios { t.Run(scenery.description, func(t *testing.T) { - actual, err := GetConfigBool(scenery.input) + actual := GetConfigBool(scenery.input) assert.Nil(t, err) assert.Equal(t, scenery.expected, actual) }) diff --git a/controllers/account.go b/controllers/account.go index 567f2072..5e045507 100644 --- a/controllers/account.go +++ b/controllers/account.go @@ -78,13 +78,23 @@ func (c *ApiController) Signup() { return } - application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + if err != nil { + c.ResponseError(err.Error()) + return + } + if !application.EnableSignUp { c.ResponseError(c.T("account:The application does not allow to sign up new account")) return } - organization := object.GetOrganization(util.GetId("admin", authForm.Organization)) + organization, err := object.GetOrganization(util.GetId("admin", authForm.Organization)) + if err != nil { + c.ResponseError(c.T(err.Error())) + return + } + msg := object.CheckUserSignup(application, organization, &authForm, c.GetAcceptLanguage()) if msg != "" { c.ResponseError(msg) @@ -111,7 +121,11 @@ func (c *ApiController) Signup() { id := util.GenerateId() if application.GetSignupItemRule("ID") == "Incremental" { - lastUser := object.GetLastUser(authForm.Organization) + lastUser, err := object.GetLastUser(authForm.Organization) + if err != nil { + c.ResponseError(err.Error()) + return + } lastIdInt := -1 if lastUser != nil { @@ -173,25 +187,47 @@ func (c *ApiController) Signup() { } } - affected := object.AddUser(user) + affected, err := object.AddUser(user) + if err != nil { + c.ResponseError(err.Error()) + return + } + if !affected { c.ResponseError(c.T("account:Failed to add user"), util.StructToJson(user)) return } - object.AddUserToOriginalDatabase(user) + err = object.AddUserToOriginalDatabase(user) + if err != nil { + c.ResponseError(err.Error()) + return + } if application.HasPromptPage() { // The prompt page needs the user to be signed in c.SetSessionUsername(user.GetId()) } - object.DisableVerificationCode(authForm.Email) - object.DisableVerificationCode(checkPhone) + err = object.DisableVerificationCode(authForm.Email) + if err != nil { + c.ResponseError(err.Error()) + return + } + + err = object.DisableVerificationCode(checkPhone) + if err != nil { + c.ResponseError(err.Error()) + return + } isSignupFromPricing := authForm.Plan != "" && authForm.Pricing != "" if isSignupFromPricing { - object.Subscribe(organization.Name, user.Name, authForm.Plan, authForm.Pricing) + _, err = object.Subscribe(organization.Name, user.Name, authForm.Plan, authForm.Pricing) + if err != nil { + c.ResponseError(err.Error()) + return + } } record := object.NewRecord(c.Ctx) @@ -231,7 +267,11 @@ func (c *ApiController) Logout() { c.ClearUserSession() owner, username := util.GetOwnerAndNameFromId(user) - object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID()) + _, err := object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID()) + if err != nil { + c.ResponseError(err.Error()) + return + } util.LogInfo(c.Ctx, "API: [%s] logged out", user) @@ -252,7 +292,12 @@ func (c *ApiController) Logout() { return } - affected, application, token := object.ExpireTokenByAccessToken(accessToken) + affected, application, token, err := object.ExpireTokenByAccessToken(accessToken) + if err != nil { + c.ResponseError(err.Error()) + return + } + if !affected { c.ResponseError(c.T("token:Token not found, invalid accessToken")) return @@ -272,7 +317,12 @@ func (c *ApiController) Logout() { // TODO https://github.com/casdoor/casdoor/pull/1494#discussion_r1095675265 owner, username := util.GetOwnerAndNameFromId(user) - object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID()) + _, err := object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID()) + if err != nil { + c.ResponseError(err.Error()) + return + } + util.LogInfo(c.Ctx, "API: [%s] logged out", user) c.Ctx.Redirect(http.StatusFound, fmt.Sprintf("%s?state=%s", strings.TrimRight(redirectUri, "/"), state)) @@ -290,6 +340,7 @@ func (c *ApiController) Logout() { // @Success 200 {object} controllers.Response The Response object // @router /get-account [get] func (c *ApiController) GetAccount() { + var err error user, ok := c.RequireSignedInUser() if !ok { return @@ -297,20 +348,39 @@ func (c *ApiController) GetAccount() { managedAccounts := c.Input().Get("managedAccounts") if managedAccounts == "1" { - user = object.ExtendManagedAccountsWithUser(user) + user, err = object.ExtendManagedAccountsWithUser(user) + if err != nil { + c.ResponseError(err.Error()) + return + } } - object.ExtendUserWithRolesAndPermissions(user) + err = object.ExtendUserWithRolesAndPermissions(user) + if err != nil { + c.ResponseError(err.Error()) + return + } user.Permissions = object.GetMaskedPermissions(user.Permissions) user.Roles = object.GetMaskedRoles(user.Roles) - organization := object.GetMaskedOrganization(object.GetOrganizationByUser(user)) + organization, err := object.GetMaskedOrganization(object.GetOrganizationByUser(user)) + if err != nil { + c.ResponseError(err.Error()) + return + } + + u, err := object.GetMaskedUser(user) + if err != nil { + c.ResponseError(err.Error()) + return + } + resp := Response{ Status: "ok", Sub: user.Id, Name: user.Name, - Data: object.GetMaskedUser(user), + Data: u, Data2: organization, } c.Data["json"] = resp @@ -391,7 +461,12 @@ func (c *ApiController) GetCaptcha() { if captchaProvider != nil { if captchaProvider.Type == "Default" { - id, img := object.GetCaptcha() + id, img, err := object.GetCaptcha() + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(Captcha{Type: captchaProvider.Type, CaptchaId: id, CaptchaImage: img}) return } else if captchaProvider.Type != "" { diff --git a/controllers/application.go b/controllers/application.go index 8cd31ca5..3d167db3 100644 --- a/controllers/application.go +++ b/controllers/application.go @@ -40,21 +40,35 @@ func (c *ApiController) GetApplications() { sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") organization := c.Input().Get("organization") - + var err error if limit == "" || page == "" { var applications []*object.Application if organization == "" { - applications = object.GetApplications(owner) + applications, err = object.GetApplications(owner) } else { - applications = object.GetOrganizationApplications(owner, organization) + applications, err = object.GetOrganizationApplications(owner, organization) + } + + if err != nil { + panic(err) } c.Data["json"] = object.GetMaskedApplications(applications, userId) c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetApplicationCount(owner, field, value))) - applications := object.GetMaskedApplications(object.GetPaginationApplications(owner, paginator.Offset(), limit, field, value, sortField, sortOrder), userId) + count, err := object.GetApplicationCount(owner, field, value) + if err != nil { + panic(err) + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + app, err := object.GetPaginationApplications(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + panic(err) + } + + applications := object.GetMaskedApplications(app, userId) c.ResponseOk(applications, paginator.Nums()) } } @@ -69,8 +83,12 @@ func (c *ApiController) GetApplications() { func (c *ApiController) GetApplication() { userId := c.GetSessionUsername() id := c.Input().Get("id") + app, err := object.GetApplication(id) + if err != nil { + panic(err) + } - c.Data["json"] = object.GetMaskedApplication(object.GetApplication(id), userId) + c.Data["json"] = object.GetMaskedApplication(app, userId) c.ServeJSON() } @@ -84,13 +102,22 @@ func (c *ApiController) GetApplication() { func (c *ApiController) GetUserApplication() { userId := c.GetSessionUsername() id := c.Input().Get("id") - user := object.GetUser(id) + user, err := object.GetUser(id) + if err != nil { + panic(err) + } + if user == nil { c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), id)) return } - c.Data["json"] = object.GetMaskedApplication(object.GetApplicationByUser(user), userId) + app, err := object.GetApplicationByUser(user) + if err != nil { + panic(err) + } + + c.Data["json"] = object.GetMaskedApplication(app, userId) c.ServeJSON() } @@ -118,13 +145,30 @@ func (c *ApiController) GetOrganizationApplications() { } if limit == "" || page == "" { - applications := object.GetOrganizationApplications(owner, organization) + applications, err := object.GetOrganizationApplications(owner, organization) + if err != nil { + panic(err) + } + c.Data["json"] = object.GetMaskedApplications(applications, userId) c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetOrganizationApplicationCount(owner, organization, field, value))) - applications := object.GetMaskedApplications(object.GetPaginationOrganizationApplications(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder), userId) + + count, err := object.GetOrganizationApplicationCount(owner, organization, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + app, err := object.GetPaginationOrganizationApplications(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + + applications := object.GetMaskedApplications(app, userId) c.ResponseOk(applications, paginator.Nums()) } } @@ -166,8 +210,13 @@ func (c *ApiController) AddApplication() { return } - count := object.GetApplicationCount("", "", "") - if err := checkQuotaForApplication(count); err != nil { + count, err := object.GetApplicationCount("", "", "") + if err != nil { + c.ResponseError(err.Error()) + return + } + + if err := checkQuotaForApplication(int(count)); err != nil { c.ResponseError(err.Error()) return } diff --git a/controllers/auth.go b/controllers/auth.go index a8f202bc..771139e6 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -93,7 +93,12 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob c.ResponseError(c.T("auth:Challenge method should be S256")) return } - code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge, c.Ctx.Request.Host, c.GetAcceptLanguage()) + code, err := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge, c.Ctx.Request.Host, c.GetAcceptLanguage()) + if err != nil { + c.ResponseError(err.Error(), nil) + return + } + resp = codeToResponse(code) if application.EnableSigninSession || application.HasPromptPage() { @@ -142,12 +147,16 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob } if resp.Status == "ok" { - object.AddSession(&object.Session{ + _, err = object.AddSession(&object.Session{ Owner: user.Owner, Name: user.Name, Application: application.Name, SessionId: []string{c.Ctx.Input.CruSession.SessionID()}, }) + if err != nil { + c.ResponseError(err.Error(), nil) + return + } } return resp @@ -171,7 +180,12 @@ func (c *ApiController) GetApplicationLogin() { scope := c.Input().Get("scope") state := c.Input().Get("state") - msg, application := object.CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, c.GetAcceptLanguage()) + msg, application, err := object.CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, c.GetAcceptLanguage()) + if err != nil { + c.ResponseError(err.Error()) + return + } + application = object.GetMaskedApplication(application, "") if msg != "" { c.ResponseError(msg, application) @@ -248,7 +262,10 @@ func (c *ApiController) Login() { var msg string if authForm.Password == "" { - if user = object.GetUserByFields(authForm.Organization, authForm.Username); user == nil { + if user, err = object.GetUserByFields(authForm.Organization, authForm.Username); err != nil { + c.ResponseError(err.Error(), nil) + return + } else if user == nil { c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(authForm.Organization, authForm.Username))) return } @@ -272,9 +289,18 @@ func (c *ApiController) Login() { } // disable the verification code - object.DisableVerificationCode(checkDest) + err := object.DisableVerificationCode(checkDest) + if err != nil { + c.ResponseError(err.Error(), nil) + return + } } else { - application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + if err != nil { + c.ResponseError(err.Error(), nil) + return + } + if application == nil { c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) return @@ -284,7 +310,10 @@ func (c *ApiController) Login() { return } var enableCaptcha bool - if enableCaptcha = object.CheckToEnableCaptcha(application, authForm.Organization, authForm.Username); enableCaptcha { + if enableCaptcha, err = object.CheckToEnableCaptcha(application, authForm.Organization, authForm.Username); err != nil { + c.ResponseError(err.Error()) + return + } else if enableCaptcha { isHuman, err := captcha.VerifyCaptchaByCaptchaType(authForm.CaptchaType, authForm.CaptchaToken, authForm.ClientSecret) if err != nil { c.ResponseError(err.Error()) @@ -304,7 +333,12 @@ func (c *ApiController) Login() { if msg != "" { resp = &Response{Status: "error", Msg: msg} } else { - application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + if err != nil { + c.ResponseError(err.Error()) + return + } + if application == nil { c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) return @@ -312,7 +346,11 @@ func (c *ApiController) Login() { resp = c.HandleLoggedIn(application, user, &authForm) - organization := object.GetOrganizationByUser(user) + organization, err := object.GetOrganizationByUser(user) + if err != nil { + c.ResponseError(err.Error()) + } + if user != nil && organization.HasRequiredMfa() && !user.IsMfaEnabled() { resp.Msg = object.RequiredMfa } @@ -325,18 +363,34 @@ func (c *ApiController) Login() { } else if authForm.Provider != "" { var application *object.Application if authForm.ClientId != "" { - application = object.GetApplicationByClientId(authForm.ClientId) + application, err = object.GetApplicationByClientId(authForm.ClientId) + if err != nil { + c.ResponseError(err.Error()) + return + } } else { - application = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + application, err = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + if err != nil { + c.ResponseError(err.Error()) + return + } } if application == nil { c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) return } + organization, err := object.GetOrganization(util.GetId("admin", application.Organization)) + if err != nil { + c.ResponseError(c.T(err.Error())) + } + + provider, err := object.GetProvider(util.GetId("admin", authForm.Provider)) + if err != nil { + c.ResponseError(err.Error()) + return + } - organization := object.GetOrganization(util.GetId("admin", application.Organization)) - provider := object.GetProvider(util.GetId("admin", authForm.Provider)) providerItem := application.GetProviderItem(provider.Name) if !providerItem.IsProviderVisible() { c.ResponseError(fmt.Sprintf(c.T("auth:The provider: %s is not enabled for the application"), provider.Name)) @@ -396,9 +450,17 @@ func (c *ApiController) Login() { if authForm.Method == "signup" { user := &object.User{} if provider.Category == "SAML" { - user = object.GetUser(util.GetId(application.Organization, userInfo.Id)) + user, err = object.GetUser(util.GetId(application.Organization, userInfo.Id)) + if err != nil { + c.ResponseError(err.Error()) + return + } } else if provider.Category == "OAuth" { - user = object.GetUserByField(application.Organization, provider.Type, userInfo.Id) + user, err = object.GetUserByField(application.Organization, provider.Type, userInfo.Id) + if err != nil { + c.ResponseError(err.Error()) + return + } } if user != nil && !user.IsDeleted { @@ -419,12 +481,20 @@ func (c *ApiController) Login() { if application.EnableLinkWithEmail { if userInfo.Email != "" { // Find existing user with Email - user = object.GetUserByField(application.Organization, "email", userInfo.Email) + user, err = object.GetUserByField(application.Organization, "email", userInfo.Email) + if err != nil { + c.ResponseError(err.Error()) + return + } } if user == nil && userInfo.Phone != "" { // Find existing user with phone number - user = object.GetUserByField(application.Organization, "phone", userInfo.Phone) + user, err = object.GetUserByField(application.Organization, "phone", userInfo.Phone) + if err != nil { + c.ResponseError(err.Error()) + return + } } } @@ -440,7 +510,12 @@ func (c *ApiController) Login() { } // Handle username conflicts - tmpUser := object.GetUser(util.GetId(application.Organization, userInfo.Username)) + tmpUser, err := object.GetUser(util.GetId(application.Organization, userInfo.Username)) + if err != nil { + c.ResponseError(err.Error()) + return + } + if tmpUser != nil { uid, err := uuid.NewRandom() if err != nil { @@ -453,7 +528,13 @@ func (c *ApiController) Login() { } properties := map[string]string{} - properties["no"] = strconv.Itoa(object.GetUserCount(application.Organization, "", "") + 2) + count, err := object.GetUserCount(application.Organization, "", "") + if err != nil { + c.ResponseError(err.Error()) + return + } + + properties["no"] = strconv.Itoa(int(count + 2)) initScore, err := organization.GetInitScore() if err != nil { c.ResponseError(fmt.Errorf(c.T("account:Get init score failed, error: %w"), err).Error()) @@ -482,7 +563,12 @@ func (c *ApiController) Login() { Properties: properties, } - affected := object.AddUser(user) + affected, err := object.AddUser(user) + if err != nil { + c.ResponseError(err.Error()) + return + } + if !affected { c.ResponseError(fmt.Sprintf(c.T("auth:Failed to create user, user information is invalid: %s"), util.StructToJson(user))) return @@ -490,8 +576,17 @@ func (c *ApiController) Login() { } // sync info from 3rd-party if possible - object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) - object.LinkUserAccount(user, provider.Type, userInfo.Id) + _, err := object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) + if err != nil { + c.ResponseError(err.Error()) + return + } + + _, err = object.LinkUserAccount(user, provider.Type, userInfo.Id) + if err != nil { + c.ResponseError(err.Error()) + return + } resp = c.HandleLoggedIn(application, user, &authForm) @@ -516,18 +611,36 @@ func (c *ApiController) Login() { return } - oldUser := object.GetUserByField(application.Organization, provider.Type, userInfo.Id) + oldUser, err := object.GetUserByField(application.Organization, provider.Type, userInfo.Id) + if err != nil { + c.ResponseError(err.Error()) + return + } + if oldUser != nil { c.ResponseError(fmt.Sprintf(c.T("auth:The account for provider: %s and username: %s (%s) is already linked to another account: %s (%s)"), provider.Type, userInfo.Username, userInfo.DisplayName, oldUser.Name, oldUser.DisplayName)) return } - user := object.GetUser(userId) + user, err := object.GetUser(userId) + if err != nil { + c.ResponseError(err.Error()) + return + } // sync info from 3rd-party if possible - object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) + _, err = object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) + if err != nil { + c.ResponseError(err.Error()) + return + } + + isLinked, err := object.LinkUserAccount(user, provider.Type, userInfo.Id) + if err != nil { + c.ResponseError(err.Error()) + return + } - isLinked := object.LinkUserAccount(user, provider.Type, userInfo.Id) if isLinked { resp = &Response{Status: "ok", Msg: "", Data: isLinked} } else { @@ -536,7 +649,11 @@ func (c *ApiController) Login() { } } else if c.getMfaSessionData() != nil { mfaSession := c.getMfaSessionData() - user := object.GetUser(mfaSession.UserId) + user, err := object.GetUser(mfaSession.UserId) + if err != nil { + c.ResponseError(err.Error()) + return + } if authForm.Passcode != "" { MfaUtil := object.GetMfaUtil(authForm.MfaType, user.GetPreferMfa(false)) @@ -554,7 +671,12 @@ func (c *ApiController) Login() { } } - application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + if err != nil { + c.ResponseError(err.Error()) + return + } + if application == nil { c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) return @@ -569,7 +691,12 @@ func (c *ApiController) Login() { } else { if c.GetSessionUsername() != "" { // user already signed in to Casdoor, so let the user click the avatar button to do the quick sign-in - application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) + if err != nil { + c.ResponseError(err.Error()) + return + } + if application == nil { c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) return @@ -624,8 +751,9 @@ func (c *ApiController) HandleSamlLogin() { func (c *ApiController) HandleOfficialAccountEvent() { respBytes, err := ioutil.ReadAll(c.Ctx.Request.Body) if err != nil { - c.ResponseError(err.Error()) + panic(err) } + var data struct { MsgType string `xml:"MsgType"` Event string `xml:"Event"` @@ -633,8 +761,9 @@ func (c *ApiController) HandleOfficialAccountEvent() { } err = xml.Unmarshal(respBytes, &data) if err != nil { - c.ResponseError(err.Error()) + panic(err) } + lock.Lock() defer lock.Unlock() if data.EventKey != "" { @@ -670,7 +799,12 @@ func (c *ApiController) GetWebhookEventType() { func (c *ApiController) GetCaptchaStatus() { organization := c.Input().Get("organization") userId := c.Input().Get("user_id") - user := object.GetUserByFields(organization, userId) + user, err := object.GetUserByFields(organization, userId) + if err != nil { + c.ResponseError(err.Error()) + return + } + var captchaEnabled bool if user != nil && user.SigninWrongTimes >= object.SigninWrongTimesLimit { captchaEnabled = true diff --git a/controllers/base.go b/controllers/base.go index d70c1d3e..503f3b43 100644 --- a/controllers/base.go +++ b/controllers/base.go @@ -72,11 +72,15 @@ func (c *ApiController) isGlobalAdmin() (bool, *object.User) { func (c *ApiController) getCurrentUser() *object.User { var user *object.User + var err error userId := c.GetSessionUsername() if userId == "" { user = nil } else { - user = object.GetUser(userId) + user, err = object.GetUser(userId) + if err != nil { + panic(err) + } } return user } @@ -106,7 +110,11 @@ func (c *ApiController) GetSessionApplication() *object.Application { if clientId == nil { return nil } - application := object.GetApplicationByClientId(clientId.(string)) + application, err := object.GetApplicationByClientId(clientId.(string)) + if err != nil { + panic(err) + } + return application } @@ -192,8 +200,10 @@ func (c *ApiController) setExpireForSession() { }) } -func wrapActionResponse(affected bool) *Response { - if affected { +func wrapActionResponse(affected bool, e ...error) *Response { + if len(e) != 0 && e[0] != nil { + return &Response{Status: "error", Msg: e[0].Error()} + } else if affected { return &Response{Status: "ok", Msg: "", Data: "Affected"} } else { return &Response{Status: "ok", Msg: "", Data: "Unaffected"} diff --git a/controllers/casbin_adapter.go b/controllers/casbin_adapter.go index 75d9c673..34b2319e 100644 --- a/controllers/casbin_adapter.go +++ b/controllers/casbin_adapter.go @@ -33,19 +33,40 @@ func (c *ApiController) GetCasbinAdapters() { sortOrder := c.Input().Get("sortOrder") organization := c.Input().Get("organization") if limit == "" || page == "" { - adapters := object.GetCasbinAdapters(owner, organization) + adapters, err := object.GetCasbinAdapters(owner, organization) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(adapters) } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetCasbinAdapterCount(owner, organization, field, value))) - adapters := object.GetPaginationCasbinAdapters(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetCasbinAdapterCount(owner, organization, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + adapters, err := object.GetPaginationCasbinAdapters(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(adapters, paginator.Nums()) } } func (c *ApiController) GetCasbinAdapter() { id := c.Input().Get("id") - adapter := object.GetCasbinAdapter(id) + adapter, err := object.GetCasbinAdapter(id) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(adapter) } @@ -89,7 +110,11 @@ func (c *ApiController) DeleteCasbinAdapter() { func (c *ApiController) SyncPolicies() { id := c.Input().Get("id") - adapter := object.GetCasbinAdapter(id) + adapter, err := object.GetCasbinAdapter(id) + if err != nil { + c.ResponseError(err.Error()) + return + } policies, err := object.SyncPolicies(adapter) if err != nil { @@ -102,9 +127,14 @@ func (c *ApiController) SyncPolicies() { func (c *ApiController) UpdatePolicy() { id := c.Input().Get("id") - adapter := object.GetCasbinAdapter(id) + adapter, err := object.GetCasbinAdapter(id) + if err != nil { + c.ResponseError(err.Error()) + return + } + var policies []xormadapter.CasbinRule - err := json.Unmarshal(c.Ctx.Input.RequestBody, &policies) + err = json.Unmarshal(c.Ctx.Input.RequestBody, &policies) if err != nil { c.ResponseError(err.Error()) return @@ -121,9 +151,14 @@ func (c *ApiController) UpdatePolicy() { func (c *ApiController) AddPolicy() { id := c.Input().Get("id") - adapter := object.GetCasbinAdapter(id) + adapter, err := object.GetCasbinAdapter(id) + if err != nil { + c.ResponseError(err.Error()) + return + } + var policy xormadapter.CasbinRule - err := json.Unmarshal(c.Ctx.Input.RequestBody, &policy) + err = json.Unmarshal(c.Ctx.Input.RequestBody, &policy) if err != nil { c.ResponseError(err.Error()) return @@ -140,9 +175,14 @@ func (c *ApiController) AddPolicy() { func (c *ApiController) RemovePolicy() { id := c.Input().Get("id") - adapter := object.GetCasbinAdapter(id) + adapter, err := object.GetCasbinAdapter(id) + if err != nil { + c.ResponseError(err.Error()) + return + } + var policy xormadapter.CasbinRule - err := json.Unmarshal(c.Ctx.Input.RequestBody, &policy) + err = json.Unmarshal(c.Ctx.Input.RequestBody, &policy) if err != nil { c.ResponseError(err.Error()) return diff --git a/controllers/cert.go b/controllers/cert.go index 64dac054..2fa93bc4 100644 --- a/controllers/cert.go +++ b/controllers/cert.go @@ -37,13 +37,28 @@ func (c *ApiController) GetCerts() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetMaskedCerts(object.GetCerts(owner)) + maskedCerts, err := object.GetMaskedCerts(object.GetCerts(owner)) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedCerts c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetCertCount(owner, field, value))) - certs := object.GetMaskedCerts(object.GetPaginationCerts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) + count, err := object.GetCertCount(owner, field, value) + if err != nil { + panic(err) + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + certs, err := object.GetMaskedCerts(object.GetPaginationCerts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) + if err != nil { + panic(err) + } + c.ResponseOk(certs, paginator.Nums()) } } @@ -61,13 +76,28 @@ func (c *ApiController) GetGlobleCerts() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetMaskedCerts(object.GetGlobleCerts()) + maskedCerts, err := object.GetMaskedCerts(object.GetGlobleCerts()) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedCerts c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalCertsCount(field, value))) - certs := object.GetMaskedCerts(object.GetPaginationGlobalCerts(paginator.Offset(), limit, field, value, sortField, sortOrder)) + count, err := object.GetGlobalCertsCount(field, value) + if err != nil { + panic(err) + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + certs, err := object.GetMaskedCerts(object.GetPaginationGlobalCerts(paginator.Offset(), limit, field, value, sortField, sortOrder)) + if err != nil { + panic(err) + } + c.ResponseOk(certs, paginator.Nums()) } } @@ -81,8 +111,12 @@ func (c *ApiController) GetGlobleCerts() { // @router /get-cert [get] func (c *ApiController) GetCert() { id := c.Input().Get("id") + cert, err := object.GetCert(id) + if err != nil { + panic(err) + } - c.Data["json"] = object.GetMaskedCert(object.GetCert(id)) + c.Data["json"] = object.GetMaskedCert(cert) c.ServeJSON() } diff --git a/controllers/chat.go b/controllers/chat.go index f501966d..3b041304 100644 --- a/controllers/chat.go +++ b/controllers/chat.go @@ -37,13 +37,30 @@ func (c *ApiController) GetChats() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetMaskedChats(object.GetChats(owner)) + maskedChats, err := object.GetMaskedChats(object.GetChats(owner)) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedChats c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetChatCount(owner, field, value))) - chats := object.GetMaskedChats(object.GetPaginationChats(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) + count, err := object.GetChatCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + chats, err := object.GetMaskedChats(object.GetPaginationChats(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(chats, paginator.Nums()) } } @@ -58,7 +75,12 @@ func (c *ApiController) GetChats() { func (c *ApiController) GetChat() { id := c.Input().Get("id") - c.Data["json"] = object.GetMaskedChat(object.GetChat(id)) + maskedChat, err := object.GetMaskedChat(object.GetChat(id)) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedChat c.ServeJSON() } diff --git a/controllers/enforcer.go b/controllers/enforcer.go index 9b9e98a2..42ce8e77 100644 --- a/controllers/enforcer.go +++ b/controllers/enforcer.go @@ -34,8 +34,7 @@ func (c *ApiController) Enforce() { } if permissionId != "" { - c.Data["json"] = object.Enforce(permissionId, &request) - c.ServeJSON() + c.ResponseOk(object.Enforce(permissionId, &request)) return } @@ -44,9 +43,15 @@ func (c *ApiController) Enforce() { if modelId != "" { owner, modelName := util.GetOwnerAndNameFromId(modelId) - permissions = object.GetPermissionsByModel(owner, modelName) + permissions, err = object.GetPermissionsByModel(owner, modelName) + if err != nil { + panic(err) + } } else { - permissions = object.GetPermissionsByResource(resourceId) + permissions, err = object.GetPermissionsByResource(resourceId) + if err != nil { + panic(err) + } } for _, permission := range permissions { @@ -63,8 +68,7 @@ func (c *ApiController) BatchEnforce() { var requests []object.CasbinRequest err := json.Unmarshal(c.Ctx.Input.RequestBody, &requests) if err != nil { - c.ResponseError(err.Error()) - return + panic(err) } if permissionId != "" { @@ -72,14 +76,17 @@ func (c *ApiController) BatchEnforce() { c.ServeJSON() } else { owner, modelName := util.GetOwnerAndNameFromId(modelId) - permissions := object.GetPermissionsByModel(owner, modelName) + permissions, err := object.GetPermissionsByModel(owner, modelName) + if err != nil { + panic(err) + } res := [][]bool{} for _, permission := range permissions { res = append(res, object.BatchEnforce(permission.GetId(), &requests)) } - c.Data["json"] = res - c.ServeJSON() + + c.ResponseOk(res) } } @@ -90,8 +97,7 @@ func (c *ApiController) GetAllObjects() { return } - c.Data["json"] = object.GetAllObjects(userId) - c.ServeJSON() + c.ResponseOk(object.GetAllObjects(userId)) } func (c *ApiController) GetAllActions() { @@ -101,8 +107,7 @@ func (c *ApiController) GetAllActions() { return } - c.Data["json"] = object.GetAllActions(userId) - c.ServeJSON() + c.ResponseOk(object.GetAllActions(userId)) } func (c *ApiController) GetAllRoles() { @@ -112,6 +117,5 @@ func (c *ApiController) GetAllRoles() { return } - c.Data["json"] = object.GetAllRoles(userId) - c.ServeJSON() + c.ResponseOk(object.GetAllRoles(userId)) } diff --git a/controllers/ldap.go b/controllers/ldap.go index 6a45f96e..02d5c639 100644 --- a/controllers/ldap.go +++ b/controllers/ldap.go @@ -45,7 +45,11 @@ func (c *ApiController) GetLdapUsers() { id := c.Input().Get("id") _, ldapId := util.GetOwnerAndNameFromId(id) - ldapServer := object.GetLdap(ldapId) + ldapServer, err := object.GetLdap(ldapId) + if err != nil { + c.ResponseError(err.Error()) + return + } conn, err := ldapServer.GetLdapConn() if err != nil { @@ -76,7 +80,11 @@ func (c *ApiController) GetLdapUsers() { for i, user := range users { uuids[i] = user.GetLdapUuid() } - existUuids := object.GetExistUuids(ldapServer.Owner, uuids) + existUuids, err := object.GetExistUuids(ldapServer.Owner, uuids) + if err != nil { + c.ResponseError(err.Error()) + return + } resp := LdapResp{ Users: object.AutoAdjustLdapUser(users), @@ -128,17 +136,23 @@ func (c *ApiController) AddLdap() { return } - if object.CheckLdapExist(&ldap) { + if ok, err := object.CheckLdapExist(&ldap); err != nil { + c.ResponseError(err.Error()) + return + } else if ok { c.ResponseError(c.T("ldap:Ldap server exist")) return } - affected := object.AddLdap(&ldap) - resp := wrapActionResponse(affected) + resp := wrapActionResponse(object.AddLdap(&ldap)) resp.Data2 = ldap if ldap.AutoSync != 0 { - object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id) + err = object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id) + if err != nil { + c.ResponseError(err.Error()) + return + } } c.Data["json"] = resp @@ -157,11 +171,24 @@ func (c *ApiController) UpdateLdap() { return } - prevLdap := object.GetLdap(ldap.Id) - affected := object.UpdateLdap(&ldap) + prevLdap, err := object.GetLdap(ldap.Id) + if err != nil { + c.ResponseError(err.Error()) + return + } + + affected, err := object.UpdateLdap(&ldap) + if err != nil { + c.ResponseError(err.Error()) + return + } if ldap.AutoSync != 0 { - object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id) + err := object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id) + if err != nil { + c.ResponseError(err.Error()) + return + } } else if ldap.AutoSync == 0 && prevLdap.AutoSync != 0 { object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id) } @@ -182,7 +209,11 @@ func (c *ApiController) DeleteLdap() { return } - affected := object.DeleteLdap(&ldap) + affected, err := object.DeleteLdap(&ldap) + if err != nil { + c.ResponseError(err.Error()) + return + } object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id) @@ -204,7 +235,11 @@ func (c *ApiController) SyncLdapUsers() { return } - object.UpdateLdapSyncTime(ldapId) + err = object.UpdateLdapSyncTime(ldapId) + if err != nil { + c.ResponseError(err.Error()) + return + } exist, failed, _ := object.SyncLdapUsers(owner, users, ldapId) diff --git a/controllers/link.go b/controllers/link.go index bc26babc..1fe96b4c 100644 --- a/controllers/link.go +++ b/controllers/link.go @@ -53,7 +53,12 @@ func (c *ApiController) Unlink() { if user.Id == unlinkedUser.Id && !user.IsGlobalAdmin { // if the user is unlinking themselves, should check the provider can be unlinked, if not, we should return an error. - application := object.GetApplicationByUser(user) + application, err := object.GetApplicationByUser(user) + if err != nil { + c.ResponseError(err.Error()) + return + } + if application == nil { c.ResponseError(c.T("link:You can't unlink yourself, you are not a member of any application")) return @@ -88,8 +93,17 @@ func (c *ApiController) Unlink() { return } - object.ClearUserOAuthProperties(&unlinkedUser, providerType) + _, err = object.ClearUserOAuthProperties(&unlinkedUser, providerType) + if err != nil { + c.ResponseError(err.Error()) + return + } + + _, err = object.LinkUserAccount(&unlinkedUser, providerType, "") + if err != nil { + c.ResponseError(err.Error()) + return + } - object.LinkUserAccount(&unlinkedUser, providerType, "") c.ResponseOk() } diff --git a/controllers/message.go b/controllers/message.go index a87f1564..550367b5 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -44,18 +44,35 @@ func (c *ApiController) GetMessages() { organization := c.Input().Get("organization") if limit == "" || page == "" { var messages []*object.Message + var err error if chat == "" { - messages = object.GetMessages(owner) + messages, err = object.GetMessages(owner) } else { - messages = object.GetChatMessages(chat) + messages, err = object.GetChatMessages(chat) + } + + if err != nil { + panic(err) } c.Data["json"] = object.GetMaskedMessages(messages) c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetMessageCount(owner, organization, field, value))) - messages := object.GetMaskedMessages(object.GetPaginationMessages(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)) + count, err := object.GetMessageCount(owner, organization, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + paginationMessages, err := object.GetPaginationMessages(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + + messages := object.GetMaskedMessages(paginationMessages) c.ResponseOk(messages, paginator.Nums()) } } @@ -69,8 +86,12 @@ func (c *ApiController) GetMessages() { // @router /get-message [get] func (c *ApiController) GetMessage() { id := c.Input().Get("id") + message, err := object.GetMessage(id) + if err != nil { + panic(err) + } - c.Data["json"] = object.GetMaskedMessage(object.GetMessage(id)) + c.Data["json"] = object.GetMaskedMessage(message) c.ServeJSON() } @@ -96,7 +117,12 @@ func (c *ApiController) GetMessageAnswer() { c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") - message := object.GetMessage(id) + message, err := object.GetMessage(id) + if err != nil { + c.ResponseError(err.Error()) + return + } + if message == nil { c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) return @@ -108,7 +134,12 @@ func (c *ApiController) GetMessageAnswer() { } chatId := util.GetId("admin", message.Chat) - chat := object.GetChat(chatId) + chat, err := object.GetChat(chatId) + if err != nil { + c.ResponseError(err.Error()) + return + } + if chat == nil || chat.Organization != message.Organization { c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) return @@ -119,14 +150,19 @@ func (c *ApiController) GetMessageAnswer() { return } - questionMessage := object.GetMessage(message.ReplyTo) + questionMessage, err := object.GetMessage(message.ReplyTo) if questionMessage == nil { c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) return } providerId := util.GetId(chat.Owner, chat.User2) - provider := object.GetProvider(providerId) + provider, err := object.GetProvider(providerId) + if err != nil { + c.ResponseError(err.Error()) + return + } + if provider == nil { c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId)) return @@ -148,7 +184,7 @@ func (c *ApiController) GetMessageAnswer() { fmt.Printf("Question: [%s]\n", questionMessage.Text) fmt.Printf("Answer: [") - err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder) + err = ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder) if err != nil { c.ResponseErrorStream(err.Error()) return @@ -165,7 +201,10 @@ func (c *ApiController) GetMessageAnswer() { answer := stringBuilder.String() message.Text = answer - object.UpdateMessage(message.GetId(), message) + _, err = object.UpdateMessage(message.GetId(), message) + if err != nil { + panic(err) + } } // UpdateMessage @@ -208,14 +247,24 @@ func (c *ApiController) AddMessage() { var chat *object.Chat if message.Chat != "" { chatId := util.GetId("admin", message.Chat) - chat = object.GetChat(chatId) + chat, err = object.GetChat(chatId) + if err != nil { + c.ResponseError(err.Error()) + return + } + if chat == nil || chat.Organization != message.Organization { c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) return } } - affected := object.AddMessage(&message) + affected, err := object.AddMessage(&message) + if err != nil { + c.ResponseError(err.Error()) + return + } + if affected { if chat != nil && chat.Type == "AI" { answerMessage := &object.Message{ @@ -228,7 +277,11 @@ func (c *ApiController) AddMessage() { Author: "AI", Text: "", } - object.AddMessage(answerMessage) + _, err = object.AddMessage(answerMessage) + if err != nil { + c.ResponseError(err.Error()) + return + } } } diff --git a/controllers/mfa.go b/controllers/mfa.go index 35befdd8..23c73204 100644 --- a/controllers/mfa.go +++ b/controllers/mfa.go @@ -46,7 +46,12 @@ func (c *ApiController) MfaSetupInitiate() { if MfaUtil == nil { c.ResponseError("Invalid auth type") } - user := object.GetUser(userId) + user, err := object.GetUser(userId) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError("User doesn't exist") return @@ -105,14 +110,19 @@ func (c *ApiController) MfaSetupEnable() { name := c.Ctx.Request.Form.Get("name") authType := c.Ctx.Request.Form.Get("type") - user := object.GetUser(util.GetId(owner, name)) + user, err := object.GetUser(util.GetId(owner, name)) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError("User doesn't exist") return } twoFactor := object.GetMfaUtil(authType, nil) - err := twoFactor.Enable(c.Ctx, user) + err = twoFactor.Enable(c.Ctx, user) if err != nil { c.ResponseError(err.Error()) return @@ -136,7 +146,12 @@ func (c *ApiController) DeleteMfa() { name := c.Ctx.Request.Form.Get("name") userId := util.GetId(owner, name) - user := object.GetUser(userId) + user, err := object.GetUser(userId) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError("User doesn't exist") return @@ -151,7 +166,12 @@ func (c *ApiController) DeleteMfa() { } } user.MultiFactorAuths = mfaProps - object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser()) + _, err = object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser()) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(user.MultiFactorAuths) } @@ -170,7 +190,12 @@ func (c *ApiController) SetPreferredMfa() { name := c.Ctx.Request.Form.Get("name") userId := util.GetId(owner, name) - user := object.GetUser(userId) + user, err := object.GetUser(userId) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError("User doesn't exist") return @@ -185,7 +210,11 @@ func (c *ApiController) SetPreferredMfa() { } } - object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser()) + _, err = object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser()) + if err != nil { + c.ResponseError(err.Error()) + return + } for i, mfaProp := range mfaProps { mfaProps[i] = object.GetMaskedProps(mfaProp) diff --git a/controllers/model.go b/controllers/model.go index 6f8c8a03..2ab13f31 100644 --- a/controllers/model.go +++ b/controllers/model.go @@ -37,13 +37,30 @@ func (c *ApiController) GetModels() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetModels(owner) + models, err := object.GetModels(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = models c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetModelCount(owner, field, value))) - models := object.GetPaginationModels(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetModelCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + models, err := object.GetPaginationModels(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(models, paginator.Nums()) } } @@ -58,7 +75,12 @@ func (c *ApiController) GetModels() { func (c *ApiController) GetModel() { id := c.Input().Get("id") - c.Data["json"] = object.GetModel(id) + model, err := object.GetModel(id) + if err != nil { + panic(err) + } + + c.Data["json"] = model c.ServeJSON() } diff --git a/controllers/organization.go b/controllers/organization.go index 730308e1..ee2af9b8 100644 --- a/controllers/organization.go +++ b/controllers/organization.go @@ -37,13 +37,30 @@ func (c *ApiController) GetOrganizations() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetMaskedOrganizations(object.GetOrganizations(owner)) + maskedOrganizations, err := object.GetMaskedOrganizations(object.GetOrganizations(owner)) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedOrganizations c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetOrganizationCount(owner, field, value))) - organizations := object.GetMaskedOrganizations(object.GetPaginationOrganizations(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) + count, err := object.GetOrganizationCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + organizations, err := object.GetMaskedOrganizations(object.GetPaginationOrganizations(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(organizations, paginator.Nums()) } } @@ -58,7 +75,12 @@ func (c *ApiController) GetOrganizations() { func (c *ApiController) GetOrganization() { id := c.Input().Get("id") - c.Data["json"] = object.GetMaskedOrganization(object.GetOrganization(id)) + maskedOrganization, err := object.GetMaskedOrganization(object.GetOrganization(id)) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedOrganization c.ServeJSON() } @@ -99,8 +121,13 @@ func (c *ApiController) AddOrganization() { return } - count := object.GetOrganizationCount("", "", "") - if err := checkQuotaForOrganization(count); err != nil { + count, err := object.GetOrganizationCount("", "", "") + if err != nil { + c.ResponseError(err.Error()) + return + } + + if err = checkQuotaForOrganization(int(count)); err != nil { c.ResponseError(err.Error()) return } @@ -158,6 +185,11 @@ func (c *ApiController) GetDefaultApplication() { // @router /get-organization-names [get] func (c *ApiController) GetOrganizationNames() { owner := c.Input().Get("owner") - organizationNames := object.GetOrganizationsByFields(owner, "name") + organizationNames, err := object.GetOrganizationsByFields(owner, "name") + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(organizationNames) } diff --git a/controllers/payment.go b/controllers/payment.go index be04e220..65e72799 100644 --- a/controllers/payment.go +++ b/controllers/payment.go @@ -37,13 +37,28 @@ func (c *ApiController) GetPayments() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetPayments(owner) + payments, err := object.GetPayments(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = payments c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPaymentCount(owner, field, value))) - payments := object.GetPaginationPayments(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetPaymentCount(owner, field, value) + if err != nil { + panic(err) + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + payments, err := object.GetPaginationPayments(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + panic(err) + } + c.ResponseOk(payments, paginator.Nums()) } } @@ -62,7 +77,12 @@ func (c *ApiController) GetUserPayments() { organization := c.Input().Get("organization") user := c.Input().Get("user") - payments := object.GetUserPayments(owner, organization, user) + payments, err := object.GetUserPayments(owner, organization, user) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(payments) } @@ -76,7 +96,12 @@ func (c *ApiController) GetUserPayments() { func (c *ApiController) GetPayment() { id := c.Input().Get("id") - c.Data["json"] = object.GetPayment(id) + payment, err := object.GetPayment(id) + if err != nil { + panic(err) + } + + c.Data["json"] = payment c.ServeJSON() } @@ -177,7 +202,12 @@ func (c *ApiController) NotifyPayment() { func (c *ApiController) InvoicePayment() { id := c.Input().Get("id") - payment := object.GetPayment(id) + payment, err := object.GetPayment(id) + if err != nil { + c.ResponseError(err.Error()) + return + } + invoiceUrl, err := object.InvoicePayment(payment) if err != nil { c.ResponseError(err.Error()) diff --git a/controllers/permission.go b/controllers/permission.go index 38acb87a..b1e7ed6a 100644 --- a/controllers/permission.go +++ b/controllers/permission.go @@ -37,13 +37,28 @@ func (c *ApiController) GetPermissions() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetPermissions(owner) + permissions, err := object.GetPermissions(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = permissions c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPermissionCount(owner, field, value))) - permissions := object.GetPaginationPermissions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetPermissionCount(owner, field, value) + if err != nil { + panic(err) + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + permissions, err := object.GetPaginationPermissions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + panic(err) + } + c.ResponseOk(permissions, paginator.Nums()) } } @@ -60,7 +75,12 @@ func (c *ApiController) GetPermissionsBySubmitter() { return } - permissions := object.GetPermissionsBySubmitter(user.Owner, user.Name) + permissions, err := object.GetPermissionsBySubmitter(user.Owner, user.Name) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(permissions, len(permissions)) return } @@ -74,7 +94,12 @@ func (c *ApiController) GetPermissionsBySubmitter() { // @router /get-permissions-by-role [get] func (c *ApiController) GetPermissionsByRole() { id := c.Input().Get("id") - permissions := object.GetPermissionsByRole(id) + permissions, err := object.GetPermissionsByRole(id) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(permissions, len(permissions)) return } @@ -89,7 +114,12 @@ func (c *ApiController) GetPermissionsByRole() { func (c *ApiController) GetPermission() { id := c.Input().Get("id") - c.Data["json"] = object.GetPermission(id) + permission, err := object.GetPermission(id) + if err != nil { + panic(err) + } + + c.Data["json"] = permission c.ServeJSON() } diff --git a/controllers/permission_upload.go b/controllers/permission_upload.go index 5770f383..caee1631 100644 --- a/controllers/permission_upload.go +++ b/controllers/permission_upload.go @@ -41,7 +41,11 @@ func (c *ApiController) UploadPermissions() { return } - affected := object.UploadPermissions(owner, fileId) + affected, err := object.UploadPermissions(owner, fileId) + if err != nil { + c.ResponseError(err.Error()) + } + if affected { c.ResponseOk() } else { diff --git a/controllers/plan.go b/controllers/plan.go index 0536966b..2139305c 100644 --- a/controllers/plan.go +++ b/controllers/plan.go @@ -37,13 +37,30 @@ func (c *ApiController) GetPlans() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetPlans(owner) + plans, err := object.GetPlans(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = plans c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPlanCount(owner, field, value))) - plan := object.GetPaginatedPlans(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetPlanCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + plan, err := object.GetPaginatedPlans(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(plan, paginator.Nums()) } } @@ -60,10 +77,16 @@ func (c *ApiController) GetPlan() { id := c.Input().Get("id") includeOption := c.Input().Get("includeOption") == "true" - plan := object.GetPlan(id) + plan, err := object.GetPlan(id) + if err != nil { + panic(err) + } if includeOption { - options := object.GetPermissionsByRole(plan.Role) + options, err := object.GetPermissionsByRole(plan.Role) + if err != nil { + panic(err) + } for _, option := range options { plan.Options = append(plan.Options, option.DisplayName) diff --git a/controllers/pricing.go b/controllers/pricing.go index 01ed0c4c..a726ff59 100644 --- a/controllers/pricing.go +++ b/controllers/pricing.go @@ -37,13 +37,30 @@ func (c *ApiController) GetPricings() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetPricings(owner) + pricings, err := object.GetPricings(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = pricings c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPricingCount(owner, field, value))) - pricing := object.GetPaginatedPricings(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetPricingCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + pricing, err := object.GetPaginatedPricings(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(pricing, paginator.Nums()) } } @@ -58,7 +75,10 @@ func (c *ApiController) GetPricings() { func (c *ApiController) GetPricing() { id := c.Input().Get("id") - pricing := object.GetPricing(id) + pricing, err := object.GetPricing(id) + if err != nil { + panic(err) + } c.Data["json"] = pricing c.ServeJSON() diff --git a/controllers/product.go b/controllers/product.go index 7bd06e47..b33cbe59 100644 --- a/controllers/product.go +++ b/controllers/product.go @@ -38,13 +38,30 @@ func (c *ApiController) GetProducts() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetProducts(owner) + products, err := object.GetProducts(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = products c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetProductCount(owner, field, value))) - products := object.GetPaginationProducts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetProductCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + products, err := object.GetPaginationProducts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(products, paginator.Nums()) } } @@ -59,8 +76,15 @@ func (c *ApiController) GetProducts() { func (c *ApiController) GetProduct() { id := c.Input().Get("id") - product := object.GetProduct(id) - object.ExtendProductWithProviders(product) + product, err := object.GetProduct(id) + if err != nil { + panic(err) + } + + err = object.ExtendProductWithProviders(product) + if err != nil { + panic(err) + } c.Data["json"] = product c.ServeJSON() @@ -145,7 +169,12 @@ func (c *ApiController) BuyProduct() { return } - user := object.GetUser(userId) + user, err := object.GetUser(userId) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), userId)) return diff --git a/controllers/provider.go b/controllers/provider.go index da62a817..f4e2455d 100644 --- a/controllers/provider.go +++ b/controllers/provider.go @@ -44,12 +44,29 @@ func (c *ApiController) GetProviders() { } if limit == "" || page == "" { - c.Data["json"] = object.GetMaskedProviders(object.GetProviders(owner), isMaskEnabled) + providers, err := object.GetProviders(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = object.GetMaskedProviders(providers, isMaskEnabled) c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetProviderCount(owner, field, value))) - providers := object.GetMaskedProviders(object.GetPaginationProviders(owner, paginator.Offset(), limit, field, value, sortField, sortOrder), isMaskEnabled) + count, err := object.GetProviderCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + paginationProviders, err := object.GetPaginationProviders(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + + providers := object.GetMaskedProviders(paginationProviders, isMaskEnabled) c.ResponseOk(providers, paginator.Nums()) } } @@ -74,12 +91,29 @@ func (c *ApiController) GetGlobalProviders() { } if limit == "" || page == "" { - c.Data["json"] = object.GetMaskedProviders(object.GetGlobalProviders(), isMaskEnabled) + globalProviders, err := object.GetGlobalProviders() + if err != nil { + panic(err) + } + + c.Data["json"] = object.GetMaskedProviders(globalProviders, isMaskEnabled) c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalProviderCount(field, value))) - providers := object.GetMaskedProviders(object.GetPaginationGlobalProviders(paginator.Offset(), limit, field, value, sortField, sortOrder), isMaskEnabled) + count, err := object.GetGlobalProviderCount(field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + paginationGlobalProviders, err := object.GetPaginationGlobalProviders(paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + + providers := object.GetMaskedProviders(paginationGlobalProviders, isMaskEnabled) c.ResponseOk(providers, paginator.Nums()) } } @@ -98,8 +132,13 @@ func (c *ApiController) GetProvider() { if !ok { return } + provider, err := object.GetProvider(id) + if err != nil { + c.ResponseError(err.Error()) + return + } - c.Data["json"] = object.GetMaskedProvider(object.GetProvider(id), isMaskEnabled) + c.Data["json"] = object.GetMaskedProvider(provider, isMaskEnabled) c.ServeJSON() } @@ -140,8 +179,13 @@ func (c *ApiController) AddProvider() { return } - count := object.GetProviderCount("", "", "") - if err := checkQuotaForProvider(count); err != nil { + count, err := object.GetProviderCount("", "", "") + if err != nil { + c.ResponseError(err.Error()) + return + } + + if err := checkQuotaForProvider(int(count)); err != nil { c.ResponseError(err.Error()) return } diff --git a/controllers/record.go b/controllers/record.go index fa49005b..59a7e707 100644 --- a/controllers/record.go +++ b/controllers/record.go @@ -42,14 +42,31 @@ func (c *ApiController) GetRecords() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetRecords() + records, err := object.GetRecords() + if err != nil { + panic(err) + } + + c.Data["json"] = records c.ServeJSON() } else { limit := util.ParseInt(limit) filterRecord := &object.Record{Organization: organization} - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetRecordCount(field, value, filterRecord))) - records := object.GetPaginationRecords(paginator.Offset(), limit, field, value, sortField, sortOrder, filterRecord) + count, err := object.GetRecordCount(field, value, filterRecord) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + records, err := object.GetPaginationRecords(paginator.Offset(), limit, field, value, sortField, sortOrder, filterRecord) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(records, paginator.Nums()) } } @@ -67,11 +84,15 @@ func (c *ApiController) GetRecordsByFilter() { record := &object.Record{} err := util.JsonToStruct(body, record) if err != nil { - c.ResponseError(err.Error()) - return + panic(err) } - c.Data["json"] = object.GetRecordsByField(record) + records, err := object.GetRecordsByField(record) + if err != nil { + panic(err) + } + + c.Data["json"] = records c.ServeJSON() } diff --git a/controllers/resource.go b/controllers/resource.go index 1c7389a6..4fd001d1 100644 --- a/controllers/resource.go +++ b/controllers/resource.go @@ -51,12 +51,28 @@ func (c *ApiController) GetResources() { } if limit == "" || page == "" { - c.Data["json"] = object.GetResources(owner, user) + resources, err := object.GetResources(owner, user) + if err != nil { + panic(err) + } + + c.Data["json"] = resources c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetResourceCount(owner, user, field, value))) - resources := object.GetPaginationResources(owner, user, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetResourceCount(owner, user, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + resources, err := object.GetPaginationResources(owner, user, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(resources, paginator.Nums()) } } @@ -68,7 +84,12 @@ func (c *ApiController) GetResources() { func (c *ApiController) GetResource() { id := c.Input().Get("id") - c.Data["json"] = object.GetResource(id) + resource, err := object.GetResource(id) + if err != nil { + panic(err) + } + + c.Data["json"] = resource c.ServeJSON() } @@ -187,7 +208,10 @@ func (c *ApiController) UploadResource() { index := len(fullFilePath) - len(ext) for i := 1; ; i++ { _, objectKey := object.GetUploadFileUrl(provider, fullFilePath, true) - if object.GetResourceCount(owner, username, "name", objectKey) == 0 { + if count, err := object.GetResourceCount(owner, username, "name", objectKey); err != nil { + c.ResponseError(err.Error()) + return + } else if count == 0 { break } @@ -223,20 +247,39 @@ func (c *ApiController) UploadResource() { Url: fileUrl, Description: description, } - object.AddOrUpdateResource(resource) + _, err = object.AddOrUpdateResource(resource) + if err != nil { + c.ResponseError(err.Error()) + return + } switch tag { case "avatar": - user := object.GetUserNoCheck(util.GetId(owner, username)) + user, err := object.GetUserNoCheck(util.GetId(owner, username)) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError(c.T("resource:User is nil for tag: avatar")) return } user.Avatar = fileUrl - object.UpdateUser(user.GetId(), user, []string{"avatar"}, false) + _, err = object.UpdateUser(user.GetId(), user, []string{"avatar"}, false) + if err != nil { + c.ResponseError(err.Error()) + return + } + case "termsOfUse": - user := object.GetUserNoCheck(util.GetId(owner, username)) + user, err := object.GetUserNoCheck(util.GetId(owner, username)) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(owner, username))) return @@ -248,9 +291,18 @@ func (c *ApiController) UploadResource() { } _, applicationId := util.GetOwnerAndNameFromIdNoCheck(strings.TrimRight(fullFilePath, ".html")) - applicationObj := object.GetApplication(applicationId) + applicationObj, err := object.GetApplication(applicationId) + if err != nil { + c.ResponseError(err.Error()) + return + } + applicationObj.TermsOfUse = fileUrl - object.UpdateApplication(applicationId, applicationObj) + _, err = object.UpdateApplication(applicationId, applicationObj) + if err != nil { + c.ResponseError(err.Error()) + return + } } c.ResponseOk(fileUrl, objectKey) diff --git a/controllers/role.go b/controllers/role.go index 0f3a63b7..b99e9cf3 100644 --- a/controllers/role.go +++ b/controllers/role.go @@ -37,13 +37,30 @@ func (c *ApiController) GetRoles() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetRoles(owner) + roles, err := object.GetRoles(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = roles c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetRoleCount(owner, field, value))) - roles := object.GetPaginationRoles(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetRoleCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + roles, err := object.GetPaginationRoles(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(roles, paginator.Nums()) } } @@ -58,7 +75,12 @@ func (c *ApiController) GetRoles() { func (c *ApiController) GetRole() { id := c.Input().Get("id") - c.Data["json"] = object.GetRole(id) + role, err := object.GetRole(id) + if err != nil { + panic(err) + } + + c.Data["json"] = role c.ServeJSON() } diff --git a/controllers/role_upload.go b/controllers/role_upload.go index 8b53c28c..3365dd3c 100644 --- a/controllers/role_upload.go +++ b/controllers/role_upload.go @@ -41,7 +41,11 @@ func (c *ApiController) UploadRoles() { return } - affected := object.UploadRoles(owner, fileId) + affected, err := object.UploadRoles(owner, fileId) + if err != nil { + c.ResponseError(err.Error()) + } + if affected { c.ResponseOk() } else { diff --git a/controllers/saml.go b/controllers/saml.go index 31f31b9b..01f96276 100644 --- a/controllers/saml.go +++ b/controllers/saml.go @@ -23,7 +23,12 @@ import ( func (c *ApiController) GetSamlMeta() { host := c.Ctx.Request.Host paramApp := c.Input().Get("application") - application := object.GetApplication(paramApp) + application, err := object.GetApplication(paramApp) + if err != nil { + c.ResponseError(err.Error()) + return + } + if application == nil { c.ResponseError(fmt.Sprintf(c.T("saml:Application %s not found"), paramApp)) return diff --git a/controllers/service.go b/controllers/service.go index d587f144..cae6c627 100644 --- a/controllers/service.go +++ b/controllers/service.go @@ -61,7 +61,12 @@ func (c *ApiController) SendEmail() { var provider *object.Provider if emailForm.Provider != "" { // called by frontend's TestEmailWidget, provider name is set by frontend - provider = object.GetProvider(util.GetId("admin", emailForm.Provider)) + provider, err = object.GetProvider(util.GetId("admin", emailForm.Provider)) + if err != nil { + c.ResponseError(err.Error()) + return + } + } else { // called by Casdoor SDK via Client ID & Client Secret, so the used Email provider will be the application' Email provider or the default Email provider var ok bool diff --git a/controllers/session.go b/controllers/session.go index e5e7c001..b450bdcb 100644 --- a/controllers/session.go +++ b/controllers/session.go @@ -37,13 +37,29 @@ func (c *ApiController) GetSessions() { sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") owner := c.Input().Get("owner") + if limit == "" || page == "" { - c.Data["json"] = object.GetSessions(owner) + sessions, err := object.GetSessions(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = sessions c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSessionCount(owner, field, value))) - sessions := object.GetPaginationSessions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetSessionCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + paginator := pagination.SetPaginator(c.Ctx, limit, count) + sessions, err := object.GetPaginationSessions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(sessions, paginator.Nums()) } } @@ -58,7 +74,12 @@ func (c *ApiController) GetSessions() { func (c *ApiController) GetSingleSession() { id := c.Input().Get("sessionPkId") - c.Data["json"] = object.GetSingleSession(id) + session, err := object.GetSingleSession(id) + if err != nil { + panic(err) + } + + c.Data["json"] = session c.ServeJSON() } @@ -132,7 +153,11 @@ func (c *ApiController) IsSessionDuplicated() { id := c.Input().Get("sessionPkId") sessionId := c.Input().Get("sessionId") - isUserSessionDuplicated := object.IsSessionDuplicated(id, sessionId) + isUserSessionDuplicated, err := object.IsSessionDuplicated(id, sessionId) + if err != nil { + panic(err) + } + c.Data["json"] = &Response{Status: "ok", Msg: "", Data: isUserSessionDuplicated} c.ServeJSON() diff --git a/controllers/subscription.go b/controllers/subscription.go index 1216c0af..03094482 100644 --- a/controllers/subscription.go +++ b/controllers/subscription.go @@ -37,13 +37,30 @@ func (c *ApiController) GetSubscriptions() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetSubscriptions(owner) + subscriptions, err := object.GetSubscriptions(owner) + if err != nil { + panic(err) + } + + c.Data["json"] = subscriptions c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSubscriptionCount(owner, field, value))) - subscription := object.GetPaginationSubscriptions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetSubscriptionCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + subscription, err := object.GetPaginationSubscriptions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(subscription, paginator.Nums()) } } @@ -58,7 +75,10 @@ func (c *ApiController) GetSubscriptions() { func (c *ApiController) GetSubscription() { id := c.Input().Get("id") - subscription := object.GetSubscription(id) + subscription, err := object.GetSubscription(id) + if err != nil { + panic(err) + } c.Data["json"] = subscription c.ServeJSON() diff --git a/controllers/syncer.go b/controllers/syncer.go index 009f2fbf..4f58e206 100644 --- a/controllers/syncer.go +++ b/controllers/syncer.go @@ -38,13 +38,30 @@ func (c *ApiController) GetSyncers() { sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") organization := c.Input().Get("organization") + if limit == "" || page == "" { - c.Data["json"] = object.GetOrganizationSyncers(owner, organization) + organizationSyncers, err := object.GetOrganizationSyncers(owner, organization) + if err != nil { + panic(err) + } + + c.Data["json"] = organizationSyncers c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSyncerCount(owner, organization, field, value))) - syncers := object.GetPaginationSyncers(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetSyncerCount(owner, organization, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + syncers, err := object.GetPaginationSyncers(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(syncers, paginator.Nums()) } } @@ -59,7 +76,12 @@ func (c *ApiController) GetSyncers() { func (c *ApiController) GetSyncer() { id := c.Input().Get("id") - c.Data["json"] = object.GetSyncer(id) + syncer, err := object.GetSyncer(id) + if err != nil { + panic(err) + } + + c.Data["json"] = syncer c.ServeJSON() } @@ -132,7 +154,11 @@ func (c *ApiController) DeleteSyncer() { // @router /run-syncer [get] func (c *ApiController) RunSyncer() { id := c.Input().Get("id") - syncer := object.GetSyncer(id) + syncer, err := object.GetSyncer(id) + if err != nil { + c.ResponseError(err.Error()) + return + } object.RunSyncer(syncer) diff --git a/controllers/token.go b/controllers/token.go index d84ea34f..47850a3a 100644 --- a/controllers/token.go +++ b/controllers/token.go @@ -41,12 +41,28 @@ func (c *ApiController) GetTokens() { sortOrder := c.Input().Get("sortOrder") organization := c.Input().Get("organization") if limit == "" || page == "" { - c.Data["json"] = object.GetTokens(owner, organization) + token, err := object.GetTokens(owner, organization) + if err != nil { + panic(err) + } + + c.Data["json"] = token c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetTokenCount(owner, organization, field, value))) - tokens := object.GetPaginationTokens(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetTokenCount(owner, organization, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + tokens, err := object.GetPaginationTokens(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(tokens, paginator.Nums()) } } @@ -60,8 +76,12 @@ func (c *ApiController) GetTokens() { // @router /get-token [get] func (c *ApiController) GetToken() { id := c.Input().Get("id") + token, err := object.GetToken(id) + if err != nil { + panic(err) + } - c.Data["json"] = object.GetToken(id) + c.Data["json"] = token c.ServeJSON() } @@ -171,8 +191,12 @@ func (c *ApiController) GetOAuthToken() { } } host := c.Ctx.Request.Host + oAuthtoken, err := object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier, scope, username, password, host, refreshToken, tag, avatar, c.GetAcceptLanguage()) + if err != nil { + panic(err) + } - c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier, scope, username, password, host, refreshToken, tag, avatar, c.GetAcceptLanguage()) + c.Data["json"] = oAuthtoken c.SetTokenErrorHttpStatus() c.ServeJSON() } @@ -210,7 +234,12 @@ func (c *ApiController) RefreshToken() { } } - c.Data["json"] = object.RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host) + refreshToken2, err := object.RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host) + if err != nil { + panic(err) + } + + c.Data["json"] = refreshToken2 c.SetTokenErrorHttpStatus() c.ServeJSON() } @@ -245,7 +274,11 @@ func (c *ApiController) IntrospectToken() { return } } - application := object.GetApplicationByClientId(clientId) + application, err := object.GetApplicationByClientId(clientId) + if err != nil { + panic(err) + } + if application == nil || application.ClientSecret != clientSecret { c.ResponseError(c.T("token:Invalid application or wrong clientSecret")) c.Data["json"] = &object.TokenError{ @@ -254,7 +287,11 @@ func (c *ApiController) IntrospectToken() { c.SetTokenErrorHttpStatus() return } - token := object.GetTokenByTokenAndApplication(tokenValue, application.Name) + token, err := object.GetTokenByTokenAndApplication(tokenValue, application.Name) + if err != nil { + panic(err) + } + if token == nil { c.Data["json"] = &object.IntrospectionResponse{Active: false} c.ServeJSON() diff --git a/controllers/user.go b/controllers/user.go index 81083284..3bab5e2e 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -37,14 +37,36 @@ func (c *ApiController) GetGlobalUsers() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetMaskedUsers(object.GetGlobalUsers()) + maskedUsers, err := object.GetMaskedUsers(object.GetGlobalUsers()) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedUsers c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalUserCount(field, value))) - users := object.GetPaginationGlobalUsers(paginator.Offset(), limit, field, value, sortField, sortOrder) - users = object.GetMaskedUsers(users) + count, err := object.GetGlobalUserCount(field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + users, err := object.GetPaginationGlobalUsers(paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + + users, err = object.GetMaskedUsers(users) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(users, paginator.Nums()) } } @@ -64,14 +86,36 @@ func (c *ApiController) GetUsers() { value := c.Input().Get("value") sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") + if limit == "" || page == "" { - c.Data["json"] = object.GetMaskedUsers(object.GetUsers(owner)) + maskedUsers, err := object.GetMaskedUsers(object.GetUsers(owner)) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedUsers c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetUserCount(owner, field, value))) - users := object.GetPaginationUsers(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) - users = object.GetMaskedUsers(users) + count, err := object.GetUserCount(owner, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + users, err := object.GetPaginationUsers(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + + users, err = object.GetMaskedUsers(users) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(users, paginator.Nums()) } } @@ -93,10 +137,14 @@ func (c *ApiController) GetUser() { phone := c.Input().Get("phone") userId := c.Input().Get("userId") owner := c.Input().Get("owner") - + var err error var userFromUserId *object.User if userId != "" && owner != "" { - userFromUserId = object.GetUserByUserId(owner, userId) + userFromUserId, err = object.GetUserByUserId(owner, userId) + if err != nil { + panic(err) + } + id = util.GetId(userFromUserId.Owner, userFromUserId.Name) } @@ -104,7 +152,11 @@ func (c *ApiController) GetUser() { owner = util.GetOwnerFromId(id) } - organization := object.GetOrganization(util.GetId("admin", owner)) + organization, err := object.GetOrganization(util.GetId("admin", owner)) + if err != nil { + panic(err) + } + if !organization.IsProfilePublic { requestUserId := c.GetSessionUsername() hasPermission, err := object.CheckUserPermission(requestUserId, id, false, c.GetAcceptLanguage()) @@ -117,18 +169,30 @@ func (c *ApiController) GetUser() { var user *object.User switch { case email != "": - user = object.GetUserByEmail(owner, email) + user, err = object.GetUserByEmail(owner, email) case phone != "": - user = object.GetUserByPhone(owner, phone) + user, err = object.GetUserByPhone(owner, phone) case userId != "": user = userFromUserId default: - user = object.GetUser(id) + user, err = object.GetUser(id) } - object.ExtendUserWithRolesAndPermissions(user) + if err != nil { + panic(err) + } - c.Data["json"] = object.GetMaskedUser(user) + err = object.ExtendUserWithRolesAndPermissions(user) + if err != nil { + panic(err) + } + + maskedUser, err := object.GetMaskedUser(user) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedUser c.ServeJSON() } @@ -158,7 +222,12 @@ func (c *ApiController) UpdateUser() { return } } - oldUser := object.GetUser(id) + oldUser, err := object.GetUser(id) + if err != nil { + c.ResponseError(err.Error()) + return + } + if oldUser == nil { c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), id)) return @@ -185,9 +254,18 @@ func (c *ApiController) UpdateUser() { columns = strings.Split(columnsStr, ",") } - affected := object.UpdateUser(id, &user, columns, isAdmin) + affected, err := object.UpdateUser(id, &user, columns, isAdmin) + if err != nil { + c.ResponseError(err.Error()) + return + } + if affected { - object.UpdateUserToOriginalDatabase(&user) + err = object.UpdateUserToOriginalDatabase(&user) + if err != nil { + c.ResponseError(err.Error()) + return + } } c.Data["json"] = wrapActionResponse(affected) @@ -209,8 +287,13 @@ func (c *ApiController) AddUser() { return } - count := object.GetUserCount("", "", "") - if err := checkQuotaForUser(count); err != nil { + count, err := object.GetUserCount("", "", "") + if err != nil { + c.ResponseError(err.Error()) + return + } + + if err := checkQuotaForUser(int(count)); err != nil { c.ResponseError(err.Error()) return } @@ -261,7 +344,12 @@ func (c *ApiController) GetEmailAndPhone() { organization := c.Ctx.Request.Form.Get("organization") username := c.Ctx.Request.Form.Get("username") - user := object.GetUserByFields(organization, username) + user, err := object.GetUserByFields(organization, username) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(organization, username))) return @@ -335,7 +423,11 @@ func (c *ApiController) SetPassword() { c.SetSession("verifiedCode", "") } - targetUser := object.GetUser(userId) + targetUser, err := object.GetUser(userId) + if err != nil { + c.ResponseError(err.Error()) + return + } if oldPassword != "" { msg := object.CheckPassword(targetUser, oldPassword, c.GetAcceptLanguage()) @@ -346,7 +438,12 @@ func (c *ApiController) SetPassword() { } targetUser.Password = newPassword - object.SetUserField(targetUser, "password", targetUser.Password) + _, err = object.SetUserField(targetUser, "password", targetUser.Password) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk() } @@ -384,7 +481,12 @@ func (c *ApiController) GetSortedUsers() { sorter := c.Input().Get("sorter") limit := util.ParseInt(c.Input().Get("limit")) - c.Data["json"] = object.GetMaskedUsers(object.GetSortedUsers(owner, sorter, limit)) + maskedUsers, err := object.GetMaskedUsers(object.GetSortedUsers(owner, sorter, limit)) + if err != nil { + panic(err) + } + + c.Data["json"] = maskedUsers c.ServeJSON() } @@ -400,11 +502,16 @@ func (c *ApiController) GetUserCount() { owner := c.Input().Get("owner") isOnline := c.Input().Get("isOnline") - count := 0 + var count int64 + var err error if isOnline == "" { - count = object.GetUserCount(owner, "", "") + count, err = object.GetUserCount(owner, "", "") } else { - count = object.GetOnlineUserCount(owner, util.ParseInt(isOnline)) + count, err = object.GetOnlineUserCount(owner, util.ParseInt(isOnline)) + } + if err != nil { + c.ResponseError(err.Error()) + return } c.Data["json"] = count diff --git a/controllers/user_upload.go b/controllers/user_upload.go index 121cec22..bef60559 100644 --- a/controllers/user_upload.go +++ b/controllers/user_upload.go @@ -57,7 +57,12 @@ func (c *ApiController) UploadUsers() { return } - affected := object.UploadUsers(owner, fileId) + affected, err := object.UploadUsers(owner, fileId) + if err != nil { + c.ResponseError(err.Error()) + return + } + if affected { c.ResponseOk() } else { diff --git a/controllers/util.go b/controllers/util.go index 161ecf2f..036b5303 100644 --- a/controllers/util.go +++ b/controllers/util.go @@ -92,7 +92,11 @@ func (c *ApiController) RequireSignedInUser() (*object.User, bool) { return nil, false } - user := object.GetUser(userId) + user, err := object.GetUser(userId) + if err != nil { + panic(err) + } + if user == nil { c.ClearUserSession() c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), userId)) @@ -138,7 +142,11 @@ func (c *ApiController) IsMaskedEnabled() (bool, bool) { func (c *ApiController) GetProviderFromContext(category string) (*object.Provider, *object.User, bool) { providerName := c.Input().Get("provider") if providerName != "" { - provider := object.GetProvider(util.GetId("admin", providerName)) + provider, err := object.GetProvider(util.GetId("admin", providerName)) + if err != nil { + panic(err) + } + if provider == nil { c.ResponseError(fmt.Sprintf(c.T("util:The provider: %s is not found"), providerName)) return nil, nil, false @@ -151,13 +159,21 @@ func (c *ApiController) GetProviderFromContext(category string) (*object.Provide return nil, nil, false } - application, user := object.GetApplicationByUserId(userId) + application, user, err := object.GetApplicationByUserId(userId) + if err != nil { + panic(err) + } + if application == nil { c.ResponseError(fmt.Sprintf(c.T("util:No application is found for userId: %s"), userId)) return nil, nil, false } - provider := application.GetProviderByCategory(category) + provider, err := application.GetProviderByCategory(category) + if err != nil { + panic(err) + } + if provider == nil { c.ResponseError(fmt.Sprintf(c.T("util:No provider for category: %s is found for application: %s"), category, application.Name)) return nil, nil, false diff --git a/controllers/verification.go b/controllers/verification.go index 66ae9207..5d53eee8 100644 --- a/controllers/verification.go +++ b/controllers/verification.go @@ -66,8 +66,17 @@ func (c *ApiController) SendVerificationCode() { } } - application := object.GetApplication(vform.ApplicationId) - organization := object.GetOrganization(util.GetId(application.Owner, application.Organization)) + application, err := object.GetApplication(vform.ApplicationId) + if err != nil { + c.ResponseError(err.Error()) + return + } + + organization, err := object.GetOrganization(util.GetId(application.Owner, application.Organization)) + if err != nil { + c.ResponseError(c.T(err.Error())) + } + if organization == nil { c.ResponseError(c.T("check:Organization does not exist")) return @@ -77,12 +86,20 @@ func (c *ApiController) SendVerificationCode() { // checkUser != "", means method is ForgetVerification if vform.CheckUser != "" { owner := application.Organization - user = object.GetUser(util.GetId(owner, vform.CheckUser)) + user, err = object.GetUser(util.GetId(owner, vform.CheckUser)) + if err != nil { + c.ResponseError(err.Error()) + return + } } // mfaSessionData != nil, means method is MfaSetupVerification if mfaSessionData := c.getMfaSessionData(); mfaSessionData != nil { - user = object.GetUser(mfaSessionData.UserId) + user, err = object.GetUser(mfaSessionData.UserId) + if err != nil { + c.ResponseError(err.Error()) + return + } } sendResp := errors.New("invalid dest type") @@ -99,7 +116,12 @@ func (c *ApiController) SendVerificationCode() { vform.Dest = user.Email } - user = object.GetUserByEmail(organization.Name, vform.Dest) + user, err = object.GetUserByEmail(organization.Name, vform.Dest) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError(c.T("verification:the user does not exist, please sign up first")) return @@ -113,7 +135,12 @@ func (c *ApiController) SendVerificationCode() { } } - provider := application.GetEmailProvider() + provider, err := application.GetEmailProvider() + if err != nil { + c.ResponseError(err.Error()) + return + } + sendResp = object.SendVerificationCodeToEmail(organization, user, provider, remoteAddr, vform.Dest) case object.VerifyTypePhone: if vform.Method == LoginVerification || vform.Method == ForgetVerification { @@ -121,7 +148,10 @@ func (c *ApiController) SendVerificationCode() { vform.Dest = user.Phone } - if user = object.GetUserByPhone(organization.Name, vform.Dest); user == nil { + if user, err = object.GetUserByPhone(organization.Name, vform.Dest); err != nil { + c.ResponseError(err.Error()) + return + } else if user == nil { c.ResponseError(c.T("verification:the user does not exist, please sign up first")) return } @@ -140,7 +170,12 @@ func (c *ApiController) SendVerificationCode() { vform.CountryCode = mfaProps.CountryCode } - provider := application.GetSmsProvider() + provider, err := application.GetSmsProvider() + if err != nil { + c.ResponseError(err.Error()) + return + } + if phone, ok := util.GetE164Number(vform.Dest, vform.CountryCode); !ok { c.ResponseError(fmt.Sprintf(c.T("verification:Phone number is invalid in your region %s"), vform.CountryCode)) return @@ -213,7 +248,12 @@ func (c *ApiController) ResetEmailOrPhone() { } checkDest := dest - organization := object.GetOrganizationByUser(user) + organization, err := object.GetOrganizationByUser(user) + if err != nil { + c.ResponseError(c.T(err.Error())) + return + } + if destType == object.VerifyTypePhone { if object.HasUserByField(user.Owner, "phone", dest) { c.ResponseError(c.T("check:Phone already exists")) @@ -260,16 +300,25 @@ func (c *ApiController) ResetEmailOrPhone() { switch destType { case object.VerifyTypeEmail: user.Email = dest - object.SetUserField(user, "email", user.Email) + _, err = object.SetUserField(user, "email", user.Email) case object.VerifyTypePhone: user.Phone = dest - object.SetUserField(user, "phone", user.Phone) + _, err = object.SetUserField(user, "phone", user.Phone) default: c.ResponseError(c.T("verification:Unknown type")) return } + if err != nil { + c.ResponseError(err.Error()) + return + } + + err = object.DisableVerificationCode(checkDest) + if err != nil { + c.ResponseError(err.Error()) + return + } - object.DisableVerificationCode(checkDest) c.ResponseOk() } @@ -287,7 +336,11 @@ func (c *ApiController) VerifyCode() { var user *object.User if authForm.Name != "" { - user = object.GetUserByFields(authForm.Organization, authForm.Name) + user, err = object.GetUserByFields(authForm.Organization, authForm.Name) + if err != nil { + c.ResponseError(err.Error()) + return + } } var checkDest string @@ -302,7 +355,10 @@ func (c *ApiController) VerifyCode() { } } - if user = object.GetUserByFields(authForm.Organization, authForm.Username); user == nil { + if user, err = object.GetUserByFields(authForm.Organization, authForm.Username); err != nil { + c.ResponseError(err.Error()) + return + } else if user == nil { c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(authForm.Organization, authForm.Username))) return } @@ -321,7 +377,11 @@ func (c *ApiController) VerifyCode() { c.ResponseError(result.Msg) return } - object.DisableVerificationCode(checkDest) + err = object.DisableVerificationCode(checkDest) + if err != nil { + c.ResponseError(err.Error()) + return + } c.SetSession("verifiedCode", authForm.Code) c.ResponseOk() diff --git a/controllers/webauthn.go b/controllers/webauthn.go index 1bc67beb..ae24e858 100644 --- a/controllers/webauthn.go +++ b/controllers/webauthn.go @@ -33,7 +33,12 @@ import ( // @Success 200 {object} protocol.CredentialCreation The CredentialCreationOptions object // @router /webauthn/signup/begin [get] func (c *ApiController) WebAuthnSignupBegin() { - webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host) + webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host) + if err != nil { + c.ResponseError(err.Error()) + return + } + user := c.getCurrentUser() if user == nil { c.ResponseError(c.T("general:Please login first")) @@ -64,7 +69,12 @@ func (c *ApiController) WebAuthnSignupBegin() { // @Success 200 {object} Response "The Response object" // @router /webauthn/signup/finish [post] func (c *ApiController) WebAuthnSignupFinish() { - webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host) + webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host) + if err != nil { + c.ResponseError(err.Error()) + return + } + user := c.getCurrentUser() if user == nil { c.ResponseError(c.T("general:Please login first")) @@ -84,7 +94,12 @@ func (c *ApiController) WebAuthnSignupFinish() { return } isGlobalAdmin := c.IsGlobalAdmin() - user.AddCredentials(*credential, isGlobalAdmin) + _, err = user.AddCredentials(*credential, isGlobalAdmin) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk() } @@ -97,10 +112,20 @@ func (c *ApiController) WebAuthnSignupFinish() { // @Success 200 {object} protocol.CredentialAssertion The CredentialAssertion object // @router /webauthn/signin/begin [get] func (c *ApiController) WebAuthnSigninBegin() { - webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host) + webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host) + if err != nil { + c.ResponseError(err.Error()) + return + } + userOwner := c.Input().Get("owner") userName := c.Input().Get("name") - user := object.GetUserByFields(userOwner, userName) + user, err := object.GetUserByFields(userOwner, userName) + if err != nil { + c.ResponseError(err.Error()) + return + } + if user == nil { c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(userOwner, userName))) return @@ -129,7 +154,12 @@ func (c *ApiController) WebAuthnSigninBegin() { // @router /webauthn/signin/finish [post] func (c *ApiController) WebAuthnSigninFinish() { responseType := c.Input().Get("responseType") - webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host) + webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host) + if err != nil { + c.ResponseError(err.Error()) + return + } + sessionObj := c.GetSession("authentication") sessionData, ok := sessionObj.(webauthn.SessionData) if !ok { @@ -138,8 +168,13 @@ func (c *ApiController) WebAuthnSigninFinish() { } c.Ctx.Request.Body = io.NopCloser(bytes.NewBuffer(c.Ctx.Input.RequestBody)) userId := string(sessionData.UserID) - user := object.GetUser(userId) - _, err := webauthnObj.FinishLogin(user, sessionData, c.Ctx.Request) + user, err := object.GetUser(userId) + if err != nil { + c.ResponseError(err.Error()) + return + } + + _, err = webauthnObj.FinishLogin(user, sessionData, c.Ctx.Request) if err != nil { c.ResponseError(err.Error()) return @@ -147,7 +182,12 @@ func (c *ApiController) WebAuthnSigninFinish() { c.SetSessionUsername(userId) util.LogInfo(c.Ctx, "API: [%s] signed in", userId) - application := object.GetApplicationByUser(user) + application, err := object.GetApplicationByUser(user) + if err != nil { + c.ResponseError(err.Error()) + return + } + var authForm form.AuthForm authForm.Type = responseType resp := c.HandleLoggedIn(application, user, &authForm) diff --git a/controllers/webhook.go b/controllers/webhook.go index 17e43c4a..13e24451 100644 --- a/controllers/webhook.go +++ b/controllers/webhook.go @@ -38,13 +38,31 @@ func (c *ApiController) GetWebhooks() { sortField := c.Input().Get("sortField") sortOrder := c.Input().Get("sortOrder") organization := c.Input().Get("organization") + if limit == "" || page == "" { - c.Data["json"] = object.GetWebhooks(owner, organization) + webhooks, err := object.GetWebhooks(owner, organization) + if err != nil { + panic(err) + } + + c.Data["json"] = webhooks c.ServeJSON() } else { limit := util.ParseInt(limit) - paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetWebhookCount(owner, organization, field, value))) - webhooks := object.GetPaginationWebhooks(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + count, err := object.GetWebhookCount(owner, organization, field, value) + if err != nil { + c.ResponseError(err.Error()) + return + } + + paginator := pagination.SetPaginator(c.Ctx, limit, count) + + webhooks, err := object.GetPaginationWebhooks(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) + if err != nil { + c.ResponseError(err.Error()) + return + } + c.ResponseOk(webhooks, paginator.Nums()) } } @@ -59,7 +77,12 @@ func (c *ApiController) GetWebhooks() { func (c *ApiController) GetWebhook() { id := c.Input().Get("id") - c.Data["json"] = object.GetWebhook(id) + webhook, err := object.GetWebhook(id) + if err != nil { + panic(err) + } + + c.Data["json"] = webhook c.ServeJSON() } diff --git a/ldap/util.go b/ldap/util.go index a196c47f..6b7cd234 100644 --- a/ldap/util.go +++ b/ldap/util.go @@ -84,6 +84,7 @@ func stringInSlice(value string, list []string) bool { } func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int) { + var err error r := m.GetSearchRequest() name, org, code := getNameAndOrgFromFilter(string(r.BaseObject()), r.FilterString()) @@ -93,11 +94,19 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int) if name == "*" && m.Client.IsOrgAdmin { // get all users from organization 'org' if m.Client.IsGlobalAdmin && org == "*" { - filteredUsers = object.GetGlobalUsers() + + filteredUsers, err = object.GetGlobalUsers() + if err != nil { + panic(err) + } return filteredUsers, ldap.LDAPResultSuccess } if m.Client.IsGlobalAdmin || org == m.Client.OrgName { - filteredUsers = object.GetUsers(org) + filteredUsers, err = object.GetUsers(org) + if err != nil { + panic(err) + } + return filteredUsers, ldap.LDAPResultSuccess } else { return nil, ldap.LDAPResultInsufficientAccessRights @@ -112,13 +121,21 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int) return nil, ldap.LDAPResultInsufficientAccessRights } - user := object.GetUser(userId) + user, err := object.GetUser(userId) + if err != nil { + panic(err) + } + if user != nil { filteredUsers = append(filteredUsers, user) return filteredUsers, ldap.LDAPResultSuccess } - organization := object.GetOrganization(util.GetId("admin", org)) + organization, err := object.GetOrganization(util.GetId("admin", org)) + if err != nil { + panic(err) + } + if organization == nil { return nil, ldap.LDAPResultNoSuchObject } @@ -127,7 +144,11 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int) return nil, ldap.LDAPResultNoSuchObject } - users := object.GetUsersByTag(org, name) + users, err := object.GetUsersByTag(org, name) + if err != nil { + panic(err) + } + filteredUsers = append(filteredUsers, users...) return filteredUsers, ldap.LDAPResultSuccess } @@ -137,7 +158,11 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int) // TODO not handle salt yet // @return {md5}5f4dcc3b5aa765d61d8327deb882cf99 func getUserPasswordWithType(user *object.User) string { - org := object.GetOrganizationByUser(user) + org, err := object.GetOrganizationByUser(user) + if err != nil { + panic(err) + } + if org.PasswordType == "" || org.PasswordType == "plain" { return user.Password } diff --git a/object/adapter.go b/object/adapter.go index 53ec8303..9e0bf6b8 100644 --- a/object/adapter.go +++ b/object/adapter.go @@ -119,7 +119,7 @@ func (a *Adapter) close() { } func (a *Adapter) createTable() { - showSql, _ := conf.GetConfigBool("showSql") + showSql := conf.GetConfigBool("showSql") a.Engine.ShowSQL(showSql) tableNamePrefix := conf.GetConfigString("tableNamePrefix") diff --git a/object/application.go b/object/application.go index 6b2eb451..ec45add2 100644 --- a/object/application.go +++ b/object/application.go @@ -79,134 +79,155 @@ type Application struct { FormBackgroundUrl string `xorm:"varchar(200)" json:"formBackgroundUrl"` } -func GetApplicationCount(owner, field, value string) int { +func GetApplicationCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Application{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Application{}) } -func GetOrganizationApplicationCount(owner, Organization, field, value string) int { +func GetOrganizationApplicationCount(owner, Organization, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Application{Organization: Organization}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Application{Organization: Organization}) } -func GetApplications(owner string) []*Application { +func GetApplications(owner string) ([]*Application, error) { applications := []*Application{} err := adapter.Engine.Desc("created_time").Find(&applications, &Application{Owner: owner}) if err != nil { - panic(err) + return applications, err } - return applications + return applications, nil } -func GetOrganizationApplications(owner string, organization string) []*Application { +func GetOrganizationApplications(owner string, organization string) ([]*Application, error) { applications := []*Application{} err := adapter.Engine.Desc("created_time").Find(&applications, &Application{Organization: organization}) if err != nil { - panic(err) + return applications, err } - return applications + return applications, nil } -func GetPaginationApplications(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Application { +func GetPaginationApplications(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Application, error) { var applications []*Application session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&applications) if err != nil { - panic(err) + return applications, err } - return applications + return applications, nil } -func GetPaginationOrganizationApplications(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Application { +func GetPaginationOrganizationApplications(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Application, error) { applications := []*Application{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&applications, &Application{Organization: organization}) if err != nil { - panic(err) + return applications, err } - return applications + return applications, nil } -func getProviderMap(owner string) map[string]*Provider { - providers := GetProviders(owner) - m := map[string]*Provider{} +func getProviderMap(owner string) (m map[string]*Provider, err error) { + providers, err := GetProviders(owner) + if err != nil { + return nil, err + } + + m = map[string]*Provider{} for _, provider := range providers { // Get QRCode only once if provider.Type == "WeChat" && provider.DisableSsl && provider.Content == "" { - provider.Content, _ = idp.GetWechatOfficialAccountQRCode(provider.ClientId2, provider.ClientSecret2) + provider.Content, err = idp.GetWechatOfficialAccountQRCode(provider.ClientId2, provider.ClientSecret2) + if err != nil { + return + } UpdateProvider(provider.Owner+"/"+provider.Name, provider) } m[provider.Name] = GetMaskedProvider(provider, true) } - return m + + return m, err } -func extendApplicationWithProviders(application *Application) { - m := getProviderMap(application.Organization) +func extendApplicationWithProviders(application *Application) (err error) { + m, err := getProviderMap(application.Organization) + if err != nil { + return err + } + for _, providerItem := range application.Providers { if provider, ok := m[providerItem.Name]; ok { providerItem.Provider = provider } } + + return } -func extendApplicationWithOrg(application *Application) { - organization := getOrganization(application.Owner, application.Organization) +func extendApplicationWithOrg(application *Application) (err error) { + organization, err := getOrganization(application.Owner, application.Organization) application.OrganizationObj = organization + return } -func getApplication(owner string, name string) *Application { +func getApplication(owner string, name string) (*Application, error) { if owner == "" || name == "" { - return nil + return nil, nil } application := Application{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&application) if err != nil { - panic(err) + return nil, err } if existed { - extendApplicationWithProviders(&application) - extendApplicationWithOrg(&application) - return &application + err = extendApplicationWithProviders(&application) + if err != nil { + return nil, err + } + + err = extendApplicationWithOrg(&application) + if err != nil { + return nil, err + } + + return &application, nil } else { - return nil + return nil, nil } } -func GetApplicationByOrganizationName(organization string) *Application { +func GetApplicationByOrganizationName(organization string) (*Application, error) { application := Application{} existed, err := adapter.Engine.Where("organization=?", organization).Get(&application) if err != nil { - panic(err) + return nil, nil } if existed { - extendApplicationWithProviders(&application) - extendApplicationWithOrg(&application) - return &application + err = extendApplicationWithProviders(&application) + if err != nil { + return nil, err + } + + err = extendApplicationWithOrg(&application) + if err != nil { + return nil, err + } + + return &application, nil } else { - return nil + return nil, nil } } -func GetApplicationByUser(user *User) *Application { +func GetApplicationByUser(user *User) (*Application, error) { if user.SignupApplication != "" { return getApplication("admin", user.SignupApplication) } else { @@ -214,38 +235,46 @@ func GetApplicationByUser(user *User) *Application { } } -func GetApplicationByUserId(userId string) (*Application, *User) { - var application *Application - +func GetApplicationByUserId(userId string) (application *Application, user *User, err error) { owner, name := util.GetOwnerAndNameFromId(userId) if owner == "app" { - application = getApplication("admin", name) - return application, nil + application, err = getApplication("admin", name) + return } - user := GetUser(userId) - application = GetApplicationByUser(user) - - return application, user + user, err = GetUser(userId) + if err != nil { + return nil, nil, err + } + application, err = GetApplicationByUser(user) + return } -func GetApplicationByClientId(clientId string) *Application { +func GetApplicationByClientId(clientId string) (*Application, error) { application := Application{} existed, err := adapter.Engine.Where("client_id=?", clientId).Get(&application) if err != nil { - panic(err) + return nil, err } if existed { - extendApplicationWithProviders(&application) - extendApplicationWithOrg(&application) - return &application + err = extendApplicationWithProviders(&application) + if err != nil { + return nil, err + } + + err = extendApplicationWithOrg(&application) + if err != nil { + return nil, err + } + + return &application, nil } else { - return nil + return nil, nil } } -func GetApplication(id string) *Application { +func GetApplication(id string) (*Application, error) { owner, name := util.GetOwnerAndNameFromId(id) return getApplication(owner, name) } @@ -288,11 +317,11 @@ func GetMaskedApplications(applications []*Application, userId string) []*Applic return applications } -func UpdateApplication(id string, application *Application) bool { +func UpdateApplication(id string, application *Application) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - oldApplication := getApplication(owner, name) + oldApplication, err := getApplication(owner, name) if oldApplication == nil { - return false + return false, err } if name == "app-built-in" { @@ -300,14 +329,19 @@ func UpdateApplication(id string, application *Application) bool { } if name != application.Name { - err := applicationChangeTrigger(name, application.Name) + err = applicationChangeTrigger(name, application.Name) if err != nil { - return false + return false, err } } - if oldApplication.ClientId != application.ClientId && GetApplicationByClientId(application.ClientId) != nil { - return false + applicationByClientId, err := GetApplicationByClientId(application.ClientId) + if err != nil { + return false, err + } + + if oldApplication.ClientId != application.ClientId && applicationByClientId != nil { + return false, err } for _, providerItem := range application.Providers { @@ -320,13 +354,13 @@ func UpdateApplication(id string, application *Application) bool { } affected, err := session.Update(application) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddApplication(application *Application) bool { +func AddApplication(application *Application) (bool, error) { if application.Owner == "" { application.Owner = "admin" } @@ -339,32 +373,39 @@ func AddApplication(application *Application) bool { if application.ClientSecret == "" { application.ClientSecret = util.GenerateClientSecret() } - if GetApplicationByClientId(application.ClientId) != nil { - return false + + app, err := GetApplicationByClientId(application.ClientId) + if err != nil { + return false, err } + + if app != nil { + return false, nil + } + for _, providerItem := range application.Providers { providerItem.Provider = nil } affected, err := adapter.Engine.Insert(application) if err != nil { - panic(err) + return false, nil } - return affected != 0 + return affected != 0, nil } -func DeleteApplication(application *Application) bool { +func DeleteApplication(application *Application) (bool, error) { if application.Name == "app-built-in" { - return false + return false, nil } affected, err := adapter.Engine.ID(core.PK{application.Owner, application.Name}).Delete(&Application{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (application *Application) GetId() string { @@ -383,33 +424,43 @@ func (application *Application) IsRedirectUriValid(redirectUri string) bool { return isValid } -func IsOriginAllowed(origin string) bool { - applications := GetApplications("") +func IsOriginAllowed(origin string) (bool, error) { + applications, err := GetApplications("") + if err != nil { + return false, err + } + for _, application := range applications { if application.IsRedirectUriValid(origin) { - return true + return true, nil } } - return false + return false, nil } -func getApplicationMap(organization string) map[string]*Application { - applications := GetOrganizationApplications("admin", organization) - +func getApplicationMap(organization string) (map[string]*Application, error) { applicationMap := make(map[string]*Application) + applications, err := GetOrganizationApplications("admin", organization) + if err != nil { + return applicationMap, err + } + for _, application := range applications { applicationMap[application.Name] = application } - return applicationMap + return applicationMap, nil } -func ExtendManagedAccountsWithUser(user *User) *User { +func ExtendManagedAccountsWithUser(user *User) (*User, error) { if user.ManagedAccounts == nil || len(user.ManagedAccounts) == 0 { - return user + return user, nil } - applicationMap := getApplicationMap(user.Owner) + applicationMap, err := getApplicationMap(user.Owner) + if err != nil { + return user, err + } var managedAccounts []ManagedAccount for _, managedAccount := range user.ManagedAccounts { @@ -421,7 +472,7 @@ func ExtendManagedAccountsWithUser(user *User) *User { } user.ManagedAccounts = managedAccounts - return user + return user, nil } func applicationChangeTrigger(oldName string, newName string) error { diff --git a/object/application_item.go b/object/application_item.go index 47dd93a8..7125593d 100644 --- a/object/application_item.go +++ b/object/application_item.go @@ -14,8 +14,12 @@ package object -func (application *Application) GetProviderByCategory(category string) *Provider { - providers := GetProviders(application.Organization) +func (application *Application) GetProviderByCategory(category string) (*Provider, error) { + providers, err := GetProviders(application.Organization) + if err != nil { + return nil, err + } + m := map[string]*Provider{} for _, provider := range providers { if provider.Category != category { @@ -27,22 +31,22 @@ func (application *Application) GetProviderByCategory(category string) *Provider for _, providerItem := range application.Providers { if provider, ok := m[providerItem.Name]; ok { - return provider + return provider, nil } } - return nil + return nil, nil } -func (application *Application) GetEmailProvider() *Provider { +func (application *Application) GetEmailProvider() (*Provider, error) { return application.GetProviderByCategory("Email") } -func (application *Application) GetSmsProvider() *Provider { +func (application *Application) GetSmsProvider() (*Provider, error) { return application.GetProviderByCategory("SMS") } -func (application *Application) GetStorageProvider() *Provider { +func (application *Application) GetStorageProvider() (*Provider, error) { return application.GetProviderByCategory("Storage") } diff --git a/object/avatar.go b/object/avatar.go index 387b63e8..7a2191f6 100644 --- a/object/avatar.go +++ b/object/avatar.go @@ -28,7 +28,11 @@ var defaultStorageProvider *Provider = nil func InitDefaultStorageProvider() { defaultStorageProviderStr := conf.GetConfigString("defaultStorageProvider") if defaultStorageProviderStr != "" { - defaultStorageProvider = getProvider("admin", defaultStorageProviderStr) + var err error + defaultStorageProvider, err = getProvider("admin", defaultStorageProviderStr) + if err != nil { + panic(err) + } } } @@ -50,40 +54,44 @@ func downloadFile(url string) (*bytes.Buffer, error) { return fileBuffer, nil } -func getPermanentAvatarUrl(organization string, username string, url string, upload bool) string { +func getPermanentAvatarUrl(organization string, username string, url string, upload bool) (string, error) { if url == "" { - return "" + return "", nil } if defaultStorageProvider == nil { - return "" + return "", nil } fullFilePath := fmt.Sprintf("/avatar/%s/%s.png", organization, username) uploadedFileUrl, _ := GetUploadFileUrl(defaultStorageProvider, fullFilePath, false) if upload { - DownloadAndUpload(url, fullFilePath, "en") + if err := DownloadAndUpload(url, fullFilePath, "en"); err != nil { + return "", err + } } - return uploadedFileUrl + return uploadedFileUrl, nil } -func DownloadAndUpload(url string, fullFilePath string, lang string) { +func DownloadAndUpload(url string, fullFilePath string, lang string) (err error) { fileBuffer, err := downloadFile(url) if err != nil { - panic(err) + return } _, _, err = UploadFileSafe(defaultStorageProvider, fullFilePath, fileBuffer, lang) if err != nil { - panic(err) + return } + + return } -func getPermanentAvatarUrlFromBuffer(organization string, username string, fileBuffer *bytes.Buffer, ext string, upload bool) string { +func getPermanentAvatarUrlFromBuffer(organization string, username string, fileBuffer *bytes.Buffer, ext string, upload bool) (string, error) { if defaultStorageProvider == nil { - return "" + return "", nil } fullFilePath := fmt.Sprintf("/avatar/%s/%s%s", organization, username, ext) @@ -92,9 +100,9 @@ func getPermanentAvatarUrlFromBuffer(organization string, username string, fileB if upload { _, _, err := UploadFileSafe(defaultStorageProvider, fullFilePath, fileBuffer, "en") if err != nil { - panic(err) + return "", err } } - return uploadedFileUrl + return uploadedFileUrl, nil } diff --git a/object/avatar_test.go b/object/avatar_test.go index 7827c011..d5dcbf2d 100644 --- a/object/avatar_test.go +++ b/object/avatar_test.go @@ -27,13 +27,21 @@ func TestSyncPermanentAvatars(t *testing.T) { InitDefaultStorageProvider() proxy.InitHttpClient() - users := GetGlobalUsers() + users, err := GetGlobalUsers() + if err != nil { + panic(err) + } + for i, user := range users { if user.Avatar == "" { continue } - user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) + user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) + if err != nil { + panic(err) + } + updateUserColumn("permanent_avatar", user) fmt.Printf("[%d/%d]: Update user: [%s]'s permanent avatar: %s\n", i, len(users), user.GetId(), user.PermanentAvatar) } @@ -44,16 +52,27 @@ func TestUpdateAvatars(t *testing.T) { InitDefaultStorageProvider() proxy.InitHttpClient() - users := GetUsers("casdoor") + users, err := GetUsers("casdoor") + if err != nil { + panic(err) + } + for _, user := range users { if strings.HasPrefix(user.Avatar, "http") { continue } - updated := user.refreshAvatar() + updated, err := user.refreshAvatar() + if err != nil { + panic(err) + } + if updated { user.PermanentAvatar = "*" - UpdateUser(user.GetId(), user, []string{"avatar"}, true) + _, err = UpdateUser(user.GetId(), user, []string{"avatar"}, true) + if err != nil { + panic(err) + } } } } diff --git a/object/captcha.go b/object/captcha.go index e883ef51..5f637cca 100644 --- a/object/captcha.go +++ b/object/captcha.go @@ -20,17 +20,17 @@ import ( "github.com/dchest/captcha" ) -func GetCaptcha() (string, []byte) { +func GetCaptcha() (string, []byte, error) { id := captcha.NewLen(5) var buffer bytes.Buffer err := captcha.WriteImage(&buffer, id, 200, 80) if err != nil { - panic(err) + return "", nil, err } - return id, buffer.Bytes() + return id, buffer.Bytes(), nil } func VerifyCaptcha(id string, digits string) bool { diff --git a/object/casbin_adapter.go b/object/casbin_adapter.go index a391cd44..51ae092d 100644 --- a/object/casbin_adapter.go +++ b/object/casbin_adapter.go @@ -46,64 +46,59 @@ type CasbinAdapter struct { Adapter *xormadapter.Adapter `xorm:"-" json:"-"` } -func GetCasbinAdapterCount(owner, organization, field, value string) int { +func GetCasbinAdapterCount(owner, organization, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&CasbinAdapter{Organization: organization}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&CasbinAdapter{Organization: organization}) } -func GetCasbinAdapters(owner string, organization string) []*CasbinAdapter { +func GetCasbinAdapters(owner string, organization string) ([]*CasbinAdapter, error) { adapters := []*CasbinAdapter{} err := adapter.Engine.Where("owner = ? and organization = ?", owner, organization).Find(&adapters) if err != nil { - panic(err) + return adapters, err } - return adapters + return adapters, nil } -func GetPaginationCasbinAdapters(owner, organization string, page, limit int, field, value, sort, order string) []*CasbinAdapter { +func GetPaginationCasbinAdapters(owner, organization string, page, limit int, field, value, sort, order string) ([]*CasbinAdapter, error) { session := GetSession(owner, page, limit, field, value, sort, order) adapters := []*CasbinAdapter{} err := session.Find(&adapters, &CasbinAdapter{Organization: organization}) if err != nil { - panic(err) + return adapters, err } - return adapters + return adapters, nil } -func getCasbinAdapter(owner, name string) *CasbinAdapter { +func getCasbinAdapter(owner, name string) (*CasbinAdapter, error) { if owner == "" || name == "" { - return nil + return nil, nil } casbinAdapter := CasbinAdapter{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&casbinAdapter) if err != nil { - panic(err) + return nil, err } if existed { - return &casbinAdapter + return &casbinAdapter, nil } else { - return nil + return nil, nil } } -func GetCasbinAdapter(id string) *CasbinAdapter { +func GetCasbinAdapter(id string) (*CasbinAdapter, error) { owner, name := util.GetOwnerAndNameFromId(id) return getCasbinAdapter(owner, name) } -func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) bool { +func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getCasbinAdapter(owner, name) == nil { - return false + if casbinAdapter, err := getCasbinAdapter(owner, name); casbinAdapter == nil { + return false, err } session := adapter.Engine.ID(core.PK{owner, name}).AllCols() @@ -112,28 +107,28 @@ func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) bool { } affected, err := session.Update(casbinAdapter) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddCasbinAdapter(casbinAdapter *CasbinAdapter) bool { +func AddCasbinAdapter(casbinAdapter *CasbinAdapter) (bool, error) { affected, err := adapter.Engine.Insert(casbinAdapter) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteCasbinAdapter(casbinAdapter *CasbinAdapter) bool { +func DeleteCasbinAdapter(casbinAdapter *CasbinAdapter) (bool, error) { affected, err := adapter.Engine.ID(core.PK{casbinAdapter.Owner, casbinAdapter.Name}).Delete(&CasbinAdapter{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (casbinAdapter *CasbinAdapter) GetId() string { @@ -214,7 +209,11 @@ func matrixToCasbinRules(Ptype string, policies [][]string) []*xormadapter.Casbi } func SyncPolicies(casbinAdapter *CasbinAdapter) ([]*xormadapter.CasbinRule, error) { - modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model) + modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model) + if err != nil { + return nil, err + } + enforcer, err := initEnforcer(modelObj, casbinAdapter) if err != nil { return nil, err @@ -229,7 +228,11 @@ func SyncPolicies(casbinAdapter *CasbinAdapter) ([]*xormadapter.CasbinRule, erro } func UpdatePolicy(oldPolicy, newPolicy []string, casbinAdapter *CasbinAdapter) (bool, error) { - modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model) + modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model) + if err != nil { + return false, err + } + enforcer, err := initEnforcer(modelObj, casbinAdapter) if err != nil { return false, err @@ -243,7 +246,11 @@ func UpdatePolicy(oldPolicy, newPolicy []string, casbinAdapter *CasbinAdapter) ( } func AddPolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) { - modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model) + modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model) + if err != nil { + return false, err + } + enforcer, err := initEnforcer(modelObj, casbinAdapter) if err != nil { return false, err @@ -257,7 +264,11 @@ func AddPolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) { } func RemovePolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) { - modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model) + modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model) + if err != nil { + return false, err + } + enforcer, err := initEnforcer(modelObj, casbinAdapter) if err != nil { return false, err diff --git a/object/cert.go b/object/cert.go index 2e1ccac5..697dc416 100644 --- a/object/cert.go +++ b/object/cert.go @@ -47,137 +47,133 @@ func GetMaskedCert(cert *Cert) *Cert { return cert } -func GetMaskedCerts(certs []*Cert) []*Cert { +func GetMaskedCerts(certs []*Cert, err error) ([]*Cert, error) { + if err != nil { + return nil, err + } + for _, cert := range certs { cert = GetMaskedCert(cert) } - return certs + return certs, nil } -func GetCertCount(owner, field, value string) int { +func GetCertCount(owner, field, value string) (int64, error) { session := GetSession("", -1, -1, field, value, "", "") - count, err := session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Cert{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Cert{}) } -func GetCerts(owner string) []*Cert { +func GetCerts(owner string) ([]*Cert, error) { certs := []*Cert{} err := adapter.Engine.Where("owner = ? or owner = ? ", "admin", owner).Desc("created_time").Find(&certs, &Cert{}) if err != nil { - panic(err) + return certs, err } - return certs + return certs, nil } -func GetPaginationCerts(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Cert { +func GetPaginationCerts(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Cert, error) { certs := []*Cert{} session := GetSession("", offset, limit, field, value, sortField, sortOrder) err := session.Where("owner = ? or owner = ? ", "admin", owner).Find(&certs) if err != nil { - panic(err) + return certs, err } - return certs + return certs, nil } -func GetGlobalCertsCount(field, value string) int { +func GetGlobalCertsCount(field, value string) (int64, error) { session := GetSession("", -1, -1, field, value, "", "") - count, err := session.Count(&Cert{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Cert{}) } -func GetGlobleCerts() []*Cert { +func GetGlobleCerts() ([]*Cert, error) { certs := []*Cert{} err := adapter.Engine.Desc("created_time").Find(&certs) if err != nil { - panic(err) + return certs, err } - return certs + return certs, nil } -func GetPaginationGlobalCerts(offset, limit int, field, value, sortField, sortOrder string) []*Cert { +func GetPaginationGlobalCerts(offset, limit int, field, value, sortField, sortOrder string) ([]*Cert, error) { certs := []*Cert{} session := GetSession("", offset, limit, field, value, sortField, sortOrder) err := session.Find(&certs) if err != nil { - panic(err) + return certs, err } - return certs + return certs, nil } -func getCert(owner string, name string) *Cert { +func getCert(owner string, name string) (*Cert, error) { if owner == "" || name == "" { - return nil + return nil, nil } cert := Cert{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&cert) if err != nil { - panic(err) + return &cert, err } if existed { - return &cert + return &cert, nil } else { - return nil + return nil, nil } } -func getCertByName(name string) *Cert { +func getCertByName(name string) (*Cert, error) { if name == "" { - return nil + return nil, nil } cert := Cert{Name: name} existed, err := adapter.Engine.Get(&cert) if err != nil { - panic(err) + return &cert, nil } if existed { - return &cert + return &cert, nil } else { - return nil + return nil, nil } } -func GetCert(id string) *Cert { +func GetCert(id string) (*Cert, error) { owner, name := util.GetOwnerAndNameFromId(id) return getCert(owner, name) } -func UpdateCert(id string, cert *Cert) bool { +func UpdateCert(id string, cert *Cert) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getCert(owner, name) == nil { - return false + if c, err := getCert(owner, name); err != nil { + return false, err + } else if c == nil { + return false, nil } if name != cert.Name { err := certChangeTrigger(name, cert.Name) if err != nil { - return false + return false, nil } } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(cert) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddCert(cert *Cert) bool { +func AddCert(cert *Cert) (bool, error) { if cert.Certificate == "" || cert.PrivateKey == "" { certificate, privateKey := generateRsaKeys(cert.BitSize, cert.ExpireInYears, cert.Name, cert.Owner) cert.Certificate = certificate @@ -186,26 +182,26 @@ func AddCert(cert *Cert) bool { affected, err := adapter.Engine.Insert(cert) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteCert(cert *Cert) bool { +func DeleteCert(cert *Cert) (bool, error) { affected, err := adapter.Engine.ID(core.PK{cert.Owner, cert.Name}).Delete(&Cert{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (p *Cert) GetId() string { return fmt.Sprintf("%s/%s", p.Owner, p.Name) } -func getCertByApplication(application *Application) *Cert { +func getCertByApplication(application *Application) (*Cert, error) { if application.Cert != "" { return getCertByName(application.Cert) } else { @@ -213,7 +209,7 @@ func getCertByApplication(application *Application) *Cert { } } -func GetDefaultCert() *Cert { +func GetDefaultCert() (*Cert, error) { return getCert("admin", "cert-built-in") } diff --git a/object/chat.go b/object/chat.go index 87e2c50a..875b0391 100644 --- a/object/chat.go +++ b/object/chat.go @@ -37,92 +37,104 @@ type Chat struct { MessageCount int `json:"messageCount"` } -func GetMaskedChat(chat *Chat) *Chat { +func GetMaskedChat(chat *Chat, err ...error) (*Chat, error) { + if len(err) > 0 && err[0] != nil { + return nil, err[0] + } + if chat == nil { - return nil + return nil, nil } - return chat + return chat, nil } -func GetMaskedChats(chats []*Chat) []*Chat { +func GetMaskedChats(chats []*Chat, errs ...error) ([]*Chat, error) { + if len(errs) > 0 && errs[0] != nil { + return nil, errs[0] + } + var err error for _, chat := range chats { - chat = GetMaskedChat(chat) + chat, err = GetMaskedChat(chat) + if err != nil { + return nil, err + } } - return chats + return chats, nil } -func GetChatCount(owner, field, value string) int { +func GetChatCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Chat{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Chat{}) } -func GetChats(owner string) []*Chat { +func GetChats(owner string) ([]*Chat, error) { chats := []*Chat{} err := adapter.Engine.Desc("created_time").Find(&chats, &Chat{Owner: owner}) if err != nil { - panic(err) + return chats, err } - return chats + return chats, nil } -func GetPaginationChats(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Chat { +func GetPaginationChats(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Chat, error) { chats := []*Chat{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&chats) if err != nil { - panic(err) + return chats, err } - return chats + return chats, nil } -func getChat(owner string, name string) *Chat { +func getChat(owner string, name string) (*Chat, error) { if owner == "" || name == "" { - return nil + return nil, nil } chat := Chat{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&chat) if err != nil { - panic(err) + return &chat, err } if existed { - return &chat + return &chat, nil } else { - return nil + return nil, nil } } -func GetChat(id string) *Chat { +func GetChat(id string) (*Chat, error) { owner, name := util.GetOwnerAndNameFromId(id) return getChat(owner, name) } -func UpdateChat(id string, chat *Chat) bool { +func UpdateChat(id string, chat *Chat) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getChat(owner, name) == nil { - return false + if c, err := getChat(owner, name); err != nil { + return false, err + } else if c == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(chat) if err != nil { - panic(err) + return false, nil } - return affected != 0 + return affected != 0, nil } -func AddChat(chat *Chat) bool { +func AddChat(chat *Chat) (bool, error) { if chat.Type == "AI" && chat.User2 == "" { - provider := getDefaultAiProvider() + provider, err := getDefaultAiProvider() + if err != nil { + return false, err + } + if provider != nil { chat.User2 = provider.Name } @@ -130,23 +142,23 @@ func AddChat(chat *Chat) bool { affected, err := adapter.Engine.Insert(chat) if err != nil { - panic(err) + return false, nil } - return affected != 0 + return affected != 0, nil } -func DeleteChat(chat *Chat) bool { +func DeleteChat(chat *Chat) (bool, error) { affected, err := adapter.Engine.ID(core.PK{chat.Owner, chat.Name}).Delete(&Chat{}) if err != nil { - panic(err) + return false, err } if affected != 0 { return DeleteChatMessages(chat.Name) } - return affected != 0 + return affected != 0, nil } func (p *Chat) GetId() string { diff --git a/object/check.go b/object/check.go index 07cbf46a..bf134525 100644 --- a/object/check.go +++ b/object/check.go @@ -170,7 +170,11 @@ func CheckPassword(user *User, password string, lang string, options ...bool) st } } - organization := GetOrganizationByUser(user) + organization, err := GetOrganizationByUser(user) + if err != nil { + panic(err) + } + if organization == nil { return i18n.Translate(lang, "check:Organization does not exist") } @@ -200,7 +204,11 @@ func CheckPassword(user *User, password string, lang string, options ...bool) st } func checkLdapUserPassword(user *User, password string, lang string) string { - ldaps := GetLdaps(user.Owner) + ldaps, err := GetLdaps(user.Owner) + if err != nil { + return err.Error() + } + ldapLoginSuccess := false hit := false @@ -247,7 +255,11 @@ func CheckUserPassword(organization string, username string, password string, la if len(options) > 0 { enableCaptcha = options[0] } - user := GetUserByFields(organization, username) + user, err := GetUserByFields(organization, username) + if err != nil { + panic(err) + } + if user == nil || user.IsDeleted { return nil, fmt.Sprintf(i18n.Translate(lang, "general:The user: %s doesn't exist"), util.GetId(organization, username)) } @@ -284,7 +296,11 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string) userOwner := util.GetOwnerFromId(userId) if userId != "" { - targetUser := GetUser(userId) + targetUser, err := GetUser(userId) + if err != nil { + panic(err) + } + if targetUser == nil { if strings.HasPrefix(requestUserId, "built-in/") { return true, nil @@ -300,7 +316,11 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string) if strings.HasPrefix(requestUserId, "app/") { hasPermission = true } else { - requestUser := GetUser(requestUserId) + requestUser, err := GetUser(requestUserId) + if err != nil { + return false, err + } + if requestUser == nil { return false, fmt.Errorf(i18n.Translate(lang, "check:Session outdated, please login again")) } @@ -321,13 +341,17 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string) } func CheckAccessPermission(userId string, application *Application) (bool, error) { + var err error if userId == "built-in/admin" { return true, nil } - permissions := GetPermissions(application.Organization) + permissions, err := GetPermissions(application.Organization) + if err != nil { + return false, err + } + allowed := true - var err error for _, permission := range permissions { if !permission.IsEnabled || len(permission.Users) == 0 { continue @@ -403,9 +427,9 @@ func CheckUpdateUser(oldUser, user *User, lang string) string { return "" } -func CheckToEnableCaptcha(application *Application, organization, username string) bool { +func CheckToEnableCaptcha(application *Application, organization, username string) (bool, error) { if len(application.Providers) == 0 { - return false + return false, nil } for _, providerItem := range application.Providers { @@ -414,12 +438,15 @@ func CheckToEnableCaptcha(application *Application, organization, username strin } if providerItem.Provider.Category == "Captcha" { if providerItem.Rule == "Dynamic" { - user := GetUserByFields(organization, username) - return user != nil && user.SigninWrongTimes >= SigninWrongTimesLimit + user, err := GetUserByFields(organization, username) + if err != nil { + return false, err + } + return user != nil && user.SigninWrongTimes >= SigninWrongTimesLimit, nil } - return providerItem.Rule == "Always" + return providerItem.Rule == "Always", nil } } - return false + return false, nil } diff --git a/object/init.go b/object/init.go index 4fac38c9..33474e4a 100644 --- a/object/init.go +++ b/object/init.go @@ -74,7 +74,11 @@ func getBuiltInAccountItems() []*AccountItem { } func initBuiltInOrganization() bool { - organization := getOrganization("admin", "built-in") + organization, err := getOrganization("admin", "built-in") + if err != nil { + panic(err) + } + if organization != nil { return true } @@ -96,12 +100,19 @@ func initBuiltInOrganization() bool { EnableSoftDeletion: false, IsProfilePublic: false, } - AddOrganization(organization) + _, err = AddOrganization(organization) + if err != nil { + panic(err) + } + return false } func initBuiltInUser() { - user := getUser("built-in", "admin") + user, err := getUser("built-in", "admin") + if err != nil { + panic(err) + } if user != nil { return } @@ -131,11 +142,18 @@ func initBuiltInUser() { CreatedIp: "127.0.0.1", Properties: make(map[string]string), } - AddUser(user) + _, err = AddUser(user) + if err != nil { + panic(err) + } } func initBuiltInApplication() { - application := getApplication("admin", "app-built-in") + application, err := getApplication("admin", "app-built-in") + if err != nil { + panic(err) + } + if application != nil { return } @@ -168,7 +186,10 @@ func initBuiltInApplication() { ExpireInHours: 168, FormOffset: 2, } - AddApplication(application) + _, err = AddApplication(application) + if err != nil { + panic(err) + } } func readTokenFromFile() (string, string) { @@ -187,7 +208,11 @@ func readTokenFromFile() (string, string) { func initBuiltInCert() { tokenJwtCertificate, tokenJwtPrivateKey := readTokenFromFile() - cert := getCert("admin", "cert-built-in") + cert, err := getCert("admin", "cert-built-in") + if err != nil { + panic(err) + } + if cert != nil { return } @@ -205,11 +230,18 @@ func initBuiltInCert() { Certificate: tokenJwtCertificate, PrivateKey: tokenJwtPrivateKey, } - AddCert(cert) + _, err = AddCert(cert) + if err != nil { + panic(err) + } } func initBuiltInLdap() { - ldap := GetLdap("ldap-built-in") + ldap, err := GetLdap("ldap-built-in") + if err != nil { + panic(err) + } + if ldap != nil { return } @@ -226,11 +258,18 @@ func initBuiltInLdap() { AutoSync: 0, LastSync: "", } - AddLdap(ldap) + _, err = AddLdap(ldap) + if err != nil { + panic(err) + } } func initBuiltInProvider() { - provider := GetProvider(util.GetId("admin", "provider_captcha_default")) + provider, err := GetProvider(util.GetId("admin", "provider_captcha_default")) + if err != nil { + panic(err) + } + if provider != nil { return } @@ -243,7 +282,10 @@ func initBuiltInProvider() { Category: "Captcha", Type: "Default", } - AddProvider(provider) + _, err = AddProvider(provider) + if err != nil { + panic(err) + } } func initWebAuthn() { @@ -251,7 +293,11 @@ func initWebAuthn() { } func initBuiltInModel() { - model := GetModel("built-in/model-built-in") + model, err := GetModel("built-in/model-built-in") + if err != nil { + panic(err) + } + if model != nil { return } @@ -274,11 +320,17 @@ e = some(where (p.eft == allow)) [matchers] m = r.sub == p.sub && r.obj == p.obj && r.act == p.act`, } - AddModel(model) + _, err = AddModel(model) + if err != nil { + panic(err) + } } func initBuiltInPermission() { - permission := GetPermission("built-in/permission-built-in") + permission, err := GetPermission("built-in/permission-built-in") + if err != nil { + panic(err) + } if permission != nil { return } @@ -298,5 +350,8 @@ func initBuiltInPermission() { Effect: "Allow", IsEnabled: true, } - AddPermission(permission) + _, err = AddPermission(permission) + if err != nil { + panic(err) + } } diff --git a/object/init_data.go b/object/init_data.go index 09b98ca3..fefd1af6 100644 --- a/object/init_data.go +++ b/object/init_data.go @@ -35,7 +35,11 @@ type InitData struct { } func InitFromFile() { - initData := readInitDataFromFile("./init_data.json") + initData, err := readInitDataFromFile("./init_data.json") + if err != nil { + panic(err) + } + if initData != nil { for _, organization := range initData.Organizations { initDefinedOrganization(organization) @@ -85,9 +89,9 @@ func InitFromFile() { } } -func readInitDataFromFile(filePath string) *InitData { +func readInitDataFromFile(filePath string) (*InitData, error) { if !util.FileExist(filePath) { - return nil + return nil, nil } s := util.ReadStringFromPath(filePath) @@ -111,7 +115,7 @@ func readInitDataFromFile(filePath string) *InitData { } err := util.JsonToStruct(s, data) if err != nil { - panic(err) + return nil, err } // transform nil slice to empty slice @@ -170,142 +174,246 @@ func readInitDataFromFile(filePath string) *InitData { } } - return data + return data, nil } func initDefinedOrganization(organization *Organization) { - existed := getOrganization(organization.Owner, organization.Name) + existed, err := getOrganization(organization.Owner, organization.Name) + if err != nil { + panic(err) + } + if existed != nil { return } organization.CreatedTime = util.GetCurrentTime() organization.AccountItems = getBuiltInAccountItems() - AddOrganization(organization) + _, err = AddOrganization(organization) + if err != nil { + panic(err) + } } func initDefinedApplication(application *Application) { - existed := getApplication(application.Owner, application.Name) + existed, err := getApplication(application.Owner, application.Name) + if err != nil { + panic(err) + } + if existed != nil { return } application.CreatedTime = util.GetCurrentTime() - AddApplication(application) + _, err = AddApplication(application) + if err != nil { + panic(err) + } } func initDefinedUser(user *User) { - existed := getUser(user.Owner, user.Name) + existed, err := getUser(user.Owner, user.Name) + if err != nil { + panic(err) + } if existed != nil { return } user.CreatedTime = util.GetCurrentTime() user.Id = util.GenerateId() user.Properties = make(map[string]string) - AddUser(user) + _, err = AddUser(user) + if err != nil { + panic(err) + } } func initDefinedCert(cert *Cert) { - existed := getCert(cert.Owner, cert.Name) + existed, err := getCert(cert.Owner, cert.Name) + if err != nil { + panic(err) + } + if existed != nil { return } cert.CreatedTime = util.GetCurrentTime() - AddCert(cert) + _, err = AddCert(cert) + if err != nil { + panic(err) + } } func initDefinedLdap(ldap *Ldap) { - existed := GetLdap(ldap.Id) + existed, err := GetLdap(ldap.Id) + if err != nil { + panic(err) + } + if existed != nil { return } - AddLdap(ldap) + _, err = AddLdap(ldap) + if err != nil { + panic(err) + } } func initDefinedProvider(provider *Provider) { - existed := GetProvider(util.GetId("admin", provider.Name)) + existed, err := GetProvider(util.GetId("admin", provider.Name)) + if err != nil { + panic(err) + } + if existed != nil { return } - AddProvider(provider) + _, err = AddProvider(provider) + if err != nil { + panic(err) + } } func initDefinedModel(model *Model) { - existed := GetModel(model.GetId()) + existed, err := GetModel(model.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } model.CreatedTime = util.GetCurrentTime() - AddModel(model) + _, err = AddModel(model) + if err != nil { + panic(err) + } } func initDefinedPermission(permission *Permission) { - existed := GetPermission(permission.GetId()) + existed, err := GetPermission(permission.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } permission.CreatedTime = util.GetCurrentTime() - AddPermission(permission) + _, err = AddPermission(permission) + if err != nil { + panic(err) + } } func initDefinedPayment(payment *Payment) { - existed := GetPayment(payment.GetId()) + existed, err := GetPayment(payment.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } payment.CreatedTime = util.GetCurrentTime() - AddPayment(payment) + _, err = AddPayment(payment) + if err != nil { + panic(err) + } } func initDefinedProduct(product *Product) { - existed := GetProduct(product.GetId()) + existed, err := GetProduct(product.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } product.CreatedTime = util.GetCurrentTime() - AddProduct(product) + _, err = AddProduct(product) + if err != nil { + panic(err) + } } func initDefinedResource(resource *Resource) { - existed := GetResource(resource.GetId()) + existed, err := GetResource(resource.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } resource.CreatedTime = util.GetCurrentTime() - AddResource(resource) + _, err = AddResource(resource) + if err != nil { + panic(err) + } } func initDefinedRole(role *Role) { - existed := GetRole(role.GetId()) + existed, err := GetRole(role.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } role.CreatedTime = util.GetCurrentTime() - AddRole(role) + _, err = AddRole(role) + if err != nil { + panic(err) + } } func initDefinedSyncer(syncer *Syncer) { - existed := GetSyncer(syncer.GetId()) + existed, err := GetSyncer(syncer.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } syncer.CreatedTime = util.GetCurrentTime() - AddSyncer(syncer) + _, err = AddSyncer(syncer) + if err != nil { + panic(err) + } } func initDefinedToken(token *Token) { - existed := GetToken(token.GetId()) + existed, err := GetToken(token.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } token.CreatedTime = util.GetCurrentTime() - AddToken(token) + _, err = AddToken(token) + if err != nil { + panic(err) + } } func initDefinedWebhook(webhook *Webhook) { - existed := GetWebhook(webhook.GetId()) + existed, err := GetWebhook(webhook.GetId()) + if err != nil { + panic(err) + } + if existed != nil { return } webhook.CreatedTime = util.GetCurrentTime() - AddWebhook(webhook) + _, err = AddWebhook(webhook) + if err != nil { + panic(err) + } } diff --git a/object/ldap.go b/object/ldap.go index 49c4b0b8..d76c0e63 100644 --- a/object/ldap.go +++ b/object/ldap.go @@ -37,7 +37,7 @@ type Ldap struct { LastSync string `xorm:"varchar(100)" json:"lastSync"` } -func AddLdap(ldap *Ldap) bool { +func AddLdap(ldap *Ldap) (bool, error) { if len(ldap.Id) == 0 { ldap.Id = util.GenerateId() } @@ -48,13 +48,13 @@ func AddLdap(ldap *Ldap) bool { affected, err := adapter.Engine.Insert(ldap) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func CheckLdapExist(ldap *Ldap) bool { +func CheckLdapExist(ldap *Ldap) (bool, error) { var result []*Ldap err := adapter.Engine.Find(&result, &Ldap{ Owner: ldap.Owner, @@ -65,63 +65,65 @@ func CheckLdapExist(ldap *Ldap) bool { BaseDn: ldap.BaseDn, }) if err != nil { - panic(err) + return false, err } if len(result) > 0 { - return true + return true, nil } - return false + return false, nil } -func GetLdaps(owner string) []*Ldap { +func GetLdaps(owner string) ([]*Ldap, error) { var ldaps []*Ldap err := adapter.Engine.Desc("created_time").Find(&ldaps, &Ldap{Owner: owner}) if err != nil { - panic(err) + return ldaps, err } - return ldaps + return ldaps, nil } -func GetLdap(id string) *Ldap { +func GetLdap(id string) (*Ldap, error) { if util.IsStringsEmpty(id) { - return nil + return nil, nil } ldap := Ldap{Id: id} existed, err := adapter.Engine.Get(&ldap) if err != nil { - panic(err) + return &ldap, nil } if existed { - return &ldap + return &ldap, nil } else { - return nil + return nil, nil } } -func UpdateLdap(ldap *Ldap) bool { - if GetLdap(ldap.Id) == nil { - return false +func UpdateLdap(ldap *Ldap) (bool, error) { + if l, err := GetLdap(ldap.Id); err != nil { + return false, nil + } else if l == nil { + return false, nil } affected, err := adapter.Engine.ID(ldap.Id).Cols("owner", "server_name", "host", "port", "enable_ssl", "username", "password", "base_dn", "filter", "filter_fields", "auto_sync").Update(ldap) if err != nil { - panic(err) + return false, nil } - return affected != 0 + return affected != 0, nil } -func DeleteLdap(ldap *Ldap) bool { +func DeleteLdap(ldap *Ldap) (bool, error) { affected, err := adapter.Engine.ID(ldap.Id).Delete(&Ldap{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } diff --git a/object/ldap_autosync.go b/object/ldap_autosync.go index 8eb22a8b..8e9951d5 100644 --- a/object/ldap_autosync.go +++ b/object/ldap_autosync.go @@ -18,7 +18,10 @@ var globalLdapAutoSynchronizer *LdapAutoSynchronizer func InitLdapAutoSynchronizer() { globalLdapAutoSynchronizer = NewLdapAutoSynchronizer() - globalLdapAutoSynchronizer.LdapAutoSynchronizerStartUpAll() + err := globalLdapAutoSynchronizer.LdapAutoSynchronizerStartUpAll() + if err != nil { + panic(err) + } } func NewLdapAutoSynchronizer() *LdapAutoSynchronizer { @@ -37,7 +40,11 @@ func (l *LdapAutoSynchronizer) StartAutoSync(ldapId string) error { l.Lock() defer l.Unlock() - ldap := GetLdap(ldapId) + ldap, err := GetLdap(ldapId) + if err != nil { + return err + } + if ldap == nil { return fmt.Errorf("ldap %s doesn't exist", ldapId) } @@ -49,7 +56,12 @@ func (l *LdapAutoSynchronizer) StartAutoSync(ldapId string) error { stopChan := make(chan struct{}) l.ldapIdToStopChan[ldapId] = stopChan logs.Info(fmt.Sprintf("autoSync started for %s", ldap.Id)) - util.SafeGoroutine(func() { l.syncRoutine(ldap, stopChan) }) + util.SafeGoroutine(func() { + err := l.syncRoutine(ldap, stopChan) + if err != nil { + panic(err) + } + }) return nil } @@ -63,18 +75,22 @@ func (l *LdapAutoSynchronizer) StopAutoSync(ldapId string) { } // autosync goroutine -func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) { +func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) error { ticker := time.NewTicker(time.Duration(ldap.AutoSync) * time.Minute) defer ticker.Stop() for { select { case <-stopChan: logs.Info(fmt.Sprintf("autoSync goroutine for %s stopped", ldap.Id)) - return + return nil case <-ticker.C: } - UpdateLdapSyncTime(ldap.Id) + err := UpdateLdapSyncTime(ldap.Id) + if err != nil { + return err + } + // fetch all users conn, err := ldap.GetLdapConn() if err != nil { @@ -100,24 +116,35 @@ func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) { // LdapAutoSynchronizerStartUpAll // start all autosync goroutine for existing ldap servers in each organizations -func (l *LdapAutoSynchronizer) LdapAutoSynchronizerStartUpAll() { +func (l *LdapAutoSynchronizer) LdapAutoSynchronizerStartUpAll() error { organizations := []*Organization{} err := adapter.Engine.Desc("created_time").Find(&organizations) if err != nil { logs.Info("failed to Star up LdapAutoSynchronizer; ") } for _, org := range organizations { - for _, ldap := range GetLdaps(org.Name) { + ldaps, err := GetLdaps(org.Name) + if err != nil { + return err + } + + for _, ldap := range ldaps { if ldap.AutoSync != 0 { - l.StartAutoSync(ldap.Id) + err = l.StartAutoSync(ldap.Id) + if err != nil { + return err + } } } } + return nil } -func UpdateLdapSyncTime(ldapId string) { +func UpdateLdapSyncTime(ldapId string) error { _, err := adapter.Engine.ID(ldapId).Update(&Ldap{LastSync: util.GetCurrentTime()}) if err != nil { - panic(err) + return err } + + return nil } diff --git a/object/ldap_conn.go b/object/ldap_conn.go index 055812d9..219c9ac6 100644 --- a/object/ldap_conn.go +++ b/object/ldap_conn.go @@ -255,8 +255,12 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser uuids = append(uuids, user.Uuid) } - organization := getOrganization("admin", owner) - ldap := GetLdap(ldapId) + organization, err := getOrganization("admin", owner) + if err != nil { + panic(err) + } + + ldap, err := GetLdap(ldapId) var dc []string for _, basedn := range strings.Split(ldap.BaseDn, ",") { @@ -275,7 +279,11 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser tag := strings.Join(ou, ".") for _, syncUser := range syncUsers { - existUuids := GetExistUuids(owner, uuids) + existUuids, err := GetExistUuids(owner, uuids) + if err != nil { + return nil, nil, err + } + found := false if len(existUuids) > 0 { for _, existUuid := range existUuids { @@ -287,10 +295,19 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser } if !found { - score, _ := organization.GetInitScore() + score, err := organization.GetInitScore() + if err != nil { + return nil, nil, err + } + + name, err := syncUser.buildLdapUserName() + if err != nil { + return nil, nil, err + } + newUser := &User{ Owner: owner, - Name: syncUser.buildLdapUserName(), + Name: name, CreatedTime: util.GetCurrentTime(), DisplayName: syncUser.buildLdapDisplayName(), Avatar: organization.DefaultAvatar, @@ -303,7 +320,11 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser Ldap: syncUser.Uuid, } - affected := AddUser(newUser) + affected, err := AddUser(newUser) + if err != nil { + return nil, nil, err + } + if !affected { failedUsers = append(failedUsers, syncUser) continue @@ -314,38 +335,38 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser return existUsers, failedUsers, err } -func GetExistUuids(owner string, uuids []string) []string { +func GetExistUuids(owner string, uuids []string) ([]string, error) { var existUuids []string err := adapter.Engine.Table("user").Where("owner = ?", owner).Cols("ldap"). In("ldap", uuids).Select("DISTINCT ldap").Find(&existUuids) if err != nil { - panic(err) + return existUuids, err } - return existUuids + return existUuids, nil } -func (ldapUser *LdapUser) buildLdapUserName() string { +func (ldapUser *LdapUser) buildLdapUserName() (string, error) { user := User{} uidWithNumber := fmt.Sprintf("%s_%s", ldapUser.Uid, ldapUser.UidNumber) has, err := adapter.Engine.Where("name = ? or name = ?", ldapUser.Uid, uidWithNumber).Get(&user) if err != nil { - panic(err) + return "", err } if has { if user.Name == ldapUser.Uid { - return uidWithNumber + return uidWithNumber, nil } - return fmt.Sprintf("%s_%s", uidWithNumber, randstr.Hex(6)) + return fmt.Sprintf("%s_%s", uidWithNumber, randstr.Hex(6)), nil } if ldapUser.Uid != "" { - return ldapUser.Uid + return ldapUser.Uid, nil } - return ldapUser.Cn + return ldapUser.Cn, nil } func (ldapUser *LdapUser) buildLdapDisplayName() string { diff --git a/object/message.go b/object/message.go index 84affd8f..dfa84208 100644 --- a/object/message.go +++ b/object/message.go @@ -48,109 +48,94 @@ func GetMaskedMessages(messages []*Message) []*Message { return messages } -func GetMessageCount(owner, organization, field, value string) int { +func GetMessageCount(owner, organization, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Message{Organization: organization}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Message{Organization: organization}) } -func GetMessages(owner string) []*Message { +func GetMessages(owner string) ([]*Message, error) { messages := []*Message{} err := adapter.Engine.Desc("created_time").Find(&messages, &Message{Owner: owner}) - if err != nil { - panic(err) - } - - return messages + return messages, err } -func GetChatMessages(chat string) []*Message { +func GetChatMessages(chat string) ([]*Message, error) { messages := []*Message{} err := adapter.Engine.Asc("created_time").Find(&messages, &Message{Chat: chat}) - if err != nil { - panic(err) - } - - return messages + return messages, err } -func GetPaginationMessages(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Message { +func GetPaginationMessages(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Message, error) { messages := []*Message{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&messages, &Message{Organization: organization}) - if err != nil { - panic(err) - } - - return messages + return messages, err } -func getMessage(owner string, name string) *Message { +func getMessage(owner string, name string) (*Message, error) { if owner == "" || name == "" { - return nil + return nil, nil } message := Message{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&message) if err != nil { - panic(err) + return nil, err } if existed { - return &message + return &message, nil } else { - return nil + return nil, nil } } -func GetMessage(id string) *Message { +func GetMessage(id string) (*Message, error) { owner, name := util.GetOwnerAndNameFromId(id) return getMessage(owner, name) } -func UpdateMessage(id string, message *Message) bool { +func UpdateMessage(id string, message *Message) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getMessage(owner, name) == nil { - return false + if m, err := getMessage(owner, name); err != nil { + return false, err + } else if m == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(message) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddMessage(message *Message) bool { +func AddMessage(message *Message) (bool, error) { affected, err := adapter.Engine.Insert(message) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteMessage(message *Message) bool { +func DeleteMessage(message *Message) (bool, error) { affected, err := adapter.Engine.ID(core.PK{message.Owner, message.Name}).Delete(&Message{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteChatMessages(chat string) bool { +func DeleteChatMessages(chat string) (bool, error) { affected, err := adapter.Engine.Delete(&Message{Chat: chat}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (p *Message) GetId() string { diff --git a/object/mfa.go b/object/mfa.go index 98759ec5..d4e74385 100644 --- a/object/mfa.go +++ b/object/mfa.go @@ -83,7 +83,11 @@ func RecoverTfs(user *User, recoveryCode string) error { return fmt.Errorf("recovery code not found") } - affected := UpdateUser(user.GetId(), user, []string{"two_factor_auth"}, user.IsAdminUser()) + affected, err := UpdateUser(user.GetId(), user, []string{"two_factor_auth"}, user.IsAdminUser()) + if err != nil { + return err + } + if !affected { return fmt.Errorf("") } diff --git a/object/mfa_sms.go b/object/mfa_sms.go index 6ac7dce7..d298d2e2 100644 --- a/object/mfa_sms.go +++ b/object/mfa_sms.go @@ -100,7 +100,11 @@ func (mfa *SmsMfa) Enable(ctx *context.Context, user *User) error { } user.MultiFactorAuths = append(user.MultiFactorAuths, mfa.Config) - affected := UpdateUser(user.GetId(), user, []string{"multi_factor_auths"}, user.IsAdminUser()) + affected, err := UpdateUser(user.GetId(), user, []string{"multi_factor_auths"}, user.IsAdminUser()) + if err != nil { + return err + } + if !affected { return fmt.Errorf("failed to enable two factor authentication") } diff --git a/object/migrator.go b/object/migrator.go index 5e0e8dcf..d86aff5d 100644 --- a/object/migrator.go +++ b/object/migrator.go @@ -44,5 +44,8 @@ func DoMigration() { } m := migrate.New(adapter.Engine, options, migrations) - m.Migrate() + err := m.Migrate() + if err != nil { + panic(err) + } } diff --git a/object/model.go b/object/model.go index 43e86680..4484b33f 100644 --- a/object/model.go +++ b/object/model.go @@ -32,56 +32,51 @@ type Model struct { IsEnabled bool `json:"isEnabled"` } -func GetModelCount(owner, field, value string) int { +func GetModelCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Model{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Model{}) } -func GetModels(owner string) []*Model { +func GetModels(owner string) ([]*Model, error) { models := []*Model{} err := adapter.Engine.Desc("created_time").Find(&models, &Model{Owner: owner}) if err != nil { - panic(err) + return models, err } - return models + return models, nil } -func GetPaginationModels(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Model { +func GetPaginationModels(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Model, error) { models := []*Model{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&models) if err != nil { - panic(err) + return models, err } - return models + return models, nil } -func getModel(owner string, name string) *Model { +func getModel(owner string, name string) (*Model, error) { if owner == "" || name == "" { - return nil + return nil, nil } m := Model{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&m) if err != nil { - panic(err) + return &m, err } if existed { - return &m + return &m, nil } else { - return nil + return nil, nil } } -func GetModel(id string) *Model { +func GetModel(id string) (*Model, error) { owner, name := util.GetOwnerAndNameFromId(id) return getModel(owner, name) } @@ -92,48 +87,53 @@ func UpdateModelWithCheck(id string, modelObj *Model) error { if err != nil { return err } - UpdateModel(id, modelObj) + _, err = UpdateModel(id, modelObj) + if err != nil { + return err + } return nil } -func UpdateModel(id string, modelObj *Model) bool { +func UpdateModel(id string, modelObj *Model) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getModel(owner, name) == nil { - return false + if m, err := getModel(owner, name); err != nil { + return false, err + } else if m == nil { + return false, nil } if name != modelObj.Name { err := modelChangeTrigger(name, modelObj.Name) if err != nil { - return false + return false, err } } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(modelObj) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, err } -func AddModel(model *Model) bool { +func AddModel(model *Model) (bool, error) { affected, err := adapter.Engine.Insert(model) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteModel(model *Model) bool { +func DeleteModel(model *Model) (bool, error) { affected, err := adapter.Engine.ID(core.PK{model.Owner, model.Name}).Delete(&Model{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (model *Model) GetId() string { diff --git a/object/oidc_discovery.go b/object/oidc_discovery.go index d99986c8..cdf4b6ce 100644 --- a/object/oidc_discovery.go +++ b/object/oidc_discovery.go @@ -110,8 +110,12 @@ func GetOidcDiscovery(host string) OidcDiscovery { } func GetJsonWebKeySet() (jose.JSONWebKeySet, error) { - certs := GetCerts("admin") jwks := jose.JSONWebKeySet{} + certs, err := GetCerts("admin") + if err != nil { + return jwks, err + } + // follows the protocol rfc 7517(draft) // link here: https://self-issued.info/docs/draft-ietf-jose-json-web-key.html // or https://datatracker.ietf.org/doc/html/draft-ietf-jose-json-web-key diff --git a/object/organization.go b/object/organization.go index e5206b44..fed8d584 100644 --- a/object/organization.go +++ b/object/organization.go @@ -70,92 +70,102 @@ type Organization struct { AccountItems []*AccountItem `xorm:"varchar(3000)" json:"accountItems"` } -func GetOrganizationCount(owner, field, value string) int { +func GetOrganizationCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Organization{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Organization{}) } -func GetOrganizations(owner string) []*Organization { +func GetOrganizations(owner string) ([]*Organization, error) { organizations := []*Organization{} err := adapter.Engine.Desc("created_time").Find(&organizations, &Organization{Owner: owner}) if err != nil { - panic(err) + return nil, err } - return organizations + return organizations, nil } -func GetOrganizationsByFields(owner string, fields ...string) []*Organization { +func GetOrganizationsByFields(owner string, fields ...string) ([]*Organization, error) { organizations := []*Organization{} err := adapter.Engine.Desc("created_time").Cols(fields...).Find(&organizations, &Organization{Owner: owner}) if err != nil { - panic(err) + return nil, err } - return organizations + return organizations, nil } -func GetPaginationOrganizations(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Organization { +func GetPaginationOrganizations(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Organization, error) { organizations := []*Organization{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&organizations) if err != nil { - panic(err) + return nil, err } - return organizations + return organizations, nil } -func getOrganization(owner string, name string) *Organization { +func getOrganization(owner string, name string) (*Organization, error) { if owner == "" || name == "" { - return nil + return nil, nil } organization := Organization{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&organization) if err != nil { - panic(err) + return nil, err } if existed { - return &organization + return &organization, nil } - return nil + return nil, nil } -func GetOrganization(id string) *Organization { +func GetOrganization(id string) (*Organization, error) { owner, name := util.GetOwnerAndNameFromId(id) return getOrganization(owner, name) } -func GetMaskedOrganization(organization *Organization) *Organization { +func GetMaskedOrganization(organization *Organization, errs ...error) (*Organization, error) { + if len(errs) > 0 && errs[0] != nil { + return nil, errs[0] + } + if organization == nil { - return nil + return nil, nil } if organization.MasterPassword != "" { organization.MasterPassword = "***" } - return organization + return organization, nil } -func GetMaskedOrganizations(organizations []*Organization) []*Organization { - for _, organization := range organizations { - organization = GetMaskedOrganization(organization) +func GetMaskedOrganizations(organizations []*Organization, errs ...error) ([]*Organization, error) { + if len(errs) > 0 && errs[0] != nil { + return nil, errs[0] } - return organizations + + var err error + for _, organization := range organizations { + organization, err = GetMaskedOrganization(organization) + if err != nil { + return nil, err + } + } + + return organizations, nil } -func UpdateOrganization(id string, organization *Organization) bool { +func UpdateOrganization(id string, organization *Organization) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getOrganization(owner, name) == nil { - return false + if org, err := getOrganization(owner, name); err != nil { + return false, err + } else if org == nil { + return false, nil } if name == "built-in" { @@ -165,7 +175,7 @@ func UpdateOrganization(id string, organization *Organization) bool { if name != organization.Name { err := organizationChangeTrigger(name, organization.Name) if err != nil { - return false + return false, nil } } @@ -183,35 +193,35 @@ func UpdateOrganization(id string, organization *Organization) bool { } affected, err := session.Update(organization) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddOrganization(organization *Organization) bool { +func AddOrganization(organization *Organization) (bool, error) { affected, err := adapter.Engine.Insert(organization) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteOrganization(organization *Organization) bool { +func DeleteOrganization(organization *Organization) (bool, error) { if organization.Name == "built-in" { - return false + return false, nil } affected, err := adapter.Engine.ID(core.PK{organization.Owner, organization.Name}).Delete(&Organization{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func GetOrganizationByUser(user *User) *Organization { +func GetOrganizationByUser(user *User) (*Organization, error) { return getOrganization("admin", user.Owner) } @@ -248,13 +258,21 @@ func CheckAccountItemModifyRule(accountItem *AccountItem, isAdmin bool, lang str } func GetDefaultApplication(id string) (*Application, error) { - organization := GetOrganization(id) + organization, err := GetOrganization(id) + if err != nil { + return nil, err + } + if organization == nil { return nil, fmt.Errorf("The organization: %s does not exist", id) } if organization.DefaultApplication != "" { - defaultApplication := getApplication("admin", organization.DefaultApplication) + defaultApplication, err := getApplication("admin", organization.DefaultApplication) + if err != nil { + return nil, err + } + if defaultApplication == nil { return nil, fmt.Errorf("The default application: %s does not exist", organization.DefaultApplication) } else { @@ -263,9 +281,9 @@ func GetDefaultApplication(id string) (*Application, error) { } applications := []*Application{} - err := adapter.Engine.Asc("created_time").Find(&applications, &Application{Organization: organization.Name}) + err = adapter.Engine.Asc("created_time").Find(&applications, &Application{Organization: organization.Name}) if err != nil { - panic(err) + return nil, err } if len(applications) == 0 { @@ -280,8 +298,15 @@ func GetDefaultApplication(id string) (*Application, error) { } } - extendApplicationWithProviders(defaultApplication) - extendApplicationWithOrg(defaultApplication) + err = extendApplicationWithProviders(defaultApplication) + if err != nil { + return nil, err + } + + err = extendApplicationWithOrg(defaultApplication) + if err != nil { + return nil, err + } return defaultApplication, nil } diff --git a/object/payment.go b/object/payment.go index 4f3f7d78..34798f06 100644 --- a/object/payment.go +++ b/object/payment.go @@ -56,74 +56,71 @@ type Payment struct { InvoiceUrl string `xorm:"varchar(255)" json:"invoiceUrl"` } -func GetPaymentCount(owner, field, value string) int { +func GetPaymentCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Payment{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Payment{}) } -func GetPayments(owner string) []*Payment { +func GetPayments(owner string) ([]*Payment, error) { payments := []*Payment{} err := adapter.Engine.Desc("created_time").Find(&payments, &Payment{Owner: owner}) if err != nil { - panic(err) + return nil, err } - return payments + return payments, nil } -func GetUserPayments(owner string, organization string, user string) []*Payment { +func GetUserPayments(owner string, organization string, user string) ([]*Payment, error) { payments := []*Payment{} err := adapter.Engine.Desc("created_time").Find(&payments, &Payment{Owner: owner, Organization: organization, User: user}) if err != nil { - panic(err) + return nil, err } - return payments + return payments, nil } -func GetPaginationPayments(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Payment { +func GetPaginationPayments(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Payment, error) { payments := []*Payment{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&payments) if err != nil { - panic(err) + return nil, err } - return payments + return payments, nil } -func getPayment(owner string, name string) *Payment { +func getPayment(owner string, name string) (*Payment, error) { if owner == "" || name == "" { - return nil + return nil, nil } payment := Payment{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&payment) if err != nil { - panic(err) + return nil, err } if existed { - return &payment + return &payment, nil } else { - return nil + return nil, nil } } -func GetPayment(id string) *Payment { +func GetPayment(id string) (*Payment, error) { owner, name := util.GetOwnerAndNameFromId(id) return getPayment(owner, name) } -func UpdatePayment(id string, payment *Payment) bool { +func UpdatePayment(id string, payment *Payment) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getPayment(owner, name) == nil { - return false + if p, err := getPayment(owner, name); err != nil { + return false, err + } else if p == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(payment) @@ -131,42 +128,53 @@ func UpdatePayment(id string, payment *Payment) bool { panic(err) } - return affected != 0 + return affected != 0, nil } -func AddPayment(payment *Payment) bool { +func AddPayment(payment *Payment) (bool, error) { affected, err := adapter.Engine.Insert(payment) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeletePayment(payment *Payment) bool { +func DeletePayment(payment *Payment) (bool, error) { affected, err := adapter.Engine.ID(core.PK{payment.Owner, payment.Name}).Delete(&Payment{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func notifyPayment(request *http.Request, body []byte, owner string, providerName string, productName string, paymentName string) (*Payment, error, string) { - provider := getProvider(owner, providerName) + provider, err := getProvider(owner, providerName) + if err != nil { + panic(err) + } pProvider, cert, err := provider.getPaymentProvider() if err != nil { panic(err) } - payment := getPayment(owner, paymentName) + payment, err := getPayment(owner, paymentName) + if err != nil { + panic(err) + } + if payment == nil { err = fmt.Errorf("the payment: %s does not exist", paymentName) return nil, err, pProvider.GetResponseError(err) } - product := getProduct(owner, productName) + product, err := getProduct(owner, productName) + if err != nil { + panic(err) + } + if product == nil { err = fmt.Errorf("the product: %s does not exist", productName) return payment, err, pProvider.GetResponseError(err) @@ -201,14 +209,21 @@ func NotifyPayment(request *http.Request, body []byte, owner string, providerNam payment.State = "Paid" } - UpdatePayment(payment.GetId(), payment) + _, err = UpdatePayment(payment.GetId(), payment) + if err != nil { + panic(err) + } } return err, errorResponse } func invoicePayment(payment *Payment) (string, error) { - provider := getProvider(payment.Owner, payment.Provider) + provider, err := getProvider(payment.Owner, payment.Provider) + if err != nil { + panic(err) + } + if provider == nil { return "", fmt.Errorf("the payment provider: %s does not exist", payment.Provider) } @@ -237,7 +252,11 @@ func InvoicePayment(payment *Payment) (string, error) { } payment.InvoiceUrl = invoiceUrl - affected := UpdatePayment(payment.GetId(), payment) + affected, err := UpdatePayment(payment.GetId(), payment) + if err != nil { + return "", err + } + if !affected { return "", fmt.Errorf("failed to update the payment: %s", payment.Name) } diff --git a/object/permission.go b/object/permission.go index 006f9d55..8cfd6ea8 100644 --- a/object/permission.go +++ b/object/permission.go @@ -66,96 +66,97 @@ func (p *Permission) GetId() string { return util.GetId(p.Owner, p.Name) } -func GetPermissionCount(owner, field, value string) int { +func GetPermissionCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Permission{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Permission{}) } -func GetPermissions(owner string) []*Permission { +func GetPermissions(owner string) ([]*Permission, error) { permissions := []*Permission{} err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner}) if err != nil { - panic(err) + return permissions, err } - return permissions + return permissions, nil } -func GetPaginationPermissions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Permission { +func GetPaginationPermissions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Permission, error) { permissions := []*Permission{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&permissions) if err != nil { - panic(err) + return permissions, err } - return permissions + return permissions, nil } -func getPermission(owner string, name string) *Permission { +func getPermission(owner string, name string) (*Permission, error) { if owner == "" || name == "" { - return nil + return nil, nil } permission := Permission{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&permission) if err != nil { - panic(err) + return &permission, err } if existed { - return &permission + return &permission, nil } else { - return nil + return nil, nil } } -func GetPermission(id string) *Permission { +func GetPermission(id string) (*Permission, error) { owner, name := util.GetOwnerAndNameFromId(id) return getPermission(owner, name) } // checkPermissionValid verifies if the permission is valid -func checkPermissionValid(permission *Permission) { +func checkPermissionValid(permission *Permission) error { enforcer := getEnforcer(permission) enforcer.EnableAutoSave(false) policies := getPolicies(permission) _, err := enforcer.AddPolicies(policies) if err != nil { - panic(err) + return err } if !HasRoleDefinition(enforcer.GetModel()) { permission.Roles = []string{} - return + return nil } groupingPolicies := getGroupingPolicies(permission) if len(groupingPolicies) > 0 { _, err := enforcer.AddGroupingPolicies(groupingPolicies) if err != nil { - panic(err) + return err } } + + return nil } -func UpdatePermission(id string, permission *Permission) bool { - checkPermissionValid(permission) +func UpdatePermission(id string, permission *Permission) (bool, error) { + err := checkPermissionValid(permission) + if err != nil { + return false, err + } + owner, name := util.GetOwnerAndNameFromId(id) - oldPermission := getPermission(owner, name) + oldPermission, err := getPermission(owner, name) if oldPermission == nil { - return false + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(permission) if err != nil { - panic(err) + return false, err } if affected != 0 { @@ -166,7 +167,7 @@ func UpdatePermission(id string, permission *Permission) bool { if isEmpty { err = adapter.Engine.DropTables(oldPermission.Adapter) if err != nil { - panic(err) + return false, err } } } @@ -174,13 +175,13 @@ func UpdatePermission(id string, permission *Permission) bool { addPolicies(permission) } - return affected != 0 + return affected != 0, nil } -func AddPermission(permission *Permission) bool { +func AddPermission(permission *Permission) (bool, error) { affected, err := adapter.Engine.Insert(permission) if err != nil { - panic(err) + return false, err } if affected != 0 { @@ -188,7 +189,7 @@ func AddPermission(permission *Permission) bool { addPolicies(permission) } - return affected != 0 + return affected != 0, nil } func AddPermissions(permissions []*Permission) bool { @@ -239,10 +240,10 @@ func AddPermissionsInBatch(permissions []*Permission) bool { return affected } -func DeletePermission(permission *Permission) bool { +func DeletePermission(permission *Permission) (bool, error) { affected, err := adapter.Engine.ID(core.PK{permission.Owner, permission.Name}).Delete(&Permission{}) if err != nil { - panic(err) + return false, err } if affected != 0 { @@ -253,67 +254,67 @@ func DeletePermission(permission *Permission) bool { if isEmpty { err = adapter.Engine.DropTables(permission.Adapter) if err != nil { - panic(err) + return false, err } } } } - return affected != 0 + return affected != 0, nil } -func GetPermissionsByUser(userId string) []*Permission { +func GetPermissionsByUser(userId string) ([]*Permission, error) { permissions := []*Permission{} err := adapter.Engine.Where("users like ?", "%"+userId+"\"%").Find(&permissions) if err != nil { - panic(err) + return permissions, err } for i := range permissions { permissions[i].Users = nil } - return permissions + return permissions, nil } -func GetPermissionsByRole(roleId string) []*Permission { +func GetPermissionsByRole(roleId string) ([]*Permission, error) { permissions := []*Permission{} err := adapter.Engine.Where("roles like ?", "%"+roleId+"\"%").Find(&permissions) if err != nil { - panic(err) + return permissions, err } - return permissions + return permissions, nil } -func GetPermissionsByResource(resourceId string) []*Permission { +func GetPermissionsByResource(resourceId string) ([]*Permission, error) { permissions := []*Permission{} err := adapter.Engine.Where("resources like ?", "%"+resourceId+"\"%").Find(&permissions) if err != nil { - panic(err) + return permissions, err } - return permissions + return permissions, nil } -func GetPermissionsBySubmitter(owner string, submitter string) []*Permission { +func GetPermissionsBySubmitter(owner string, submitter string) ([]*Permission, error) { permissions := []*Permission{} err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner, Submitter: submitter}) if err != nil { - panic(err) + return permissions, err } - return permissions + return permissions, nil } -func GetPermissionsByModel(owner string, model string) []*Permission { +func GetPermissionsByModel(owner string, model string) ([]*Permission, error) { permissions := []*Permission{} err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner, Model: model}) if err != nil { - panic(err) + return permissions, err } - return permissions + return permissions, nil } func ContainsAsterisk(userId string, users []string) bool { diff --git a/object/permission_enforcer.go b/object/permission_enforcer.go index 5a53afb2..7f7e5af5 100644 --- a/object/permission_enforcer.go +++ b/object/permission_enforcer.go @@ -29,7 +29,11 @@ import ( func getEnforcer(permission *Permission) *casbin.Enforcer { tableName := "permission_rule" if len(permission.Adapter) != 0 { - adapterObj := getCasbinAdapter(permission.Owner, permission.Adapter) + adapterObj, err := getCasbinAdapter(permission.Owner, permission.Adapter) + if err != nil { + panic(err) + } + if adapterObj != nil && adapterObj.Table != "" { tableName = adapterObj.Table } @@ -42,7 +46,11 @@ func getEnforcer(permission *Permission) *casbin.Enforcer { panic(err) } - permissionModel := getModel(permission.Owner, permission.Model) + permissionModel, err := getModel(permission.Owner, permission.Model) + if err != nil { + panic(err) + } + m := model.Model{} if permissionModel != nil { m, err = GetBuiltInModel(permissionModel.ModelText) @@ -122,21 +130,30 @@ func getPolicies(permission *Permission) [][]string { return policies } -func getRolesInRole(roleId string, visited map[string]struct{}) []*Role { - role := GetRole(roleId) +func getRolesInRole(roleId string, visited map[string]struct{}) ([]*Role, error) { + role, err := GetRole(roleId) + if err != nil { + return []*Role{}, err + } + if role == nil { - return []*Role{} + return []*Role{}, nil } visited[roleId] = struct{}{} roles := []*Role{role} for _, subRole := range role.Roles { if _, ok := visited[subRole]; !ok { - roles = append(roles, getRolesInRole(subRole, visited)...) + r, err := getRolesInRole(subRole, visited) + if err != nil { + return []*Role{}, err + } + + roles = append(roles, r...) } } - return roles + return roles, nil } func getGroupingPolicies(permission *Permission) [][]string { @@ -147,8 +164,10 @@ func getGroupingPolicies(permission *Permission) [][]string { for _, roleId := range permission.Roles { visited := map[string]struct{}{} - rolesInRole := getRolesInRole(roleId, visited) - + rolesInRole, err := getRolesInRole(roleId, visited) + if err != nil { + panic(err) + } for _, role := range rolesInRole { roleId := role.GetId() for _, subUser := range role.Users { @@ -223,7 +242,11 @@ func removePolicies(permission *Permission) { type CasbinRequest = []interface{} func Enforce(permissionId string, request *CasbinRequest) bool { - permission := GetPermission(permissionId) + permission, err := GetPermission(permissionId) + if err != nil { + panic(err) + } + enforcer := getEnforcer(permission) allow, err := enforcer.Enforce(*request...) @@ -234,7 +257,11 @@ func Enforce(permissionId string, request *CasbinRequest) bool { } func BatchEnforce(permissionId string, requests *[]CasbinRequest) []bool { - permission := GetPermission(permissionId) + permission, err := GetPermission(permissionId) + if err != nil { + panic(err) + } + enforcer := getEnforcer(permission) allow, err := enforcer.BatchEnforce(*requests) if err != nil { @@ -244,9 +271,18 @@ func BatchEnforce(permissionId string, requests *[]CasbinRequest) []bool { } func getAllValues(userId string, fn func(enforcer *casbin.Enforcer) []string) []string { - permissions := GetPermissionsByUser(userId) + permissions, err := GetPermissionsByUser(userId) + if err != nil { + panic(err) + } + for _, role := range GetAllRoles(userId) { - permissions = append(permissions, GetPermissionsByRole(role)...) + permissionsByRole, err := GetPermissionsByRole(role) + if err != nil { + panic(err) + } + + permissions = append(permissions, permissionsByRole...) } var values []string @@ -270,7 +306,11 @@ func GetAllActions(userId string) []string { } func GetAllRoles(userId string) []string { - roles := GetRolesByUser(userId) + roles, err := GetRolesByUser(userId) + if err != nil { + panic(err) + } + var res []string for _, role := range roles { res = append(res, role.Name) diff --git a/object/permission_upload.go b/object/permission_upload.go index 15e94801..c3533a9d 100644 --- a/object/permission_upload.go +++ b/object/permission_upload.go @@ -18,21 +18,29 @@ import ( "github.com/casdoor/casdoor/xlsx" ) -func getPermissionMap(owner string) map[string]*Permission { +func getPermissionMap(owner string) (map[string]*Permission, error) { m := map[string]*Permission{} - permissions := GetPermissions(owner) + permissions, err := GetPermissions(owner) + if err != nil { + return nil, err + } + for _, permission := range permissions { m[permission.GetId()] = permission } - return m + return m, err } -func UploadPermissions(owner string, fileId string) bool { +func UploadPermissions(owner string, fileId string) (bool, error) { table := xlsx.ReadXlsxFile(fileId) - oldUserMap := getPermissionMap(owner) + oldUserMap, err := getPermissionMap(owner) + if err != nil { + return false, err + } + newPermissions := []*Permission{} for index, line := range table { if index == 0 || parseLineItem(&line, 0) == "" { @@ -71,7 +79,7 @@ func UploadPermissions(owner string, fileId string) bool { } if len(newPermissions) == 0 { - return false + return false, nil } - return AddPermissionsInBatch(newPermissions) + return AddPermissionsInBatch(newPermissions), nil } diff --git a/object/plan.go b/object/plan.go index 1c86d1a6..48d482cf 100644 --- a/object/plan.go +++ b/object/plan.go @@ -37,109 +37,115 @@ type Plan struct { Options []string `xorm:"-" json:"options"` } -func GetPlanCount(owner, field, value string) int { +func GetPlanCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Plan{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Plan{}) } -func GetPlans(owner string) []*Plan { +func GetPlans(owner string) ([]*Plan, error) { plans := []*Plan{} err := adapter.Engine.Desc("created_time").Find(&plans, &Plan{Owner: owner}) if err != nil { - panic(err) + return plans, err } - return plans + return plans, nil } -func GetPaginatedPlans(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Plan { +func GetPaginatedPlans(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Plan, error) { plans := []*Plan{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&plans) if err != nil { - panic(err) + return plans, err } - return plans + return plans, nil } -func getPlan(owner, name string) *Plan { +func getPlan(owner, name string) (*Plan, error) { if owner == "" || name == "" { - return nil + return nil, nil } plan := Plan{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&plan) if err != nil { - panic(err) + return &plan, err } if existed { - return &plan + return &plan, nil } else { - return nil + return nil, nil } } -func GetPlan(id string) *Plan { +func GetPlan(id string) (*Plan, error) { owner, name := util.GetOwnerAndNameFromId(id) return getPlan(owner, name) } -func UpdatePlan(id string, plan *Plan) bool { +func UpdatePlan(id string, plan *Plan) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getPlan(owner, name) == nil { - return false + if p, err := getPlan(owner, name); err != nil { + return false, err + } else if p == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(plan) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddPlan(plan *Plan) bool { +func AddPlan(plan *Plan) (bool, error) { affected, err := adapter.Engine.Insert(plan) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeletePlan(plan *Plan) bool { +func DeletePlan(plan *Plan) (bool, error) { affected, err := adapter.Engine.ID(core.PK{plan.Owner, plan.Name}).Delete(plan) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (plan *Plan) GetId() string { return fmt.Sprintf("%s/%s", plan.Owner, plan.Name) } -func Subscribe(owner string, user string, plan string, pricing string) *Subscription { - selectedPricing := GetPricing(fmt.Sprintf("%s/%s", owner, pricing)) +func Subscribe(owner string, user string, plan string, pricing string) (*Subscription, error) { + selectedPricing, err := GetPricing(fmt.Sprintf("%s/%s", owner, pricing)) + if err != nil { + return nil, err + } valid := selectedPricing != nil && selectedPricing.IsEnabled if !valid { - return nil + return nil, nil } - planBelongToPricing := selectedPricing.HasPlan(owner, plan) + planBelongToPricing, err := selectedPricing.HasPlan(owner, plan) + if err != nil { + return nil, err + } if planBelongToPricing { newSubscription := NewSubscription(owner, user, plan, selectedPricing.TrialDuration) - affected := AddSubscription(newSubscription) + affected, err := AddSubscription(newSubscription) + if err != nil { + return nil, err + } if affected { - return newSubscription + return newSubscription, nil } } - return nil + return nil, nil } diff --git a/object/pricing.go b/object/pricing.go index 24b20dec..f45667aa 100644 --- a/object/pricing.go +++ b/object/pricing.go @@ -42,96 +42,97 @@ type Pricing struct { State string `xorm:"varchar(100)" json:"state"` } -func GetPricingCount(owner, field, value string) int { +func GetPricingCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Pricing{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Pricing{}) } -func GetPricings(owner string) []*Pricing { +func GetPricings(owner string) ([]*Pricing, error) { pricings := []*Pricing{} err := adapter.Engine.Desc("created_time").Find(&pricings, &Pricing{Owner: owner}) if err != nil { - panic(err) + return pricings, err } - return pricings + + return pricings, nil } -func GetPaginatedPricings(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Pricing { +func GetPaginatedPricings(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Pricing, error) { pricings := []*Pricing{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&pricings) if err != nil { - panic(err) + return pricings, err } - return pricings + return pricings, nil } -func getPricing(owner, name string) *Pricing { +func getPricing(owner, name string) (*Pricing, error) { if owner == "" || name == "" { - return nil + return nil, nil } pricing := Pricing{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&pricing) if err != nil { - panic(err) + return &pricing, err } if existed { - return &pricing + return &pricing, nil } else { - return nil + return nil, nil } } -func GetPricing(id string) *Pricing { +func GetPricing(id string) (*Pricing, error) { owner, name := util.GetOwnerAndNameFromId(id) return getPricing(owner, name) } -func UpdatePricing(id string, pricing *Pricing) bool { +func UpdatePricing(id string, pricing *Pricing) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getPricing(owner, name) == nil { - return false + if p, err := getPricing(owner, name); err != nil { + return false, err + } else if p == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(pricing) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddPricing(pricing *Pricing) bool { +func AddPricing(pricing *Pricing) (bool, error) { affected, err := adapter.Engine.Insert(pricing) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeletePricing(pricing *Pricing) bool { +func DeletePricing(pricing *Pricing) (bool, error) { affected, err := adapter.Engine.ID(core.PK{pricing.Owner, pricing.Name}).Delete(pricing) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (pricing *Pricing) GetId() string { return fmt.Sprintf("%s/%s", pricing.Owner, pricing.Name) } -func (pricing *Pricing) HasPlan(owner string, plan string) bool { - selectedPlan := GetPlan(fmt.Sprintf("%s/%s", owner, plan)) +func (pricing *Pricing) HasPlan(owner string, plan string) (bool, error) { + selectedPlan, err := GetPlan(fmt.Sprintf("%s/%s", owner, plan)) + if err != nil { + return false, err + } if selectedPlan == nil { - return false + return false, nil } result := false @@ -143,5 +144,5 @@ func (pricing *Pricing) HasPlan(owner string, plan string) bool { } } - return result + return result, nil } diff --git a/object/product.go b/object/product.go index d1034adc..564b7bdb 100644 --- a/object/product.go +++ b/object/product.go @@ -43,90 +43,87 @@ type Product struct { ProviderObjs []*Provider `xorm:"-" json:"providerObjs"` } -func GetProductCount(owner, field, value string) int { +func GetProductCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Product{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Product{}) } -func GetProducts(owner string) []*Product { +func GetProducts(owner string) ([]*Product, error) { products := []*Product{} err := adapter.Engine.Desc("created_time").Find(&products, &Product{Owner: owner}) if err != nil { - panic(err) + return products, err } - return products + return products, nil } -func GetPaginationProducts(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Product { +func GetPaginationProducts(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Product, error) { products := []*Product{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&products) if err != nil { - panic(err) + return products, err } - return products + return products, nil } -func getProduct(owner string, name string) *Product { +func getProduct(owner string, name string) (*Product, error) { if owner == "" || name == "" { - return nil + return nil, nil } product := Product{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&product) if err != nil { - panic(err) + return &product, nil } if existed { - return &product + return &product, nil } else { - return nil + return nil, nil } } -func GetProduct(id string) *Product { +func GetProduct(id string) (*Product, error) { owner, name := util.GetOwnerAndNameFromId(id) return getProduct(owner, name) } -func UpdateProduct(id string, product *Product) bool { +func UpdateProduct(id string, product *Product) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getProduct(owner, name) == nil { - return false + if p, err := getProduct(owner, name); err != nil { + return false, err + } else if p == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(product) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddProduct(product *Product) bool { +func AddProduct(product *Product) (bool, error) { affected, err := adapter.Engine.Insert(product) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteProduct(product *Product) bool { +func DeleteProduct(product *Product) (bool, error) { affected, err := adapter.Engine.ID(core.PK{product.Owner, product.Name}).Delete(&Product{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (product *Product) GetId() string { @@ -143,7 +140,11 @@ func (product *Product) isValidProvider(provider *Provider) bool { } func (product *Product) getProvider(providerId string) (*Provider, error) { - provider := getProvider(product.Owner, providerId) + provider, err := getProvider(product.Owner, providerId) + if err != nil { + return nil, err + } + if provider == nil { return nil, fmt.Errorf("the payment provider: %s does not exist", providerId) } @@ -156,7 +157,11 @@ func (product *Product) getProvider(providerId string) (*Provider, error) { } func BuyProduct(id string, providerName string, user *User, host string) (string, error) { - product := GetProduct(id) + product, err := GetProduct(id) + if err != nil { + return "", err + } + if product == nil { return "", fmt.Errorf("the product: %s does not exist", id) } @@ -205,7 +210,11 @@ func BuyProduct(id string, providerName string, user *User, host string) (string ReturnUrl: product.ReturnUrl, State: "Created", } - affected := AddPayment(&payment) + affected, err := AddPayment(&payment) + if err != nil { + return "", err + } + if !affected { return "", fmt.Errorf("failed to add payment: %s", util.StructToJson(payment)) } @@ -213,17 +222,23 @@ func BuyProduct(id string, providerName string, user *User, host string) (string return payUrl, err } -func ExtendProductWithProviders(product *Product) { +func ExtendProductWithProviders(product *Product) error { if product == nil { - return + return nil } product.ProviderObjs = []*Provider{} - m := getProviderMap(product.Owner) + m, err := getProviderMap(product.Owner) + if err != nil { + return err + } + for _, providerItem := range product.Providers { if provider, ok := m[providerItem]; ok { product.ProviderObjs = append(product.ProviderObjs, provider) } } + + return nil } diff --git a/object/product_test.go b/object/product_test.go index 682777ca..43ecb9ed 100644 --- a/object/product_test.go +++ b/object/product_test.go @@ -27,9 +27,9 @@ import ( func TestProduct(t *testing.T) { InitConfig() - product := GetProduct("admin/product_123") - provider := getProvider(product.Owner, "provider_pay_alipay") - cert := getCert(product.Owner, "cert-pay-alipay") + product, _ := GetProduct("admin/product_123") + provider, _ := getProvider(product.Owner, "provider_pay_alipay") + cert, _ := getCert(product.Owner, "cert-pay-alipay") pProvider, err := pp.GetPaymentProvider(provider.Type, provider.ClientId, provider.ClientSecret, provider.Host, cert.Certificate, cert.PrivateKey, cert.AuthorityPublicKey, cert.AuthorityRootPublicKey, provider.ClientId2) if err != nil { panic(err) diff --git a/object/provider.go b/object/provider.go index 9382687d..449822ea 100644 --- a/object/provider.go +++ b/object/provider.go @@ -103,103 +103,93 @@ func GetMaskedProviders(providers []*Provider, isMaskEnabled bool) []*Provider { return providers } -func GetProviderCount(owner, field, value string) int { +func GetProviderCount(owner, field, value string) (int64, error) { session := GetSession("", -1, -1, field, value, "", "") - count, err := session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Provider{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Provider{}) } -func GetGlobalProviderCount(field, value string) int { +func GetGlobalProviderCount(field, value string) (int64, error) { session := GetSession("", -1, -1, field, value, "", "") - count, err := session.Count(&Provider{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Provider{}) } -func GetProviders(owner string) []*Provider { +func GetProviders(owner string) ([]*Provider, error) { providers := []*Provider{} err := adapter.Engine.Where("owner = ? or owner = ? ", "admin", owner).Desc("created_time").Find(&providers, &Provider{}) if err != nil { - panic(err) + return providers, err } - return providers + return providers, nil } -func GetGlobalProviders() []*Provider { +func GetGlobalProviders() ([]*Provider, error) { providers := []*Provider{} err := adapter.Engine.Desc("created_time").Find(&providers) if err != nil { - panic(err) + return providers, err } - return providers + return providers, nil } -func GetPaginationProviders(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Provider { +func GetPaginationProviders(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Provider, error) { providers := []*Provider{} session := GetSession("", offset, limit, field, value, sortField, sortOrder) err := session.Where("owner = ? or owner = ? ", "admin", owner).Find(&providers) if err != nil { - panic(err) + return providers, err } - return providers + return providers, nil } -func GetPaginationGlobalProviders(offset, limit int, field, value, sortField, sortOrder string) []*Provider { +func GetPaginationGlobalProviders(offset, limit int, field, value, sortField, sortOrder string) ([]*Provider, error) { providers := []*Provider{} session := GetSession("", offset, limit, field, value, sortField, sortOrder) err := session.Find(&providers) if err != nil { - panic(err) + return providers, err } - return providers + return providers, nil } -func getProvider(owner string, name string) *Provider { +func getProvider(owner string, name string) (*Provider, error) { if owner == "" || name == "" { - return nil + return nil, nil } provider := Provider{Name: name} existed, err := adapter.Engine.Get(&provider) if err != nil { - panic(err) + return &provider, err } if existed { - return &provider + return &provider, nil } else { - return nil + return nil, nil } } -func GetProvider(id string) *Provider { +func GetProvider(id string) (*Provider, error) { owner, name := util.GetOwnerAndNameFromId(id) return getProvider(owner, name) } -func getDefaultAiProvider() *Provider { +func getDefaultAiProvider() (*Provider, error) { provider := Provider{Owner: "admin", Category: "AI"} existed, err := adapter.Engine.Get(&provider) if err != nil { - panic(err) + return &provider, err } if !existed { - return nil + return nil, nil } - return &provider + return &provider, nil } func GetWechatMiniProgramProvider(application *Application) *Provider { @@ -212,16 +202,18 @@ func GetWechatMiniProgramProvider(application *Application) *Provider { return nil } -func UpdateProvider(id string, provider *Provider) bool { +func UpdateProvider(id string, provider *Provider) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getProvider(owner, name) == nil { - return false + if p, err := getProvider(owner, name); err != nil { + return false, err + } else if p == nil { + return false, nil } if name != provider.Name { err := providerChangeTrigger(name, provider.Name) if err != nil { - return false + return false, nil } } @@ -238,37 +230,41 @@ func UpdateProvider(id string, provider *Provider) bool { affected, err := session.Update(provider) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddProvider(provider *Provider) bool { +func AddProvider(provider *Provider) (bool, error) { provider.Endpoint = util.GetEndPoint(provider.Endpoint) provider.IntranetEndpoint = util.GetEndPoint(provider.IntranetEndpoint) affected, err := adapter.Engine.Insert(provider) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteProvider(provider *Provider) bool { +func DeleteProvider(provider *Provider) (bool, error) { affected, err := adapter.Engine.ID(core.PK{provider.Owner, provider.Name}).Delete(&Provider{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (p *Provider) getPaymentProvider() (pp.PaymentProvider, *Cert, error) { cert := &Cert{} if p.Cert != "" { - cert = getCert(p.Owner, p.Cert) + cert, err := getCert(p.Owner, p.Cert) + if err != nil { + return nil, nil, err + } + if cert == nil { return nil, nil, fmt.Errorf("the cert: %s does not exist", p.Cert) } @@ -309,7 +305,11 @@ func GetCaptchaProviderByApplication(applicationId, isCurrentProvider, lang stri if isCurrentProvider == "true" { return GetCaptchaProviderByOwnerName(applicationId, lang) } - application := GetApplication(applicationId) + application, err := GetApplication(applicationId) + if err != nil { + return nil, err + } + if application == nil || len(application.Providers) == 0 { return nil, fmt.Errorf(i18n.Translate(lang, "provider:Invalid application id")) } diff --git a/object/record.go b/object/record.go index e54fcb6d..f2ad432d 100644 --- a/object/record.go +++ b/object/record.go @@ -26,11 +26,7 @@ import ( var logPostOnly bool func init() { - var err error - logPostOnly, err = conf.GetConfigBool("logPostOnly") - if err != nil { - // panic(err) - } + logPostOnly = conf.GetConfigBool("logPostOnly") } type Record struct { @@ -108,49 +104,48 @@ func AddRecord(record *Record) bool { return affected != 0 } -func GetRecordCount(field, value string, filterRecord *Record) int { +func GetRecordCount(field, value string, filterRecord *Record) (int64, error) { session := GetSession("", -1, -1, field, value, "", "") - count, err := session.Count(filterRecord) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(filterRecord) } -func GetRecords() []*Record { +func GetRecords() ([]*Record, error) { records := []*Record{} err := adapter.Engine.Desc("id").Find(&records) if err != nil { - panic(err) + return records, err } - return records + return records, nil } -func GetPaginationRecords(offset, limit int, field, value, sortField, sortOrder string, filterRecord *Record) []*Record { +func GetPaginationRecords(offset, limit int, field, value, sortField, sortOrder string, filterRecord *Record) ([]*Record, error) { records := []*Record{} session := GetSession("", offset, limit, field, value, sortField, sortOrder) err := session.Find(&records, filterRecord) if err != nil { - panic(err) + return records, err } - return records + return records, nil } -func GetRecordsByField(record *Record) []*Record { +func GetRecordsByField(record *Record) ([]*Record, error) { records := []*Record{} err := adapter.Engine.Find(&records, record) if err != nil { - panic(err) + return records, err } - return records + return records, nil } func SendWebhooks(record *Record) error { - webhooks := getWebhooksByOrganization(record.Organization) + webhooks, err := getWebhooksByOrganization(record.Organization) + if err != nil { + return err + } + for _, webhook := range webhooks { if !webhook.IsEnabled { continue @@ -166,7 +161,11 @@ func SendWebhooks(record *Record) error { if matched { if webhook.IsUserExtended { - user := GetMaskedUser(getUser(record.Organization, record.User)) + user, err := GetMaskedUser(getUser(record.Organization, record.User)) + if err != nil { + return err + } + record.ExtendedUser = user } diff --git a/object/resource.go b/object/resource.go index f72dbc10..5cd64913 100644 --- a/object/resource.go +++ b/object/resource.go @@ -39,17 +39,12 @@ type Resource struct { Description string `xorm:"varchar(255)" json:"description"` } -func GetResourceCount(owner, user, field, value string) int { +func GetResourceCount(owner, user, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Resource{User: user}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Resource{User: user}) } -func GetResources(owner string, user string) []*Resource { +func GetResources(owner string, user string) ([]*Resource, error) { if owner == "built-in" { owner = "" user = "" @@ -58,13 +53,13 @@ func GetResources(owner string, user string) []*Resource { resources := []*Resource{} err := adapter.Engine.Desc("created_time").Find(&resources, &Resource{Owner: owner, User: user}) if err != nil { - panic(err) + return resources, err } - return resources + return resources, err } -func GetPaginationResources(owner, user string, offset, limit int, field, value, sortField, sortOrder string) []*Resource { +func GetPaginationResources(owner, user string, offset, limit int, field, value, sortField, sortOrder string) ([]*Resource, error) { if owner == "built-in" { owner = "" user = "" @@ -74,70 +69,74 @@ func GetPaginationResources(owner, user string, offset, limit int, field, value, session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&resources, &Resource{User: user}) if err != nil { - panic(err) + return resources, err } - return resources + return resources, nil } -func getResource(owner string, name string) *Resource { +func getResource(owner string, name string) (*Resource, error) { resource := Resource{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&resource) if err != nil { - panic(err) + return &resource, err } if existed { - return &resource + return &resource, nil } - return nil + return nil, nil } -func GetResource(id string) *Resource { +func GetResource(id string) (*Resource, error) { owner, name := util.GetOwnerAndNameFromIdNoCheck(id) return getResource(owner, name) } -func UpdateResource(id string, resource *Resource) bool { +func UpdateResource(id string, resource *Resource) (bool, error) { owner, name := util.GetOwnerAndNameFromIdNoCheck(id) - if getResource(owner, name) == nil { - return false + if r, err := getResource(owner, name); err != nil { + return false, err + } else if r == nil { + return false, nil } _, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(resource) if err != nil { - panic(err) + return false, err } // return affected != 0 - return true + return true, nil } -func AddResource(resource *Resource) bool { +func AddResource(resource *Resource) (bool, error) { affected, err := adapter.Engine.Insert(resource) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteResource(resource *Resource) bool { +func DeleteResource(resource *Resource) (bool, error) { affected, err := adapter.Engine.ID(core.PK{resource.Owner, resource.Name}).Delete(&Resource{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (resource *Resource) GetId() string { return fmt.Sprintf("%s/%s", resource.Owner, resource.Name) } -func AddOrUpdateResource(resource *Resource) bool { - if getResource(resource.Owner, resource.Name) == nil { +func AddOrUpdateResource(resource *Resource) (bool, error) { + if r, err := getResource(resource.Owner, resource.Name); err != nil { + return false, err + } else if r == nil { return AddResource(resource) } else { return UpdateResource(resource.GetId(), resource) diff --git a/object/role.go b/object/role.go index ec1c9493..8a3276c0 100644 --- a/object/role.go +++ b/object/role.go @@ -36,79 +36,90 @@ type Role struct { IsEnabled bool `json:"isEnabled"` } -func GetRoleCount(owner, field, value string) int { +func GetRoleCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Role{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Role{}) } -func GetRoles(owner string) []*Role { +func GetRoles(owner string) ([]*Role, error) { roles := []*Role{} err := adapter.Engine.Desc("created_time").Find(&roles, &Role{Owner: owner}) if err != nil { - panic(err) + return roles, err } - return roles + return roles, nil } -func GetPaginationRoles(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Role { +func GetPaginationRoles(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Role, error) { roles := []*Role{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&roles) if err != nil { - panic(err) + return roles, err } - return roles + return roles, nil } -func getRole(owner string, name string) *Role { +func getRole(owner string, name string) (*Role, error) { if owner == "" || name == "" { - return nil + return nil, nil } role := Role{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&role) if err != nil { - panic(err) + return &role, err } if existed { - return &role + return &role, nil } else { - return nil + return nil, nil } } -func GetRole(id string) *Role { +func GetRole(id string) (*Role, error) { owner, name := util.GetOwnerAndNameFromId(id) return getRole(owner, name) } -func UpdateRole(id string, role *Role) bool { +func UpdateRole(id string, role *Role) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - oldRole := getRole(owner, name) + oldRole, err := getRole(owner, name) + if err != nil { + return false, err + } + if oldRole == nil { - return false + return false, nil } visited := map[string]struct{}{} - permissions := GetPermissionsByRole(id) + permissions, err := GetPermissionsByRole(id) + if err != nil { + return false, err + } + for _, permission := range permissions { removeGroupingPolicies(permission) removePolicies(permission) visited[permission.GetId()] = struct{}{} } - ancestorRoles := GetAncestorRoles(id) + ancestorRoles, err := GetAncestorRoles(id) + if err != nil { + return false, err + } + for _, r := range ancestorRoles { - permissions := GetPermissionsByRole(r.GetId()) + permissions, err := GetPermissionsByRole(r.GetId()) + if err != nil { + return false, err + } + for _, permission := range permissions { permissionId := permission.GetId() if _, ok := visited[permissionId]; !ok { @@ -121,27 +132,38 @@ func UpdateRole(id string, role *Role) bool { if name != role.Name { err := roleChangeTrigger(name, role.Name) if err != nil { - return false + return false, nil } } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(role) if err != nil { - panic(err) + return false, err } visited = map[string]struct{}{} newRoleID := role.GetId() - permissions = GetPermissionsByRole(newRoleID) + permissions, err = GetPermissionsByRole(newRoleID) + if err != nil { + return false, err + } + for _, permission := range permissions { addGroupingPolicies(permission) addPolicies(permission) visited[permission.GetId()] = struct{}{} } - ancestorRoles = GetAncestorRoles(newRoleID) + ancestorRoles, err = GetAncestorRoles(newRoleID) + if err != nil { + return false, err + } + for _, r := range ancestorRoles { - permissions := GetPermissionsByRole(r.GetId()) + permissions, err := GetPermissionsByRole(r.GetId()) + if err != nil { + return false, err + } for _, permission := range permissions { permissionId := permission.GetId() if _, ok := visited[permissionId]; !ok { @@ -151,16 +173,16 @@ func UpdateRole(id string, role *Role) bool { } } - return affected != 0 + return affected != 0, nil } -func AddRole(role *Role) bool { +func AddRole(role *Role) (bool, error) { affected, err := adapter.Engine.Insert(role) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func AddRoles(roles []*Role) bool { @@ -202,38 +224,45 @@ func AddRolesInBatch(roles []*Role) bool { return affected } -func DeleteRole(role *Role) bool { +func DeleteRole(role *Role) (bool, error) { roleId := role.GetId() - permissions := GetPermissionsByRole(roleId) + permissions, err := GetPermissionsByRole(roleId) + if err != nil { + return false, err + } + for _, permission := range permissions { permission.Roles = util.DeleteVal(permission.Roles, roleId) - UpdatePermission(permission.GetId(), permission) + _, err := UpdatePermission(permission.GetId(), permission) + if err != nil { + return false, err + } } affected, err := adapter.Engine.ID(core.PK{role.Owner, role.Name}).Delete(&Role{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (role *Role) GetId() string { return fmt.Sprintf("%s/%s", role.Owner, role.Name) } -func GetRolesByUser(userId string) []*Role { +func GetRolesByUser(userId string) ([]*Role, error) { roles := []*Role{} err := adapter.Engine.Where("users like ?", "%"+userId+"\"%").Find(&roles) if err != nil { - panic(err) + return roles, err } for i := range roles { roles[i].Users = nil } - return roles + return roles, nil } func roleChangeTrigger(oldName string, newName string) error { @@ -250,6 +279,7 @@ func roleChangeTrigger(oldName string, newName string) error { if err != nil { return err } + for _, role := range roles { for j, u := range role.Roles { owner, name := util.GetOwnerAndNameFromId(u) @@ -268,6 +298,7 @@ func roleChangeTrigger(oldName string, newName string) error { if err != nil { return err } + for _, permission := range permissions { for j, u := range permission.Roles { // u = organization/username @@ -293,17 +324,17 @@ func GetMaskedRoles(roles []*Role) []*Role { return roles } -func GetRolesByNamePrefix(owner string, prefix string) []*Role { +func GetRolesByNamePrefix(owner string, prefix string) ([]*Role, error) { roles := []*Role{} err := adapter.Engine.Where("owner=? and name like ?", owner, prefix+"%").Find(&roles) if err != nil { - panic(err) + return roles, err } - return roles + return roles, nil } -func GetAncestorRoles(roleId string) []*Role { +func GetAncestorRoles(roleId string) ([]*Role, error) { var ( result []*Role roleMap = make(map[string]*Role) @@ -312,7 +343,11 @@ func GetAncestorRoles(roleId string) []*Role { owner, _ := util.GetOwnerAndNameFromIdNoCheck(roleId) - allRoles := GetRoles(owner) + allRoles, err := GetRoles(owner) + if err != nil { + return nil, err + } + for _, r := range allRoles { roleMap[r.GetId()] = r } @@ -331,7 +366,7 @@ func GetAncestorRoles(roleId string) []*Role { } } - return result + return result, nil } // containsRole is a helper function to check if a slice of roles contains a specific roleId diff --git a/object/role_upload.go b/object/role_upload.go index a91908ab..22786a28 100644 --- a/object/role_upload.go +++ b/object/role_upload.go @@ -18,21 +18,29 @@ import ( "github.com/casdoor/casdoor/xlsx" ) -func getRoleMap(owner string) map[string]*Role { +func getRoleMap(owner string) (map[string]*Role, error) { m := map[string]*Role{} - roles := GetRoles(owner) + roles, err := GetRoles(owner) + if err != nil { + return nil, err + } + for _, role := range roles { m[role.GetId()] = role } - return m + return m, nil } -func UploadRoles(owner string, fileId string) bool { +func UploadRoles(owner string, fileId string) (bool, error) { table := xlsx.ReadXlsxFile(fileId) - oldUserMap := getRoleMap(owner) + oldUserMap, err := getRoleMap(owner) + if err != nil { + return false, err + } + newRoles := []*Role{} for index, line := range table { if index == 0 || parseLineItem(&line, 0) == "" { @@ -57,7 +65,7 @@ func UploadRoles(owner string, fileId string) bool { } if len(newRoles) == 0 { - return false + return false, nil } - return AddRolesInBatch(newRoles) + return AddRolesInBatch(newRoles), nil } diff --git a/object/saml_idp.go b/object/saml_idp.go index 94b98c28..9a933a47 100644 --- a/object/saml_idp.go +++ b/object/saml_idp.go @@ -105,7 +105,11 @@ func NewSamlResponse(user *User, host string, certificate string, destination st roles := attributes.CreateElement("saml:Attribute") roles.CreateAttr("Name", "Roles") roles.CreateAttr("NameFormat", "urn:oasis:names:tc:SAML:2.0:attrname-format:basic") - ExtendUserWithRolesAndPermissions(user) + err := ExtendUserWithRolesAndPermissions(user) + if err != nil { + return nil, err + } + for _, role := range user.Roles { roles.CreateElement("saml:AttributeValue").CreateAttr("xsi:type", "xs:string").Element().SetText(role.Name) } @@ -186,7 +190,11 @@ type Attribute struct { } func GetSamlMeta(application *Application, host string) (*IdpEntityDescriptor, error) { - cert := getCertByApplication(application) + cert, err := getCertByApplication(application) + if err != nil { + return nil, err + } + block, _ := pem.Decode([]byte(cert.Certificate)) certificate := base64.StdEncoding.EncodeToString(block.Bytes) @@ -263,7 +271,11 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h } // get certificate string - cert := getCertByApplication(application) + cert, err := getCertByApplication(application) + if err != nil { + return "", "", "", err + } + block, _ := pem.Decode([]byte(cert.Certificate)) certificate := base64.StdEncoding.EncodeToString(block.Bytes) diff --git a/object/saml_sp.go b/object/saml_sp.go index 9efd6b64..b9c3c622 100644 --- a/object/saml_sp.go +++ b/object/saml_sp.go @@ -43,7 +43,10 @@ func ParseSamlResponse(samlResponse string, provider *Provider, host string) (st } func GenerateSamlRequest(id, relayState, host, lang string) (auth string, method string, err error) { - provider := GetProvider(id) + provider, err := GetProvider(id) + if err != nil { + return "", "", err + } if provider.Category != "SAML" { return "", "", fmt.Errorf(i18n.Translate(lang, "saml_sp:provider %s's category is not SAML"), provider.Name) } @@ -92,27 +95,33 @@ func buildSp(provider *Provider, samlResponse string, host string) (*saml2.SAMLS } if provider.EnableSignAuthnRequest { sp.SignAuthnRequests = true - sp.SPKeyStore = buildSpKeyStore() + sp.SPKeyStore, err = buildSpKeyStore() + if err != nil { + return nil, err + } } return sp, nil } -func buildSpKeyStore() dsig.X509KeyStore { +func buildSpKeyStore() (dsig.X509KeyStore, error) { keyPair, err := tls.LoadX509KeyPair("object/token_jwt_key.pem", "object/token_jwt_key.key") if err != nil { - panic(err) + return nil, err } return &dsig.TLSCertKeyStore{ PrivateKey: keyPair.PrivateKey, Certificate: keyPair.Certificate, - } + }, nil } -func buildSpCertificateStore(provider *Provider, samlResponse string) (dsig.MemoryX509CertificateStore, error) { +func buildSpCertificateStore(provider *Provider, samlResponse string) (certStore dsig.MemoryX509CertificateStore, err error) { certEncodedData := "" if samlResponse != "" { - certEncodedData = getCertificateFromSamlResponse(samlResponse, provider.Type) + certEncodedData, err = getCertificateFromSamlResponse(samlResponse, provider.Type) + if err != nil { + return + } } else if provider.IdP != "" { certEncodedData = provider.IdP } @@ -126,17 +135,18 @@ func buildSpCertificateStore(provider *Provider, samlResponse string) (dsig.Memo return dsig.MemoryX509CertificateStore{}, err } - certStore := dsig.MemoryX509CertificateStore{ + certStore = dsig.MemoryX509CertificateStore{ Roots: []*x509.Certificate{idpCert}, } return certStore, nil } -func getCertificateFromSamlResponse(samlResponse string, providerType string) string { +func getCertificateFromSamlResponse(samlResponse string, providerType string) (string, error) { de, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { - panic(err) + return "", err } + deStr := strings.Replace(string(de), "\n", "", -1) tagMap := map[string]string{ "Aliyun IDaaS": "ds", @@ -145,5 +155,5 @@ func getCertificateFromSamlResponse(samlResponse string, providerType string) st tag := tagMap[providerType] expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)", tag, tag) res := regexp.MustCompile(expression).FindStringSubmatch(deStr) - return res[1] + return res[1], nil } diff --git a/object/session.go b/object/session.go index 765fc473..3a38dea7 100644 --- a/object/session.go +++ b/object/session.go @@ -36,7 +36,7 @@ type Session struct { SessionId []string `json:"sessionId"` } -func GetSessions(owner string) []*Session { +func GetSessions(owner string) ([]*Session, error) { sessions := []*Session{} var err error if owner != "" { @@ -45,61 +45,58 @@ func GetSessions(owner string) []*Session { err = adapter.Engine.Desc("created_time").Find(&sessions) } if err != nil { - panic(err) + return sessions, err } - return sessions + return sessions, nil } -func GetPaginationSessions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Session { +func GetPaginationSessions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Session, error) { sessions := []*Session{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&sessions) if err != nil { - panic(err) + return sessions, err } - return sessions + return sessions, nil } -func GetSessionCount(owner, field, value string) int { +func GetSessionCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Session{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Session{}) } -func GetSingleSession(id string) *Session { +func GetSingleSession(id string) (*Session, error) { owner, name, application := util.GetOwnerAndNameAndOtherFromId(id) session := Session{Owner: owner, Name: name, Application: application} get, err := adapter.Engine.Get(&session) if err != nil { - panic(err) + return &session, err } if !get { - return nil + return nil, nil } - return &session + return &session, nil } -func UpdateSession(id string, session *Session) bool { +func UpdateSession(id string, session *Session) (bool, error) { owner, name, application := util.GetOwnerAndNameAndOtherFromId(id) - if GetSingleSession(id) == nil { - return false + if ss, err := GetSingleSession(id); err != nil { + return false, err + } else if ss == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name, application}).Update(session) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func removeExtraSessionIds(session *Session) { @@ -108,17 +105,21 @@ func removeExtraSessionIds(session *Session) { } } -func AddSession(session *Session) bool { - dbSession := GetSingleSession(session.GetId()) +func AddSession(session *Session) (bool, error) { + dbSession, err := GetSingleSession(session.GetId()) + if err != nil { + return false, err + } + if dbSession == nil { session.CreatedTime = util.GetCurrentTime() affected, err := adapter.Engine.Insert(session) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } else { m := make(map[string]struct{}) for _, v := range dbSession.SessionId { @@ -136,10 +137,14 @@ func AddSession(session *Session) bool { } } -func DeleteSession(id string) bool { +func DeleteSession(id string) (bool, error) { owner, name, application := util.GetOwnerAndNameAndOtherFromId(id) if owner == CasdoorOrganization && application == CasdoorApplication { - session := GetSingleSession(id) + session, err := GetSingleSession(id) + if err != nil { + return false, err + } + if session != nil { DeleteBeegoSession(session.SessionId) } @@ -147,16 +152,19 @@ func DeleteSession(id string) bool { affected, err := adapter.Engine.ID(core.PK{owner, name, application}).Delete(&Session{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteSessionId(id string, sessionId string) bool { - session := GetSingleSession(id) +func DeleteSessionId(id string, sessionId string) (bool, error) { + session, err := GetSingleSession(id) + if err != nil { + return false, err + } if session == nil { - return false + return false, nil } owner, _, application := util.GetOwnerAndNameAndOtherFromId(id) @@ -185,17 +193,21 @@ func (session *Session) GetId() string { return fmt.Sprintf("%s/%s/%s", session.Owner, session.Name, session.Application) } -func IsSessionDuplicated(id string, sessionId string) bool { - session := GetSingleSession(id) +func IsSessionDuplicated(id string, sessionId string) (bool, error) { + session, err := GetSingleSession(id) + if err != nil { + return false, err + } + if session == nil { - return false + return false, nil } else { if len(session.SessionId) > 1 { - return true + return true, nil } else if len(session.SessionId) < 1 { - return false + return false, nil } else { - return session.SessionId[0] != sessionId + return session.SessionId[0] != sessionId, nil } } } diff --git a/object/storage.go b/object/storage.go index 1917457f..ccc4ddb1 100644 --- a/object/storage.go +++ b/object/storage.go @@ -30,11 +30,7 @@ import ( var isCloudIntranet bool func init() { - var err error - isCloudIntranet, err = conf.GetConfigBool("isCloudIntranet") - if err != nil { - // panic(err) - } + isCloudIntranet = conf.GetConfigBool("isCloudIntranet") } func getProviderEndpoint(provider *Provider) string { diff --git a/object/subscription.go b/object/subscription.go index 5724fb29..d944f63b 100644 --- a/object/subscription.go +++ b/object/subscription.go @@ -63,90 +63,87 @@ func NewSubscription(owner string, user string, plan string, duration int) *Subs } } -func GetSubscriptionCount(owner, field, value string) int { +func GetSubscriptionCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Subscription{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Subscription{}) } -func GetSubscriptions(owner string) []*Subscription { +func GetSubscriptions(owner string) ([]*Subscription, error) { subscriptions := []*Subscription{} err := adapter.Engine.Desc("created_time").Find(&subscriptions, &Subscription{Owner: owner}) if err != nil { - panic(err) + return subscriptions, err } - return subscriptions + return subscriptions, nil } -func GetPaginationSubscriptions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Subscription { +func GetPaginationSubscriptions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Subscription, error) { subscriptions := []*Subscription{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&subscriptions) if err != nil { - panic(err) + return subscriptions, err } - return subscriptions + return subscriptions, nil } -func getSubscription(owner string, name string) *Subscription { +func getSubscription(owner string, name string) (*Subscription, error) { if owner == "" || name == "" { - return nil + return nil, nil } subscription := Subscription{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&subscription) if err != nil { - panic(err) + return nil, err } if existed { - return &subscription + return &subscription, nil } else { - return nil + return nil, nil } } -func GetSubscription(id string) *Subscription { +func GetSubscription(id string) (*Subscription, error) { owner, name := util.GetOwnerAndNameFromId(id) return getSubscription(owner, name) } -func UpdateSubscription(id string, subscription *Subscription) bool { +func UpdateSubscription(id string, subscription *Subscription) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getSubscription(owner, name) == nil { - return false + if s, err := getSubscription(owner, name); err != nil { + return false, err + } else if s == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(subscription) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddSubscription(subscription *Subscription) bool { +func AddSubscription(subscription *Subscription) (bool, error) { affected, err := adapter.Engine.Insert(subscription) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteSubscription(subscription *Subscription) bool { +func DeleteSubscription(subscription *Subscription) (bool, error) { affected, err := adapter.Engine.ID(core.PK{subscription.Owner, subscription.Name}).Delete(&Subscription{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (subscription *Subscription) GetId() string { diff --git a/object/syncer.go b/object/syncer.go index c77c3241..7e836a14 100644 --- a/object/syncer.go +++ b/object/syncer.go @@ -55,66 +55,61 @@ type Syncer struct { Adapter *Adapter `xorm:"-" json:"-"` } -func GetSyncerCount(owner, organization, field, value string) int { +func GetSyncerCount(owner, organization, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Syncer{Organization: organization}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Syncer{Organization: organization}) } -func GetSyncers(owner string) []*Syncer { +func GetSyncers(owner string) ([]*Syncer, error) { syncers := []*Syncer{} err := adapter.Engine.Desc("created_time").Find(&syncers, &Syncer{Owner: owner}) if err != nil { - panic(err) + return syncers, err } - return syncers + return syncers, nil } -func GetOrganizationSyncers(owner, organization string) []*Syncer { +func GetOrganizationSyncers(owner, organization string) ([]*Syncer, error) { syncers := []*Syncer{} err := adapter.Engine.Desc("created_time").Find(&syncers, &Syncer{Owner: owner, Organization: organization}) if err != nil { - panic(err) + return syncers, err } - return syncers + return syncers, nil } -func GetPaginationSyncers(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Syncer { +func GetPaginationSyncers(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Syncer, error) { syncers := []*Syncer{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&syncers, &Syncer{Organization: organization}) if err != nil { - panic(err) + return syncers, err } - return syncers + return syncers, nil } -func getSyncer(owner string, name string) *Syncer { +func getSyncer(owner string, name string) (*Syncer, error) { if owner == "" || name == "" { - return nil + return nil, nil } syncer := Syncer{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&syncer) if err != nil { - panic(err) + return &syncer, err } if existed { - return &syncer + return &syncer, nil } else { - return nil + return nil, nil } } -func GetSyncer(id string) *Syncer { +func GetSyncer(id string) (*Syncer, error) { owner, name := util.GetOwnerAndNameFromId(id) return getSyncer(owner, name) } @@ -137,10 +132,12 @@ func GetMaskedSyncers(syncers []*Syncer) []*Syncer { return syncers } -func UpdateSyncer(id string, syncer *Syncer) bool { +func UpdateSyncer(id string, syncer *Syncer) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getSyncer(owner, name) == nil { - return false + if s, err := getSyncer(owner, name); err != nil { + return false, err + } else if s == nil { + return false, nil } session := adapter.Engine.ID(core.PK{owner, name}).AllCols() @@ -149,56 +146,66 @@ func UpdateSyncer(id string, syncer *Syncer) bool { } affected, err := session.Update(syncer) if err != nil { - panic(err) + return false, err } if affected == 1 { - addSyncerJob(syncer) + err = addSyncerJob(syncer) + if err != nil { + return false, err + } } - return affected != 0 + return affected != 0, nil } -func updateSyncerErrorText(syncer *Syncer, line string) bool { - s := getSyncer(syncer.Owner, syncer.Name) +func updateSyncerErrorText(syncer *Syncer, line string) (bool, error) { + s, err := getSyncer(syncer.Owner, syncer.Name) + if err != nil { + return false, err + } + if s == nil { - return false + return false, nil } s.ErrorText = s.ErrorText + line affected, err := adapter.Engine.ID(core.PK{s.Owner, s.Name}).Cols("error_text").Update(s) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddSyncer(syncer *Syncer) bool { +func AddSyncer(syncer *Syncer) (bool, error) { affected, err := adapter.Engine.Insert(syncer) if err != nil { - panic(err) + return false, err } if affected == 1 { - addSyncerJob(syncer) + err = addSyncerJob(syncer) + if err != nil { + return false, err + } } - return affected != 0 + return affected != 0, nil } -func DeleteSyncer(syncer *Syncer) bool { +func DeleteSyncer(syncer *Syncer) (bool, error) { affected, err := adapter.Engine.ID(core.PK{syncer.Owner, syncer.Name}).Delete(&Syncer{}) if err != nil { - panic(err) + return false, err } if affected == 1 { deleteSyncerJob(syncer) } - return affected != 0 + return affected != 0, nil } func (syncer *Syncer) GetId() string { diff --git a/object/syncer_affiliation.go b/object/syncer_affiliation.go index a64e2a65..8b0d86b2 100644 --- a/object/syncer_affiliation.go +++ b/object/syncer_affiliation.go @@ -19,22 +19,25 @@ type Affiliation struct { Name string `xorm:"varchar(128)" json:"name"` } -func (syncer *Syncer) getAffiliations() []*Affiliation { +func (syncer *Syncer) getAffiliations() ([]*Affiliation, error) { affiliations := []*Affiliation{} err := syncer.Adapter.Engine.Table(syncer.AffiliationTable).Asc("id").Find(&affiliations) if err != nil { - panic(err) + return nil, err } - return affiliations + return affiliations, nil } -func (syncer *Syncer) getAffiliationMap() ([]*Affiliation, map[int]string) { - affiliations := syncer.getAffiliations() +func (syncer *Syncer) getAffiliationMap() ([]*Affiliation, map[int]string, error) { + affiliations, err := syncer.getAffiliations() + if err != nil { + return nil, nil, err + } m := map[int]string{} for _, affiliation := range affiliations { m[affiliation.Id] = affiliation.Name } - return affiliations, m + return affiliations, m, nil } diff --git a/object/syncer_cron.go b/object/syncer_cron.go index a15f1dba..dc67dcc6 100644 --- a/object/syncer_cron.go +++ b/object/syncer_cron.go @@ -43,11 +43,11 @@ func clearCron(name string) { } } -func addSyncerJob(syncer *Syncer) { +func addSyncerJob(syncer *Syncer) error { deleteSyncerJob(syncer) if !syncer.IsEnabled { - return + return nil } syncer.initAdapter() @@ -58,10 +58,11 @@ func addSyncerJob(syncer *Syncer) { cron := getCronMap(syncer.Name) _, err := cron.AddFunc(schedule, syncer.syncUsers) if err != nil { - panic(err) + return err } cron.Start() + return nil } func deleteSyncerJob(syncer *Syncer) { diff --git a/object/syncer_public_api.go b/object/syncer_public_api.go index b016107a..ac65c03e 100644 --- a/object/syncer_public_api.go +++ b/object/syncer_public_api.go @@ -16,46 +16,74 @@ package object import "fmt" -func getDbSyncerForUser(user *User) *Syncer { - syncers := GetSyncers("admin") +func getDbSyncerForUser(user *User) (*Syncer, error) { + syncers, err := GetSyncers("admin") + if err != nil { + return nil, err + } + for _, syncer := range syncers { if syncer.Organization == user.Owner && syncer.IsEnabled && syncer.Type == "Database" { - return syncer + return syncer, nil } } - return nil + return nil, nil } -func getEnabledSyncerForOrganization(organization string) *Syncer { - syncers := GetSyncers("admin") +func getEnabledSyncerForOrganization(organization string) (*Syncer, error) { + syncers, err := GetSyncers("admin") + if err != nil { + return nil, err + } + for _, syncer := range syncers { if syncer.Organization == organization && syncer.IsEnabled { - return syncer + return syncer, nil } } - return nil + return nil, nil } -func AddUserToOriginalDatabase(user *User) { - syncer := getEnabledSyncerForOrganization(user.Owner) +func AddUserToOriginalDatabase(user *User) error { + syncer, err := getEnabledSyncerForOrganization(user.Owner) + if err != nil { + return err + } + if syncer == nil { - return + return nil } updatedOUser := syncer.createOriginalUserFromUser(user) - syncer.addUser(updatedOUser) - fmt.Printf("Add from user to oUser: %v\n", updatedOUser) -} - -func UpdateUserToOriginalDatabase(user *User) { - syncer := getEnabledSyncerForOrganization(user.Owner) - if syncer == nil { - return + _, err = syncer.addUser(updatedOUser) + if err != nil { + return err } - newUser := GetUser(user.GetId()) + fmt.Printf("Add from user to oUser: %v\n", updatedOUser) + return nil +} + +func UpdateUserToOriginalDatabase(user *User) error { + syncer, err := getEnabledSyncerForOrganization(user.Owner) + if err != nil { + return err + } + if syncer == nil { + return nil + } + + newUser, err := GetUser(user.GetId()) + if err != nil { + return err + } updatedOUser := syncer.createOriginalUserFromUser(newUser) - syncer.updateUser(updatedOUser) + _, err = syncer.updateUser(updatedOUser) + if err != nil { + return err + } + fmt.Printf("Update from user to oUser: %v\n", updatedOUser) + return nil } diff --git a/object/syncer_sync.go b/object/syncer_sync.go index ec20a417..8df8182e 100644 --- a/object/syncer_sync.go +++ b/object/syncer_sync.go @@ -37,7 +37,7 @@ func (syncer *Syncer) syncUsers() { var affiliationMap map[int]string if syncer.AffiliationTable != "" { - _, affiliationMap = syncer.getAffiliationMap() + _, affiliationMap, err = syncer.getAffiliationMap() } newUsers := []*User{} @@ -86,13 +86,19 @@ func (syncer *Syncer) syncUsers() { } } } - AddUsersInBatch(newUsers) + _, err = AddUsersInBatch(newUsers) + if err != nil { + panic(err) + } for _, user := range users { id := user.Id if _, ok := oUserMap[id]; !ok { newOUser := syncer.createOriginalUserFromUser(user) - syncer.addUser(newOUser) + _, err = syncer.addUser(newOUser) + if err != nil { + panic(err) + } fmt.Printf("New oUser: %v\n", newOUser) } } diff --git a/object/syncer_user.go b/object/syncer_user.go index d991157f..a592e3f8 100644 --- a/object/syncer_user.go +++ b/object/syncer_user.go @@ -122,14 +122,18 @@ func (syncer *Syncer) updateUser(user *OriginalUser) (bool, error) { } func (syncer *Syncer) updateUserForOriginalFields(user *User) (bool, error) { + var err error owner, name := util.GetOwnerAndNameFromId(user.GetId()) - oldUser := getUserById(owner, name) - if oldUser == nil { - return false, nil + oldUser, err := getUserById(owner, name) + if oldUser == nil || err != nil { + return false, err } if user.Avatar != oldUser.Avatar && user.Avatar != "" { - user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) + user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) + if err != nil { + return false, err + } } columns := syncer.getCasdoorColumns() @@ -175,7 +179,11 @@ func (syncer *Syncer) initAdapter() { } func RunSyncUsersJob() { - syncers := GetSyncers("admin") + syncers, err := GetSyncers("admin") + if err != nil { + panic(err) + } + for _, syncer := range syncers { addSyncerJob(syncer) } diff --git a/object/syncer_user_test.go b/object/syncer_user_test.go index 023d5371..67ecc2f6 100644 --- a/object/syncer_user_test.go +++ b/object/syncer_user_test.go @@ -23,7 +23,7 @@ import ( func TestGetUsers(t *testing.T) { InitConfig() - syncers := GetSyncers("admin") + syncers, _ := GetSyncers("admin") syncer := syncers[0] syncer.initAdapter() users, _ := syncer.getOriginalUsers() diff --git a/object/syner_db_user.go b/object/syner_db_user.go index a34ba7aa..debbe3ef 100644 --- a/object/syner_db_user.go +++ b/object/syner_db_user.go @@ -15,7 +15,11 @@ package object func (syncer *Syncer) getUsers() []*User { - users := GetUsers(syncer.Organization) + users, err := GetUsers(syncer.Organization) + if err != nil { + panic(err) + } + return users } diff --git a/object/token.go b/object/token.go index af33f1d4..b58d321c 100644 --- a/object/token.go +++ b/object/token.go @@ -91,67 +91,54 @@ type IntrospectionResponse struct { Jti string `json:"jti,omitempty"` } -func GetTokenCount(owner, organization, field, value string) int { +func GetTokenCount(owner, organization, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Token{Organization: organization}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Token{Organization: organization}) } -func GetTokens(owner string, organization string) []*Token { +func GetTokens(owner string, organization string) ([]*Token, error) { tokens := []*Token{} err := adapter.Engine.Desc("created_time").Find(&tokens, &Token{Owner: owner, Organization: organization}) - if err != nil { - panic(err) - } - - return tokens + return tokens, err } -func GetPaginationTokens(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Token { +func GetPaginationTokens(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Token, error) { tokens := []*Token{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&tokens, &Token{Organization: organization}) - if err != nil { - panic(err) - } - - return tokens + return tokens, err } -func getToken(owner string, name string) *Token { +func getToken(owner string, name string) (*Token, error) { if owner == "" || name == "" { - return nil + return nil, nil } token := Token{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&token) if err != nil { - panic(err) + return nil, err } if existed { - return &token + return &token, nil } - return nil + return nil, nil } -func getTokenByCode(code string) *Token { +func getTokenByCode(code string) (*Token, error) { token := Token{Code: code} existed, err := adapter.Engine.Get(&token) if err != nil { - panic(err) + return nil, err } if existed { - return &token + return &token, nil } - return nil + return nil, nil } func updateUsedByCode(token *Token) bool { @@ -163,7 +150,7 @@ func updateUsedByCode(token *Token) bool { return affected != 0 } -func GetToken(id string) *Token { +func GetToken(id string) (*Token, error) { owner, name := util.GetOwnerAndNameFromId(id) return getToken(owner, name) } @@ -172,124 +159,155 @@ func (token *Token) GetId() string { return fmt.Sprintf("%s/%s", token.Owner, token.Name) } -func UpdateToken(id string, token *Token) bool { +func UpdateToken(id string, token *Token) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getToken(owner, name) == nil { - return false + if t, err := getToken(owner, name); err != nil { + return false, err + } else if t == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(token) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddToken(token *Token) bool { +func AddToken(token *Token) (bool, error) { affected, err := adapter.Engine.Insert(token) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteToken(token *Token) bool { +func DeleteToken(token *Token) (bool, error) { affected, err := adapter.Engine.ID(core.PK{token.Owner, token.Name}).Delete(&Token{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func ExpireTokenByAccessToken(accessToken string) (bool, *Application, *Token) { +func ExpireTokenByAccessToken(accessToken string) (bool, *Application, *Token, error) { token := Token{AccessToken: accessToken} existed, err := adapter.Engine.Get(&token) if err != nil { - panic(err) + return false, nil, nil, err } if !existed { - return false, nil, nil + return false, nil, nil, nil } token.ExpiresIn = 0 affected, err := adapter.Engine.ID(core.PK{token.Owner, token.Name}).Cols("expires_in").Update(&token) if err != nil { - panic(err) + return false, nil, nil, err } - application := getApplication(token.Owner, token.Application) - return affected != 0, application, &token + application, err := getApplication(token.Owner, token.Application) + if err != nil { + return false, nil, nil, err + } + + return affected != 0, application, &token, nil } -func GetTokenByAccessToken(accessToken string) *Token { +func GetTokenByAccessToken(accessToken string) (*Token, error) { // Check if the accessToken is in the database token := Token{AccessToken: accessToken} existed, err := adapter.Engine.Get(&token) - if err != nil || !existed { - return nil + if err != nil { + return nil, err } - return &token + + if !existed { + return nil, nil + } + + return &token, nil } -func GetTokenByTokenAndApplication(token string, application string) *Token { +func GetTokenByTokenAndApplication(token string, application string) (*Token, error) { tokenResult := Token{} existed, err := adapter.Engine.Where("(refresh_token = ? or access_token = ? ) and application = ?", token, token, application).Get(&tokenResult) - if err != nil || !existed { - return nil + if err != nil { + return nil, err } - return &tokenResult + + if !existed { + return nil, nil + } + + return &tokenResult, nil } -func CheckOAuthLogin(clientId string, responseType string, redirectUri string, scope string, state string, lang string) (string, *Application) { +func CheckOAuthLogin(clientId string, responseType string, redirectUri string, scope string, state string, lang string) (string, *Application, error) { if responseType != "code" && responseType != "token" && responseType != "id_token" { - return fmt.Sprintf(i18n.Translate(lang, "token:Grant_type: %s is not supported in this application"), responseType), nil + return fmt.Sprintf(i18n.Translate(lang, "token:Grant_type: %s is not supported in this application"), responseType), nil, nil + } + + application, err := GetApplicationByClientId(clientId) + if err != nil { + return "", nil, err } - application := GetApplicationByClientId(clientId) if application == nil { - return i18n.Translate(lang, "token:Invalid client_id"), nil + return i18n.Translate(lang, "token:Invalid client_id"), nil, nil } if !application.IsRedirectUriValid(redirectUri) { - return fmt.Sprintf(i18n.Translate(lang, "token:Redirect URI: %s doesn't exist in the allowed Redirect URI list"), redirectUri), application + return fmt.Sprintf(i18n.Translate(lang, "token:Redirect URI: %s doesn't exist in the allowed Redirect URI list"), redirectUri), application, nil } // Mask application for /api/get-app-login application.ClientSecret = "" - return "", application + return "", application, nil } -func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string, host string, lang string) *Code { - user := GetUser(userId) +func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string, host string, lang string) (*Code, error) { + user, err := GetUser(userId) + if err != nil { + return nil, err + } + if user == nil { return &Code{ Message: fmt.Sprintf("general:The user: %s doesn't exist", userId), Code: "", - } + }, nil } if user.IsForbidden { return &Code{ Message: "error: the user is forbidden to sign in, please contact the administrator", Code: "", - } + }, nil + } + + msg, application, err := CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, lang) + if err != nil { + return nil, err } - msg, application := CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, lang) if msg != "" { return &Code{ Message: msg, Code: "", - } + }, nil } - ExtendUserWithRolesAndPermissions(user) + err = ExtendUserWithRolesAndPermissions(user) + if err != nil { + return nil, err + } accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, nonce, scope, host) if err != nil { - panic(err) + return nil, err } if challenge == "null" { @@ -313,21 +331,28 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU CodeIsUsed: false, CodeExpireIn: time.Now().Add(time.Minute * 5).Unix(), } - AddToken(token) + _, err = AddToken(token) + if err != nil { + return nil, err + } return &Code{ Message: "", Code: token.Code, - } + }, nil } -func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string, scope string, username string, password string, host string, refreshToken string, tag string, avatar string, lang string) interface{} { - application := GetApplicationByClientId(clientId) +func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string, scope string, username string, password string, host string, refreshToken string, tag string, avatar string, lang string) (interface{}, error) { + application, err := GetApplicationByClientId(clientId) + if err != nil { + return nil, err + } + if application == nil { return &TokenError{ Error: InvalidClient, ErrorDescription: "client_id is invalid", - } + }, nil } // Check if grantType is allowed in the current application @@ -336,32 +361,44 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code return &TokenError{ Error: UnsupportedGrantType, ErrorDescription: fmt.Sprintf("grant_type: %s is not supported in this application", grantType), - } + }, nil } var token *Token var tokenError *TokenError switch grantType { case "authorization_code": // Authorization Code Grant - token, tokenError = GetAuthorizationCodeToken(application, clientSecret, code, verifier) + token, tokenError, err = GetAuthorizationCodeToken(application, clientSecret, code, verifier) case "password": // Resource Owner Password Credentials Grant - token, tokenError = GetPasswordToken(application, username, password, scope, host) + token, tokenError, err = GetPasswordToken(application, username, password, scope, host) case "client_credentials": // Client Credentials Grant - token, tokenError = GetClientCredentialsToken(application, clientSecret, scope, host) + token, tokenError, err = GetClientCredentialsToken(application, clientSecret, scope, host) case "refresh_token": - return RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host) + refreshToken2, err := RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host) + if err != nil { + return nil, err + } + return refreshToken2, nil + } + + if err != nil { + return nil, err } if tag == "wechat_miniprogram" { // Wechat Mini Program - token, tokenError = GetWechatMiniProgramToken(application, code, host, username, avatar, lang) + token, tokenError, err = GetWechatMiniProgramToken(application, code, host, username, avatar, lang) + if err != nil { + return nil, err + } } if tokenError != nil { - return tokenError + return tokenError, nil } token.CodeIsUsed = true + go updateUsedByCode(token) tokenWrapper := &TokenWrapper{ @@ -373,29 +410,33 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code Scope: token.Scope, } - return tokenWrapper + return tokenWrapper, nil } -func RefreshToken(grantType string, refreshToken string, scope string, clientId string, clientSecret string, host string) interface{} { +func RefreshToken(grantType string, refreshToken string, scope string, clientId string, clientSecret string, host string) (interface{}, error) { // check parameters if grantType != "refresh_token" { return &TokenError{ Error: UnsupportedGrantType, ErrorDescription: "grant_type should be refresh_token", - } + }, nil } - application := GetApplicationByClientId(clientId) + application, err := GetApplicationByClientId(clientId) + if err != nil { + return nil, err + } + if application == nil { return &TokenError{ Error: InvalidClient, ErrorDescription: "client_id is invalid", - } + }, nil } if clientSecret != "" && application.ClientSecret != clientSecret { return &TokenError{ Error: InvalidClient, ErrorDescription: "client_secret is invalid", - } + }, nil } // check whether the refresh token is valid, and has not expired. token := Token{RefreshToken: refreshToken} @@ -404,33 +445,44 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId return &TokenError{ Error: InvalidGrant, ErrorDescription: "refresh token is invalid, expired or revoked", - } + }, nil + } + + cert, err := getCertByApplication(application) + if err != nil { + return nil, err } - cert := getCertByApplication(application) _, err = ParseJwtToken(refreshToken, cert) if err != nil { return &TokenError{ Error: InvalidGrant, ErrorDescription: fmt.Sprintf("parse refresh token error: %s", err.Error()), - } + }, nil } // generate a new token - user := getUser(application.Organization, token.User) + user, err := getUser(application.Organization, token.User) + if err != nil { + return nil, err + } + if user.IsForbidden { return &TokenError{ Error: InvalidGrant, ErrorDescription: "the user is forbidden to sign in, please contact the administrator", - } + }, nil } - ExtendUserWithRolesAndPermissions(user) + err = ExtendUserWithRolesAndPermissions(user) + if err != nil { + return nil, err + } newAccessToken, newRefreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host) if err != nil { return &TokenError{ Error: EndpointError, ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()), - } + }, nil } newToken := &Token{ @@ -447,8 +499,15 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId Scope: scope, TokenType: "Bearer", } - AddToken(newToken) - DeleteToken(&token) + _, err = AddToken(newToken) + if err != nil { + return nil, err + } + + _, err = DeleteToken(&token) + if err != nil { + return nil, err + } tokenWrapper := &TokenWrapper{ AccessToken: newToken.AccessToken, @@ -459,7 +518,7 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId Scope: newToken.Scope, } - return tokenWrapper + return tokenWrapper, nil } // PkceChallenge: base64-URL-encoded SHA256 hash of verifier, per rfc 7636 @@ -486,34 +545,38 @@ func IsGrantTypeValid(method string, grantTypes []string) bool { // GetAuthorizationCodeToken // Authorization code flow -func GetAuthorizationCodeToken(application *Application, clientSecret string, code string, verifier string) (*Token, *TokenError) { +func GetAuthorizationCodeToken(application *Application, clientSecret string, code string, verifier string) (*Token, *TokenError, error) { if code == "" { return nil, &TokenError{ Error: InvalidRequest, ErrorDescription: "authorization code should not be empty", - } + }, nil + } + + token, err := getTokenByCode(code) + if err != nil { + return nil, nil, err } - token := getTokenByCode(code) if token == nil { return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "authorization code is invalid", - } + }, nil } if token.CodeIsUsed { // anti replay attacks return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "authorization code has been used", - } + }, nil } if token.CodeChallenge != "" && pkceChallenge(verifier) != token.CodeChallenge { return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "verifier is invalid", - } + }, nil } if application.ClientSecret != clientSecret { @@ -523,13 +586,13 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co return nil, &TokenError{ Error: InvalidClient, ErrorDescription: "client_secret is invalid", - } + }, nil } else { if clientSecret != "" { return nil, &TokenError{ Error: InvalidClient, ErrorDescription: "client_secret is invalid", - } + }, nil } } } @@ -538,7 +601,7 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "the token is for wrong application (client_id)", - } + }, nil } if time.Now().Unix() > token.CodeExpireIn { @@ -546,42 +609,50 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "authorization code has expired", - } + }, nil } - return token, nil + return token, nil, nil } // GetPasswordToken // Resource Owner Password Credentials flow -func GetPasswordToken(application *Application, username string, password string, scope string, host string) (*Token, *TokenError) { - user := getUser(application.Organization, username) +func GetPasswordToken(application *Application, username string, password string, scope string, host string) (*Token, *TokenError, error) { + user, err := getUser(application.Organization, username) + if err != nil { + return nil, nil, err + } + if user == nil { return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "the user does not exist", - } + }, nil } msg := CheckPassword(user, password, "en") if msg != "" { return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "invalid username or password", - } + }, nil } if user.IsForbidden { return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "the user is forbidden to sign in, please contact the administrator", - } + }, nil + } + + err = ExtendUserWithRolesAndPermissions(user) + if err != nil { + return nil, nil, err } - ExtendUserWithRolesAndPermissions(user) accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host) if err != nil { return nil, &TokenError{ Error: EndpointError, ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()), - } + }, nil } token := &Token{ Owner: application.Owner, @@ -598,18 +669,22 @@ func GetPasswordToken(application *Application, username string, password string TokenType: "Bearer", CodeIsUsed: true, } - AddToken(token) - return token, nil + _, err = AddToken(token) + if err != nil { + return nil, nil, err + } + + return token, nil, nil } // GetClientCredentialsToken // Client Credentials flow -func GetClientCredentialsToken(application *Application, clientSecret string, scope string, host string) (*Token, *TokenError) { +func GetClientCredentialsToken(application *Application, clientSecret string, scope string, host string) (*Token, *TokenError, error) { if application.ClientSecret != clientSecret { return nil, &TokenError{ Error: InvalidClient, ErrorDescription: "client_secret is invalid", - } + }, nil } nullUser := &User{ Owner: application.Owner, @@ -623,7 +698,7 @@ func GetClientCredentialsToken(application *Application, clientSecret string, sc return nil, &TokenError{ Error: EndpointError, ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()), - } + }, nil } token := &Token{ Owner: application.Owner, @@ -639,18 +714,27 @@ func GetClientCredentialsToken(application *Application, clientSecret string, sc TokenType: "Bearer", CodeIsUsed: true, } - AddToken(token) - return token, nil + _, err = AddToken(token) + if err != nil { + return nil, nil, err + } + + return token, nil, nil } // GetTokenByUser // Implicit flow func GetTokenByUser(application *Application, user *User, scope string, host string) (*Token, error) { - ExtendUserWithRolesAndPermissions(user) + err := ExtendUserWithRolesAndPermissions(user) + if err != nil { + return nil, err + } + accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host) if err != nil { return nil, err } + token := &Token{ Owner: application.Owner, Name: tokenName, @@ -666,43 +750,56 @@ func GetTokenByUser(application *Application, user *User, scope string, host str TokenType: "Bearer", CodeIsUsed: true, } - AddToken(token) + _, err = AddToken(token) + if err != nil { + return nil, err + } + return token, nil } // GetWechatMiniProgramToken // Wechat Mini Program flow -func GetWechatMiniProgramToken(application *Application, code string, host string, username string, avatar string, lang string) (*Token, *TokenError) { +func GetWechatMiniProgramToken(application *Application, code string, host string, username string, avatar string, lang string) (*Token, *TokenError, error) { mpProvider := GetWechatMiniProgramProvider(application) if mpProvider == nil { return nil, &TokenError{ Error: InvalidClient, ErrorDescription: "the application does not support wechat mini program", - } + }, nil } - provider := GetProvider(util.GetId("admin", mpProvider.Name)) + provider, err := GetProvider(util.GetId("admin", mpProvider.Name)) + if err != nil { + return nil, nil, err + } + mpIdp := idp.NewWeChatMiniProgramIdProvider(provider.ClientId, provider.ClientSecret) session, err := mpIdp.GetSessionByCode(code) if err != nil { return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: fmt.Sprintf("get wechat mini program session error: %s", err.Error()), - } + }, nil } + openId, unionId := session.Openid, session.Unionid if openId == "" && unionId == "" { return nil, &TokenError{ Error: InvalidRequest, ErrorDescription: "the wechat mini program session is invalid", - } + }, nil } - user := getUserByWechatId(openId, unionId) + user, err := getUserByWechatId(openId, unionId) + if err != nil { + return nil, nil, err + } + if user == nil { if !application.EnableSignUp { return nil, &TokenError{ Error: InvalidGrant, ErrorDescription: "the application does not allow to sign up new account", - } + }, nil } // Add new user var name string @@ -730,16 +827,23 @@ func GetWechatMiniProgramToken(application *Application, code string, host strin UserPropertiesWechatUnionId: unionId, }, } - AddUser(user) + _, err = AddUser(user) + if err != nil { + return nil, nil, err + } + } + + err = ExtendUserWithRolesAndPermissions(user) + if err != nil { + return nil, nil, err } - ExtendUserWithRolesAndPermissions(user) accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", "", host) if err != nil { return nil, &TokenError{ Error: EndpointError, ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()), - } + }, nil } token := &Token{ @@ -757,6 +861,9 @@ func GetWechatMiniProgramToken(application *Application, code string, host strin TokenType: "Bearer", CodeIsUsed: true, } - AddToken(token) - return token, nil + _, err = AddToken(token) + if err != nil { + return nil, nil, err + } + return token, nil, nil } diff --git a/object/token_cas.go b/object/token_cas.go index 6e17dadf..ab9ab6e0 100644 --- a/object/token_cas.go +++ b/object/token_cas.go @@ -177,7 +177,9 @@ func StoreCasTokenForProxyTicket(token *CasAuthenticationSuccess, targetService, } func GenerateCasToken(userId string, service string) (string, error) { - if user := GetUser(userId); user != nil { + if user, err := GetUser(userId); err != nil { + return "", err + } else if user != nil { authenticationSuccess := CasAuthenticationSuccess{ User: user.Name, Attributes: &CasAttributes{ @@ -232,18 +234,31 @@ func GetValidationBySaml(samlRequest string, host string) (string, string, error return "", "", fmt.Errorf("ticket %s found", ticket) } - user := GetUser(userId) + user, err := GetUser(userId) + if err != nil { + return "", "", err + } + if user == nil { return "", "", fmt.Errorf("user %s found", userId) } - application := GetApplicationByUser(user) + + application, err := GetApplicationByUser(user) + if err != nil { + return "", "", err + } + if application == nil { return "", "", fmt.Errorf("application for user %s found", userId) } samlResponse := NewSamlResponse11(user, request.RequestID, host) - cert := getCertByApplication(application) + cert, err := getCertByApplication(application) + if err != nil { + return "", "", err + } + block, _ := pem.Decode([]byte(cert.Certificate)) certificate := base64.StdEncoding.EncodeToString(block.Bytes) randomKeyStore := &X509Key{ diff --git a/object/token_jwt.go b/object/token_jwt.go index 31a0dd7c..7dfd4ba6 100644 --- a/object/token_jwt.go +++ b/object/token_jwt.go @@ -273,7 +273,10 @@ func generateJwtToken(application *Application, user *User, nonce string, scope refreshToken = jwt.NewWithClaims(jwt.SigningMethodRS256, claimsWithoutThirdIdp) } - cert := getCertByApplication(application) + cert, err := getCertByApplication(application) + if err != nil { + return "", "", "", err + } // RSA private key key, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(cert.PrivateKey)) @@ -316,5 +319,10 @@ func ParseJwtToken(token string, cert *Cert) (*Claims, error) { } func ParseJwtTokenByApplication(token string, application *Application) (*Claims, error) { - return ParseJwtToken(token, getCertByApplication(application)) + cert, err := getCertByApplication(application) + if err != nil { + return nil, err + } + + return ParseJwtToken(token, cert) } diff --git a/object/user.go b/object/user.go index b646dc03..367af464 100644 --- a/object/user.go +++ b/object/user.go @@ -191,217 +191,206 @@ type ManagedAccount struct { SigninUrl string `xorm:"varchar(200)" json:"signinUrl"` } -func GetGlobalUserCount(field, value string) int { +func GetGlobalUserCount(field, value string) (int64, error) { session := GetSession("", -1, -1, field, value, "", "") - count, err := session.Count(&User{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&User{}) } -func GetGlobalUsers() []*User { +func GetGlobalUsers() ([]*User, error) { users := []*User{} err := adapter.Engine.Desc("created_time").Find(&users) if err != nil { - panic(err) + return nil, err } - return users + return users, nil } -func GetPaginationGlobalUsers(offset, limit int, field, value, sortField, sortOrder string) []*User { +func GetPaginationGlobalUsers(offset, limit int, field, value, sortField, sortOrder string) ([]*User, error) { users := []*User{} session := GetSession("", offset, limit, field, value, sortField, sortOrder) err := session.Find(&users) if err != nil { - panic(err) + return nil, err } - return users + return users, nil } -func GetUserCount(owner, field, value string) int { +func GetUserCount(owner, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&User{}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&User{}) } -func GetOnlineUserCount(owner string, isOnline int) int { - count, err := adapter.Engine.Where("is_online = ?", isOnline).Count(&User{Owner: owner}) - if err != nil { - panic(err) - } - - return int(count) +func GetOnlineUserCount(owner string, isOnline int) (int64, error) { + return adapter.Engine.Where("is_online = ?", isOnline).Count(&User{Owner: owner}) } -func GetUsers(owner string) []*User { +func GetUsers(owner string) ([]*User, error) { users := []*User{} err := adapter.Engine.Desc("created_time").Find(&users, &User{Owner: owner}) if err != nil { - panic(err) + return nil, err } - return users + return users, nil } -func GetUsersByTag(owner string, tag string) []*User { +func GetUsersByTag(owner string, tag string) ([]*User, error) { users := []*User{} err := adapter.Engine.Desc("created_time").Find(&users, &User{Owner: owner, Tag: tag}) if err != nil { - panic(err) + return nil, err } - return users + return users, nil } -func GetSortedUsers(owner string, sorter string, limit int) []*User { +func GetSortedUsers(owner string, sorter string, limit int) ([]*User, error) { users := []*User{} err := adapter.Engine.Desc(sorter).Limit(limit, 0).Find(&users, &User{Owner: owner}) if err != nil { - panic(err) + return nil, err } - return users + return users, nil } -func GetPaginationUsers(owner string, offset, limit int, field, value, sortField, sortOrder string) []*User { +func GetPaginationUsers(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*User, error) { users := []*User{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&users) if err != nil { - panic(err) + return nil, err } - return users + return users, nil } -func getUser(owner string, name string) *User { +func getUser(owner string, name string) (*User, error) { if owner == "" || name == "" { - return nil + return nil, nil } user := User{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&user) if err != nil { - panic(err) + return nil, err } if existed { - return &user + return &user, nil } else { - return nil + return nil, nil } } -func getUserById(owner string, id string) *User { +func getUserById(owner string, id string) (*User, error) { if owner == "" || id == "" { - return nil + return nil, nil } user := User{Owner: owner, Id: id} existed, err := adapter.Engine.Get(&user) if err != nil { - panic(err) + return nil, err } if existed { - return &user + return &user, nil } else { - return nil + return nil, nil } } -func getUserByWechatId(wechatOpenId string, wechatUnionId string) *User { +func getUserByWechatId(wechatOpenId string, wechatUnionId string) (*User, error) { if wechatUnionId == "" { wechatUnionId = wechatOpenId } user := &User{} existed, err := adapter.Engine.Where("wechat = ? OR wechat = ?", wechatOpenId, wechatUnionId).Get(user) if err != nil { - panic(err) + return nil, err } if existed { - return user + return user, nil } else { - return nil + return nil, nil } } -func GetUserByEmail(owner string, email string) *User { +func GetUserByEmail(owner string, email string) (*User, error) { if owner == "" || email == "" { - return nil + return nil, nil } user := User{Owner: owner, Email: email} existed, err := adapter.Engine.Get(&user) if err != nil { - panic(err) + return nil, err } if existed { - return &user + return &user, nil } else { - return nil + return nil, nil } } -func GetUserByPhone(owner string, phone string) *User { +func GetUserByPhone(owner string, phone string) (*User, error) { if owner == "" || phone == "" { - return nil + return nil, nil } user := User{Owner: owner, Phone: phone} existed, err := adapter.Engine.Get(&user) if err != nil { - panic(err) + return nil, err } if existed { - return &user + return &user, nil } else { - return nil + return nil, nil } } -func GetUserByUserId(owner string, userId string) *User { +func GetUserByUserId(owner string, userId string) (*User, error) { if owner == "" || userId == "" { - return nil + return nil, nil } user := User{Owner: owner, Id: userId} existed, err := adapter.Engine.Get(&user) if err != nil { - panic(err) + return nil, err } if existed { - return &user + return &user, nil } else { - return nil + return nil, nil } } -func GetUser(id string) *User { +func GetUser(id string) (*User, error) { owner, name := util.GetOwnerAndNameFromId(id) return getUser(owner, name) } -func GetUserNoCheck(id string) *User { +func GetUserNoCheck(id string) (*User, error) { owner, name := util.GetOwnerAndNameFromIdNoCheck(id) return getUser(owner, name) } -func GetMaskedUser(user *User) *User { +func GetMaskedUser(user *User, errs ...error) (*User, error) { + if len(errs) > 0 && errs[0] != nil { + return nil, errs[0] + } + if user == nil { - return nil + return nil, nil } if user.Password != "" { @@ -419,51 +408,69 @@ func GetMaskedUser(user *User) *User { user.MultiFactorAuths[i] = GetMaskedProps(props) } } - return user + return user, nil } -func GetMaskedUsers(users []*User) []*User { - for _, user := range users { - user = GetMaskedUser(user) +func GetMaskedUsers(users []*User, errs ...error) ([]*User, error) { + if len(errs) > 0 && errs[0] != nil { + return nil, errs[0] } - return users + + var err error + for _, user := range users { + user, err = GetMaskedUser(user) + if err != nil { + return nil, err + } + } + return users, nil } -func GetLastUser(owner string) *User { +func GetLastUser(owner string) (*User, error) { user := User{Owner: owner} existed, err := adapter.Engine.Desc("created_time", "id").Get(&user) if err != nil { - panic(err) + return nil, err } if existed { - return &user + return &user, nil } - return nil + return nil, nil } -func UpdateUser(id string, user *User, columns []string, isAdmin bool) bool { +func UpdateUser(id string, user *User, columns []string, isAdmin bool) (bool, error) { + var err error owner, name := util.GetOwnerAndNameFromIdNoCheck(id) - oldUser := getUser(owner, name) + oldUser, err := getUser(owner, name) + if err != nil { + return false, err + } if oldUser == nil { - return false + return false, nil } if name != user.Name { err := userChangeTrigger(name, user.Name) if err != nil { - return false + return false, nil } } if user.Password == "***" { user.Password = oldUser.Password } - user.UpdateUserHash() + err = user.UpdateUserHash() + if err != nil { + panic(err) + } if user.Avatar != oldUser.Avatar && user.Avatar != "" && user.PermanentAvatar != "*" { - user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false) + user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false) + if err != nil { + return false, err + } } if len(columns) == 0 { @@ -487,77 +494,105 @@ func UpdateUser(id string, user *User, columns []string, isAdmin bool) bool { affected, err := adapter.Engine.ID(core.PK{owner, name}).Cols(columns...).Update(user) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func UpdateUserForAllFields(id string, user *User) bool { +func UpdateUserForAllFields(id string, user *User) (bool, error) { + var err error owner, name := util.GetOwnerAndNameFromId(id) - oldUser := getUser(owner, name) + oldUser, err := getUser(owner, name) + if err != nil { + return false, err + } + if oldUser == nil { - return false + return false, nil } if name != user.Name { err := userChangeTrigger(name, user.Name) if err != nil { - return false + return false, nil } } - user.UpdateUserHash() + err = user.UpdateUserHash() + if err != nil { + return false, err + } if user.Avatar != oldUser.Avatar && user.Avatar != "" { - user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false) + user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false) + if err != nil { + return false, err + } } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(user) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddUser(user *User) bool { +func AddUser(user *User) (bool, error) { + var err error if user.Id == "" { user.Id = util.GenerateId() } if user.Owner == "" || user.Name == "" { - return false + return false, nil } - organization := GetOrganizationByUser(user) + organization, _ := GetOrganizationByUser(user) if organization == nil { - return false + return false, nil } user.UpdateUserPassword(organization) - user.UpdateUserHash() - user.PreHash = user.Hash - - updated := user.refreshAvatar() - if updated && user.PermanentAvatar != "*" { - user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false) + err = user.UpdateUserHash() + if err != nil { + return false, err } - user.Ranking = GetUserCount(user.Owner, "", "") + 1 + user.PreHash = user.Hash + + updated, err := user.refreshAvatar() + if err != nil { + return false, err + } + + if updated && user.PermanentAvatar != "*" { + user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false) + if err != nil { + return false, err + } + } + + count, err := GetUserCount(user.Owner, "", "") + if err != nil { + return false, err + } + user.Ranking = int(count + 1) affected, err := adapter.Engine.Insert(user) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddUsers(users []*User) bool { +func AddUsers(users []*User) (bool, error) { + var err error if len(users) == 0 { - return false + return false, nil } // organization := GetOrganizationByUser(users[0]) @@ -565,27 +600,34 @@ func AddUsers(users []*User) bool { // this function is only used for syncer or batch upload, so no need to encrypt the password // user.UpdateUserPassword(organization) - user.UpdateUserHash() + err = user.UpdateUserHash() + if err != nil { + return false, err + } + user.PreHash = user.Hash - user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) + user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) + if err != nil { + return false, err + } } affected, err := adapter.Engine.Insert(users) if err != nil { if !strings.Contains(err.Error(), "Duplicate entry") { - panic(err) + return false, err } } - return affected != 0 + return affected != 0, nil } -func AddUsersInBatch(users []*User) bool { +func AddUsersInBatch(users []*User) (bool, error) { batchSize := conf.GetConfigBatchSize() if len(users) == 0 { - return false + return false, nil } affected := false @@ -599,24 +641,29 @@ func AddUsersInBatch(users []*User) bool { tmp := users[start:end] // TODO: save to log instead of standard output // fmt.Printf("Add users: [%d - %d].\n", start, end) - if AddUsers(tmp) { + if ok, err := AddUsers(tmp); err != nil { + return false, err + } else if ok { affected = true } } - return affected + return affected, nil } -func DeleteUser(user *User) bool { +func DeleteUser(user *User) (bool, error) { // Forced offline the user first - DeleteSession(util.GetSessionId(user.Owner, user.Name, CasdoorApplication)) + _, err := DeleteSession(util.GetSessionId(user.Owner, user.Name, CasdoorApplication)) + if err != nil { + return false, err + } affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Delete(&User{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func GetUserInfo(user *User, scope string, aud string, host string) *Userinfo { @@ -644,7 +691,7 @@ func GetUserInfo(user *User, scope string, aud string, host string) *Userinfo { return &resp } -func LinkUserAccount(user *User, field string, value string) bool { +func LinkUserAccount(user *User, field string, value string) (bool, error) { return SetUserField(user, field, value) } @@ -656,13 +703,18 @@ func isUserIdGlobalAdmin(userId string) bool { return strings.HasPrefix(userId, "built-in/") } -func ExtendUserWithRolesAndPermissions(user *User) { +func ExtendUserWithRolesAndPermissions(user *User) (err error) { if user == nil { return } - user.Roles = GetRolesByUser(user.GetId()) - user.Permissions = GetPermissionsByUser(user.GetId()) + user.Roles, err = GetRolesByUser(user.GetId()) + if err != nil { + return + } + + user.Permissions, err = GetPermissionsByUser(user.GetId()) + return } func userChangeTrigger(oldName string, newName string) error { @@ -679,6 +731,7 @@ func userChangeTrigger(oldName string, newName string) error { if err != nil { return err } + for _, role := range roles { for j, u := range role.Users { // u = organization/username @@ -722,7 +775,7 @@ func userChangeTrigger(oldName string, newName string) error { return session.Commit() } -func (user *User) refreshAvatar() bool { +func (user *User) refreshAvatar() (bool, error) { var err error var fileBuffer *bytes.Buffer var ext string @@ -732,13 +785,13 @@ func (user *User) refreshAvatar() bool { client := proxy.ProxyHttpClient has, err := hasGravatar(client, user.Email) if err != nil { - panic(err) + return false, err } if has { fileBuffer, ext, err = getGravatarFileBuffer(client, user.Email) if err != nil { - panic(err) + return false, err } } } @@ -746,17 +799,20 @@ func (user *User) refreshAvatar() bool { if fileBuffer == nil && strings.Contains(user.Avatar, "Identicon") { fileBuffer, ext, err = getIdenticonFileBuffer(user.Name) if err != nil { - panic(err) + return false, err } } if fileBuffer != nil { - avatarUrl := getPermanentAvatarUrlFromBuffer(user.Owner, user.Name, fileBuffer, ext, true) + avatarUrl, err := getPermanentAvatarUrlFromBuffer(user.Owner, user.Name, fileBuffer, ext, true) + if err != nil { + return false, err + } user.Avatar = avatarUrl - return true + return true, nil } - return false + return false, nil } func (user *User) IsMfaEnabled() bool { diff --git a/object/user_cred.go b/object/user_cred.go index c1db566b..240c4144 100644 --- a/object/user_cred.go +++ b/object/user_cred.go @@ -16,18 +16,27 @@ package object import "github.com/casdoor/casdoor/cred" -func calculateHash(user *User) string { - syncer := getDbSyncerForUser(user) - if syncer == nil { - return "" +func calculateHash(user *User) (string, error) { + syncer, err := getDbSyncerForUser(user) + if err != nil { + return "", err } - return syncer.calculateHash(user) + if syncer == nil { + return "", nil + } + + return syncer.calculateHash(user), nil } -func (user *User) UpdateUserHash() { - hash := calculateHash(user) +func (user *User) UpdateUserHash() error { + hash, err := calculateHash(user) + if err != nil { + return err + } + user.Hash = hash + return nil } func (user *User) UpdateUserPassword(organization *Organization) { diff --git a/object/user_test.go b/object/user_test.go index 66a14c69..858f360a 100644 --- a/object/user_test.go +++ b/object/user_test.go @@ -36,7 +36,7 @@ func updateUserColumn(column string, user *User) bool { func TestSyncAvatarsFromGitHub(t *testing.T) { InitConfig() - users := GetGlobalUsers() + users, _ := GetGlobalUsers() for _, user := range users { if user.GitHub == "" { continue @@ -50,7 +50,7 @@ func TestSyncAvatarsFromGitHub(t *testing.T) { func TestSyncIds(t *testing.T) { InitConfig() - users := GetGlobalUsers() + users, _ := GetGlobalUsers() for _, user := range users { if user.Id != "" { continue @@ -64,13 +64,16 @@ func TestSyncIds(t *testing.T) { func TestSyncHashes(t *testing.T) { InitConfig() - users := GetGlobalUsers() + users, _ := GetGlobalUsers() for _, user := range users { if user.Hash != "" { continue } - user.UpdateUserHash() + err := user.UpdateUserHash() + if err != nil { + panic(err) + } updateUserColumn("hash", user) } } @@ -92,7 +95,7 @@ func TestGetMaskedUsers(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := GetMaskedUsers(tt.args.users); !reflect.DeepEqual(got, tt.want) { + if got, _ := GetMaskedUsers(tt.args.users); !reflect.DeepEqual(got, tt.want) { t.Errorf("GetMaskedUsers() = %v, want %v", got, tt.want) } }) @@ -102,7 +105,7 @@ func TestGetMaskedUsers(t *testing.T) { func TestGetUserByField(t *testing.T) { InitConfig() - user := GetUserByField("built-in", "DingTalk", "test") + user, _ := GetUserByField("built-in", "DingTalk", "test") if user != nil { t.Logf("%+v", user) } else { @@ -115,7 +118,7 @@ func TestGetEmailsForUsers(t *testing.T) { emailMap := map[string]int{} emails := []string{} - users := GetUsers("built-in") + users, _ := GetUsers("built-in") for _, user := range users { if user.Email == "" { continue diff --git a/object/user_upload.go b/object/user_upload.go index e19ef701..48804c72 100644 --- a/object/user_upload.go +++ b/object/user_upload.go @@ -22,15 +22,18 @@ import ( "github.com/casdoor/casdoor/xlsx" ) -func getUserMap(owner string) map[string]*User { +func getUserMap(owner string) (map[string]*User, error) { m := map[string]*User{} - users := GetUsers(owner) + users, err := GetUsers(owner) + if err != nil { + return m, err + } for _, user := range users { m[user.GetId()] = user } - return m + return m, nil } func parseLineItem(line *[]string, i int) string { @@ -70,10 +73,14 @@ func parseListItem(lines *[]string, i int) []string { return trimmedItems } -func UploadUsers(owner string, fileId string) bool { +func UploadUsers(owner string, fileId string) (bool, error) { table := xlsx.ReadXlsxFile(fileId) - oldUserMap := getUserMap(owner) + oldUserMap, err := getUserMap(owner) + if err != nil { + return false, err + } + newUsers := []*User{} for index, line := range table { if index == 0 || parseLineItem(&line, 0) == "" { @@ -135,7 +142,7 @@ func UploadUsers(owner string, fileId string) bool { } if len(newUsers) == 0 { - return false + return false, nil } return AddUsersInBatch(newUsers) } diff --git a/object/user_util.go b/object/user_util.go index e38b6c59..34ce1a56 100644 --- a/object/user_util.go +++ b/object/user_util.go @@ -24,62 +24,70 @@ import ( "github.com/xorm-io/core" ) -func GetUserByField(organizationName string, field string, value string) *User { +func GetUserByField(organizationName string, field string, value string) (*User, error) { if field == "" || value == "" { - return nil + return nil, nil } user := User{Owner: organizationName} existed, err := adapter.Engine.Where(fmt.Sprintf("%s=?", strings.ToLower(field)), value).Get(&user) if err != nil { - panic(err) + return nil, err } if existed { - return &user + return &user, nil } else { - return nil + return nil, nil } } func HasUserByField(organizationName string, field string, value string) bool { - return GetUserByField(organizationName, field, value) != nil + user, err := GetUserByField(organizationName, field, value) + if err != nil { + panic(err) + } + return user != nil } -func GetUserByFields(organization string, field string) *User { +func GetUserByFields(organization string, field string) (*User, error) { // check username - user := GetUserByField(organization, "name", field) - if user != nil { - return user + user, err := GetUserByField(organization, "name", field) + if err != nil || user != nil { + return user, err } // check email if strings.Contains(field, "@") { - user = GetUserByField(organization, "email", field) - if user != nil { - return user + user, err = GetUserByField(organization, "email", field) + if user != nil || err != nil { + return user, err } } // check phone - user = GetUserByField(organization, "phone", field) - if user != nil { - return user + user, err = GetUserByField(organization, "phone", field) + if user != nil || err != nil { + return user, err } // check ID card - user = GetUserByField(organization, "id_card", field) - if user != nil { - return user + user, err = GetUserByField(organization, "id_card", field) + if user != nil || err != nil { + return user, err } - return nil + return nil, nil } -func SetUserField(user *User, field string, value string) bool { +func SetUserField(user *User, field string, value string) (bool, error) { bean := make(map[string]interface{}) if field == "password" { - organization := GetOrganizationByUser(user) + organization, err := GetOrganizationByUser(user) + if err != nil { + return false, err + } + user.UpdateUserPassword(organization) bean[strings.ToLower(field)] = user.Password bean["password_type"] = user.PasswordType @@ -89,17 +97,25 @@ func SetUserField(user *User, field string, value string) bool { affected, err := adapter.Engine.Table(user).ID(core.PK{user.Owner, user.Name}).Update(bean) if err != nil { - panic(err) + return false, err + } + + user, err = getUser(user.Owner, user.Name) + if err != nil { + return false, err + } + + err = user.UpdateUserHash() + if err != nil { + return false, err } - user = getUser(user.Owner, user.Name) - user.UpdateUserHash() _, err = adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("hash").Update(user) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func GetUserField(user *User, field string) string { @@ -121,7 +137,7 @@ func setUserProperty(user *User, field string, value string) { } } -func SetUserOAuthProperties(organization *Organization, user *User, providerType string, userInfo *idp.UserInfo) bool { +func SetUserOAuthProperties(organization *Organization, user *User, providerType string, userInfo *idp.UserInfo) (bool, error) { if userInfo.Id != "" { propertyName := fmt.Sprintf("oauth_%s_id", providerType) setUserProperty(user, propertyName, userInfo.Id) @@ -164,11 +180,10 @@ func SetUserOAuthProperties(organization *Organization, user *User, providerType } } - affected := UpdateUserForAllFields(user.GetId(), user) - return affected + return UpdateUserForAllFields(user.GetId(), user) } -func ClearUserOAuthProperties(user *User, providerType string) bool { +func ClearUserOAuthProperties(user *User, providerType string) (bool, error) { for k := range user.Properties { prefix := fmt.Sprintf("oauth_%s_", providerType) if strings.HasPrefix(k, prefix) { @@ -178,14 +193,18 @@ func ClearUserOAuthProperties(user *User, providerType string) bool { affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("properties").Update(user) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func CheckPermissionForUpdateUser(oldUser, newUser *User, isAdmin bool, lang string) (bool, string) { - organization := GetOrganizationByUser(oldUser) + organization, err := GetOrganizationByUser(oldUser) + if err != nil { + return false, err.Error() + } + var itemsChanged []*AccountItem if oldUser.Owner != newUser.Owner { @@ -310,7 +329,7 @@ func (user *User) GetCountryCode(countryCode string) string { return user.CountryCode } - if org := GetOrganizationByUser(user); org != nil && len(org.CountryCodes) > 0 { + if org, _ := GetOrganizationByUser(user); org != nil && len(org.CountryCodes) > 0 { return org.CountryCodes[0] } return "" diff --git a/object/user_webauthn.go b/object/user_webauthn.go index 9382f166..124f23ad 100644 --- a/object/user_webauthn.go +++ b/object/user_webauthn.go @@ -16,6 +16,7 @@ package object import ( "encoding/base64" + "fmt" "net/url" "strings" @@ -24,14 +25,14 @@ import ( "github.com/go-webauthn/webauthn/webauthn" ) -func GetWebAuthnObject(host string) *webauthn.WebAuthn { +func GetWebAuthnObject(host string) (*webauthn.WebAuthn, error) { var err error _, originBackend := getOriginFromHost(host) localUrl, err := url.Parse(originBackend) if err != nil { - panic("error when parsing origin:" + err.Error()) + return nil, fmt.Errorf("error when parsing origin:" + err.Error()) } webAuthn, err := webauthn.New(&webauthn.Config{ @@ -41,10 +42,10 @@ func GetWebAuthnObject(host string) *webauthn.WebAuthn { // RPIcon: "https://duo.com/logo.png", // Optional icon URL for your site }) if err != nil { - panic(err) + return nil, err } - return webAuthn + return webAuthn, nil } // WebAuthnID @@ -84,17 +85,17 @@ func (user *User) CredentialExcludeList() []protocol.CredentialDescriptor { return credentialExcludeList } -func (user *User) AddCredentials(credential webauthn.Credential, isGlobalAdmin bool) bool { +func (user *User) AddCredentials(credential webauthn.Credential, isGlobalAdmin bool) (bool, error) { user.WebauthnCredentials = append(user.WebauthnCredentials, credential) return UpdateUser(user.GetId(), user, []string{"webauthnCredentials"}, isGlobalAdmin) } -func (user *User) DeleteCredentials(credentialIdBase64 string) bool { +func (user *User) DeleteCredentials(credentialIdBase64 string) (bool, error) { for i, credential := range user.WebauthnCredentials { if base64.StdEncoding.EncodeToString(credential.ID) == credentialIdBase64 { user.WebauthnCredentials = append(user.WebauthnCredentials[0:i], user.WebauthnCredentials[i+1:]...) return UpdateUserForAllFields(user.GetId(), user) } } - return false + return false, nil } diff --git a/object/verification.go b/object/verification.go index 8225f823..cb22905e 100644 --- a/object/verification.go +++ b/object/verification.go @@ -151,21 +151,24 @@ func AddToVerificationRecord(user *User, provider *Provider, remoteAddr, recordT return nil } -func getVerificationRecord(dest string) *VerificationRecord { +func getVerificationRecord(dest string) (*VerificationRecord, error) { var record VerificationRecord record.Receiver = dest has, err := adapter.Engine.Desc("time").Where("is_used = false").Get(&record) if err != nil { - panic(err) + return nil, err } if !has { - return nil + return nil, nil } - return &record + return &record, nil } func CheckVerificationCode(dest, code, lang string) *VerifyResult { - record := getVerificationRecord(dest) + record, err := getVerificationRecord(dest) + if err != nil { + panic(err) + } if record == nil { return &VerifyResult{noRecordError, i18n.Translate(lang, "verification:Code has not been sent yet!")} @@ -188,17 +191,15 @@ func CheckVerificationCode(dest, code, lang string) *VerifyResult { return &VerifyResult{VerificationSuccess, ""} } -func DisableVerificationCode(dest string) { - record := getVerificationRecord(dest) - if record == nil { +func DisableVerificationCode(dest string) (err error) { + record, err := getVerificationRecord(dest) + if record == nil || err != nil { return } record.IsUsed = true - _, err := adapter.Engine.ID(core.PK{record.Owner, record.Name}).AllCols().Update(record) - if err != nil { - panic(err) - } + _, err = adapter.Engine.ID(core.PK{record.Owner, record.Name}).AllCols().Update(record) + return } func CheckSigninCode(user *User, dest, code, lang string) string { diff --git a/object/webhook.go b/object/webhook.go index 15f8804d..9b4e8a88 100644 --- a/object/webhook.go +++ b/object/webhook.go @@ -42,100 +42,97 @@ type Webhook struct { IsEnabled bool `json:"isEnabled"` } -func GetWebhookCount(owner, organization, field, value string) int { +func GetWebhookCount(owner, organization, field, value string) (int64, error) { session := GetSession(owner, -1, -1, field, value, "", "") - count, err := session.Count(&Webhook{Organization: organization}) - if err != nil { - panic(err) - } - - return int(count) + return session.Count(&Webhook{Organization: organization}) } -func GetWebhooks(owner string, organization string) []*Webhook { +func GetWebhooks(owner string, organization string) ([]*Webhook, error) { webhooks := []*Webhook{} err := adapter.Engine.Desc("created_time").Find(&webhooks, &Webhook{Owner: owner, Organization: organization}) if err != nil { - panic(err) + return webhooks, err } - return webhooks + return webhooks, nil } -func GetPaginationWebhooks(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Webhook { +func GetPaginationWebhooks(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Webhook, error) { webhooks := []*Webhook{} session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) err := session.Find(&webhooks, &Webhook{Organization: organization}) if err != nil { - panic(err) + return nil, err } - return webhooks + return webhooks, nil } -func getWebhooksByOrganization(organization string) []*Webhook { +func getWebhooksByOrganization(organization string) ([]*Webhook, error) { webhooks := []*Webhook{} err := adapter.Engine.Desc("created_time").Find(&webhooks, &Webhook{Organization: organization}) if err != nil { - panic(err) + return webhooks, err } - return webhooks + return webhooks, nil } -func getWebhook(owner string, name string) *Webhook { +func getWebhook(owner string, name string) (*Webhook, error) { if owner == "" || name == "" { - return nil + return nil, nil } webhook := Webhook{Owner: owner, Name: name} existed, err := adapter.Engine.Get(&webhook) if err != nil { - panic(err) + return &webhook, err } if existed { - return &webhook + return &webhook, nil } else { - return nil + return nil, nil } } -func GetWebhook(id string) *Webhook { +func GetWebhook(id string) (*Webhook, error) { owner, name := util.GetOwnerAndNameFromId(id) return getWebhook(owner, name) } -func UpdateWebhook(id string, webhook *Webhook) bool { +func UpdateWebhook(id string, webhook *Webhook) (bool, error) { owner, name := util.GetOwnerAndNameFromId(id) - if getWebhook(owner, name) == nil { - return false + if w, err := getWebhook(owner, name); err != nil { + return false, err + } else if w == nil { + return false, nil } affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(webhook) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func AddWebhook(webhook *Webhook) bool { +func AddWebhook(webhook *Webhook) (bool, error) { affected, err := adapter.Engine.Insert(webhook) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } -func DeleteWebhook(webhook *Webhook) bool { +func DeleteWebhook(webhook *Webhook) (bool, error) { affected, err := adapter.Engine.ID(core.PK{webhook.Owner, webhook.Name}).Delete(&Webhook{}) if err != nil { - panic(err) + return false, err } - return affected != 0 + return affected != 0, nil } func (p *Webhook) GetId() string { diff --git a/routers/auto_signin_filter.go b/routers/auto_signin_filter.go index 2ce85b82..fea0d6b0 100644 --- a/routers/auto_signin_filter.go +++ b/routers/auto_signin_filter.go @@ -32,7 +32,12 @@ func AutoSigninFilter(ctx *context.Context) { accessToken := util.GetMaxLenStr(ctx.Input.Query("accessToken"), ctx.Input.Query("access_token"), parseBearerToken(ctx)) if accessToken != "" { - token := object.GetTokenByAccessToken(accessToken) + token, err := object.GetTokenByAccessToken(accessToken) + if err != nil { + responseError(ctx, err.Error()) + return + } + if token == nil { responseError(ctx, "Access token doesn't exist") return @@ -44,7 +49,11 @@ func AutoSigninFilter(ctx *context.Context) { } userId := util.GetId(token.Organization, token.User) - application, _ := object.GetApplicationByUserId(fmt.Sprintf("app/%s", token.Application)) + application, _, err := object.GetApplicationByUserId(fmt.Sprintf("app/%s", token.Application)) + if err != nil { + panic(err) + } + setSessionUser(ctx, userId) setSessionOidc(ctx, token.Scope, application.ClientId) return diff --git a/routers/base.go b/routers/base.go index 2cab9792..716ba356 100644 --- a/routers/base.go +++ b/routers/base.go @@ -72,7 +72,11 @@ func getUsernameByClientIdSecret(ctx *context.Context) string { return "" } - application := object.GetApplicationByClientId(clientId) + application, err := object.GetApplicationByClientId(clientId) + if err != nil { + panic(err) + } + if application == nil || application.ClientSecret != clientSecret { return "" } diff --git a/routers/cors_filter.go b/routers/cors_filter.go index be773f25..00f8431c 100644 --- a/routers/cors_filter.go +++ b/routers/cors_filter.go @@ -34,7 +34,12 @@ func CorsFilter(ctx *context.Context) { originConf := conf.GetConfigString("origin") if origin != "" && originConf != "" && origin != originConf { - if object.IsOriginAllowed(origin) { + ok, err := object.IsOriginAllowed(origin) + if err != nil { + panic(err) + } + + if ok { ctx.Output.Header(headerAllowOrigin, origin) ctx.Output.Header(headerAllowMethods, "POST, GET, OPTIONS, DELETE") ctx.Output.Header(headerAllowHeaders, "Content-Type, Authorization") diff --git a/routers/record.go b/routers/record.go index 203e1146..6bac98a7 100644 --- a/routers/record.go +++ b/routers/record.go @@ -43,7 +43,11 @@ func getUserByClientIdSecret(ctx *context.Context) string { return "" } - application := object.GetApplicationByClientId(clientId) + application, err := object.GetApplicationByClientId(clientId) + if err != nil { + panic(err) + } + if application == nil || application.ClientSecret != clientSecret { return "" } diff --git a/routers/static_filter.go b/routers/static_filter.go index 0ef7047a..68ed6cf2 100644 --- a/routers/static_filter.go +++ b/routers/static_filter.go @@ -30,7 +30,7 @@ import ( var ( oldStaticBaseUrl = "https://cdn.casbin.org" newStaticBaseUrl = conf.GetConfigString("staticBaseUrl") - enableGzip, _ = conf.GetConfigBool("enableGzip") + enableGzip = conf.GetConfigBool("enableGzip") ) func StaticFilter(ctx *context.Context) {