feat: return most backend API errors to frontend (#1836)

* feat: return most backend API errros to frontend

Signed-off-by: yehong <239859435@qq.com>

* refactor: reduce int type change

Signed-off-by: yehong <239859435@qq.com>

* feat: return err backend in token.go

Signed-off-by: yehong <239859435@qq.com>

---------

Signed-off-by: yehong <239859435@qq.com>
This commit is contained in:
yehong
2023-05-30 15:49:39 +08:00
committed by GitHub
parent 34151c0095
commit 02e692a300
105 changed files with 3788 additions and 1734 deletions

View File

@ -154,7 +154,11 @@ func IsAllowed(subOwner string, subName string, method string, urlPath string, o
}
}
user := object.GetUser(util.GetId(subOwner, subName))
user, err := object.GetUser(util.GetId(subOwner, subName))
if err != nil {
panic(err)
}
if user != nil && user.IsAdmin && (subOwner == objOwner || (objOwner == "admin")) {
return true
}

View File

@ -16,7 +16,6 @@ package conf
import (
"encoding/json"
"fmt"
"os"
"runtime"
"strconv"
@ -73,14 +72,13 @@ func GetConfigString(key string) string {
return res
}
func GetConfigBool(key string) (bool, error) {
func GetConfigBool(key string) bool {
value := GetConfigString(key)
if value == "true" {
return true, nil
} else if value == "false" {
return false, nil
return true
} else {
return false
}
return false, fmt.Errorf("value %s cannot be converted into bool", value)
}
func GetConfigInt64(key string) (int64, error) {

View File

@ -87,7 +87,7 @@ func TestGetConfBool(t *testing.T) {
assert.Nil(t, err)
for _, scenery := range scenarios {
t.Run(scenery.description, func(t *testing.T) {
actual, err := GetConfigBool(scenery.input)
actual := GetConfigBool(scenery.input)
assert.Nil(t, err)
assert.Equal(t, scenery.expected, actual)
})

View File

@ -78,13 +78,23 @@ func (c *ApiController) Signup() {
return
}
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
if !application.EnableSignUp {
c.ResponseError(c.T("account:The application does not allow to sign up new account"))
return
}
organization := object.GetOrganization(util.GetId("admin", authForm.Organization))
organization, err := object.GetOrganization(util.GetId("admin", authForm.Organization))
if err != nil {
c.ResponseError(c.T(err.Error()))
return
}
msg := object.CheckUserSignup(application, organization, &authForm, c.GetAcceptLanguage())
if msg != "" {
c.ResponseError(msg)
@ -111,7 +121,11 @@ func (c *ApiController) Signup() {
id := util.GenerateId()
if application.GetSignupItemRule("ID") == "Incremental" {
lastUser := object.GetLastUser(authForm.Organization)
lastUser, err := object.GetLastUser(authForm.Organization)
if err != nil {
c.ResponseError(err.Error())
return
}
lastIdInt := -1
if lastUser != nil {
@ -173,25 +187,47 @@ func (c *ApiController) Signup() {
}
}
affected := object.AddUser(user)
affected, err := object.AddUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
if !affected {
c.ResponseError(c.T("account:Failed to add user"), util.StructToJson(user))
return
}
object.AddUserToOriginalDatabase(user)
err = object.AddUserToOriginalDatabase(user)
if err != nil {
c.ResponseError(err.Error())
return
}
if application.HasPromptPage() {
// The prompt page needs the user to be signed in
c.SetSessionUsername(user.GetId())
}
object.DisableVerificationCode(authForm.Email)
object.DisableVerificationCode(checkPhone)
err = object.DisableVerificationCode(authForm.Email)
if err != nil {
c.ResponseError(err.Error())
return
}
err = object.DisableVerificationCode(checkPhone)
if err != nil {
c.ResponseError(err.Error())
return
}
isSignupFromPricing := authForm.Plan != "" && authForm.Pricing != ""
if isSignupFromPricing {
object.Subscribe(organization.Name, user.Name, authForm.Plan, authForm.Pricing)
_, err = object.Subscribe(organization.Name, user.Name, authForm.Plan, authForm.Pricing)
if err != nil {
c.ResponseError(err.Error())
return
}
}
record := object.NewRecord(c.Ctx)
@ -231,7 +267,11 @@ func (c *ApiController) Logout() {
c.ClearUserSession()
owner, username := util.GetOwnerAndNameFromId(user)
object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID())
_, err := object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID())
if err != nil {
c.ResponseError(err.Error())
return
}
util.LogInfo(c.Ctx, "API: [%s] logged out", user)
@ -252,7 +292,12 @@ func (c *ApiController) Logout() {
return
}
affected, application, token := object.ExpireTokenByAccessToken(accessToken)
affected, application, token, err := object.ExpireTokenByAccessToken(accessToken)
if err != nil {
c.ResponseError(err.Error())
return
}
if !affected {
c.ResponseError(c.T("token:Token not found, invalid accessToken"))
return
@ -272,7 +317,12 @@ func (c *ApiController) Logout() {
// TODO https://github.com/casdoor/casdoor/pull/1494#discussion_r1095675265
owner, username := util.GetOwnerAndNameFromId(user)
object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID())
_, err := object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID())
if err != nil {
c.ResponseError(err.Error())
return
}
util.LogInfo(c.Ctx, "API: [%s] logged out", user)
c.Ctx.Redirect(http.StatusFound, fmt.Sprintf("%s?state=%s", strings.TrimRight(redirectUri, "/"), state))
@ -290,6 +340,7 @@ func (c *ApiController) Logout() {
// @Success 200 {object} controllers.Response The Response object
// @router /get-account [get]
func (c *ApiController) GetAccount() {
var err error
user, ok := c.RequireSignedInUser()
if !ok {
return
@ -297,20 +348,39 @@ func (c *ApiController) GetAccount() {
managedAccounts := c.Input().Get("managedAccounts")
if managedAccounts == "1" {
user = object.ExtendManagedAccountsWithUser(user)
user, err = object.ExtendManagedAccountsWithUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
}
object.ExtendUserWithRolesAndPermissions(user)
err = object.ExtendUserWithRolesAndPermissions(user)
if err != nil {
c.ResponseError(err.Error())
return
}
user.Permissions = object.GetMaskedPermissions(user.Permissions)
user.Roles = object.GetMaskedRoles(user.Roles)
organization := object.GetMaskedOrganization(object.GetOrganizationByUser(user))
organization, err := object.GetMaskedOrganization(object.GetOrganizationByUser(user))
if err != nil {
c.ResponseError(err.Error())
return
}
u, err := object.GetMaskedUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
resp := Response{
Status: "ok",
Sub: user.Id,
Name: user.Name,
Data: object.GetMaskedUser(user),
Data: u,
Data2: organization,
}
c.Data["json"] = resp
@ -391,7 +461,12 @@ func (c *ApiController) GetCaptcha() {
if captchaProvider != nil {
if captchaProvider.Type == "Default" {
id, img := object.GetCaptcha()
id, img, err := object.GetCaptcha()
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(Captcha{Type: captchaProvider.Type, CaptchaId: id, CaptchaImage: img})
return
} else if captchaProvider.Type != "" {

View File

@ -40,21 +40,35 @@ func (c *ApiController) GetApplications() {
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization")
var err error
if limit == "" || page == "" {
var applications []*object.Application
if organization == "" {
applications = object.GetApplications(owner)
applications, err = object.GetApplications(owner)
} else {
applications = object.GetOrganizationApplications(owner, organization)
applications, err = object.GetOrganizationApplications(owner, organization)
}
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedApplications(applications, userId)
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetApplicationCount(owner, field, value)))
applications := object.GetMaskedApplications(object.GetPaginationApplications(owner, paginator.Offset(), limit, field, value, sortField, sortOrder), userId)
count, err := object.GetApplicationCount(owner, field, value)
if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
app, err := object.GetPaginationApplications(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
panic(err)
}
applications := object.GetMaskedApplications(app, userId)
c.ResponseOk(applications, paginator.Nums())
}
}
@ -69,8 +83,12 @@ func (c *ApiController) GetApplications() {
func (c *ApiController) GetApplication() {
userId := c.GetSessionUsername()
id := c.Input().Get("id")
app, err := object.GetApplication(id)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedApplication(object.GetApplication(id), userId)
c.Data["json"] = object.GetMaskedApplication(app, userId)
c.ServeJSON()
}
@ -84,13 +102,22 @@ func (c *ApiController) GetApplication() {
func (c *ApiController) GetUserApplication() {
userId := c.GetSessionUsername()
id := c.Input().Get("id")
user := object.GetUser(id)
user, err := object.GetUser(id)
if err != nil {
panic(err)
}
if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), id))
return
}
c.Data["json"] = object.GetMaskedApplication(object.GetApplicationByUser(user), userId)
app, err := object.GetApplicationByUser(user)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedApplication(app, userId)
c.ServeJSON()
}
@ -118,13 +145,30 @@ func (c *ApiController) GetOrganizationApplications() {
}
if limit == "" || page == "" {
applications := object.GetOrganizationApplications(owner, organization)
applications, err := object.GetOrganizationApplications(owner, organization)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedApplications(applications, userId)
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetOrganizationApplicationCount(owner, organization, field, value)))
applications := object.GetMaskedApplications(object.GetPaginationOrganizationApplications(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder), userId)
count, err := object.GetOrganizationApplicationCount(owner, organization, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
app, err := object.GetPaginationOrganizationApplications(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
applications := object.GetMaskedApplications(app, userId)
c.ResponseOk(applications, paginator.Nums())
}
}
@ -166,8 +210,13 @@ func (c *ApiController) AddApplication() {
return
}
count := object.GetApplicationCount("", "", "")
if err := checkQuotaForApplication(count); err != nil {
count, err := object.GetApplicationCount("", "", "")
if err != nil {
c.ResponseError(err.Error())
return
}
if err := checkQuotaForApplication(int(count)); err != nil {
c.ResponseError(err.Error())
return
}

View File

@ -93,7 +93,12 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob
c.ResponseError(c.T("auth:Challenge method should be S256"))
return
}
code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge, c.Ctx.Request.Host, c.GetAcceptLanguage())
code, err := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge, c.Ctx.Request.Host, c.GetAcceptLanguage())
if err != nil {
c.ResponseError(err.Error(), nil)
return
}
resp = codeToResponse(code)
if application.EnableSigninSession || application.HasPromptPage() {
@ -142,12 +147,16 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob
}
if resp.Status == "ok" {
object.AddSession(&object.Session{
_, err = object.AddSession(&object.Session{
Owner: user.Owner,
Name: user.Name,
Application: application.Name,
SessionId: []string{c.Ctx.Input.CruSession.SessionID()},
})
if err != nil {
c.ResponseError(err.Error(), nil)
return
}
}
return resp
@ -171,7 +180,12 @@ func (c *ApiController) GetApplicationLogin() {
scope := c.Input().Get("scope")
state := c.Input().Get("state")
msg, application := object.CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, c.GetAcceptLanguage())
msg, application, err := object.CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, c.GetAcceptLanguage())
if err != nil {
c.ResponseError(err.Error())
return
}
application = object.GetMaskedApplication(application, "")
if msg != "" {
c.ResponseError(msg, application)
@ -248,7 +262,10 @@ func (c *ApiController) Login() {
var msg string
if authForm.Password == "" {
if user = object.GetUserByFields(authForm.Organization, authForm.Username); user == nil {
if user, err = object.GetUserByFields(authForm.Organization, authForm.Username); err != nil {
c.ResponseError(err.Error(), nil)
return
} else if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(authForm.Organization, authForm.Username)))
return
}
@ -272,9 +289,18 @@ func (c *ApiController) Login() {
}
// disable the verification code
object.DisableVerificationCode(checkDest)
err := object.DisableVerificationCode(checkDest)
if err != nil {
c.ResponseError(err.Error(), nil)
return
}
} else {
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error(), nil)
return
}
if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return
@ -284,7 +310,10 @@ func (c *ApiController) Login() {
return
}
var enableCaptcha bool
if enableCaptcha = object.CheckToEnableCaptcha(application, authForm.Organization, authForm.Username); enableCaptcha {
if enableCaptcha, err = object.CheckToEnableCaptcha(application, authForm.Organization, authForm.Username); err != nil {
c.ResponseError(err.Error())
return
} else if enableCaptcha {
isHuman, err := captcha.VerifyCaptchaByCaptchaType(authForm.CaptchaType, authForm.CaptchaToken, authForm.ClientSecret)
if err != nil {
c.ResponseError(err.Error())
@ -304,7 +333,12 @@ func (c *ApiController) Login() {
if msg != "" {
resp = &Response{Status: "error", Msg: msg}
} else {
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return
@ -312,7 +346,11 @@ func (c *ApiController) Login() {
resp = c.HandleLoggedIn(application, user, &authForm)
organization := object.GetOrganizationByUser(user)
organization, err := object.GetOrganizationByUser(user)
if err != nil {
c.ResponseError(err.Error())
}
if user != nil && organization.HasRequiredMfa() && !user.IsMfaEnabled() {
resp.Msg = object.RequiredMfa
}
@ -325,18 +363,34 @@ func (c *ApiController) Login() {
} else if authForm.Provider != "" {
var application *object.Application
if authForm.ClientId != "" {
application = object.GetApplicationByClientId(authForm.ClientId)
application, err = object.GetApplicationByClientId(authForm.ClientId)
if err != nil {
c.ResponseError(err.Error())
return
}
} else {
application = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
application, err = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
}
if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return
}
organization, err := object.GetOrganization(util.GetId("admin", application.Organization))
if err != nil {
c.ResponseError(c.T(err.Error()))
}
provider, err := object.GetProvider(util.GetId("admin", authForm.Provider))
if err != nil {
c.ResponseError(err.Error())
return
}
organization := object.GetOrganization(util.GetId("admin", application.Organization))
provider := object.GetProvider(util.GetId("admin", authForm.Provider))
providerItem := application.GetProviderItem(provider.Name)
if !providerItem.IsProviderVisible() {
c.ResponseError(fmt.Sprintf(c.T("auth:The provider: %s is not enabled for the application"), provider.Name))
@ -396,9 +450,17 @@ func (c *ApiController) Login() {
if authForm.Method == "signup" {
user := &object.User{}
if provider.Category == "SAML" {
user = object.GetUser(util.GetId(application.Organization, userInfo.Id))
user, err = object.GetUser(util.GetId(application.Organization, userInfo.Id))
if err != nil {
c.ResponseError(err.Error())
return
}
} else if provider.Category == "OAuth" {
user = object.GetUserByField(application.Organization, provider.Type, userInfo.Id)
user, err = object.GetUserByField(application.Organization, provider.Type, userInfo.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
}
if user != nil && !user.IsDeleted {
@ -419,12 +481,20 @@ func (c *ApiController) Login() {
if application.EnableLinkWithEmail {
if userInfo.Email != "" {
// Find existing user with Email
user = object.GetUserByField(application.Organization, "email", userInfo.Email)
user, err = object.GetUserByField(application.Organization, "email", userInfo.Email)
if err != nil {
c.ResponseError(err.Error())
return
}
}
if user == nil && userInfo.Phone != "" {
// Find existing user with phone number
user = object.GetUserByField(application.Organization, "phone", userInfo.Phone)
user, err = object.GetUserByField(application.Organization, "phone", userInfo.Phone)
if err != nil {
c.ResponseError(err.Error())
return
}
}
}
@ -440,7 +510,12 @@ func (c *ApiController) Login() {
}
// Handle username conflicts
tmpUser := object.GetUser(util.GetId(application.Organization, userInfo.Username))
tmpUser, err := object.GetUser(util.GetId(application.Organization, userInfo.Username))
if err != nil {
c.ResponseError(err.Error())
return
}
if tmpUser != nil {
uid, err := uuid.NewRandom()
if err != nil {
@ -453,7 +528,13 @@ func (c *ApiController) Login() {
}
properties := map[string]string{}
properties["no"] = strconv.Itoa(object.GetUserCount(application.Organization, "", "") + 2)
count, err := object.GetUserCount(application.Organization, "", "")
if err != nil {
c.ResponseError(err.Error())
return
}
properties["no"] = strconv.Itoa(int(count + 2))
initScore, err := organization.GetInitScore()
if err != nil {
c.ResponseError(fmt.Errorf(c.T("account:Get init score failed, error: %w"), err).Error())
@ -482,7 +563,12 @@ func (c *ApiController) Login() {
Properties: properties,
}
affected := object.AddUser(user)
affected, err := object.AddUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
if !affected {
c.ResponseError(fmt.Sprintf(c.T("auth:Failed to create user, user information is invalid: %s"), util.StructToJson(user)))
return
@ -490,8 +576,17 @@ func (c *ApiController) Login() {
}
// sync info from 3rd-party if possible
object.SetUserOAuthProperties(organization, user, provider.Type, userInfo)
object.LinkUserAccount(user, provider.Type, userInfo.Id)
_, err := object.SetUserOAuthProperties(organization, user, provider.Type, userInfo)
if err != nil {
c.ResponseError(err.Error())
return
}
_, err = object.LinkUserAccount(user, provider.Type, userInfo.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
resp = c.HandleLoggedIn(application, user, &authForm)
@ -516,18 +611,36 @@ func (c *ApiController) Login() {
return
}
oldUser := object.GetUserByField(application.Organization, provider.Type, userInfo.Id)
oldUser, err := object.GetUserByField(application.Organization, provider.Type, userInfo.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
if oldUser != nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The account for provider: %s and username: %s (%s) is already linked to another account: %s (%s)"), provider.Type, userInfo.Username, userInfo.DisplayName, oldUser.Name, oldUser.DisplayName))
return
}
user := object.GetUser(userId)
user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
// sync info from 3rd-party if possible
object.SetUserOAuthProperties(organization, user, provider.Type, userInfo)
_, err = object.SetUserOAuthProperties(organization, user, provider.Type, userInfo)
if err != nil {
c.ResponseError(err.Error())
return
}
isLinked, err := object.LinkUserAccount(user, provider.Type, userInfo.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
isLinked := object.LinkUserAccount(user, provider.Type, userInfo.Id)
if isLinked {
resp = &Response{Status: "ok", Msg: "", Data: isLinked}
} else {
@ -536,7 +649,11 @@ func (c *ApiController) Login() {
}
} else if c.getMfaSessionData() != nil {
mfaSession := c.getMfaSessionData()
user := object.GetUser(mfaSession.UserId)
user, err := object.GetUser(mfaSession.UserId)
if err != nil {
c.ResponseError(err.Error())
return
}
if authForm.Passcode != "" {
MfaUtil := object.GetMfaUtil(authForm.MfaType, user.GetPreferMfa(false))
@ -554,7 +671,12 @@ func (c *ApiController) Login() {
}
}
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return
@ -569,7 +691,12 @@ func (c *ApiController) Login() {
} else {
if c.GetSessionUsername() != "" {
// user already signed in to Casdoor, so let the user click the avatar button to do the quick sign-in
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return
@ -624,8 +751,9 @@ func (c *ApiController) HandleSamlLogin() {
func (c *ApiController) HandleOfficialAccountEvent() {
respBytes, err := ioutil.ReadAll(c.Ctx.Request.Body)
if err != nil {
c.ResponseError(err.Error())
panic(err)
}
var data struct {
MsgType string `xml:"MsgType"`
Event string `xml:"Event"`
@ -633,8 +761,9 @@ func (c *ApiController) HandleOfficialAccountEvent() {
}
err = xml.Unmarshal(respBytes, &data)
if err != nil {
c.ResponseError(err.Error())
panic(err)
}
lock.Lock()
defer lock.Unlock()
if data.EventKey != "" {
@ -670,7 +799,12 @@ func (c *ApiController) GetWebhookEventType() {
func (c *ApiController) GetCaptchaStatus() {
organization := c.Input().Get("organization")
userId := c.Input().Get("user_id")
user := object.GetUserByFields(organization, userId)
user, err := object.GetUserByFields(organization, userId)
if err != nil {
c.ResponseError(err.Error())
return
}
var captchaEnabled bool
if user != nil && user.SigninWrongTimes >= object.SigninWrongTimesLimit {
captchaEnabled = true

View File

@ -72,11 +72,15 @@ func (c *ApiController) isGlobalAdmin() (bool, *object.User) {
func (c *ApiController) getCurrentUser() *object.User {
var user *object.User
var err error
userId := c.GetSessionUsername()
if userId == "" {
user = nil
} else {
user = object.GetUser(userId)
user, err = object.GetUser(userId)
if err != nil {
panic(err)
}
}
return user
}
@ -106,7 +110,11 @@ func (c *ApiController) GetSessionApplication() *object.Application {
if clientId == nil {
return nil
}
application := object.GetApplicationByClientId(clientId.(string))
application, err := object.GetApplicationByClientId(clientId.(string))
if err != nil {
panic(err)
}
return application
}
@ -192,8 +200,10 @@ func (c *ApiController) setExpireForSession() {
})
}
func wrapActionResponse(affected bool) *Response {
if affected {
func wrapActionResponse(affected bool, e ...error) *Response {
if len(e) != 0 && e[0] != nil {
return &Response{Status: "error", Msg: e[0].Error()}
} else if affected {
return &Response{Status: "ok", Msg: "", Data: "Affected"}
} else {
return &Response{Status: "ok", Msg: "", Data: "Unaffected"}

View File

@ -33,19 +33,40 @@ func (c *ApiController) GetCasbinAdapters() {
sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization")
if limit == "" || page == "" {
adapters := object.GetCasbinAdapters(owner, organization)
adapters, err := object.GetCasbinAdapters(owner, organization)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(adapters)
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetCasbinAdapterCount(owner, organization, field, value)))
adapters := object.GetPaginationCasbinAdapters(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetCasbinAdapterCount(owner, organization, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
adapters, err := object.GetPaginationCasbinAdapters(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(adapters, paginator.Nums())
}
}
func (c *ApiController) GetCasbinAdapter() {
id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id)
adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(adapter)
}
@ -89,7 +110,11 @@ func (c *ApiController) DeleteCasbinAdapter() {
func (c *ApiController) SyncPolicies() {
id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id)
adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
policies, err := object.SyncPolicies(adapter)
if err != nil {
@ -102,9 +127,14 @@ func (c *ApiController) SyncPolicies() {
func (c *ApiController) UpdatePolicy() {
id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id)
adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
var policies []xormadapter.CasbinRule
err := json.Unmarshal(c.Ctx.Input.RequestBody, &policies)
err = json.Unmarshal(c.Ctx.Input.RequestBody, &policies)
if err != nil {
c.ResponseError(err.Error())
return
@ -121,9 +151,14 @@ func (c *ApiController) UpdatePolicy() {
func (c *ApiController) AddPolicy() {
id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id)
adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
var policy xormadapter.CasbinRule
err := json.Unmarshal(c.Ctx.Input.RequestBody, &policy)
err = json.Unmarshal(c.Ctx.Input.RequestBody, &policy)
if err != nil {
c.ResponseError(err.Error())
return
@ -140,9 +175,14 @@ func (c *ApiController) AddPolicy() {
func (c *ApiController) RemovePolicy() {
id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id)
adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
var policy xormadapter.CasbinRule
err := json.Unmarshal(c.Ctx.Input.RequestBody, &policy)
err = json.Unmarshal(c.Ctx.Input.RequestBody, &policy)
if err != nil {
c.ResponseError(err.Error())
return

View File

@ -37,13 +37,28 @@ func (c *ApiController) GetCerts() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedCerts(object.GetCerts(owner))
maskedCerts, err := object.GetMaskedCerts(object.GetCerts(owner))
if err != nil {
panic(err)
}
c.Data["json"] = maskedCerts
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetCertCount(owner, field, value)))
certs := object.GetMaskedCerts(object.GetPaginationCerts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
count, err := object.GetCertCount(owner, field, value)
if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
certs, err := object.GetMaskedCerts(object.GetPaginationCerts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
if err != nil {
panic(err)
}
c.ResponseOk(certs, paginator.Nums())
}
}
@ -61,13 +76,28 @@ func (c *ApiController) GetGlobleCerts() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedCerts(object.GetGlobleCerts())
maskedCerts, err := object.GetMaskedCerts(object.GetGlobleCerts())
if err != nil {
panic(err)
}
c.Data["json"] = maskedCerts
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalCertsCount(field, value)))
certs := object.GetMaskedCerts(object.GetPaginationGlobalCerts(paginator.Offset(), limit, field, value, sortField, sortOrder))
count, err := object.GetGlobalCertsCount(field, value)
if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
certs, err := object.GetMaskedCerts(object.GetPaginationGlobalCerts(paginator.Offset(), limit, field, value, sortField, sortOrder))
if err != nil {
panic(err)
}
c.ResponseOk(certs, paginator.Nums())
}
}
@ -81,8 +111,12 @@ func (c *ApiController) GetGlobleCerts() {
// @router /get-cert [get]
func (c *ApiController) GetCert() {
id := c.Input().Get("id")
cert, err := object.GetCert(id)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedCert(object.GetCert(id))
c.Data["json"] = object.GetMaskedCert(cert)
c.ServeJSON()
}

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetChats() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedChats(object.GetChats(owner))
maskedChats, err := object.GetMaskedChats(object.GetChats(owner))
if err != nil {
panic(err)
}
c.Data["json"] = maskedChats
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetChatCount(owner, field, value)))
chats := object.GetMaskedChats(object.GetPaginationChats(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
count, err := object.GetChatCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
chats, err := object.GetMaskedChats(object.GetPaginationChats(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(chats, paginator.Nums())
}
}
@ -58,7 +75,12 @@ func (c *ApiController) GetChats() {
func (c *ApiController) GetChat() {
id := c.Input().Get("id")
c.Data["json"] = object.GetMaskedChat(object.GetChat(id))
maskedChat, err := object.GetMaskedChat(object.GetChat(id))
if err != nil {
panic(err)
}
c.Data["json"] = maskedChat
c.ServeJSON()
}

View File

@ -34,8 +34,7 @@ func (c *ApiController) Enforce() {
}
if permissionId != "" {
c.Data["json"] = object.Enforce(permissionId, &request)
c.ServeJSON()
c.ResponseOk(object.Enforce(permissionId, &request))
return
}
@ -44,9 +43,15 @@ func (c *ApiController) Enforce() {
if modelId != "" {
owner, modelName := util.GetOwnerAndNameFromId(modelId)
permissions = object.GetPermissionsByModel(owner, modelName)
permissions, err = object.GetPermissionsByModel(owner, modelName)
if err != nil {
panic(err)
}
} else {
permissions = object.GetPermissionsByResource(resourceId)
permissions, err = object.GetPermissionsByResource(resourceId)
if err != nil {
panic(err)
}
}
for _, permission := range permissions {
@ -63,8 +68,7 @@ func (c *ApiController) BatchEnforce() {
var requests []object.CasbinRequest
err := json.Unmarshal(c.Ctx.Input.RequestBody, &requests)
if err != nil {
c.ResponseError(err.Error())
return
panic(err)
}
if permissionId != "" {
@ -72,14 +76,17 @@ func (c *ApiController) BatchEnforce() {
c.ServeJSON()
} else {
owner, modelName := util.GetOwnerAndNameFromId(modelId)
permissions := object.GetPermissionsByModel(owner, modelName)
permissions, err := object.GetPermissionsByModel(owner, modelName)
if err != nil {
panic(err)
}
res := [][]bool{}
for _, permission := range permissions {
res = append(res, object.BatchEnforce(permission.GetId(), &requests))
}
c.Data["json"] = res
c.ServeJSON()
c.ResponseOk(res)
}
}
@ -90,8 +97,7 @@ func (c *ApiController) GetAllObjects() {
return
}
c.Data["json"] = object.GetAllObjects(userId)
c.ServeJSON()
c.ResponseOk(object.GetAllObjects(userId))
}
func (c *ApiController) GetAllActions() {
@ -101,8 +107,7 @@ func (c *ApiController) GetAllActions() {
return
}
c.Data["json"] = object.GetAllActions(userId)
c.ServeJSON()
c.ResponseOk(object.GetAllActions(userId))
}
func (c *ApiController) GetAllRoles() {
@ -112,6 +117,5 @@ func (c *ApiController) GetAllRoles() {
return
}
c.Data["json"] = object.GetAllRoles(userId)
c.ServeJSON()
c.ResponseOk(object.GetAllRoles(userId))
}

View File

@ -45,7 +45,11 @@ func (c *ApiController) GetLdapUsers() {
id := c.Input().Get("id")
_, ldapId := util.GetOwnerAndNameFromId(id)
ldapServer := object.GetLdap(ldapId)
ldapServer, err := object.GetLdap(ldapId)
if err != nil {
c.ResponseError(err.Error())
return
}
conn, err := ldapServer.GetLdapConn()
if err != nil {
@ -76,7 +80,11 @@ func (c *ApiController) GetLdapUsers() {
for i, user := range users {
uuids[i] = user.GetLdapUuid()
}
existUuids := object.GetExistUuids(ldapServer.Owner, uuids)
existUuids, err := object.GetExistUuids(ldapServer.Owner, uuids)
if err != nil {
c.ResponseError(err.Error())
return
}
resp := LdapResp{
Users: object.AutoAdjustLdapUser(users),
@ -128,17 +136,23 @@ func (c *ApiController) AddLdap() {
return
}
if object.CheckLdapExist(&ldap) {
if ok, err := object.CheckLdapExist(&ldap); err != nil {
c.ResponseError(err.Error())
return
} else if ok {
c.ResponseError(c.T("ldap:Ldap server exist"))
return
}
affected := object.AddLdap(&ldap)
resp := wrapActionResponse(affected)
resp := wrapActionResponse(object.AddLdap(&ldap))
resp.Data2 = ldap
if ldap.AutoSync != 0 {
object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id)
err = object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
}
c.Data["json"] = resp
@ -157,11 +171,24 @@ func (c *ApiController) UpdateLdap() {
return
}
prevLdap := object.GetLdap(ldap.Id)
affected := object.UpdateLdap(&ldap)
prevLdap, err := object.GetLdap(ldap.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
affected, err := object.UpdateLdap(&ldap)
if err != nil {
c.ResponseError(err.Error())
return
}
if ldap.AutoSync != 0 {
object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id)
err := object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
} else if ldap.AutoSync == 0 && prevLdap.AutoSync != 0 {
object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id)
}
@ -182,7 +209,11 @@ func (c *ApiController) DeleteLdap() {
return
}
affected := object.DeleteLdap(&ldap)
affected, err := object.DeleteLdap(&ldap)
if err != nil {
c.ResponseError(err.Error())
return
}
object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id)
@ -204,7 +235,11 @@ func (c *ApiController) SyncLdapUsers() {
return
}
object.UpdateLdapSyncTime(ldapId)
err = object.UpdateLdapSyncTime(ldapId)
if err != nil {
c.ResponseError(err.Error())
return
}
exist, failed, _ := object.SyncLdapUsers(owner, users, ldapId)

View File

@ -53,7 +53,12 @@ func (c *ApiController) Unlink() {
if user.Id == unlinkedUser.Id && !user.IsGlobalAdmin {
// if the user is unlinking themselves, should check the provider can be unlinked, if not, we should return an error.
application := object.GetApplicationByUser(user)
application, err := object.GetApplicationByUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil {
c.ResponseError(c.T("link:You can't unlink yourself, you are not a member of any application"))
return
@ -88,8 +93,17 @@ func (c *ApiController) Unlink() {
return
}
object.ClearUserOAuthProperties(&unlinkedUser, providerType)
_, err = object.ClearUserOAuthProperties(&unlinkedUser, providerType)
if err != nil {
c.ResponseError(err.Error())
return
}
_, err = object.LinkUserAccount(&unlinkedUser, providerType, "")
if err != nil {
c.ResponseError(err.Error())
return
}
object.LinkUserAccount(&unlinkedUser, providerType, "")
c.ResponseOk()
}

View File

@ -44,18 +44,35 @@ func (c *ApiController) GetMessages() {
organization := c.Input().Get("organization")
if limit == "" || page == "" {
var messages []*object.Message
var err error
if chat == "" {
messages = object.GetMessages(owner)
messages, err = object.GetMessages(owner)
} else {
messages = object.GetChatMessages(chat)
messages, err = object.GetChatMessages(chat)
}
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedMessages(messages)
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetMessageCount(owner, organization, field, value)))
messages := object.GetMaskedMessages(object.GetPaginationMessages(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder))
count, err := object.GetMessageCount(owner, organization, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
paginationMessages, err := object.GetPaginationMessages(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
messages := object.GetMaskedMessages(paginationMessages)
c.ResponseOk(messages, paginator.Nums())
}
}
@ -69,8 +86,12 @@ func (c *ApiController) GetMessages() {
// @router /get-message [get]
func (c *ApiController) GetMessage() {
id := c.Input().Get("id")
message, err := object.GetMessage(id)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedMessage(object.GetMessage(id))
c.Data["json"] = object.GetMaskedMessage(message)
c.ServeJSON()
}
@ -96,7 +117,12 @@ func (c *ApiController) GetMessageAnswer() {
c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
message := object.GetMessage(id)
message, err := object.GetMessage(id)
if err != nil {
c.ResponseError(err.Error())
return
}
if message == nil {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id))
return
@ -108,7 +134,12 @@ func (c *ApiController) GetMessageAnswer() {
}
chatId := util.GetId("admin", message.Chat)
chat := object.GetChat(chatId)
chat, err := object.GetChat(chatId)
if err != nil {
c.ResponseError(err.Error())
return
}
if chat == nil || chat.Organization != message.Organization {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId))
return
@ -119,14 +150,19 @@ func (c *ApiController) GetMessageAnswer() {
return
}
questionMessage := object.GetMessage(message.ReplyTo)
questionMessage, err := object.GetMessage(message.ReplyTo)
if questionMessage == nil {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id))
return
}
providerId := util.GetId(chat.Owner, chat.User2)
provider := object.GetProvider(providerId)
provider, err := object.GetProvider(providerId)
if err != nil {
c.ResponseError(err.Error())
return
}
if provider == nil {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId))
return
@ -148,7 +184,7 @@ func (c *ApiController) GetMessageAnswer() {
fmt.Printf("Question: [%s]\n", questionMessage.Text)
fmt.Printf("Answer: [")
err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder)
err = ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder)
if err != nil {
c.ResponseErrorStream(err.Error())
return
@ -165,7 +201,10 @@ func (c *ApiController) GetMessageAnswer() {
answer := stringBuilder.String()
message.Text = answer
object.UpdateMessage(message.GetId(), message)
_, err = object.UpdateMessage(message.GetId(), message)
if err != nil {
panic(err)
}
}
// UpdateMessage
@ -208,14 +247,24 @@ func (c *ApiController) AddMessage() {
var chat *object.Chat
if message.Chat != "" {
chatId := util.GetId("admin", message.Chat)
chat = object.GetChat(chatId)
chat, err = object.GetChat(chatId)
if err != nil {
c.ResponseError(err.Error())
return
}
if chat == nil || chat.Organization != message.Organization {
c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId))
return
}
}
affected := object.AddMessage(&message)
affected, err := object.AddMessage(&message)
if err != nil {
c.ResponseError(err.Error())
return
}
if affected {
if chat != nil && chat.Type == "AI" {
answerMessage := &object.Message{
@ -228,7 +277,11 @@ func (c *ApiController) AddMessage() {
Author: "AI",
Text: "",
}
object.AddMessage(answerMessage)
_, err = object.AddMessage(answerMessage)
if err != nil {
c.ResponseError(err.Error())
return
}
}
}

View File

@ -46,7 +46,12 @@ func (c *ApiController) MfaSetupInitiate() {
if MfaUtil == nil {
c.ResponseError("Invalid auth type")
}
user := object.GetUser(userId)
user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError("User doesn't exist")
return
@ -105,14 +110,19 @@ func (c *ApiController) MfaSetupEnable() {
name := c.Ctx.Request.Form.Get("name")
authType := c.Ctx.Request.Form.Get("type")
user := object.GetUser(util.GetId(owner, name))
user, err := object.GetUser(util.GetId(owner, name))
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError("User doesn't exist")
return
}
twoFactor := object.GetMfaUtil(authType, nil)
err := twoFactor.Enable(c.Ctx, user)
err = twoFactor.Enable(c.Ctx, user)
if err != nil {
c.ResponseError(err.Error())
return
@ -136,7 +146,12 @@ func (c *ApiController) DeleteMfa() {
name := c.Ctx.Request.Form.Get("name")
userId := util.GetId(owner, name)
user := object.GetUser(userId)
user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError("User doesn't exist")
return
@ -151,7 +166,12 @@ func (c *ApiController) DeleteMfa() {
}
}
user.MultiFactorAuths = mfaProps
object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser())
_, err = object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser())
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(user.MultiFactorAuths)
}
@ -170,7 +190,12 @@ func (c *ApiController) SetPreferredMfa() {
name := c.Ctx.Request.Form.Get("name")
userId := util.GetId(owner, name)
user := object.GetUser(userId)
user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError("User doesn't exist")
return
@ -185,7 +210,11 @@ func (c *ApiController) SetPreferredMfa() {
}
}
object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser())
_, err = object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser())
if err != nil {
c.ResponseError(err.Error())
return
}
for i, mfaProp := range mfaProps {
mfaProps[i] = object.GetMaskedProps(mfaProp)

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetModels() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetModels(owner)
models, err := object.GetModels(owner)
if err != nil {
panic(err)
}
c.Data["json"] = models
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetModelCount(owner, field, value)))
models := object.GetPaginationModels(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetModelCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
models, err := object.GetPaginationModels(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(models, paginator.Nums())
}
}
@ -58,7 +75,12 @@ func (c *ApiController) GetModels() {
func (c *ApiController) GetModel() {
id := c.Input().Get("id")
c.Data["json"] = object.GetModel(id)
model, err := object.GetModel(id)
if err != nil {
panic(err)
}
c.Data["json"] = model
c.ServeJSON()
}

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetOrganizations() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedOrganizations(object.GetOrganizations(owner))
maskedOrganizations, err := object.GetMaskedOrganizations(object.GetOrganizations(owner))
if err != nil {
panic(err)
}
c.Data["json"] = maskedOrganizations
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetOrganizationCount(owner, field, value)))
organizations := object.GetMaskedOrganizations(object.GetPaginationOrganizations(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
count, err := object.GetOrganizationCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
organizations, err := object.GetMaskedOrganizations(object.GetPaginationOrganizations(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(organizations, paginator.Nums())
}
}
@ -58,7 +75,12 @@ func (c *ApiController) GetOrganizations() {
func (c *ApiController) GetOrganization() {
id := c.Input().Get("id")
c.Data["json"] = object.GetMaskedOrganization(object.GetOrganization(id))
maskedOrganization, err := object.GetMaskedOrganization(object.GetOrganization(id))
if err != nil {
panic(err)
}
c.Data["json"] = maskedOrganization
c.ServeJSON()
}
@ -99,8 +121,13 @@ func (c *ApiController) AddOrganization() {
return
}
count := object.GetOrganizationCount("", "", "")
if err := checkQuotaForOrganization(count); err != nil {
count, err := object.GetOrganizationCount("", "", "")
if err != nil {
c.ResponseError(err.Error())
return
}
if err = checkQuotaForOrganization(int(count)); err != nil {
c.ResponseError(err.Error())
return
}
@ -158,6 +185,11 @@ func (c *ApiController) GetDefaultApplication() {
// @router /get-organization-names [get]
func (c *ApiController) GetOrganizationNames() {
owner := c.Input().Get("owner")
organizationNames := object.GetOrganizationsByFields(owner, "name")
organizationNames, err := object.GetOrganizationsByFields(owner, "name")
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(organizationNames)
}

View File

@ -37,13 +37,28 @@ func (c *ApiController) GetPayments() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetPayments(owner)
payments, err := object.GetPayments(owner)
if err != nil {
panic(err)
}
c.Data["json"] = payments
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPaymentCount(owner, field, value)))
payments := object.GetPaginationPayments(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetPaymentCount(owner, field, value)
if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
payments, err := object.GetPaginationPayments(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
panic(err)
}
c.ResponseOk(payments, paginator.Nums())
}
}
@ -62,7 +77,12 @@ func (c *ApiController) GetUserPayments() {
organization := c.Input().Get("organization")
user := c.Input().Get("user")
payments := object.GetUserPayments(owner, organization, user)
payments, err := object.GetUserPayments(owner, organization, user)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(payments)
}
@ -76,7 +96,12 @@ func (c *ApiController) GetUserPayments() {
func (c *ApiController) GetPayment() {
id := c.Input().Get("id")
c.Data["json"] = object.GetPayment(id)
payment, err := object.GetPayment(id)
if err != nil {
panic(err)
}
c.Data["json"] = payment
c.ServeJSON()
}
@ -177,7 +202,12 @@ func (c *ApiController) NotifyPayment() {
func (c *ApiController) InvoicePayment() {
id := c.Input().Get("id")
payment := object.GetPayment(id)
payment, err := object.GetPayment(id)
if err != nil {
c.ResponseError(err.Error())
return
}
invoiceUrl, err := object.InvoicePayment(payment)
if err != nil {
c.ResponseError(err.Error())

View File

@ -37,13 +37,28 @@ func (c *ApiController) GetPermissions() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetPermissions(owner)
permissions, err := object.GetPermissions(owner)
if err != nil {
panic(err)
}
c.Data["json"] = permissions
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPermissionCount(owner, field, value)))
permissions := object.GetPaginationPermissions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetPermissionCount(owner, field, value)
if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
permissions, err := object.GetPaginationPermissions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
panic(err)
}
c.ResponseOk(permissions, paginator.Nums())
}
}
@ -60,7 +75,12 @@ func (c *ApiController) GetPermissionsBySubmitter() {
return
}
permissions := object.GetPermissionsBySubmitter(user.Owner, user.Name)
permissions, err := object.GetPermissionsBySubmitter(user.Owner, user.Name)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(permissions, len(permissions))
return
}
@ -74,7 +94,12 @@ func (c *ApiController) GetPermissionsBySubmitter() {
// @router /get-permissions-by-role [get]
func (c *ApiController) GetPermissionsByRole() {
id := c.Input().Get("id")
permissions := object.GetPermissionsByRole(id)
permissions, err := object.GetPermissionsByRole(id)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(permissions, len(permissions))
return
}
@ -89,7 +114,12 @@ func (c *ApiController) GetPermissionsByRole() {
func (c *ApiController) GetPermission() {
id := c.Input().Get("id")
c.Data["json"] = object.GetPermission(id)
permission, err := object.GetPermission(id)
if err != nil {
panic(err)
}
c.Data["json"] = permission
c.ServeJSON()
}

View File

@ -41,7 +41,11 @@ func (c *ApiController) UploadPermissions() {
return
}
affected := object.UploadPermissions(owner, fileId)
affected, err := object.UploadPermissions(owner, fileId)
if err != nil {
c.ResponseError(err.Error())
}
if affected {
c.ResponseOk()
} else {

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetPlans() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetPlans(owner)
plans, err := object.GetPlans(owner)
if err != nil {
panic(err)
}
c.Data["json"] = plans
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPlanCount(owner, field, value)))
plan := object.GetPaginatedPlans(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetPlanCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
plan, err := object.GetPaginatedPlans(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(plan, paginator.Nums())
}
}
@ -60,10 +77,16 @@ func (c *ApiController) GetPlan() {
id := c.Input().Get("id")
includeOption := c.Input().Get("includeOption") == "true"
plan := object.GetPlan(id)
plan, err := object.GetPlan(id)
if err != nil {
panic(err)
}
if includeOption {
options := object.GetPermissionsByRole(plan.Role)
options, err := object.GetPermissionsByRole(plan.Role)
if err != nil {
panic(err)
}
for _, option := range options {
plan.Options = append(plan.Options, option.DisplayName)

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetPricings() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetPricings(owner)
pricings, err := object.GetPricings(owner)
if err != nil {
panic(err)
}
c.Data["json"] = pricings
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPricingCount(owner, field, value)))
pricing := object.GetPaginatedPricings(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetPricingCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
pricing, err := object.GetPaginatedPricings(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(pricing, paginator.Nums())
}
}
@ -58,7 +75,10 @@ func (c *ApiController) GetPricings() {
func (c *ApiController) GetPricing() {
id := c.Input().Get("id")
pricing := object.GetPricing(id)
pricing, err := object.GetPricing(id)
if err != nil {
panic(err)
}
c.Data["json"] = pricing
c.ServeJSON()

View File

@ -38,13 +38,30 @@ func (c *ApiController) GetProducts() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetProducts(owner)
products, err := object.GetProducts(owner)
if err != nil {
panic(err)
}
c.Data["json"] = products
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetProductCount(owner, field, value)))
products := object.GetPaginationProducts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetProductCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
products, err := object.GetPaginationProducts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(products, paginator.Nums())
}
}
@ -59,8 +76,15 @@ func (c *ApiController) GetProducts() {
func (c *ApiController) GetProduct() {
id := c.Input().Get("id")
product := object.GetProduct(id)
object.ExtendProductWithProviders(product)
product, err := object.GetProduct(id)
if err != nil {
panic(err)
}
err = object.ExtendProductWithProviders(product)
if err != nil {
panic(err)
}
c.Data["json"] = product
c.ServeJSON()
@ -145,7 +169,12 @@ func (c *ApiController) BuyProduct() {
return
}
user := object.GetUser(userId)
user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), userId))
return

View File

@ -44,12 +44,29 @@ func (c *ApiController) GetProviders() {
}
if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedProviders(object.GetProviders(owner), isMaskEnabled)
providers, err := object.GetProviders(owner)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedProviders(providers, isMaskEnabled)
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetProviderCount(owner, field, value)))
providers := object.GetMaskedProviders(object.GetPaginationProviders(owner, paginator.Offset(), limit, field, value, sortField, sortOrder), isMaskEnabled)
count, err := object.GetProviderCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
paginationProviders, err := object.GetPaginationProviders(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
providers := object.GetMaskedProviders(paginationProviders, isMaskEnabled)
c.ResponseOk(providers, paginator.Nums())
}
}
@ -74,12 +91,29 @@ func (c *ApiController) GetGlobalProviders() {
}
if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedProviders(object.GetGlobalProviders(), isMaskEnabled)
globalProviders, err := object.GetGlobalProviders()
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedProviders(globalProviders, isMaskEnabled)
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalProviderCount(field, value)))
providers := object.GetMaskedProviders(object.GetPaginationGlobalProviders(paginator.Offset(), limit, field, value, sortField, sortOrder), isMaskEnabled)
count, err := object.GetGlobalProviderCount(field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
paginationGlobalProviders, err := object.GetPaginationGlobalProviders(paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
providers := object.GetMaskedProviders(paginationGlobalProviders, isMaskEnabled)
c.ResponseOk(providers, paginator.Nums())
}
}
@ -98,8 +132,13 @@ func (c *ApiController) GetProvider() {
if !ok {
return
}
provider, err := object.GetProvider(id)
if err != nil {
c.ResponseError(err.Error())
return
}
c.Data["json"] = object.GetMaskedProvider(object.GetProvider(id), isMaskEnabled)
c.Data["json"] = object.GetMaskedProvider(provider, isMaskEnabled)
c.ServeJSON()
}
@ -140,8 +179,13 @@ func (c *ApiController) AddProvider() {
return
}
count := object.GetProviderCount("", "", "")
if err := checkQuotaForProvider(count); err != nil {
count, err := object.GetProviderCount("", "", "")
if err != nil {
c.ResponseError(err.Error())
return
}
if err := checkQuotaForProvider(int(count)); err != nil {
c.ResponseError(err.Error())
return
}

View File

@ -42,14 +42,31 @@ func (c *ApiController) GetRecords() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetRecords()
records, err := object.GetRecords()
if err != nil {
panic(err)
}
c.Data["json"] = records
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
filterRecord := &object.Record{Organization: organization}
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetRecordCount(field, value, filterRecord)))
records := object.GetPaginationRecords(paginator.Offset(), limit, field, value, sortField, sortOrder, filterRecord)
count, err := object.GetRecordCount(field, value, filterRecord)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
records, err := object.GetPaginationRecords(paginator.Offset(), limit, field, value, sortField, sortOrder, filterRecord)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(records, paginator.Nums())
}
}
@ -67,11 +84,15 @@ func (c *ApiController) GetRecordsByFilter() {
record := &object.Record{}
err := util.JsonToStruct(body, record)
if err != nil {
c.ResponseError(err.Error())
return
panic(err)
}
c.Data["json"] = object.GetRecordsByField(record)
records, err := object.GetRecordsByField(record)
if err != nil {
panic(err)
}
c.Data["json"] = records
c.ServeJSON()
}

View File

@ -51,12 +51,28 @@ func (c *ApiController) GetResources() {
}
if limit == "" || page == "" {
c.Data["json"] = object.GetResources(owner, user)
resources, err := object.GetResources(owner, user)
if err != nil {
panic(err)
}
c.Data["json"] = resources
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetResourceCount(owner, user, field, value)))
resources := object.GetPaginationResources(owner, user, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetResourceCount(owner, user, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
resources, err := object.GetPaginationResources(owner, user, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(resources, paginator.Nums())
}
}
@ -68,7 +84,12 @@ func (c *ApiController) GetResources() {
func (c *ApiController) GetResource() {
id := c.Input().Get("id")
c.Data["json"] = object.GetResource(id)
resource, err := object.GetResource(id)
if err != nil {
panic(err)
}
c.Data["json"] = resource
c.ServeJSON()
}
@ -187,7 +208,10 @@ func (c *ApiController) UploadResource() {
index := len(fullFilePath) - len(ext)
for i := 1; ; i++ {
_, objectKey := object.GetUploadFileUrl(provider, fullFilePath, true)
if object.GetResourceCount(owner, username, "name", objectKey) == 0 {
if count, err := object.GetResourceCount(owner, username, "name", objectKey); err != nil {
c.ResponseError(err.Error())
return
} else if count == 0 {
break
}
@ -223,20 +247,39 @@ func (c *ApiController) UploadResource() {
Url: fileUrl,
Description: description,
}
object.AddOrUpdateResource(resource)
_, err = object.AddOrUpdateResource(resource)
if err != nil {
c.ResponseError(err.Error())
return
}
switch tag {
case "avatar":
user := object.GetUserNoCheck(util.GetId(owner, username))
user, err := object.GetUserNoCheck(util.GetId(owner, username))
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError(c.T("resource:User is nil for tag: avatar"))
return
}
user.Avatar = fileUrl
object.UpdateUser(user.GetId(), user, []string{"avatar"}, false)
_, err = object.UpdateUser(user.GetId(), user, []string{"avatar"}, false)
if err != nil {
c.ResponseError(err.Error())
return
}
case "termsOfUse":
user := object.GetUserNoCheck(util.GetId(owner, username))
user, err := object.GetUserNoCheck(util.GetId(owner, username))
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(owner, username)))
return
@ -248,9 +291,18 @@ func (c *ApiController) UploadResource() {
}
_, applicationId := util.GetOwnerAndNameFromIdNoCheck(strings.TrimRight(fullFilePath, ".html"))
applicationObj := object.GetApplication(applicationId)
applicationObj, err := object.GetApplication(applicationId)
if err != nil {
c.ResponseError(err.Error())
return
}
applicationObj.TermsOfUse = fileUrl
object.UpdateApplication(applicationId, applicationObj)
_, err = object.UpdateApplication(applicationId, applicationObj)
if err != nil {
c.ResponseError(err.Error())
return
}
}
c.ResponseOk(fileUrl, objectKey)

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetRoles() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetRoles(owner)
roles, err := object.GetRoles(owner)
if err != nil {
panic(err)
}
c.Data["json"] = roles
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetRoleCount(owner, field, value)))
roles := object.GetPaginationRoles(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetRoleCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
roles, err := object.GetPaginationRoles(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(roles, paginator.Nums())
}
}
@ -58,7 +75,12 @@ func (c *ApiController) GetRoles() {
func (c *ApiController) GetRole() {
id := c.Input().Get("id")
c.Data["json"] = object.GetRole(id)
role, err := object.GetRole(id)
if err != nil {
panic(err)
}
c.Data["json"] = role
c.ServeJSON()
}

View File

@ -41,7 +41,11 @@ func (c *ApiController) UploadRoles() {
return
}
affected := object.UploadRoles(owner, fileId)
affected, err := object.UploadRoles(owner, fileId)
if err != nil {
c.ResponseError(err.Error())
}
if affected {
c.ResponseOk()
} else {

View File

@ -23,7 +23,12 @@ import (
func (c *ApiController) GetSamlMeta() {
host := c.Ctx.Request.Host
paramApp := c.Input().Get("application")
application := object.GetApplication(paramApp)
application, err := object.GetApplication(paramApp)
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil {
c.ResponseError(fmt.Sprintf(c.T("saml:Application %s not found"), paramApp))
return

View File

@ -61,7 +61,12 @@ func (c *ApiController) SendEmail() {
var provider *object.Provider
if emailForm.Provider != "" {
// called by frontend's TestEmailWidget, provider name is set by frontend
provider = object.GetProvider(util.GetId("admin", emailForm.Provider))
provider, err = object.GetProvider(util.GetId("admin", emailForm.Provider))
if err != nil {
c.ResponseError(err.Error())
return
}
} else {
// called by Casdoor SDK via Client ID & Client Secret, so the used Email provider will be the application' Email provider or the default Email provider
var ok bool

View File

@ -37,13 +37,29 @@ func (c *ApiController) GetSessions() {
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
owner := c.Input().Get("owner")
if limit == "" || page == "" {
c.Data["json"] = object.GetSessions(owner)
sessions, err := object.GetSessions(owner)
if err != nil {
panic(err)
}
c.Data["json"] = sessions
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSessionCount(owner, field, value)))
sessions := object.GetPaginationSessions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetSessionCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
sessions, err := object.GetPaginationSessions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(sessions, paginator.Nums())
}
}
@ -58,7 +74,12 @@ func (c *ApiController) GetSessions() {
func (c *ApiController) GetSingleSession() {
id := c.Input().Get("sessionPkId")
c.Data["json"] = object.GetSingleSession(id)
session, err := object.GetSingleSession(id)
if err != nil {
panic(err)
}
c.Data["json"] = session
c.ServeJSON()
}
@ -132,7 +153,11 @@ func (c *ApiController) IsSessionDuplicated() {
id := c.Input().Get("sessionPkId")
sessionId := c.Input().Get("sessionId")
isUserSessionDuplicated := object.IsSessionDuplicated(id, sessionId)
isUserSessionDuplicated, err := object.IsSessionDuplicated(id, sessionId)
if err != nil {
panic(err)
}
c.Data["json"] = &Response{Status: "ok", Msg: "", Data: isUserSessionDuplicated}
c.ServeJSON()

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetSubscriptions() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetSubscriptions(owner)
subscriptions, err := object.GetSubscriptions(owner)
if err != nil {
panic(err)
}
c.Data["json"] = subscriptions
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSubscriptionCount(owner, field, value)))
subscription := object.GetPaginationSubscriptions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetSubscriptionCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
subscription, err := object.GetPaginationSubscriptions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(subscription, paginator.Nums())
}
}
@ -58,7 +75,10 @@ func (c *ApiController) GetSubscriptions() {
func (c *ApiController) GetSubscription() {
id := c.Input().Get("id")
subscription := object.GetSubscription(id)
subscription, err := object.GetSubscription(id)
if err != nil {
panic(err)
}
c.Data["json"] = subscription
c.ServeJSON()

View File

@ -38,13 +38,30 @@ func (c *ApiController) GetSyncers() {
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization")
if limit == "" || page == "" {
c.Data["json"] = object.GetOrganizationSyncers(owner, organization)
organizationSyncers, err := object.GetOrganizationSyncers(owner, organization)
if err != nil {
panic(err)
}
c.Data["json"] = organizationSyncers
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSyncerCount(owner, organization, field, value)))
syncers := object.GetPaginationSyncers(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetSyncerCount(owner, organization, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
syncers, err := object.GetPaginationSyncers(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(syncers, paginator.Nums())
}
}
@ -59,7 +76,12 @@ func (c *ApiController) GetSyncers() {
func (c *ApiController) GetSyncer() {
id := c.Input().Get("id")
c.Data["json"] = object.GetSyncer(id)
syncer, err := object.GetSyncer(id)
if err != nil {
panic(err)
}
c.Data["json"] = syncer
c.ServeJSON()
}
@ -132,7 +154,11 @@ func (c *ApiController) DeleteSyncer() {
// @router /run-syncer [get]
func (c *ApiController) RunSyncer() {
id := c.Input().Get("id")
syncer := object.GetSyncer(id)
syncer, err := object.GetSyncer(id)
if err != nil {
c.ResponseError(err.Error())
return
}
object.RunSyncer(syncer)

View File

@ -41,12 +41,28 @@ func (c *ApiController) GetTokens() {
sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization")
if limit == "" || page == "" {
c.Data["json"] = object.GetTokens(owner, organization)
token, err := object.GetTokens(owner, organization)
if err != nil {
panic(err)
}
c.Data["json"] = token
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetTokenCount(owner, organization, field, value)))
tokens := object.GetPaginationTokens(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetTokenCount(owner, organization, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
tokens, err := object.GetPaginationTokens(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(tokens, paginator.Nums())
}
}
@ -60,8 +76,12 @@ func (c *ApiController) GetTokens() {
// @router /get-token [get]
func (c *ApiController) GetToken() {
id := c.Input().Get("id")
token, err := object.GetToken(id)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetToken(id)
c.Data["json"] = token
c.ServeJSON()
}
@ -171,8 +191,12 @@ func (c *ApiController) GetOAuthToken() {
}
}
host := c.Ctx.Request.Host
oAuthtoken, err := object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier, scope, username, password, host, refreshToken, tag, avatar, c.GetAcceptLanguage())
if err != nil {
panic(err)
}
c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier, scope, username, password, host, refreshToken, tag, avatar, c.GetAcceptLanguage())
c.Data["json"] = oAuthtoken
c.SetTokenErrorHttpStatus()
c.ServeJSON()
}
@ -210,7 +234,12 @@ func (c *ApiController) RefreshToken() {
}
}
c.Data["json"] = object.RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host)
refreshToken2, err := object.RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host)
if err != nil {
panic(err)
}
c.Data["json"] = refreshToken2
c.SetTokenErrorHttpStatus()
c.ServeJSON()
}
@ -245,7 +274,11 @@ func (c *ApiController) IntrospectToken() {
return
}
}
application := object.GetApplicationByClientId(clientId)
application, err := object.GetApplicationByClientId(clientId)
if err != nil {
panic(err)
}
if application == nil || application.ClientSecret != clientSecret {
c.ResponseError(c.T("token:Invalid application or wrong clientSecret"))
c.Data["json"] = &object.TokenError{
@ -254,7 +287,11 @@ func (c *ApiController) IntrospectToken() {
c.SetTokenErrorHttpStatus()
return
}
token := object.GetTokenByTokenAndApplication(tokenValue, application.Name)
token, err := object.GetTokenByTokenAndApplication(tokenValue, application.Name)
if err != nil {
panic(err)
}
if token == nil {
c.Data["json"] = &object.IntrospectionResponse{Active: false}
c.ServeJSON()

View File

@ -37,14 +37,36 @@ func (c *ApiController) GetGlobalUsers() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedUsers(object.GetGlobalUsers())
maskedUsers, err := object.GetMaskedUsers(object.GetGlobalUsers())
if err != nil {
panic(err)
}
c.Data["json"] = maskedUsers
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalUserCount(field, value)))
users := object.GetPaginationGlobalUsers(paginator.Offset(), limit, field, value, sortField, sortOrder)
users = object.GetMaskedUsers(users)
count, err := object.GetGlobalUserCount(field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
users, err := object.GetPaginationGlobalUsers(paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
users, err = object.GetMaskedUsers(users)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(users, paginator.Nums())
}
}
@ -64,14 +86,36 @@ func (c *ApiController) GetUsers() {
value := c.Input().Get("value")
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedUsers(object.GetUsers(owner))
maskedUsers, err := object.GetMaskedUsers(object.GetUsers(owner))
if err != nil {
panic(err)
}
c.Data["json"] = maskedUsers
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetUserCount(owner, field, value)))
users := object.GetPaginationUsers(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
users = object.GetMaskedUsers(users)
count, err := object.GetUserCount(owner, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
users, err := object.GetPaginationUsers(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
users, err = object.GetMaskedUsers(users)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(users, paginator.Nums())
}
}
@ -93,10 +137,14 @@ func (c *ApiController) GetUser() {
phone := c.Input().Get("phone")
userId := c.Input().Get("userId")
owner := c.Input().Get("owner")
var err error
var userFromUserId *object.User
if userId != "" && owner != "" {
userFromUserId = object.GetUserByUserId(owner, userId)
userFromUserId, err = object.GetUserByUserId(owner, userId)
if err != nil {
panic(err)
}
id = util.GetId(userFromUserId.Owner, userFromUserId.Name)
}
@ -104,7 +152,11 @@ func (c *ApiController) GetUser() {
owner = util.GetOwnerFromId(id)
}
organization := object.GetOrganization(util.GetId("admin", owner))
organization, err := object.GetOrganization(util.GetId("admin", owner))
if err != nil {
panic(err)
}
if !organization.IsProfilePublic {
requestUserId := c.GetSessionUsername()
hasPermission, err := object.CheckUserPermission(requestUserId, id, false, c.GetAcceptLanguage())
@ -117,18 +169,30 @@ func (c *ApiController) GetUser() {
var user *object.User
switch {
case email != "":
user = object.GetUserByEmail(owner, email)
user, err = object.GetUserByEmail(owner, email)
case phone != "":
user = object.GetUserByPhone(owner, phone)
user, err = object.GetUserByPhone(owner, phone)
case userId != "":
user = userFromUserId
default:
user = object.GetUser(id)
user, err = object.GetUser(id)
}
object.ExtendUserWithRolesAndPermissions(user)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedUser(user)
err = object.ExtendUserWithRolesAndPermissions(user)
if err != nil {
panic(err)
}
maskedUser, err := object.GetMaskedUser(user)
if err != nil {
panic(err)
}
c.Data["json"] = maskedUser
c.ServeJSON()
}
@ -158,7 +222,12 @@ func (c *ApiController) UpdateUser() {
return
}
}
oldUser := object.GetUser(id)
oldUser, err := object.GetUser(id)
if err != nil {
c.ResponseError(err.Error())
return
}
if oldUser == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), id))
return
@ -185,9 +254,18 @@ func (c *ApiController) UpdateUser() {
columns = strings.Split(columnsStr, ",")
}
affected := object.UpdateUser(id, &user, columns, isAdmin)
affected, err := object.UpdateUser(id, &user, columns, isAdmin)
if err != nil {
c.ResponseError(err.Error())
return
}
if affected {
object.UpdateUserToOriginalDatabase(&user)
err = object.UpdateUserToOriginalDatabase(&user)
if err != nil {
c.ResponseError(err.Error())
return
}
}
c.Data["json"] = wrapActionResponse(affected)
@ -209,8 +287,13 @@ func (c *ApiController) AddUser() {
return
}
count := object.GetUserCount("", "", "")
if err := checkQuotaForUser(count); err != nil {
count, err := object.GetUserCount("", "", "")
if err != nil {
c.ResponseError(err.Error())
return
}
if err := checkQuotaForUser(int(count)); err != nil {
c.ResponseError(err.Error())
return
}
@ -261,7 +344,12 @@ func (c *ApiController) GetEmailAndPhone() {
organization := c.Ctx.Request.Form.Get("organization")
username := c.Ctx.Request.Form.Get("username")
user := object.GetUserByFields(organization, username)
user, err := object.GetUserByFields(organization, username)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(organization, username)))
return
@ -335,7 +423,11 @@ func (c *ApiController) SetPassword() {
c.SetSession("verifiedCode", "")
}
targetUser := object.GetUser(userId)
targetUser, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if oldPassword != "" {
msg := object.CheckPassword(targetUser, oldPassword, c.GetAcceptLanguage())
@ -346,7 +438,12 @@ func (c *ApiController) SetPassword() {
}
targetUser.Password = newPassword
object.SetUserField(targetUser, "password", targetUser.Password)
_, err = object.SetUserField(targetUser, "password", targetUser.Password)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk()
}
@ -384,7 +481,12 @@ func (c *ApiController) GetSortedUsers() {
sorter := c.Input().Get("sorter")
limit := util.ParseInt(c.Input().Get("limit"))
c.Data["json"] = object.GetMaskedUsers(object.GetSortedUsers(owner, sorter, limit))
maskedUsers, err := object.GetMaskedUsers(object.GetSortedUsers(owner, sorter, limit))
if err != nil {
panic(err)
}
c.Data["json"] = maskedUsers
c.ServeJSON()
}
@ -400,11 +502,16 @@ func (c *ApiController) GetUserCount() {
owner := c.Input().Get("owner")
isOnline := c.Input().Get("isOnline")
count := 0
var count int64
var err error
if isOnline == "" {
count = object.GetUserCount(owner, "", "")
count, err = object.GetUserCount(owner, "", "")
} else {
count = object.GetOnlineUserCount(owner, util.ParseInt(isOnline))
count, err = object.GetOnlineUserCount(owner, util.ParseInt(isOnline))
}
if err != nil {
c.ResponseError(err.Error())
return
}
c.Data["json"] = count

View File

@ -57,7 +57,12 @@ func (c *ApiController) UploadUsers() {
return
}
affected := object.UploadUsers(owner, fileId)
affected, err := object.UploadUsers(owner, fileId)
if err != nil {
c.ResponseError(err.Error())
return
}
if affected {
c.ResponseOk()
} else {

View File

@ -92,7 +92,11 @@ func (c *ApiController) RequireSignedInUser() (*object.User, bool) {
return nil, false
}
user := object.GetUser(userId)
user, err := object.GetUser(userId)
if err != nil {
panic(err)
}
if user == nil {
c.ClearUserSession()
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), userId))
@ -138,7 +142,11 @@ func (c *ApiController) IsMaskedEnabled() (bool, bool) {
func (c *ApiController) GetProviderFromContext(category string) (*object.Provider, *object.User, bool) {
providerName := c.Input().Get("provider")
if providerName != "" {
provider := object.GetProvider(util.GetId("admin", providerName))
provider, err := object.GetProvider(util.GetId("admin", providerName))
if err != nil {
panic(err)
}
if provider == nil {
c.ResponseError(fmt.Sprintf(c.T("util:The provider: %s is not found"), providerName))
return nil, nil, false
@ -151,13 +159,21 @@ func (c *ApiController) GetProviderFromContext(category string) (*object.Provide
return nil, nil, false
}
application, user := object.GetApplicationByUserId(userId)
application, user, err := object.GetApplicationByUserId(userId)
if err != nil {
panic(err)
}
if application == nil {
c.ResponseError(fmt.Sprintf(c.T("util:No application is found for userId: %s"), userId))
return nil, nil, false
}
provider := application.GetProviderByCategory(category)
provider, err := application.GetProviderByCategory(category)
if err != nil {
panic(err)
}
if provider == nil {
c.ResponseError(fmt.Sprintf(c.T("util:No provider for category: %s is found for application: %s"), category, application.Name))
return nil, nil, false

View File

@ -66,8 +66,17 @@ func (c *ApiController) SendVerificationCode() {
}
}
application := object.GetApplication(vform.ApplicationId)
organization := object.GetOrganization(util.GetId(application.Owner, application.Organization))
application, err := object.GetApplication(vform.ApplicationId)
if err != nil {
c.ResponseError(err.Error())
return
}
organization, err := object.GetOrganization(util.GetId(application.Owner, application.Organization))
if err != nil {
c.ResponseError(c.T(err.Error()))
}
if organization == nil {
c.ResponseError(c.T("check:Organization does not exist"))
return
@ -77,12 +86,20 @@ func (c *ApiController) SendVerificationCode() {
// checkUser != "", means method is ForgetVerification
if vform.CheckUser != "" {
owner := application.Organization
user = object.GetUser(util.GetId(owner, vform.CheckUser))
user, err = object.GetUser(util.GetId(owner, vform.CheckUser))
if err != nil {
c.ResponseError(err.Error())
return
}
}
// mfaSessionData != nil, means method is MfaSetupVerification
if mfaSessionData := c.getMfaSessionData(); mfaSessionData != nil {
user = object.GetUser(mfaSessionData.UserId)
user, err = object.GetUser(mfaSessionData.UserId)
if err != nil {
c.ResponseError(err.Error())
return
}
}
sendResp := errors.New("invalid dest type")
@ -99,7 +116,12 @@ func (c *ApiController) SendVerificationCode() {
vform.Dest = user.Email
}
user = object.GetUserByEmail(organization.Name, vform.Dest)
user, err = object.GetUserByEmail(organization.Name, vform.Dest)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError(c.T("verification:the user does not exist, please sign up first"))
return
@ -113,7 +135,12 @@ func (c *ApiController) SendVerificationCode() {
}
}
provider := application.GetEmailProvider()
provider, err := application.GetEmailProvider()
if err != nil {
c.ResponseError(err.Error())
return
}
sendResp = object.SendVerificationCodeToEmail(organization, user, provider, remoteAddr, vform.Dest)
case object.VerifyTypePhone:
if vform.Method == LoginVerification || vform.Method == ForgetVerification {
@ -121,7 +148,10 @@ func (c *ApiController) SendVerificationCode() {
vform.Dest = user.Phone
}
if user = object.GetUserByPhone(organization.Name, vform.Dest); user == nil {
if user, err = object.GetUserByPhone(organization.Name, vform.Dest); err != nil {
c.ResponseError(err.Error())
return
} else if user == nil {
c.ResponseError(c.T("verification:the user does not exist, please sign up first"))
return
}
@ -140,7 +170,12 @@ func (c *ApiController) SendVerificationCode() {
vform.CountryCode = mfaProps.CountryCode
}
provider := application.GetSmsProvider()
provider, err := application.GetSmsProvider()
if err != nil {
c.ResponseError(err.Error())
return
}
if phone, ok := util.GetE164Number(vform.Dest, vform.CountryCode); !ok {
c.ResponseError(fmt.Sprintf(c.T("verification:Phone number is invalid in your region %s"), vform.CountryCode))
return
@ -213,7 +248,12 @@ func (c *ApiController) ResetEmailOrPhone() {
}
checkDest := dest
organization := object.GetOrganizationByUser(user)
organization, err := object.GetOrganizationByUser(user)
if err != nil {
c.ResponseError(c.T(err.Error()))
return
}
if destType == object.VerifyTypePhone {
if object.HasUserByField(user.Owner, "phone", dest) {
c.ResponseError(c.T("check:Phone already exists"))
@ -260,16 +300,25 @@ func (c *ApiController) ResetEmailOrPhone() {
switch destType {
case object.VerifyTypeEmail:
user.Email = dest
object.SetUserField(user, "email", user.Email)
_, err = object.SetUserField(user, "email", user.Email)
case object.VerifyTypePhone:
user.Phone = dest
object.SetUserField(user, "phone", user.Phone)
_, err = object.SetUserField(user, "phone", user.Phone)
default:
c.ResponseError(c.T("verification:Unknown type"))
return
}
if err != nil {
c.ResponseError(err.Error())
return
}
err = object.DisableVerificationCode(checkDest)
if err != nil {
c.ResponseError(err.Error())
return
}
object.DisableVerificationCode(checkDest)
c.ResponseOk()
}
@ -287,7 +336,11 @@ func (c *ApiController) VerifyCode() {
var user *object.User
if authForm.Name != "" {
user = object.GetUserByFields(authForm.Organization, authForm.Name)
user, err = object.GetUserByFields(authForm.Organization, authForm.Name)
if err != nil {
c.ResponseError(err.Error())
return
}
}
var checkDest string
@ -302,7 +355,10 @@ func (c *ApiController) VerifyCode() {
}
}
if user = object.GetUserByFields(authForm.Organization, authForm.Username); user == nil {
if user, err = object.GetUserByFields(authForm.Organization, authForm.Username); err != nil {
c.ResponseError(err.Error())
return
} else if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(authForm.Organization, authForm.Username)))
return
}
@ -321,7 +377,11 @@ func (c *ApiController) VerifyCode() {
c.ResponseError(result.Msg)
return
}
object.DisableVerificationCode(checkDest)
err = object.DisableVerificationCode(checkDest)
if err != nil {
c.ResponseError(err.Error())
return
}
c.SetSession("verifiedCode", authForm.Code)
c.ResponseOk()

View File

@ -33,7 +33,12 @@ import (
// @Success 200 {object} protocol.CredentialCreation The CredentialCreationOptions object
// @router /webauthn/signup/begin [get]
func (c *ApiController) WebAuthnSignupBegin() {
webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host)
webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}
user := c.getCurrentUser()
if user == nil {
c.ResponseError(c.T("general:Please login first"))
@ -64,7 +69,12 @@ func (c *ApiController) WebAuthnSignupBegin() {
// @Success 200 {object} Response "The Response object"
// @router /webauthn/signup/finish [post]
func (c *ApiController) WebAuthnSignupFinish() {
webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host)
webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}
user := c.getCurrentUser()
if user == nil {
c.ResponseError(c.T("general:Please login first"))
@ -84,7 +94,12 @@ func (c *ApiController) WebAuthnSignupFinish() {
return
}
isGlobalAdmin := c.IsGlobalAdmin()
user.AddCredentials(*credential, isGlobalAdmin)
_, err = user.AddCredentials(*credential, isGlobalAdmin)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk()
}
@ -97,10 +112,20 @@ func (c *ApiController) WebAuthnSignupFinish() {
// @Success 200 {object} protocol.CredentialAssertion The CredentialAssertion object
// @router /webauthn/signin/begin [get]
func (c *ApiController) WebAuthnSigninBegin() {
webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host)
webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}
userOwner := c.Input().Get("owner")
userName := c.Input().Get("name")
user := object.GetUserByFields(userOwner, userName)
user, err := object.GetUserByFields(userOwner, userName)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(userOwner, userName)))
return
@ -129,7 +154,12 @@ func (c *ApiController) WebAuthnSigninBegin() {
// @router /webauthn/signin/finish [post]
func (c *ApiController) WebAuthnSigninFinish() {
responseType := c.Input().Get("responseType")
webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host)
webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}
sessionObj := c.GetSession("authentication")
sessionData, ok := sessionObj.(webauthn.SessionData)
if !ok {
@ -138,8 +168,13 @@ func (c *ApiController) WebAuthnSigninFinish() {
}
c.Ctx.Request.Body = io.NopCloser(bytes.NewBuffer(c.Ctx.Input.RequestBody))
userId := string(sessionData.UserID)
user := object.GetUser(userId)
_, err := webauthnObj.FinishLogin(user, sessionData, c.Ctx.Request)
user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
_, err = webauthnObj.FinishLogin(user, sessionData, c.Ctx.Request)
if err != nil {
c.ResponseError(err.Error())
return
@ -147,7 +182,12 @@ func (c *ApiController) WebAuthnSigninFinish() {
c.SetSessionUsername(userId)
util.LogInfo(c.Ctx, "API: [%s] signed in", userId)
application := object.GetApplicationByUser(user)
application, err := object.GetApplicationByUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
var authForm form.AuthForm
authForm.Type = responseType
resp := c.HandleLoggedIn(application, user, &authForm)

View File

@ -38,13 +38,31 @@ func (c *ApiController) GetWebhooks() {
sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization")
if limit == "" || page == "" {
c.Data["json"] = object.GetWebhooks(owner, organization)
webhooks, err := object.GetWebhooks(owner, organization)
if err != nil {
panic(err)
}
c.Data["json"] = webhooks
c.ServeJSON()
} else {
limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetWebhookCount(owner, organization, field, value)))
webhooks := object.GetPaginationWebhooks(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
count, err := object.GetWebhookCount(owner, organization, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
webhooks, err := object.GetPaginationWebhooks(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(webhooks, paginator.Nums())
}
}
@ -59,7 +77,12 @@ func (c *ApiController) GetWebhooks() {
func (c *ApiController) GetWebhook() {
id := c.Input().Get("id")
c.Data["json"] = object.GetWebhook(id)
webhook, err := object.GetWebhook(id)
if err != nil {
panic(err)
}
c.Data["json"] = webhook
c.ServeJSON()
}

View File

@ -84,6 +84,7 @@ func stringInSlice(value string, list []string) bool {
}
func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int) {
var err error
r := m.GetSearchRequest()
name, org, code := getNameAndOrgFromFilter(string(r.BaseObject()), r.FilterString())
@ -93,11 +94,19 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int)
if name == "*" && m.Client.IsOrgAdmin { // get all users from organization 'org'
if m.Client.IsGlobalAdmin && org == "*" {
filteredUsers = object.GetGlobalUsers()
filteredUsers, err = object.GetGlobalUsers()
if err != nil {
panic(err)
}
return filteredUsers, ldap.LDAPResultSuccess
}
if m.Client.IsGlobalAdmin || org == m.Client.OrgName {
filteredUsers = object.GetUsers(org)
filteredUsers, err = object.GetUsers(org)
if err != nil {
panic(err)
}
return filteredUsers, ldap.LDAPResultSuccess
} else {
return nil, ldap.LDAPResultInsufficientAccessRights
@ -112,13 +121,21 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int)
return nil, ldap.LDAPResultInsufficientAccessRights
}
user := object.GetUser(userId)
user, err := object.GetUser(userId)
if err != nil {
panic(err)
}
if user != nil {
filteredUsers = append(filteredUsers, user)
return filteredUsers, ldap.LDAPResultSuccess
}
organization := object.GetOrganization(util.GetId("admin", org))
organization, err := object.GetOrganization(util.GetId("admin", org))
if err != nil {
panic(err)
}
if organization == nil {
return nil, ldap.LDAPResultNoSuchObject
}
@ -127,7 +144,11 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int)
return nil, ldap.LDAPResultNoSuchObject
}
users := object.GetUsersByTag(org, name)
users, err := object.GetUsersByTag(org, name)
if err != nil {
panic(err)
}
filteredUsers = append(filteredUsers, users...)
return filteredUsers, ldap.LDAPResultSuccess
}
@ -137,7 +158,11 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int)
// TODO not handle salt yet
// @return {md5}5f4dcc3b5aa765d61d8327deb882cf99
func getUserPasswordWithType(user *object.User) string {
org := object.GetOrganizationByUser(user)
org, err := object.GetOrganizationByUser(user)
if err != nil {
panic(err)
}
if org.PasswordType == "" || org.PasswordType == "plain" {
return user.Password
}

View File

@ -119,7 +119,7 @@ func (a *Adapter) close() {
}
func (a *Adapter) createTable() {
showSql, _ := conf.GetConfigBool("showSql")
showSql := conf.GetConfigBool("showSql")
a.Engine.ShowSQL(showSql)
tableNamePrefix := conf.GetConfigString("tableNamePrefix")

View File

@ -79,134 +79,155 @@ type Application struct {
FormBackgroundUrl string `xorm:"varchar(200)" json:"formBackgroundUrl"`
}
func GetApplicationCount(owner, field, value string) int {
func GetApplicationCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Application{})
if err != nil {
panic(err)
return session.Count(&Application{})
}
return int(count)
}
func GetOrganizationApplicationCount(owner, Organization, field, value string) int {
func GetOrganizationApplicationCount(owner, Organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Application{Organization: Organization})
if err != nil {
panic(err)
return session.Count(&Application{Organization: Organization})
}
return int(count)
}
func GetApplications(owner string) []*Application {
func GetApplications(owner string) ([]*Application, error) {
applications := []*Application{}
err := adapter.Engine.Desc("created_time").Find(&applications, &Application{Owner: owner})
if err != nil {
panic(err)
return applications, err
}
return applications
return applications, nil
}
func GetOrganizationApplications(owner string, organization string) []*Application {
func GetOrganizationApplications(owner string, organization string) ([]*Application, error) {
applications := []*Application{}
err := adapter.Engine.Desc("created_time").Find(&applications, &Application{Organization: organization})
if err != nil {
panic(err)
return applications, err
}
return applications
return applications, nil
}
func GetPaginationApplications(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Application {
func GetPaginationApplications(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Application, error) {
var applications []*Application
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&applications)
if err != nil {
panic(err)
return applications, err
}
return applications
return applications, nil
}
func GetPaginationOrganizationApplications(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Application {
func GetPaginationOrganizationApplications(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Application, error) {
applications := []*Application{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&applications, &Application{Organization: organization})
if err != nil {
panic(err)
return applications, err
}
return applications
return applications, nil
}
func getProviderMap(owner string) map[string]*Provider {
providers := GetProviders(owner)
m := map[string]*Provider{}
func getProviderMap(owner string) (m map[string]*Provider, err error) {
providers, err := GetProviders(owner)
if err != nil {
return nil, err
}
m = map[string]*Provider{}
for _, provider := range providers {
// Get QRCode only once
if provider.Type == "WeChat" && provider.DisableSsl && provider.Content == "" {
provider.Content, _ = idp.GetWechatOfficialAccountQRCode(provider.ClientId2, provider.ClientSecret2)
provider.Content, err = idp.GetWechatOfficialAccountQRCode(provider.ClientId2, provider.ClientSecret2)
if err != nil {
return
}
UpdateProvider(provider.Owner+"/"+provider.Name, provider)
}
m[provider.Name] = GetMaskedProvider(provider, true)
}
return m
return m, err
}
func extendApplicationWithProviders(application *Application) (err error) {
m, err := getProviderMap(application.Organization)
if err != nil {
return err
}
func extendApplicationWithProviders(application *Application) {
m := getProviderMap(application.Organization)
for _, providerItem := range application.Providers {
if provider, ok := m[providerItem.Name]; ok {
providerItem.Provider = provider
}
}
return
}
func extendApplicationWithOrg(application *Application) {
organization := getOrganization(application.Owner, application.Organization)
func extendApplicationWithOrg(application *Application) (err error) {
organization, err := getOrganization(application.Owner, application.Organization)
application.OrganizationObj = organization
return
}
func getApplication(owner string, name string) *Application {
func getApplication(owner string, name string) (*Application, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
application := Application{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&application)
if err != nil {
panic(err)
return nil, err
}
if existed {
extendApplicationWithProviders(&application)
extendApplicationWithOrg(&application)
return &application
err = extendApplicationWithProviders(&application)
if err != nil {
return nil, err
}
err = extendApplicationWithOrg(&application)
if err != nil {
return nil, err
}
return &application, nil
} else {
return nil
return nil, nil
}
}
func GetApplicationByOrganizationName(organization string) *Application {
func GetApplicationByOrganizationName(organization string) (*Application, error) {
application := Application{}
existed, err := adapter.Engine.Where("organization=?", organization).Get(&application)
if err != nil {
panic(err)
return nil, nil
}
if existed {
extendApplicationWithProviders(&application)
extendApplicationWithOrg(&application)
return &application
err = extendApplicationWithProviders(&application)
if err != nil {
return nil, err
}
err = extendApplicationWithOrg(&application)
if err != nil {
return nil, err
}
return &application, nil
} else {
return nil
return nil, nil
}
}
func GetApplicationByUser(user *User) *Application {
func GetApplicationByUser(user *User) (*Application, error) {
if user.SignupApplication != "" {
return getApplication("admin", user.SignupApplication)
} else {
@ -214,38 +235,46 @@ func GetApplicationByUser(user *User) *Application {
}
}
func GetApplicationByUserId(userId string) (*Application, *User) {
var application *Application
func GetApplicationByUserId(userId string) (application *Application, user *User, err error) {
owner, name := util.GetOwnerAndNameFromId(userId)
if owner == "app" {
application = getApplication("admin", name)
return application, nil
application, err = getApplication("admin", name)
return
}
user := GetUser(userId)
application = GetApplicationByUser(user)
return application, user
user, err = GetUser(userId)
if err != nil {
return nil, nil, err
}
application, err = GetApplicationByUser(user)
return
}
func GetApplicationByClientId(clientId string) *Application {
func GetApplicationByClientId(clientId string) (*Application, error) {
application := Application{}
existed, err := adapter.Engine.Where("client_id=?", clientId).Get(&application)
if err != nil {
panic(err)
return nil, err
}
if existed {
extendApplicationWithProviders(&application)
extendApplicationWithOrg(&application)
return &application
err = extendApplicationWithProviders(&application)
if err != nil {
return nil, err
}
err = extendApplicationWithOrg(&application)
if err != nil {
return nil, err
}
return &application, nil
} else {
return nil
return nil, nil
}
}
func GetApplication(id string) *Application {
func GetApplication(id string) (*Application, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getApplication(owner, name)
}
@ -288,11 +317,11 @@ func GetMaskedApplications(applications []*Application, userId string) []*Applic
return applications
}
func UpdateApplication(id string, application *Application) bool {
func UpdateApplication(id string, application *Application) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
oldApplication := getApplication(owner, name)
oldApplication, err := getApplication(owner, name)
if oldApplication == nil {
return false
return false, err
}
if name == "app-built-in" {
@ -300,14 +329,19 @@ func UpdateApplication(id string, application *Application) bool {
}
if name != application.Name {
err := applicationChangeTrigger(name, application.Name)
err = applicationChangeTrigger(name, application.Name)
if err != nil {
return false
return false, err
}
}
if oldApplication.ClientId != application.ClientId && GetApplicationByClientId(application.ClientId) != nil {
return false
applicationByClientId, err := GetApplicationByClientId(application.ClientId)
if err != nil {
return false, err
}
if oldApplication.ClientId != application.ClientId && applicationByClientId != nil {
return false, err
}
for _, providerItem := range application.Providers {
@ -320,13 +354,13 @@ func UpdateApplication(id string, application *Application) bool {
}
affected, err := session.Update(application)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddApplication(application *Application) bool {
func AddApplication(application *Application) (bool, error) {
if application.Owner == "" {
application.Owner = "admin"
}
@ -339,32 +373,39 @@ func AddApplication(application *Application) bool {
if application.ClientSecret == "" {
application.ClientSecret = util.GenerateClientSecret()
}
if GetApplicationByClientId(application.ClientId) != nil {
return false
app, err := GetApplicationByClientId(application.ClientId)
if err != nil {
return false, err
}
if app != nil {
return false, nil
}
for _, providerItem := range application.Providers {
providerItem.Provider = nil
}
affected, err := adapter.Engine.Insert(application)
if err != nil {
panic(err)
return false, nil
}
return affected != 0
return affected != 0, nil
}
func DeleteApplication(application *Application) bool {
func DeleteApplication(application *Application) (bool, error) {
if application.Name == "app-built-in" {
return false
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{application.Owner, application.Name}).Delete(&Application{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (application *Application) GetId() string {
@ -383,33 +424,43 @@ func (application *Application) IsRedirectUriValid(redirectUri string) bool {
return isValid
}
func IsOriginAllowed(origin string) bool {
applications := GetApplications("")
func IsOriginAllowed(origin string) (bool, error) {
applications, err := GetApplications("")
if err != nil {
return false, err
}
for _, application := range applications {
if application.IsRedirectUriValid(origin) {
return true
return true, nil
}
}
return false
return false, nil
}
func getApplicationMap(organization string) map[string]*Application {
applications := GetOrganizationApplications("admin", organization)
func getApplicationMap(organization string) (map[string]*Application, error) {
applicationMap := make(map[string]*Application)
applications, err := GetOrganizationApplications("admin", organization)
if err != nil {
return applicationMap, err
}
for _, application := range applications {
applicationMap[application.Name] = application
}
return applicationMap
return applicationMap, nil
}
func ExtendManagedAccountsWithUser(user *User) *User {
func ExtendManagedAccountsWithUser(user *User) (*User, error) {
if user.ManagedAccounts == nil || len(user.ManagedAccounts) == 0 {
return user
return user, nil
}
applicationMap := getApplicationMap(user.Owner)
applicationMap, err := getApplicationMap(user.Owner)
if err != nil {
return user, err
}
var managedAccounts []ManagedAccount
for _, managedAccount := range user.ManagedAccounts {
@ -421,7 +472,7 @@ func ExtendManagedAccountsWithUser(user *User) *User {
}
user.ManagedAccounts = managedAccounts
return user
return user, nil
}
func applicationChangeTrigger(oldName string, newName string) error {

View File

@ -14,8 +14,12 @@
package object
func (application *Application) GetProviderByCategory(category string) *Provider {
providers := GetProviders(application.Organization)
func (application *Application) GetProviderByCategory(category string) (*Provider, error) {
providers, err := GetProviders(application.Organization)
if err != nil {
return nil, err
}
m := map[string]*Provider{}
for _, provider := range providers {
if provider.Category != category {
@ -27,22 +31,22 @@ func (application *Application) GetProviderByCategory(category string) *Provider
for _, providerItem := range application.Providers {
if provider, ok := m[providerItem.Name]; ok {
return provider
return provider, nil
}
}
return nil
return nil, nil
}
func (application *Application) GetEmailProvider() *Provider {
func (application *Application) GetEmailProvider() (*Provider, error) {
return application.GetProviderByCategory("Email")
}
func (application *Application) GetSmsProvider() *Provider {
func (application *Application) GetSmsProvider() (*Provider, error) {
return application.GetProviderByCategory("SMS")
}
func (application *Application) GetStorageProvider() *Provider {
func (application *Application) GetStorageProvider() (*Provider, error) {
return application.GetProviderByCategory("Storage")
}

View File

@ -28,7 +28,11 @@ var defaultStorageProvider *Provider = nil
func InitDefaultStorageProvider() {
defaultStorageProviderStr := conf.GetConfigString("defaultStorageProvider")
if defaultStorageProviderStr != "" {
defaultStorageProvider = getProvider("admin", defaultStorageProviderStr)
var err error
defaultStorageProvider, err = getProvider("admin", defaultStorageProviderStr)
if err != nil {
panic(err)
}
}
}
@ -50,40 +54,44 @@ func downloadFile(url string) (*bytes.Buffer, error) {
return fileBuffer, nil
}
func getPermanentAvatarUrl(organization string, username string, url string, upload bool) string {
func getPermanentAvatarUrl(organization string, username string, url string, upload bool) (string, error) {
if url == "" {
return ""
return "", nil
}
if defaultStorageProvider == nil {
return ""
return "", nil
}
fullFilePath := fmt.Sprintf("/avatar/%s/%s.png", organization, username)
uploadedFileUrl, _ := GetUploadFileUrl(defaultStorageProvider, fullFilePath, false)
if upload {
DownloadAndUpload(url, fullFilePath, "en")
if err := DownloadAndUpload(url, fullFilePath, "en"); err != nil {
return "", err
}
}
return uploadedFileUrl
return uploadedFileUrl, nil
}
func DownloadAndUpload(url string, fullFilePath string, lang string) {
func DownloadAndUpload(url string, fullFilePath string, lang string) (err error) {
fileBuffer, err := downloadFile(url)
if err != nil {
panic(err)
return
}
_, _, err = UploadFileSafe(defaultStorageProvider, fullFilePath, fileBuffer, lang)
if err != nil {
panic(err)
}
return
}
func getPermanentAvatarUrlFromBuffer(organization string, username string, fileBuffer *bytes.Buffer, ext string, upload bool) string {
return
}
func getPermanentAvatarUrlFromBuffer(organization string, username string, fileBuffer *bytes.Buffer, ext string, upload bool) (string, error) {
if defaultStorageProvider == nil {
return ""
return "", nil
}
fullFilePath := fmt.Sprintf("/avatar/%s/%s%s", organization, username, ext)
@ -92,9 +100,9 @@ func getPermanentAvatarUrlFromBuffer(organization string, username string, fileB
if upload {
_, _, err := UploadFileSafe(defaultStorageProvider, fullFilePath, fileBuffer, "en")
if err != nil {
panic(err)
return "", err
}
}
return uploadedFileUrl
return uploadedFileUrl, nil
}

View File

@ -27,13 +27,21 @@ func TestSyncPermanentAvatars(t *testing.T) {
InitDefaultStorageProvider()
proxy.InitHttpClient()
users := GetGlobalUsers()
users, err := GetGlobalUsers()
if err != nil {
panic(err)
}
for i, user := range users {
if user.Avatar == "" {
continue
}
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
if err != nil {
panic(err)
}
updateUserColumn("permanent_avatar", user)
fmt.Printf("[%d/%d]: Update user: [%s]'s permanent avatar: %s\n", i, len(users), user.GetId(), user.PermanentAvatar)
}
@ -44,16 +52,27 @@ func TestUpdateAvatars(t *testing.T) {
InitDefaultStorageProvider()
proxy.InitHttpClient()
users := GetUsers("casdoor")
users, err := GetUsers("casdoor")
if err != nil {
panic(err)
}
for _, user := range users {
if strings.HasPrefix(user.Avatar, "http") {
continue
}
updated := user.refreshAvatar()
updated, err := user.refreshAvatar()
if err != nil {
panic(err)
}
if updated {
user.PermanentAvatar = "*"
UpdateUser(user.GetId(), user, []string{"avatar"}, true)
_, err = UpdateUser(user.GetId(), user, []string{"avatar"}, true)
if err != nil {
panic(err)
}
}
}
}

View File

@ -20,17 +20,17 @@ import (
"github.com/dchest/captcha"
)
func GetCaptcha() (string, []byte) {
func GetCaptcha() (string, []byte, error) {
id := captcha.NewLen(5)
var buffer bytes.Buffer
err := captcha.WriteImage(&buffer, id, 200, 80)
if err != nil {
panic(err)
return "", nil, err
}
return id, buffer.Bytes()
return id, buffer.Bytes(), nil
}
func VerifyCaptcha(id string, digits string) bool {

View File

@ -46,64 +46,59 @@ type CasbinAdapter struct {
Adapter *xormadapter.Adapter `xorm:"-" json:"-"`
}
func GetCasbinAdapterCount(owner, organization, field, value string) int {
func GetCasbinAdapterCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&CasbinAdapter{Organization: organization})
if err != nil {
panic(err)
return session.Count(&CasbinAdapter{Organization: organization})
}
return int(count)
}
func GetCasbinAdapters(owner string, organization string) []*CasbinAdapter {
func GetCasbinAdapters(owner string, organization string) ([]*CasbinAdapter, error) {
adapters := []*CasbinAdapter{}
err := adapter.Engine.Where("owner = ? and organization = ?", owner, organization).Find(&adapters)
if err != nil {
panic(err)
return adapters, err
}
return adapters
return adapters, nil
}
func GetPaginationCasbinAdapters(owner, organization string, page, limit int, field, value, sort, order string) []*CasbinAdapter {
func GetPaginationCasbinAdapters(owner, organization string, page, limit int, field, value, sort, order string) ([]*CasbinAdapter, error) {
session := GetSession(owner, page, limit, field, value, sort, order)
adapters := []*CasbinAdapter{}
err := session.Find(&adapters, &CasbinAdapter{Organization: organization})
if err != nil {
panic(err)
return adapters, err
}
return adapters
return adapters, nil
}
func getCasbinAdapter(owner, name string) *CasbinAdapter {
func getCasbinAdapter(owner, name string) (*CasbinAdapter, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
casbinAdapter := CasbinAdapter{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&casbinAdapter)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &casbinAdapter
return &casbinAdapter, nil
} else {
return nil
return nil, nil
}
}
func GetCasbinAdapter(id string) *CasbinAdapter {
func GetCasbinAdapter(id string) (*CasbinAdapter, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getCasbinAdapter(owner, name)
}
func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) bool {
func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getCasbinAdapter(owner, name) == nil {
return false
if casbinAdapter, err := getCasbinAdapter(owner, name); casbinAdapter == nil {
return false, err
}
session := adapter.Engine.ID(core.PK{owner, name}).AllCols()
@ -112,28 +107,28 @@ func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) bool {
}
affected, err := session.Update(casbinAdapter)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddCasbinAdapter(casbinAdapter *CasbinAdapter) bool {
func AddCasbinAdapter(casbinAdapter *CasbinAdapter) (bool, error) {
affected, err := adapter.Engine.Insert(casbinAdapter)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteCasbinAdapter(casbinAdapter *CasbinAdapter) bool {
func DeleteCasbinAdapter(casbinAdapter *CasbinAdapter) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{casbinAdapter.Owner, casbinAdapter.Name}).Delete(&CasbinAdapter{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (casbinAdapter *CasbinAdapter) GetId() string {
@ -214,7 +209,11 @@ func matrixToCasbinRules(Ptype string, policies [][]string) []*xormadapter.Casbi
}
func SyncPolicies(casbinAdapter *CasbinAdapter) ([]*xormadapter.CasbinRule, error) {
modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model)
modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model)
if err != nil {
return nil, err
}
enforcer, err := initEnforcer(modelObj, casbinAdapter)
if err != nil {
return nil, err
@ -229,7 +228,11 @@ func SyncPolicies(casbinAdapter *CasbinAdapter) ([]*xormadapter.CasbinRule, erro
}
func UpdatePolicy(oldPolicy, newPolicy []string, casbinAdapter *CasbinAdapter) (bool, error) {
modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model)
modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model)
if err != nil {
return false, err
}
enforcer, err := initEnforcer(modelObj, casbinAdapter)
if err != nil {
return false, err
@ -243,7 +246,11 @@ func UpdatePolicy(oldPolicy, newPolicy []string, casbinAdapter *CasbinAdapter) (
}
func AddPolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) {
modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model)
modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model)
if err != nil {
return false, err
}
enforcer, err := initEnforcer(modelObj, casbinAdapter)
if err != nil {
return false, err
@ -257,7 +264,11 @@ func AddPolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) {
}
func RemovePolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) {
modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model)
modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model)
if err != nil {
return false, err
}
enforcer, err := initEnforcer(modelObj, casbinAdapter)
if err != nil {
return false, err

View File

@ -47,137 +47,133 @@ func GetMaskedCert(cert *Cert) *Cert {
return cert
}
func GetMaskedCerts(certs []*Cert) []*Cert {
func GetMaskedCerts(certs []*Cert, err error) ([]*Cert, error) {
if err != nil {
return nil, err
}
for _, cert := range certs {
cert = GetMaskedCert(cert)
}
return certs
return certs, nil
}
func GetCertCount(owner, field, value string) int {
func GetCertCount(owner, field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Cert{})
if err != nil {
panic(err)
return session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Cert{})
}
return int(count)
}
func GetCerts(owner string) []*Cert {
func GetCerts(owner string) ([]*Cert, error) {
certs := []*Cert{}
err := adapter.Engine.Where("owner = ? or owner = ? ", "admin", owner).Desc("created_time").Find(&certs, &Cert{})
if err != nil {
panic(err)
return certs, err
}
return certs
return certs, nil
}
func GetPaginationCerts(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Cert {
func GetPaginationCerts(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Cert, error) {
certs := []*Cert{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Where("owner = ? or owner = ? ", "admin", owner).Find(&certs)
if err != nil {
panic(err)
return certs, err
}
return certs
return certs, nil
}
func GetGlobalCertsCount(field, value string) int {
func GetGlobalCertsCount(field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Count(&Cert{})
if err != nil {
panic(err)
return session.Count(&Cert{})
}
return int(count)
}
func GetGlobleCerts() []*Cert {
func GetGlobleCerts() ([]*Cert, error) {
certs := []*Cert{}
err := adapter.Engine.Desc("created_time").Find(&certs)
if err != nil {
panic(err)
return certs, err
}
return certs
return certs, nil
}
func GetPaginationGlobalCerts(offset, limit int, field, value, sortField, sortOrder string) []*Cert {
func GetPaginationGlobalCerts(offset, limit int, field, value, sortField, sortOrder string) ([]*Cert, error) {
certs := []*Cert{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Find(&certs)
if err != nil {
panic(err)
return certs, err
}
return certs
return certs, nil
}
func getCert(owner string, name string) *Cert {
func getCert(owner string, name string) (*Cert, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
cert := Cert{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&cert)
if err != nil {
panic(err)
return &cert, err
}
if existed {
return &cert
return &cert, nil
} else {
return nil
return nil, nil
}
}
func getCertByName(name string) *Cert {
func getCertByName(name string) (*Cert, error) {
if name == "" {
return nil
return nil, nil
}
cert := Cert{Name: name}
existed, err := adapter.Engine.Get(&cert)
if err != nil {
panic(err)
return &cert, nil
}
if existed {
return &cert
return &cert, nil
} else {
return nil
return nil, nil
}
}
func GetCert(id string) *Cert {
func GetCert(id string) (*Cert, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getCert(owner, name)
}
func UpdateCert(id string, cert *Cert) bool {
func UpdateCert(id string, cert *Cert) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getCert(owner, name) == nil {
return false
if c, err := getCert(owner, name); err != nil {
return false, err
} else if c == nil {
return false, nil
}
if name != cert.Name {
err := certChangeTrigger(name, cert.Name)
if err != nil {
return false
return false, nil
}
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(cert)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddCert(cert *Cert) bool {
func AddCert(cert *Cert) (bool, error) {
if cert.Certificate == "" || cert.PrivateKey == "" {
certificate, privateKey := generateRsaKeys(cert.BitSize, cert.ExpireInYears, cert.Name, cert.Owner)
cert.Certificate = certificate
@ -186,26 +182,26 @@ func AddCert(cert *Cert) bool {
affected, err := adapter.Engine.Insert(cert)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteCert(cert *Cert) bool {
func DeleteCert(cert *Cert) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{cert.Owner, cert.Name}).Delete(&Cert{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (p *Cert) GetId() string {
return fmt.Sprintf("%s/%s", p.Owner, p.Name)
}
func getCertByApplication(application *Application) *Cert {
func getCertByApplication(application *Application) (*Cert, error) {
if application.Cert != "" {
return getCertByName(application.Cert)
} else {
@ -213,7 +209,7 @@ func getCertByApplication(application *Application) *Cert {
}
}
func GetDefaultCert() *Cert {
func GetDefaultCert() (*Cert, error) {
return getCert("admin", "cert-built-in")
}

View File

@ -37,92 +37,104 @@ type Chat struct {
MessageCount int `json:"messageCount"`
}
func GetMaskedChat(chat *Chat) *Chat {
func GetMaskedChat(chat *Chat, err ...error) (*Chat, error) {
if len(err) > 0 && err[0] != nil {
return nil, err[0]
}
if chat == nil {
return nil
return nil, nil
}
return chat
return chat, nil
}
func GetMaskedChats(chats []*Chat) []*Chat {
func GetMaskedChats(chats []*Chat, errs ...error) ([]*Chat, error) {
if len(errs) > 0 && errs[0] != nil {
return nil, errs[0]
}
var err error
for _, chat := range chats {
chat = GetMaskedChat(chat)
}
return chats
}
func GetChatCount(owner, field, value string) int {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Chat{})
chat, err = GetMaskedChat(chat)
if err != nil {
panic(err)
return nil, err
}
}
return chats, nil
}
return int(count)
func GetChatCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
return session.Count(&Chat{})
}
func GetChats(owner string) []*Chat {
func GetChats(owner string) ([]*Chat, error) {
chats := []*Chat{}
err := adapter.Engine.Desc("created_time").Find(&chats, &Chat{Owner: owner})
if err != nil {
panic(err)
return chats, err
}
return chats
return chats, nil
}
func GetPaginationChats(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Chat {
func GetPaginationChats(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Chat, error) {
chats := []*Chat{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&chats)
if err != nil {
panic(err)
return chats, err
}
return chats
return chats, nil
}
func getChat(owner string, name string) *Chat {
func getChat(owner string, name string) (*Chat, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
chat := Chat{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&chat)
if err != nil {
panic(err)
return &chat, err
}
if existed {
return &chat
return &chat, nil
} else {
return nil
return nil, nil
}
}
func GetChat(id string) *Chat {
func GetChat(id string) (*Chat, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getChat(owner, name)
}
func UpdateChat(id string, chat *Chat) bool {
func UpdateChat(id string, chat *Chat) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getChat(owner, name) == nil {
return false
if c, err := getChat(owner, name); err != nil {
return false, err
} else if c == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(chat)
if err != nil {
panic(err)
return false, nil
}
return affected != 0
return affected != 0, nil
}
func AddChat(chat *Chat) bool {
func AddChat(chat *Chat) (bool, error) {
if chat.Type == "AI" && chat.User2 == "" {
provider := getDefaultAiProvider()
provider, err := getDefaultAiProvider()
if err != nil {
return false, err
}
if provider != nil {
chat.User2 = provider.Name
}
@ -130,23 +142,23 @@ func AddChat(chat *Chat) bool {
affected, err := adapter.Engine.Insert(chat)
if err != nil {
panic(err)
return false, nil
}
return affected != 0
return affected != 0, nil
}
func DeleteChat(chat *Chat) bool {
func DeleteChat(chat *Chat) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{chat.Owner, chat.Name}).Delete(&Chat{})
if err != nil {
panic(err)
return false, err
}
if affected != 0 {
return DeleteChatMessages(chat.Name)
}
return affected != 0
return affected != 0, nil
}
func (p *Chat) GetId() string {

View File

@ -170,7 +170,11 @@ func CheckPassword(user *User, password string, lang string, options ...bool) st
}
}
organization := GetOrganizationByUser(user)
organization, err := GetOrganizationByUser(user)
if err != nil {
panic(err)
}
if organization == nil {
return i18n.Translate(lang, "check:Organization does not exist")
}
@ -200,7 +204,11 @@ func CheckPassword(user *User, password string, lang string, options ...bool) st
}
func checkLdapUserPassword(user *User, password string, lang string) string {
ldaps := GetLdaps(user.Owner)
ldaps, err := GetLdaps(user.Owner)
if err != nil {
return err.Error()
}
ldapLoginSuccess := false
hit := false
@ -247,7 +255,11 @@ func CheckUserPassword(organization string, username string, password string, la
if len(options) > 0 {
enableCaptcha = options[0]
}
user := GetUserByFields(organization, username)
user, err := GetUserByFields(organization, username)
if err != nil {
panic(err)
}
if user == nil || user.IsDeleted {
return nil, fmt.Sprintf(i18n.Translate(lang, "general:The user: %s doesn't exist"), util.GetId(organization, username))
}
@ -284,7 +296,11 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string)
userOwner := util.GetOwnerFromId(userId)
if userId != "" {
targetUser := GetUser(userId)
targetUser, err := GetUser(userId)
if err != nil {
panic(err)
}
if targetUser == nil {
if strings.HasPrefix(requestUserId, "built-in/") {
return true, nil
@ -300,7 +316,11 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string)
if strings.HasPrefix(requestUserId, "app/") {
hasPermission = true
} else {
requestUser := GetUser(requestUserId)
requestUser, err := GetUser(requestUserId)
if err != nil {
return false, err
}
if requestUser == nil {
return false, fmt.Errorf(i18n.Translate(lang, "check:Session outdated, please login again"))
}
@ -321,13 +341,17 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string)
}
func CheckAccessPermission(userId string, application *Application) (bool, error) {
var err error
if userId == "built-in/admin" {
return true, nil
}
permissions := GetPermissions(application.Organization)
permissions, err := GetPermissions(application.Organization)
if err != nil {
return false, err
}
allowed := true
var err error
for _, permission := range permissions {
if !permission.IsEnabled || len(permission.Users) == 0 {
continue
@ -403,9 +427,9 @@ func CheckUpdateUser(oldUser, user *User, lang string) string {
return ""
}
func CheckToEnableCaptcha(application *Application, organization, username string) bool {
func CheckToEnableCaptcha(application *Application, organization, username string) (bool, error) {
if len(application.Providers) == 0 {
return false
return false, nil
}
for _, providerItem := range application.Providers {
@ -414,12 +438,15 @@ func CheckToEnableCaptcha(application *Application, organization, username strin
}
if providerItem.Provider.Category == "Captcha" {
if providerItem.Rule == "Dynamic" {
user := GetUserByFields(organization, username)
return user != nil && user.SigninWrongTimes >= SigninWrongTimesLimit
user, err := GetUserByFields(organization, username)
if err != nil {
return false, err
}
return providerItem.Rule == "Always"
return user != nil && user.SigninWrongTimes >= SigninWrongTimesLimit, nil
}
return providerItem.Rule == "Always", nil
}
}
return false
return false, nil
}

View File

@ -74,7 +74,11 @@ func getBuiltInAccountItems() []*AccountItem {
}
func initBuiltInOrganization() bool {
organization := getOrganization("admin", "built-in")
organization, err := getOrganization("admin", "built-in")
if err != nil {
panic(err)
}
if organization != nil {
return true
}
@ -96,12 +100,19 @@ func initBuiltInOrganization() bool {
EnableSoftDeletion: false,
IsProfilePublic: false,
}
AddOrganization(organization)
_, err = AddOrganization(organization)
if err != nil {
panic(err)
}
return false
}
func initBuiltInUser() {
user := getUser("built-in", "admin")
user, err := getUser("built-in", "admin")
if err != nil {
panic(err)
}
if user != nil {
return
}
@ -131,11 +142,18 @@ func initBuiltInUser() {
CreatedIp: "127.0.0.1",
Properties: make(map[string]string),
}
AddUser(user)
_, err = AddUser(user)
if err != nil {
panic(err)
}
}
func initBuiltInApplication() {
application := getApplication("admin", "app-built-in")
application, err := getApplication("admin", "app-built-in")
if err != nil {
panic(err)
}
if application != nil {
return
}
@ -168,7 +186,10 @@ func initBuiltInApplication() {
ExpireInHours: 168,
FormOffset: 2,
}
AddApplication(application)
_, err = AddApplication(application)
if err != nil {
panic(err)
}
}
func readTokenFromFile() (string, string) {
@ -187,7 +208,11 @@ func readTokenFromFile() (string, string) {
func initBuiltInCert() {
tokenJwtCertificate, tokenJwtPrivateKey := readTokenFromFile()
cert := getCert("admin", "cert-built-in")
cert, err := getCert("admin", "cert-built-in")
if err != nil {
panic(err)
}
if cert != nil {
return
}
@ -205,11 +230,18 @@ func initBuiltInCert() {
Certificate: tokenJwtCertificate,
PrivateKey: tokenJwtPrivateKey,
}
AddCert(cert)
_, err = AddCert(cert)
if err != nil {
panic(err)
}
}
func initBuiltInLdap() {
ldap := GetLdap("ldap-built-in")
ldap, err := GetLdap("ldap-built-in")
if err != nil {
panic(err)
}
if ldap != nil {
return
}
@ -226,11 +258,18 @@ func initBuiltInLdap() {
AutoSync: 0,
LastSync: "",
}
AddLdap(ldap)
_, err = AddLdap(ldap)
if err != nil {
panic(err)
}
}
func initBuiltInProvider() {
provider := GetProvider(util.GetId("admin", "provider_captcha_default"))
provider, err := GetProvider(util.GetId("admin", "provider_captcha_default"))
if err != nil {
panic(err)
}
if provider != nil {
return
}
@ -243,7 +282,10 @@ func initBuiltInProvider() {
Category: "Captcha",
Type: "Default",
}
AddProvider(provider)
_, err = AddProvider(provider)
if err != nil {
panic(err)
}
}
func initWebAuthn() {
@ -251,7 +293,11 @@ func initWebAuthn() {
}
func initBuiltInModel() {
model := GetModel("built-in/model-built-in")
model, err := GetModel("built-in/model-built-in")
if err != nil {
panic(err)
}
if model != nil {
return
}
@ -274,11 +320,17 @@ e = some(where (p.eft == allow))
[matchers]
m = r.sub == p.sub && r.obj == p.obj && r.act == p.act`,
}
AddModel(model)
_, err = AddModel(model)
if err != nil {
panic(err)
}
}
func initBuiltInPermission() {
permission := GetPermission("built-in/permission-built-in")
permission, err := GetPermission("built-in/permission-built-in")
if err != nil {
panic(err)
}
if permission != nil {
return
}
@ -298,5 +350,8 @@ func initBuiltInPermission() {
Effect: "Allow",
IsEnabled: true,
}
AddPermission(permission)
_, err = AddPermission(permission)
if err != nil {
panic(err)
}
}

View File

@ -35,7 +35,11 @@ type InitData struct {
}
func InitFromFile() {
initData := readInitDataFromFile("./init_data.json")
initData, err := readInitDataFromFile("./init_data.json")
if err != nil {
panic(err)
}
if initData != nil {
for _, organization := range initData.Organizations {
initDefinedOrganization(organization)
@ -85,9 +89,9 @@ func InitFromFile() {
}
}
func readInitDataFromFile(filePath string) *InitData {
func readInitDataFromFile(filePath string) (*InitData, error) {
if !util.FileExist(filePath) {
return nil
return nil, nil
}
s := util.ReadStringFromPath(filePath)
@ -111,7 +115,7 @@ func readInitDataFromFile(filePath string) *InitData {
}
err := util.JsonToStruct(s, data)
if err != nil {
panic(err)
return nil, err
}
// transform nil slice to empty slice
@ -170,142 +174,246 @@ func readInitDataFromFile(filePath string) *InitData {
}
}
return data
return data, nil
}
func initDefinedOrganization(organization *Organization) {
existed := getOrganization(organization.Owner, organization.Name)
existed, err := getOrganization(organization.Owner, organization.Name)
if err != nil {
panic(err)
}
if existed != nil {
return
}
organization.CreatedTime = util.GetCurrentTime()
organization.AccountItems = getBuiltInAccountItems()
AddOrganization(organization)
_, err = AddOrganization(organization)
if err != nil {
panic(err)
}
}
func initDefinedApplication(application *Application) {
existed := getApplication(application.Owner, application.Name)
existed, err := getApplication(application.Owner, application.Name)
if err != nil {
panic(err)
}
if existed != nil {
return
}
application.CreatedTime = util.GetCurrentTime()
AddApplication(application)
_, err = AddApplication(application)
if err != nil {
panic(err)
}
}
func initDefinedUser(user *User) {
existed := getUser(user.Owner, user.Name)
existed, err := getUser(user.Owner, user.Name)
if err != nil {
panic(err)
}
if existed != nil {
return
}
user.CreatedTime = util.GetCurrentTime()
user.Id = util.GenerateId()
user.Properties = make(map[string]string)
AddUser(user)
_, err = AddUser(user)
if err != nil {
panic(err)
}
}
func initDefinedCert(cert *Cert) {
existed := getCert(cert.Owner, cert.Name)
existed, err := getCert(cert.Owner, cert.Name)
if err != nil {
panic(err)
}
if existed != nil {
return
}
cert.CreatedTime = util.GetCurrentTime()
AddCert(cert)
_, err = AddCert(cert)
if err != nil {
panic(err)
}
}
func initDefinedLdap(ldap *Ldap) {
existed := GetLdap(ldap.Id)
existed, err := GetLdap(ldap.Id)
if err != nil {
panic(err)
}
if existed != nil {
return
}
AddLdap(ldap)
_, err = AddLdap(ldap)
if err != nil {
panic(err)
}
}
func initDefinedProvider(provider *Provider) {
existed := GetProvider(util.GetId("admin", provider.Name))
existed, err := GetProvider(util.GetId("admin", provider.Name))
if err != nil {
panic(err)
}
if existed != nil {
return
}
AddProvider(provider)
_, err = AddProvider(provider)
if err != nil {
panic(err)
}
}
func initDefinedModel(model *Model) {
existed := GetModel(model.GetId())
existed, err := GetModel(model.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
model.CreatedTime = util.GetCurrentTime()
AddModel(model)
_, err = AddModel(model)
if err != nil {
panic(err)
}
}
func initDefinedPermission(permission *Permission) {
existed := GetPermission(permission.GetId())
existed, err := GetPermission(permission.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
permission.CreatedTime = util.GetCurrentTime()
AddPermission(permission)
_, err = AddPermission(permission)
if err != nil {
panic(err)
}
}
func initDefinedPayment(payment *Payment) {
existed := GetPayment(payment.GetId())
existed, err := GetPayment(payment.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
payment.CreatedTime = util.GetCurrentTime()
AddPayment(payment)
_, err = AddPayment(payment)
if err != nil {
panic(err)
}
}
func initDefinedProduct(product *Product) {
existed := GetProduct(product.GetId())
existed, err := GetProduct(product.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
product.CreatedTime = util.GetCurrentTime()
AddProduct(product)
_, err = AddProduct(product)
if err != nil {
panic(err)
}
}
func initDefinedResource(resource *Resource) {
existed := GetResource(resource.GetId())
existed, err := GetResource(resource.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
resource.CreatedTime = util.GetCurrentTime()
AddResource(resource)
_, err = AddResource(resource)
if err != nil {
panic(err)
}
}
func initDefinedRole(role *Role) {
existed := GetRole(role.GetId())
existed, err := GetRole(role.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
role.CreatedTime = util.GetCurrentTime()
AddRole(role)
_, err = AddRole(role)
if err != nil {
panic(err)
}
}
func initDefinedSyncer(syncer *Syncer) {
existed := GetSyncer(syncer.GetId())
existed, err := GetSyncer(syncer.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
syncer.CreatedTime = util.GetCurrentTime()
AddSyncer(syncer)
_, err = AddSyncer(syncer)
if err != nil {
panic(err)
}
}
func initDefinedToken(token *Token) {
existed := GetToken(token.GetId())
existed, err := GetToken(token.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
token.CreatedTime = util.GetCurrentTime()
AddToken(token)
_, err = AddToken(token)
if err != nil {
panic(err)
}
}
func initDefinedWebhook(webhook *Webhook) {
existed := GetWebhook(webhook.GetId())
existed, err := GetWebhook(webhook.GetId())
if err != nil {
panic(err)
}
if existed != nil {
return
}
webhook.CreatedTime = util.GetCurrentTime()
AddWebhook(webhook)
_, err = AddWebhook(webhook)
if err != nil {
panic(err)
}
}

View File

@ -37,7 +37,7 @@ type Ldap struct {
LastSync string `xorm:"varchar(100)" json:"lastSync"`
}
func AddLdap(ldap *Ldap) bool {
func AddLdap(ldap *Ldap) (bool, error) {
if len(ldap.Id) == 0 {
ldap.Id = util.GenerateId()
}
@ -48,13 +48,13 @@ func AddLdap(ldap *Ldap) bool {
affected, err := adapter.Engine.Insert(ldap)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func CheckLdapExist(ldap *Ldap) bool {
func CheckLdapExist(ldap *Ldap) (bool, error) {
var result []*Ldap
err := adapter.Engine.Find(&result, &Ldap{
Owner: ldap.Owner,
@ -65,63 +65,65 @@ func CheckLdapExist(ldap *Ldap) bool {
BaseDn: ldap.BaseDn,
})
if err != nil {
panic(err)
return false, err
}
if len(result) > 0 {
return true
return true, nil
}
return false
return false, nil
}
func GetLdaps(owner string) []*Ldap {
func GetLdaps(owner string) ([]*Ldap, error) {
var ldaps []*Ldap
err := adapter.Engine.Desc("created_time").Find(&ldaps, &Ldap{Owner: owner})
if err != nil {
panic(err)
return ldaps, err
}
return ldaps
return ldaps, nil
}
func GetLdap(id string) *Ldap {
func GetLdap(id string) (*Ldap, error) {
if util.IsStringsEmpty(id) {
return nil
return nil, nil
}
ldap := Ldap{Id: id}
existed, err := adapter.Engine.Get(&ldap)
if err != nil {
panic(err)
return &ldap, nil
}
if existed {
return &ldap
return &ldap, nil
} else {
return nil
return nil, nil
}
}
func UpdateLdap(ldap *Ldap) bool {
if GetLdap(ldap.Id) == nil {
return false
func UpdateLdap(ldap *Ldap) (bool, error) {
if l, err := GetLdap(ldap.Id); err != nil {
return false, nil
} else if l == nil {
return false, nil
}
affected, err := adapter.Engine.ID(ldap.Id).Cols("owner", "server_name", "host",
"port", "enable_ssl", "username", "password", "base_dn", "filter", "filter_fields", "auto_sync").Update(ldap)
if err != nil {
panic(err)
return false, nil
}
return affected != 0
return affected != 0, nil
}
func DeleteLdap(ldap *Ldap) bool {
func DeleteLdap(ldap *Ldap) (bool, error) {
affected, err := adapter.Engine.ID(ldap.Id).Delete(&Ldap{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}

View File

@ -18,7 +18,10 @@ var globalLdapAutoSynchronizer *LdapAutoSynchronizer
func InitLdapAutoSynchronizer() {
globalLdapAutoSynchronizer = NewLdapAutoSynchronizer()
globalLdapAutoSynchronizer.LdapAutoSynchronizerStartUpAll()
err := globalLdapAutoSynchronizer.LdapAutoSynchronizerStartUpAll()
if err != nil {
panic(err)
}
}
func NewLdapAutoSynchronizer() *LdapAutoSynchronizer {
@ -37,7 +40,11 @@ func (l *LdapAutoSynchronizer) StartAutoSync(ldapId string) error {
l.Lock()
defer l.Unlock()
ldap := GetLdap(ldapId)
ldap, err := GetLdap(ldapId)
if err != nil {
return err
}
if ldap == nil {
return fmt.Errorf("ldap %s doesn't exist", ldapId)
}
@ -49,7 +56,12 @@ func (l *LdapAutoSynchronizer) StartAutoSync(ldapId string) error {
stopChan := make(chan struct{})
l.ldapIdToStopChan[ldapId] = stopChan
logs.Info(fmt.Sprintf("autoSync started for %s", ldap.Id))
util.SafeGoroutine(func() { l.syncRoutine(ldap, stopChan) })
util.SafeGoroutine(func() {
err := l.syncRoutine(ldap, stopChan)
if err != nil {
panic(err)
}
})
return nil
}
@ -63,18 +75,22 @@ func (l *LdapAutoSynchronizer) StopAutoSync(ldapId string) {
}
// autosync goroutine
func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) {
func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) error {
ticker := time.NewTicker(time.Duration(ldap.AutoSync) * time.Minute)
defer ticker.Stop()
for {
select {
case <-stopChan:
logs.Info(fmt.Sprintf("autoSync goroutine for %s stopped", ldap.Id))
return
return nil
case <-ticker.C:
}
UpdateLdapSyncTime(ldap.Id)
err := UpdateLdapSyncTime(ldap.Id)
if err != nil {
return err
}
// fetch all users
conn, err := ldap.GetLdapConn()
if err != nil {
@ -100,24 +116,35 @@ func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) {
// LdapAutoSynchronizerStartUpAll
// start all autosync goroutine for existing ldap servers in each organizations
func (l *LdapAutoSynchronizer) LdapAutoSynchronizerStartUpAll() {
func (l *LdapAutoSynchronizer) LdapAutoSynchronizerStartUpAll() error {
organizations := []*Organization{}
err := adapter.Engine.Desc("created_time").Find(&organizations)
if err != nil {
logs.Info("failed to Star up LdapAutoSynchronizer; ")
}
for _, org := range organizations {
for _, ldap := range GetLdaps(org.Name) {
if ldap.AutoSync != 0 {
l.StartAutoSync(ldap.Id)
}
}
}
ldaps, err := GetLdaps(org.Name)
if err != nil {
return err
}
func UpdateLdapSyncTime(ldapId string) {
for _, ldap := range ldaps {
if ldap.AutoSync != 0 {
err = l.StartAutoSync(ldap.Id)
if err != nil {
return err
}
}
}
}
return nil
}
func UpdateLdapSyncTime(ldapId string) error {
_, err := adapter.Engine.ID(ldapId).Update(&Ldap{LastSync: util.GetCurrentTime()})
if err != nil {
panic(err)
return err
}
return nil
}

View File

@ -255,8 +255,12 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
uuids = append(uuids, user.Uuid)
}
organization := getOrganization("admin", owner)
ldap := GetLdap(ldapId)
organization, err := getOrganization("admin", owner)
if err != nil {
panic(err)
}
ldap, err := GetLdap(ldapId)
var dc []string
for _, basedn := range strings.Split(ldap.BaseDn, ",") {
@ -275,7 +279,11 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
tag := strings.Join(ou, ".")
for _, syncUser := range syncUsers {
existUuids := GetExistUuids(owner, uuids)
existUuids, err := GetExistUuids(owner, uuids)
if err != nil {
return nil, nil, err
}
found := false
if len(existUuids) > 0 {
for _, existUuid := range existUuids {
@ -287,10 +295,19 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
}
if !found {
score, _ := organization.GetInitScore()
score, err := organization.GetInitScore()
if err != nil {
return nil, nil, err
}
name, err := syncUser.buildLdapUserName()
if err != nil {
return nil, nil, err
}
newUser := &User{
Owner: owner,
Name: syncUser.buildLdapUserName(),
Name: name,
CreatedTime: util.GetCurrentTime(),
DisplayName: syncUser.buildLdapDisplayName(),
Avatar: organization.DefaultAvatar,
@ -303,7 +320,11 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
Ldap: syncUser.Uuid,
}
affected := AddUser(newUser)
affected, err := AddUser(newUser)
if err != nil {
return nil, nil, err
}
if !affected {
failedUsers = append(failedUsers, syncUser)
continue
@ -314,38 +335,38 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
return existUsers, failedUsers, err
}
func GetExistUuids(owner string, uuids []string) []string {
func GetExistUuids(owner string, uuids []string) ([]string, error) {
var existUuids []string
err := adapter.Engine.Table("user").Where("owner = ?", owner).Cols("ldap").
In("ldap", uuids).Select("DISTINCT ldap").Find(&existUuids)
if err != nil {
panic(err)
return existUuids, err
}
return existUuids
return existUuids, nil
}
func (ldapUser *LdapUser) buildLdapUserName() string {
func (ldapUser *LdapUser) buildLdapUserName() (string, error) {
user := User{}
uidWithNumber := fmt.Sprintf("%s_%s", ldapUser.Uid, ldapUser.UidNumber)
has, err := adapter.Engine.Where("name = ? or name = ?", ldapUser.Uid, uidWithNumber).Get(&user)
if err != nil {
panic(err)
return "", err
}
if has {
if user.Name == ldapUser.Uid {
return uidWithNumber
return uidWithNumber, nil
}
return fmt.Sprintf("%s_%s", uidWithNumber, randstr.Hex(6))
return fmt.Sprintf("%s_%s", uidWithNumber, randstr.Hex(6)), nil
}
if ldapUser.Uid != "" {
return ldapUser.Uid
return ldapUser.Uid, nil
}
return ldapUser.Cn
return ldapUser.Cn, nil
}
func (ldapUser *LdapUser) buildLdapDisplayName() string {

View File

@ -48,109 +48,94 @@ func GetMaskedMessages(messages []*Message) []*Message {
return messages
}
func GetMessageCount(owner, organization, field, value string) int {
func GetMessageCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Message{Organization: organization})
if err != nil {
panic(err)
return session.Count(&Message{Organization: organization})
}
return int(count)
}
func GetMessages(owner string) []*Message {
func GetMessages(owner string) ([]*Message, error) {
messages := []*Message{}
err := adapter.Engine.Desc("created_time").Find(&messages, &Message{Owner: owner})
if err != nil {
panic(err)
return messages, err
}
return messages
}
func GetChatMessages(chat string) []*Message {
func GetChatMessages(chat string) ([]*Message, error) {
messages := []*Message{}
err := adapter.Engine.Asc("created_time").Find(&messages, &Message{Chat: chat})
if err != nil {
panic(err)
return messages, err
}
return messages
}
func GetPaginationMessages(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Message {
func GetPaginationMessages(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Message, error) {
messages := []*Message{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&messages, &Message{Organization: organization})
if err != nil {
panic(err)
return messages, err
}
return messages
}
func getMessage(owner string, name string) *Message {
func getMessage(owner string, name string) (*Message, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
message := Message{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&message)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &message
return &message, nil
} else {
return nil
return nil, nil
}
}
func GetMessage(id string) *Message {
func GetMessage(id string) (*Message, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getMessage(owner, name)
}
func UpdateMessage(id string, message *Message) bool {
func UpdateMessage(id string, message *Message) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getMessage(owner, name) == nil {
return false
if m, err := getMessage(owner, name); err != nil {
return false, err
} else if m == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(message)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddMessage(message *Message) bool {
func AddMessage(message *Message) (bool, error) {
affected, err := adapter.Engine.Insert(message)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteMessage(message *Message) bool {
func DeleteMessage(message *Message) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{message.Owner, message.Name}).Delete(&Message{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteChatMessages(chat string) bool {
func DeleteChatMessages(chat string) (bool, error) {
affected, err := adapter.Engine.Delete(&Message{Chat: chat})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (p *Message) GetId() string {

View File

@ -83,7 +83,11 @@ func RecoverTfs(user *User, recoveryCode string) error {
return fmt.Errorf("recovery code not found")
}
affected := UpdateUser(user.GetId(), user, []string{"two_factor_auth"}, user.IsAdminUser())
affected, err := UpdateUser(user.GetId(), user, []string{"two_factor_auth"}, user.IsAdminUser())
if err != nil {
return err
}
if !affected {
return fmt.Errorf("")
}

View File

@ -100,7 +100,11 @@ func (mfa *SmsMfa) Enable(ctx *context.Context, user *User) error {
}
user.MultiFactorAuths = append(user.MultiFactorAuths, mfa.Config)
affected := UpdateUser(user.GetId(), user, []string{"multi_factor_auths"}, user.IsAdminUser())
affected, err := UpdateUser(user.GetId(), user, []string{"multi_factor_auths"}, user.IsAdminUser())
if err != nil {
return err
}
if !affected {
return fmt.Errorf("failed to enable two factor authentication")
}

View File

@ -44,5 +44,8 @@ func DoMigration() {
}
m := migrate.New(adapter.Engine, options, migrations)
m.Migrate()
err := m.Migrate()
if err != nil {
panic(err)
}
}

View File

@ -32,56 +32,51 @@ type Model struct {
IsEnabled bool `json:"isEnabled"`
}
func GetModelCount(owner, field, value string) int {
func GetModelCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Model{})
if err != nil {
panic(err)
return session.Count(&Model{})
}
return int(count)
}
func GetModels(owner string) []*Model {
func GetModels(owner string) ([]*Model, error) {
models := []*Model{}
err := adapter.Engine.Desc("created_time").Find(&models, &Model{Owner: owner})
if err != nil {
panic(err)
return models, err
}
return models
return models, nil
}
func GetPaginationModels(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Model {
func GetPaginationModels(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Model, error) {
models := []*Model{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&models)
if err != nil {
panic(err)
return models, err
}
return models
return models, nil
}
func getModel(owner string, name string) *Model {
func getModel(owner string, name string) (*Model, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
m := Model{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&m)
if err != nil {
panic(err)
return &m, err
}
if existed {
return &m
return &m, nil
} else {
return nil
return nil, nil
}
}
func GetModel(id string) *Model {
func GetModel(id string) (*Model, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getModel(owner, name)
}
@ -92,48 +87,53 @@ func UpdateModelWithCheck(id string, modelObj *Model) error {
if err != nil {
return err
}
UpdateModel(id, modelObj)
_, err = UpdateModel(id, modelObj)
if err != nil {
return err
}
return nil
}
func UpdateModel(id string, modelObj *Model) bool {
func UpdateModel(id string, modelObj *Model) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getModel(owner, name) == nil {
return false
if m, err := getModel(owner, name); err != nil {
return false, err
} else if m == nil {
return false, nil
}
if name != modelObj.Name {
err := modelChangeTrigger(name, modelObj.Name)
if err != nil {
return false
return false, err
}
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(modelObj)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, err
}
func AddModel(model *Model) bool {
func AddModel(model *Model) (bool, error) {
affected, err := adapter.Engine.Insert(model)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteModel(model *Model) bool {
func DeleteModel(model *Model) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{model.Owner, model.Name}).Delete(&Model{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (model *Model) GetId() string {

View File

@ -110,8 +110,12 @@ func GetOidcDiscovery(host string) OidcDiscovery {
}
func GetJsonWebKeySet() (jose.JSONWebKeySet, error) {
certs := GetCerts("admin")
jwks := jose.JSONWebKeySet{}
certs, err := GetCerts("admin")
if err != nil {
return jwks, err
}
// follows the protocol rfc 7517(draft)
// link here: https://self-issued.info/docs/draft-ietf-jose-json-web-key.html
// or https://datatracker.ietf.org/doc/html/draft-ietf-jose-json-web-key

View File

@ -70,92 +70,102 @@ type Organization struct {
AccountItems []*AccountItem `xorm:"varchar(3000)" json:"accountItems"`
}
func GetOrganizationCount(owner, field, value string) int {
func GetOrganizationCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Organization{})
if err != nil {
panic(err)
return session.Count(&Organization{})
}
return int(count)
}
func GetOrganizations(owner string) []*Organization {
func GetOrganizations(owner string) ([]*Organization, error) {
organizations := []*Organization{}
err := adapter.Engine.Desc("created_time").Find(&organizations, &Organization{Owner: owner})
if err != nil {
panic(err)
return nil, err
}
return organizations
return organizations, nil
}
func GetOrganizationsByFields(owner string, fields ...string) []*Organization {
func GetOrganizationsByFields(owner string, fields ...string) ([]*Organization, error) {
organizations := []*Organization{}
err := adapter.Engine.Desc("created_time").Cols(fields...).Find(&organizations, &Organization{Owner: owner})
if err != nil {
panic(err)
return nil, err
}
return organizations
return organizations, nil
}
func GetPaginationOrganizations(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Organization {
func GetPaginationOrganizations(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Organization, error) {
organizations := []*Organization{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&organizations)
if err != nil {
panic(err)
return nil, err
}
return organizations
return organizations, nil
}
func getOrganization(owner string, name string) *Organization {
func getOrganization(owner string, name string) (*Organization, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
organization := Organization{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&organization)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &organization
return &organization, nil
}
return nil
return nil, nil
}
func GetOrganization(id string) *Organization {
func GetOrganization(id string) (*Organization, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getOrganization(owner, name)
}
func GetMaskedOrganization(organization *Organization) *Organization {
func GetMaskedOrganization(organization *Organization, errs ...error) (*Organization, error) {
if len(errs) > 0 && errs[0] != nil {
return nil, errs[0]
}
if organization == nil {
return nil
return nil, nil
}
if organization.MasterPassword != "" {
organization.MasterPassword = "***"
}
return organization
return organization, nil
}
func GetMaskedOrganizations(organizations []*Organization) []*Organization {
func GetMaskedOrganizations(organizations []*Organization, errs ...error) ([]*Organization, error) {
if len(errs) > 0 && errs[0] != nil {
return nil, errs[0]
}
var err error
for _, organization := range organizations {
organization = GetMaskedOrganization(organization)
organization, err = GetMaskedOrganization(organization)
if err != nil {
return nil, err
}
return organizations
}
func UpdateOrganization(id string, organization *Organization) bool {
return organizations, nil
}
func UpdateOrganization(id string, organization *Organization) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getOrganization(owner, name) == nil {
return false
if org, err := getOrganization(owner, name); err != nil {
return false, err
} else if org == nil {
return false, nil
}
if name == "built-in" {
@ -165,7 +175,7 @@ func UpdateOrganization(id string, organization *Organization) bool {
if name != organization.Name {
err := organizationChangeTrigger(name, organization.Name)
if err != nil {
return false
return false, nil
}
}
@ -183,35 +193,35 @@ func UpdateOrganization(id string, organization *Organization) bool {
}
affected, err := session.Update(organization)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddOrganization(organization *Organization) bool {
func AddOrganization(organization *Organization) (bool, error) {
affected, err := adapter.Engine.Insert(organization)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteOrganization(organization *Organization) bool {
func DeleteOrganization(organization *Organization) (bool, error) {
if organization.Name == "built-in" {
return false
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{organization.Owner, organization.Name}).Delete(&Organization{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func GetOrganizationByUser(user *User) *Organization {
func GetOrganizationByUser(user *User) (*Organization, error) {
return getOrganization("admin", user.Owner)
}
@ -248,13 +258,21 @@ func CheckAccountItemModifyRule(accountItem *AccountItem, isAdmin bool, lang str
}
func GetDefaultApplication(id string) (*Application, error) {
organization := GetOrganization(id)
organization, err := GetOrganization(id)
if err != nil {
return nil, err
}
if organization == nil {
return nil, fmt.Errorf("The organization: %s does not exist", id)
}
if organization.DefaultApplication != "" {
defaultApplication := getApplication("admin", organization.DefaultApplication)
defaultApplication, err := getApplication("admin", organization.DefaultApplication)
if err != nil {
return nil, err
}
if defaultApplication == nil {
return nil, fmt.Errorf("The default application: %s does not exist", organization.DefaultApplication)
} else {
@ -263,9 +281,9 @@ func GetDefaultApplication(id string) (*Application, error) {
}
applications := []*Application{}
err := adapter.Engine.Asc("created_time").Find(&applications, &Application{Organization: organization.Name})
err = adapter.Engine.Asc("created_time").Find(&applications, &Application{Organization: organization.Name})
if err != nil {
panic(err)
return nil, err
}
if len(applications) == 0 {
@ -280,8 +298,15 @@ func GetDefaultApplication(id string) (*Application, error) {
}
}
extendApplicationWithProviders(defaultApplication)
extendApplicationWithOrg(defaultApplication)
err = extendApplicationWithProviders(defaultApplication)
if err != nil {
return nil, err
}
err = extendApplicationWithOrg(defaultApplication)
if err != nil {
return nil, err
}
return defaultApplication, nil
}

View File

@ -56,74 +56,71 @@ type Payment struct {
InvoiceUrl string `xorm:"varchar(255)" json:"invoiceUrl"`
}
func GetPaymentCount(owner, field, value string) int {
func GetPaymentCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Payment{})
if err != nil {
panic(err)
return session.Count(&Payment{})
}
return int(count)
}
func GetPayments(owner string) []*Payment {
func GetPayments(owner string) ([]*Payment, error) {
payments := []*Payment{}
err := adapter.Engine.Desc("created_time").Find(&payments, &Payment{Owner: owner})
if err != nil {
panic(err)
return nil, err
}
return payments
return payments, nil
}
func GetUserPayments(owner string, organization string, user string) []*Payment {
func GetUserPayments(owner string, organization string, user string) ([]*Payment, error) {
payments := []*Payment{}
err := adapter.Engine.Desc("created_time").Find(&payments, &Payment{Owner: owner, Organization: organization, User: user})
if err != nil {
panic(err)
return nil, err
}
return payments
return payments, nil
}
func GetPaginationPayments(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Payment {
func GetPaginationPayments(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Payment, error) {
payments := []*Payment{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&payments)
if err != nil {
panic(err)
return nil, err
}
return payments
return payments, nil
}
func getPayment(owner string, name string) *Payment {
func getPayment(owner string, name string) (*Payment, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
payment := Payment{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&payment)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &payment
return &payment, nil
} else {
return nil
return nil, nil
}
}
func GetPayment(id string) *Payment {
func GetPayment(id string) (*Payment, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getPayment(owner, name)
}
func UpdatePayment(id string, payment *Payment) bool {
func UpdatePayment(id string, payment *Payment) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getPayment(owner, name) == nil {
return false
if p, err := getPayment(owner, name); err != nil {
return false, err
} else if p == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(payment)
@ -131,42 +128,53 @@ func UpdatePayment(id string, payment *Payment) bool {
panic(err)
}
return affected != 0
return affected != 0, nil
}
func AddPayment(payment *Payment) bool {
func AddPayment(payment *Payment) (bool, error) {
affected, err := adapter.Engine.Insert(payment)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeletePayment(payment *Payment) bool {
func DeletePayment(payment *Payment) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{payment.Owner, payment.Name}).Delete(&Payment{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func notifyPayment(request *http.Request, body []byte, owner string, providerName string, productName string, paymentName string) (*Payment, error, string) {
provider := getProvider(owner, providerName)
provider, err := getProvider(owner, providerName)
if err != nil {
panic(err)
}
pProvider, cert, err := provider.getPaymentProvider()
if err != nil {
panic(err)
}
payment := getPayment(owner, paymentName)
payment, err := getPayment(owner, paymentName)
if err != nil {
panic(err)
}
if payment == nil {
err = fmt.Errorf("the payment: %s does not exist", paymentName)
return nil, err, pProvider.GetResponseError(err)
}
product := getProduct(owner, productName)
product, err := getProduct(owner, productName)
if err != nil {
panic(err)
}
if product == nil {
err = fmt.Errorf("the product: %s does not exist", productName)
return payment, err, pProvider.GetResponseError(err)
@ -201,14 +209,21 @@ func NotifyPayment(request *http.Request, body []byte, owner string, providerNam
payment.State = "Paid"
}
UpdatePayment(payment.GetId(), payment)
_, err = UpdatePayment(payment.GetId(), payment)
if err != nil {
panic(err)
}
}
return err, errorResponse
}
func invoicePayment(payment *Payment) (string, error) {
provider := getProvider(payment.Owner, payment.Provider)
provider, err := getProvider(payment.Owner, payment.Provider)
if err != nil {
panic(err)
}
if provider == nil {
return "", fmt.Errorf("the payment provider: %s does not exist", payment.Provider)
}
@ -237,7 +252,11 @@ func InvoicePayment(payment *Payment) (string, error) {
}
payment.InvoiceUrl = invoiceUrl
affected := UpdatePayment(payment.GetId(), payment)
affected, err := UpdatePayment(payment.GetId(), payment)
if err != nil {
return "", err
}
if !affected {
return "", fmt.Errorf("failed to update the payment: %s", payment.Name)
}

View File

@ -66,96 +66,97 @@ func (p *Permission) GetId() string {
return util.GetId(p.Owner, p.Name)
}
func GetPermissionCount(owner, field, value string) int {
func GetPermissionCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Permission{})
if err != nil {
panic(err)
return session.Count(&Permission{})
}
return int(count)
}
func GetPermissions(owner string) []*Permission {
func GetPermissions(owner string) ([]*Permission, error) {
permissions := []*Permission{}
err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner})
if err != nil {
panic(err)
return permissions, err
}
return permissions
return permissions, nil
}
func GetPaginationPermissions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Permission {
func GetPaginationPermissions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Permission, error) {
permissions := []*Permission{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&permissions)
if err != nil {
panic(err)
return permissions, err
}
return permissions
return permissions, nil
}
func getPermission(owner string, name string) *Permission {
func getPermission(owner string, name string) (*Permission, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
permission := Permission{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&permission)
if err != nil {
panic(err)
return &permission, err
}
if existed {
return &permission
return &permission, nil
} else {
return nil
return nil, nil
}
}
func GetPermission(id string) *Permission {
func GetPermission(id string) (*Permission, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getPermission(owner, name)
}
// checkPermissionValid verifies if the permission is valid
func checkPermissionValid(permission *Permission) {
func checkPermissionValid(permission *Permission) error {
enforcer := getEnforcer(permission)
enforcer.EnableAutoSave(false)
policies := getPolicies(permission)
_, err := enforcer.AddPolicies(policies)
if err != nil {
panic(err)
return err
}
if !HasRoleDefinition(enforcer.GetModel()) {
permission.Roles = []string{}
return
return nil
}
groupingPolicies := getGroupingPolicies(permission)
if len(groupingPolicies) > 0 {
_, err := enforcer.AddGroupingPolicies(groupingPolicies)
if err != nil {
panic(err)
}
return err
}
}
func UpdatePermission(id string, permission *Permission) bool {
checkPermissionValid(permission)
return nil
}
func UpdatePermission(id string, permission *Permission) (bool, error) {
err := checkPermissionValid(permission)
if err != nil {
return false, err
}
owner, name := util.GetOwnerAndNameFromId(id)
oldPermission := getPermission(owner, name)
oldPermission, err := getPermission(owner, name)
if oldPermission == nil {
return false
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(permission)
if err != nil {
panic(err)
return false, err
}
if affected != 0 {
@ -166,7 +167,7 @@ func UpdatePermission(id string, permission *Permission) bool {
if isEmpty {
err = adapter.Engine.DropTables(oldPermission.Adapter)
if err != nil {
panic(err)
return false, err
}
}
}
@ -174,13 +175,13 @@ func UpdatePermission(id string, permission *Permission) bool {
addPolicies(permission)
}
return affected != 0
return affected != 0, nil
}
func AddPermission(permission *Permission) bool {
func AddPermission(permission *Permission) (bool, error) {
affected, err := adapter.Engine.Insert(permission)
if err != nil {
panic(err)
return false, err
}
if affected != 0 {
@ -188,7 +189,7 @@ func AddPermission(permission *Permission) bool {
addPolicies(permission)
}
return affected != 0
return affected != 0, nil
}
func AddPermissions(permissions []*Permission) bool {
@ -239,10 +240,10 @@ func AddPermissionsInBatch(permissions []*Permission) bool {
return affected
}
func DeletePermission(permission *Permission) bool {
func DeletePermission(permission *Permission) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{permission.Owner, permission.Name}).Delete(&Permission{})
if err != nil {
panic(err)
return false, err
}
if affected != 0 {
@ -253,67 +254,67 @@ func DeletePermission(permission *Permission) bool {
if isEmpty {
err = adapter.Engine.DropTables(permission.Adapter)
if err != nil {
panic(err)
return false, err
}
}
}
}
return affected != 0
return affected != 0, nil
}
func GetPermissionsByUser(userId string) []*Permission {
func GetPermissionsByUser(userId string) ([]*Permission, error) {
permissions := []*Permission{}
err := adapter.Engine.Where("users like ?", "%"+userId+"\"%").Find(&permissions)
if err != nil {
panic(err)
return permissions, err
}
for i := range permissions {
permissions[i].Users = nil
}
return permissions
return permissions, nil
}
func GetPermissionsByRole(roleId string) []*Permission {
func GetPermissionsByRole(roleId string) ([]*Permission, error) {
permissions := []*Permission{}
err := adapter.Engine.Where("roles like ?", "%"+roleId+"\"%").Find(&permissions)
if err != nil {
panic(err)
return permissions, err
}
return permissions
return permissions, nil
}
func GetPermissionsByResource(resourceId string) []*Permission {
func GetPermissionsByResource(resourceId string) ([]*Permission, error) {
permissions := []*Permission{}
err := adapter.Engine.Where("resources like ?", "%"+resourceId+"\"%").Find(&permissions)
if err != nil {
panic(err)
return permissions, err
}
return permissions
return permissions, nil
}
func GetPermissionsBySubmitter(owner string, submitter string) []*Permission {
func GetPermissionsBySubmitter(owner string, submitter string) ([]*Permission, error) {
permissions := []*Permission{}
err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner, Submitter: submitter})
if err != nil {
panic(err)
return permissions, err
}
return permissions
return permissions, nil
}
func GetPermissionsByModel(owner string, model string) []*Permission {
func GetPermissionsByModel(owner string, model string) ([]*Permission, error) {
permissions := []*Permission{}
err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner, Model: model})
if err != nil {
panic(err)
return permissions, err
}
return permissions
return permissions, nil
}
func ContainsAsterisk(userId string, users []string) bool {

View File

@ -29,7 +29,11 @@ import (
func getEnforcer(permission *Permission) *casbin.Enforcer {
tableName := "permission_rule"
if len(permission.Adapter) != 0 {
adapterObj := getCasbinAdapter(permission.Owner, permission.Adapter)
adapterObj, err := getCasbinAdapter(permission.Owner, permission.Adapter)
if err != nil {
panic(err)
}
if adapterObj != nil && adapterObj.Table != "" {
tableName = adapterObj.Table
}
@ -42,7 +46,11 @@ func getEnforcer(permission *Permission) *casbin.Enforcer {
panic(err)
}
permissionModel := getModel(permission.Owner, permission.Model)
permissionModel, err := getModel(permission.Owner, permission.Model)
if err != nil {
panic(err)
}
m := model.Model{}
if permissionModel != nil {
m, err = GetBuiltInModel(permissionModel.ModelText)
@ -122,21 +130,30 @@ func getPolicies(permission *Permission) [][]string {
return policies
}
func getRolesInRole(roleId string, visited map[string]struct{}) []*Role {
role := GetRole(roleId)
func getRolesInRole(roleId string, visited map[string]struct{}) ([]*Role, error) {
role, err := GetRole(roleId)
if err != nil {
return []*Role{}, err
}
if role == nil {
return []*Role{}
return []*Role{}, nil
}
visited[roleId] = struct{}{}
roles := []*Role{role}
for _, subRole := range role.Roles {
if _, ok := visited[subRole]; !ok {
roles = append(roles, getRolesInRole(subRole, visited)...)
r, err := getRolesInRole(subRole, visited)
if err != nil {
return []*Role{}, err
}
roles = append(roles, r...)
}
}
return roles
return roles, nil
}
func getGroupingPolicies(permission *Permission) [][]string {
@ -147,8 +164,10 @@ func getGroupingPolicies(permission *Permission) [][]string {
for _, roleId := range permission.Roles {
visited := map[string]struct{}{}
rolesInRole := getRolesInRole(roleId, visited)
rolesInRole, err := getRolesInRole(roleId, visited)
if err != nil {
panic(err)
}
for _, role := range rolesInRole {
roleId := role.GetId()
for _, subUser := range role.Users {
@ -223,7 +242,11 @@ func removePolicies(permission *Permission) {
type CasbinRequest = []interface{}
func Enforce(permissionId string, request *CasbinRequest) bool {
permission := GetPermission(permissionId)
permission, err := GetPermission(permissionId)
if err != nil {
panic(err)
}
enforcer := getEnforcer(permission)
allow, err := enforcer.Enforce(*request...)
@ -234,7 +257,11 @@ func Enforce(permissionId string, request *CasbinRequest) bool {
}
func BatchEnforce(permissionId string, requests *[]CasbinRequest) []bool {
permission := GetPermission(permissionId)
permission, err := GetPermission(permissionId)
if err != nil {
panic(err)
}
enforcer := getEnforcer(permission)
allow, err := enforcer.BatchEnforce(*requests)
if err != nil {
@ -244,9 +271,18 @@ func BatchEnforce(permissionId string, requests *[]CasbinRequest) []bool {
}
func getAllValues(userId string, fn func(enforcer *casbin.Enforcer) []string) []string {
permissions := GetPermissionsByUser(userId)
permissions, err := GetPermissionsByUser(userId)
if err != nil {
panic(err)
}
for _, role := range GetAllRoles(userId) {
permissions = append(permissions, GetPermissionsByRole(role)...)
permissionsByRole, err := GetPermissionsByRole(role)
if err != nil {
panic(err)
}
permissions = append(permissions, permissionsByRole...)
}
var values []string
@ -270,7 +306,11 @@ func GetAllActions(userId string) []string {
}
func GetAllRoles(userId string) []string {
roles := GetRolesByUser(userId)
roles, err := GetRolesByUser(userId)
if err != nil {
panic(err)
}
var res []string
for _, role := range roles {
res = append(res, role.Name)

View File

@ -18,21 +18,29 @@ import (
"github.com/casdoor/casdoor/xlsx"
)
func getPermissionMap(owner string) map[string]*Permission {
func getPermissionMap(owner string) (map[string]*Permission, error) {
m := map[string]*Permission{}
permissions := GetPermissions(owner)
permissions, err := GetPermissions(owner)
if err != nil {
return nil, err
}
for _, permission := range permissions {
m[permission.GetId()] = permission
}
return m
return m, err
}
func UploadPermissions(owner string, fileId string) bool {
func UploadPermissions(owner string, fileId string) (bool, error) {
table := xlsx.ReadXlsxFile(fileId)
oldUserMap := getPermissionMap(owner)
oldUserMap, err := getPermissionMap(owner)
if err != nil {
return false, err
}
newPermissions := []*Permission{}
for index, line := range table {
if index == 0 || parseLineItem(&line, 0) == "" {
@ -71,7 +79,7 @@ func UploadPermissions(owner string, fileId string) bool {
}
if len(newPermissions) == 0 {
return false
return false, nil
}
return AddPermissionsInBatch(newPermissions)
return AddPermissionsInBatch(newPermissions), nil
}

View File

@ -37,109 +37,115 @@ type Plan struct {
Options []string `xorm:"-" json:"options"`
}
func GetPlanCount(owner, field, value string) int {
func GetPlanCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Plan{})
if err != nil {
panic(err)
return session.Count(&Plan{})
}
return int(count)
}
func GetPlans(owner string) []*Plan {
func GetPlans(owner string) ([]*Plan, error) {
plans := []*Plan{}
err := adapter.Engine.Desc("created_time").Find(&plans, &Plan{Owner: owner})
if err != nil {
panic(err)
return plans, err
}
return plans
return plans, nil
}
func GetPaginatedPlans(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Plan {
func GetPaginatedPlans(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Plan, error) {
plans := []*Plan{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&plans)
if err != nil {
panic(err)
return plans, err
}
return plans
return plans, nil
}
func getPlan(owner, name string) *Plan {
func getPlan(owner, name string) (*Plan, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
plan := Plan{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&plan)
if err != nil {
panic(err)
return &plan, err
}
if existed {
return &plan
return &plan, nil
} else {
return nil
return nil, nil
}
}
func GetPlan(id string) *Plan {
func GetPlan(id string) (*Plan, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getPlan(owner, name)
}
func UpdatePlan(id string, plan *Plan) bool {
func UpdatePlan(id string, plan *Plan) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getPlan(owner, name) == nil {
return false
if p, err := getPlan(owner, name); err != nil {
return false, err
} else if p == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(plan)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddPlan(plan *Plan) bool {
func AddPlan(plan *Plan) (bool, error) {
affected, err := adapter.Engine.Insert(plan)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeletePlan(plan *Plan) bool {
func DeletePlan(plan *Plan) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{plan.Owner, plan.Name}).Delete(plan)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (plan *Plan) GetId() string {
return fmt.Sprintf("%s/%s", plan.Owner, plan.Name)
}
func Subscribe(owner string, user string, plan string, pricing string) *Subscription {
selectedPricing := GetPricing(fmt.Sprintf("%s/%s", owner, pricing))
func Subscribe(owner string, user string, plan string, pricing string) (*Subscription, error) {
selectedPricing, err := GetPricing(fmt.Sprintf("%s/%s", owner, pricing))
if err != nil {
return nil, err
}
valid := selectedPricing != nil && selectedPricing.IsEnabled
if !valid {
return nil
return nil, nil
}
planBelongToPricing := selectedPricing.HasPlan(owner, plan)
planBelongToPricing, err := selectedPricing.HasPlan(owner, plan)
if err != nil {
return nil, err
}
if planBelongToPricing {
newSubscription := NewSubscription(owner, user, plan, selectedPricing.TrialDuration)
affected := AddSubscription(newSubscription)
affected, err := AddSubscription(newSubscription)
if err != nil {
return nil, err
}
if affected {
return newSubscription
return newSubscription, nil
}
}
return nil
return nil, nil
}

View File

@ -42,96 +42,97 @@ type Pricing struct {
State string `xorm:"varchar(100)" json:"state"`
}
func GetPricingCount(owner, field, value string) int {
func GetPricingCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Pricing{})
if err != nil {
panic(err)
return session.Count(&Pricing{})
}
return int(count)
}
func GetPricings(owner string) []*Pricing {
func GetPricings(owner string) ([]*Pricing, error) {
pricings := []*Pricing{}
err := adapter.Engine.Desc("created_time").Find(&pricings, &Pricing{Owner: owner})
if err != nil {
panic(err)
}
return pricings
return pricings, err
}
func GetPaginatedPricings(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Pricing {
return pricings, nil
}
func GetPaginatedPricings(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Pricing, error) {
pricings := []*Pricing{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&pricings)
if err != nil {
panic(err)
return pricings, err
}
return pricings
return pricings, nil
}
func getPricing(owner, name string) *Pricing {
func getPricing(owner, name string) (*Pricing, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
pricing := Pricing{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&pricing)
if err != nil {
panic(err)
return &pricing, err
}
if existed {
return &pricing
return &pricing, nil
} else {
return nil
return nil, nil
}
}
func GetPricing(id string) *Pricing {
func GetPricing(id string) (*Pricing, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getPricing(owner, name)
}
func UpdatePricing(id string, pricing *Pricing) bool {
func UpdatePricing(id string, pricing *Pricing) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getPricing(owner, name) == nil {
return false
if p, err := getPricing(owner, name); err != nil {
return false, err
} else if p == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(pricing)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddPricing(pricing *Pricing) bool {
func AddPricing(pricing *Pricing) (bool, error) {
affected, err := adapter.Engine.Insert(pricing)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeletePricing(pricing *Pricing) bool {
func DeletePricing(pricing *Pricing) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{pricing.Owner, pricing.Name}).Delete(pricing)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (pricing *Pricing) GetId() string {
return fmt.Sprintf("%s/%s", pricing.Owner, pricing.Name)
}
func (pricing *Pricing) HasPlan(owner string, plan string) bool {
selectedPlan := GetPlan(fmt.Sprintf("%s/%s", owner, plan))
func (pricing *Pricing) HasPlan(owner string, plan string) (bool, error) {
selectedPlan, err := GetPlan(fmt.Sprintf("%s/%s", owner, plan))
if err != nil {
return false, err
}
if selectedPlan == nil {
return false
return false, nil
}
result := false
@ -143,5 +144,5 @@ func (pricing *Pricing) HasPlan(owner string, plan string) bool {
}
}
return result
return result, nil
}

View File

@ -43,90 +43,87 @@ type Product struct {
ProviderObjs []*Provider `xorm:"-" json:"providerObjs"`
}
func GetProductCount(owner, field, value string) int {
func GetProductCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Product{})
if err != nil {
panic(err)
return session.Count(&Product{})
}
return int(count)
}
func GetProducts(owner string) []*Product {
func GetProducts(owner string) ([]*Product, error) {
products := []*Product{}
err := adapter.Engine.Desc("created_time").Find(&products, &Product{Owner: owner})
if err != nil {
panic(err)
return products, err
}
return products
return products, nil
}
func GetPaginationProducts(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Product {
func GetPaginationProducts(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Product, error) {
products := []*Product{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&products)
if err != nil {
panic(err)
return products, err
}
return products
return products, nil
}
func getProduct(owner string, name string) *Product {
func getProduct(owner string, name string) (*Product, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
product := Product{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&product)
if err != nil {
panic(err)
return &product, nil
}
if existed {
return &product
return &product, nil
} else {
return nil
return nil, nil
}
}
func GetProduct(id string) *Product {
func GetProduct(id string) (*Product, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getProduct(owner, name)
}
func UpdateProduct(id string, product *Product) bool {
func UpdateProduct(id string, product *Product) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getProduct(owner, name) == nil {
return false
if p, err := getProduct(owner, name); err != nil {
return false, err
} else if p == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(product)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddProduct(product *Product) bool {
func AddProduct(product *Product) (bool, error) {
affected, err := adapter.Engine.Insert(product)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteProduct(product *Product) bool {
func DeleteProduct(product *Product) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{product.Owner, product.Name}).Delete(&Product{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (product *Product) GetId() string {
@ -143,7 +140,11 @@ func (product *Product) isValidProvider(provider *Provider) bool {
}
func (product *Product) getProvider(providerId string) (*Provider, error) {
provider := getProvider(product.Owner, providerId)
provider, err := getProvider(product.Owner, providerId)
if err != nil {
return nil, err
}
if provider == nil {
return nil, fmt.Errorf("the payment provider: %s does not exist", providerId)
}
@ -156,7 +157,11 @@ func (product *Product) getProvider(providerId string) (*Provider, error) {
}
func BuyProduct(id string, providerName string, user *User, host string) (string, error) {
product := GetProduct(id)
product, err := GetProduct(id)
if err != nil {
return "", err
}
if product == nil {
return "", fmt.Errorf("the product: %s does not exist", id)
}
@ -205,7 +210,11 @@ func BuyProduct(id string, providerName string, user *User, host string) (string
ReturnUrl: product.ReturnUrl,
State: "Created",
}
affected := AddPayment(&payment)
affected, err := AddPayment(&payment)
if err != nil {
return "", err
}
if !affected {
return "", fmt.Errorf("failed to add payment: %s", util.StructToJson(payment))
}
@ -213,17 +222,23 @@ func BuyProduct(id string, providerName string, user *User, host string) (string
return payUrl, err
}
func ExtendProductWithProviders(product *Product) {
func ExtendProductWithProviders(product *Product) error {
if product == nil {
return
return nil
}
product.ProviderObjs = []*Provider{}
m := getProviderMap(product.Owner)
m, err := getProviderMap(product.Owner)
if err != nil {
return err
}
for _, providerItem := range product.Providers {
if provider, ok := m[providerItem]; ok {
product.ProviderObjs = append(product.ProviderObjs, provider)
}
}
return nil
}

View File

@ -27,9 +27,9 @@ import (
func TestProduct(t *testing.T) {
InitConfig()
product := GetProduct("admin/product_123")
provider := getProvider(product.Owner, "provider_pay_alipay")
cert := getCert(product.Owner, "cert-pay-alipay")
product, _ := GetProduct("admin/product_123")
provider, _ := getProvider(product.Owner, "provider_pay_alipay")
cert, _ := getCert(product.Owner, "cert-pay-alipay")
pProvider, err := pp.GetPaymentProvider(provider.Type, provider.ClientId, provider.ClientSecret, provider.Host, cert.Certificate, cert.PrivateKey, cert.AuthorityPublicKey, cert.AuthorityRootPublicKey, provider.ClientId2)
if err != nil {
panic(err)

View File

@ -103,103 +103,93 @@ func GetMaskedProviders(providers []*Provider, isMaskEnabled bool) []*Provider {
return providers
}
func GetProviderCount(owner, field, value string) int {
func GetProviderCount(owner, field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Provider{})
if err != nil {
panic(err)
return session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Provider{})
}
return int(count)
}
func GetGlobalProviderCount(field, value string) int {
func GetGlobalProviderCount(field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Count(&Provider{})
if err != nil {
panic(err)
return session.Count(&Provider{})
}
return int(count)
}
func GetProviders(owner string) []*Provider {
func GetProviders(owner string) ([]*Provider, error) {
providers := []*Provider{}
err := adapter.Engine.Where("owner = ? or owner = ? ", "admin", owner).Desc("created_time").Find(&providers, &Provider{})
if err != nil {
panic(err)
return providers, err
}
return providers
return providers, nil
}
func GetGlobalProviders() []*Provider {
func GetGlobalProviders() ([]*Provider, error) {
providers := []*Provider{}
err := adapter.Engine.Desc("created_time").Find(&providers)
if err != nil {
panic(err)
return providers, err
}
return providers
return providers, nil
}
func GetPaginationProviders(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Provider {
func GetPaginationProviders(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Provider, error) {
providers := []*Provider{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Where("owner = ? or owner = ? ", "admin", owner).Find(&providers)
if err != nil {
panic(err)
return providers, err
}
return providers
return providers, nil
}
func GetPaginationGlobalProviders(offset, limit int, field, value, sortField, sortOrder string) []*Provider {
func GetPaginationGlobalProviders(offset, limit int, field, value, sortField, sortOrder string) ([]*Provider, error) {
providers := []*Provider{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Find(&providers)
if err != nil {
panic(err)
return providers, err
}
return providers
return providers, nil
}
func getProvider(owner string, name string) *Provider {
func getProvider(owner string, name string) (*Provider, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
provider := Provider{Name: name}
existed, err := adapter.Engine.Get(&provider)
if err != nil {
panic(err)
return &provider, err
}
if existed {
return &provider
return &provider, nil
} else {
return nil
return nil, nil
}
}
func GetProvider(id string) *Provider {
func GetProvider(id string) (*Provider, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getProvider(owner, name)
}
func getDefaultAiProvider() *Provider {
func getDefaultAiProvider() (*Provider, error) {
provider := Provider{Owner: "admin", Category: "AI"}
existed, err := adapter.Engine.Get(&provider)
if err != nil {
panic(err)
return &provider, err
}
if !existed {
return nil
return nil, nil
}
return &provider
return &provider, nil
}
func GetWechatMiniProgramProvider(application *Application) *Provider {
@ -212,16 +202,18 @@ func GetWechatMiniProgramProvider(application *Application) *Provider {
return nil
}
func UpdateProvider(id string, provider *Provider) bool {
func UpdateProvider(id string, provider *Provider) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getProvider(owner, name) == nil {
return false
if p, err := getProvider(owner, name); err != nil {
return false, err
} else if p == nil {
return false, nil
}
if name != provider.Name {
err := providerChangeTrigger(name, provider.Name)
if err != nil {
return false
return false, nil
}
}
@ -238,37 +230,41 @@ func UpdateProvider(id string, provider *Provider) bool {
affected, err := session.Update(provider)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddProvider(provider *Provider) bool {
func AddProvider(provider *Provider) (bool, error) {
provider.Endpoint = util.GetEndPoint(provider.Endpoint)
provider.IntranetEndpoint = util.GetEndPoint(provider.IntranetEndpoint)
affected, err := adapter.Engine.Insert(provider)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteProvider(provider *Provider) bool {
func DeleteProvider(provider *Provider) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{provider.Owner, provider.Name}).Delete(&Provider{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (p *Provider) getPaymentProvider() (pp.PaymentProvider, *Cert, error) {
cert := &Cert{}
if p.Cert != "" {
cert = getCert(p.Owner, p.Cert)
cert, err := getCert(p.Owner, p.Cert)
if err != nil {
return nil, nil, err
}
if cert == nil {
return nil, nil, fmt.Errorf("the cert: %s does not exist", p.Cert)
}
@ -309,7 +305,11 @@ func GetCaptchaProviderByApplication(applicationId, isCurrentProvider, lang stri
if isCurrentProvider == "true" {
return GetCaptchaProviderByOwnerName(applicationId, lang)
}
application := GetApplication(applicationId)
application, err := GetApplication(applicationId)
if err != nil {
return nil, err
}
if application == nil || len(application.Providers) == 0 {
return nil, fmt.Errorf(i18n.Translate(lang, "provider:Invalid application id"))
}

View File

@ -26,11 +26,7 @@ import (
var logPostOnly bool
func init() {
var err error
logPostOnly, err = conf.GetConfigBool("logPostOnly")
if err != nil {
// panic(err)
}
logPostOnly = conf.GetConfigBool("logPostOnly")
}
type Record struct {
@ -108,49 +104,48 @@ func AddRecord(record *Record) bool {
return affected != 0
}
func GetRecordCount(field, value string, filterRecord *Record) int {
func GetRecordCount(field, value string, filterRecord *Record) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Count(filterRecord)
if err != nil {
panic(err)
return session.Count(filterRecord)
}
return int(count)
}
func GetRecords() []*Record {
func GetRecords() ([]*Record, error) {
records := []*Record{}
err := adapter.Engine.Desc("id").Find(&records)
if err != nil {
panic(err)
return records, err
}
return records
return records, nil
}
func GetPaginationRecords(offset, limit int, field, value, sortField, sortOrder string, filterRecord *Record) []*Record {
func GetPaginationRecords(offset, limit int, field, value, sortField, sortOrder string, filterRecord *Record) ([]*Record, error) {
records := []*Record{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Find(&records, filterRecord)
if err != nil {
panic(err)
return records, err
}
return records
return records, nil
}
func GetRecordsByField(record *Record) []*Record {
func GetRecordsByField(record *Record) ([]*Record, error) {
records := []*Record{}
err := adapter.Engine.Find(&records, record)
if err != nil {
panic(err)
return records, err
}
return records
return records, nil
}
func SendWebhooks(record *Record) error {
webhooks := getWebhooksByOrganization(record.Organization)
webhooks, err := getWebhooksByOrganization(record.Organization)
if err != nil {
return err
}
for _, webhook := range webhooks {
if !webhook.IsEnabled {
continue
@ -166,7 +161,11 @@ func SendWebhooks(record *Record) error {
if matched {
if webhook.IsUserExtended {
user := GetMaskedUser(getUser(record.Organization, record.User))
user, err := GetMaskedUser(getUser(record.Organization, record.User))
if err != nil {
return err
}
record.ExtendedUser = user
}

View File

@ -39,17 +39,12 @@ type Resource struct {
Description string `xorm:"varchar(255)" json:"description"`
}
func GetResourceCount(owner, user, field, value string) int {
func GetResourceCount(owner, user, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Resource{User: user})
if err != nil {
panic(err)
return session.Count(&Resource{User: user})
}
return int(count)
}
func GetResources(owner string, user string) []*Resource {
func GetResources(owner string, user string) ([]*Resource, error) {
if owner == "built-in" {
owner = ""
user = ""
@ -58,13 +53,13 @@ func GetResources(owner string, user string) []*Resource {
resources := []*Resource{}
err := adapter.Engine.Desc("created_time").Find(&resources, &Resource{Owner: owner, User: user})
if err != nil {
panic(err)
return resources, err
}
return resources
return resources, err
}
func GetPaginationResources(owner, user string, offset, limit int, field, value, sortField, sortOrder string) []*Resource {
func GetPaginationResources(owner, user string, offset, limit int, field, value, sortField, sortOrder string) ([]*Resource, error) {
if owner == "built-in" {
owner = ""
user = ""
@ -74,70 +69,74 @@ func GetPaginationResources(owner, user string, offset, limit int, field, value,
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&resources, &Resource{User: user})
if err != nil {
panic(err)
return resources, err
}
return resources
return resources, nil
}
func getResource(owner string, name string) *Resource {
func getResource(owner string, name string) (*Resource, error) {
resource := Resource{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&resource)
if err != nil {
panic(err)
return &resource, err
}
if existed {
return &resource
return &resource, nil
}
return nil
return nil, nil
}
func GetResource(id string) *Resource {
func GetResource(id string) (*Resource, error) {
owner, name := util.GetOwnerAndNameFromIdNoCheck(id)
return getResource(owner, name)
}
func UpdateResource(id string, resource *Resource) bool {
func UpdateResource(id string, resource *Resource) (bool, error) {
owner, name := util.GetOwnerAndNameFromIdNoCheck(id)
if getResource(owner, name) == nil {
return false
if r, err := getResource(owner, name); err != nil {
return false, err
} else if r == nil {
return false, nil
}
_, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(resource)
if err != nil {
panic(err)
return false, err
}
// return affected != 0
return true
return true, nil
}
func AddResource(resource *Resource) bool {
func AddResource(resource *Resource) (bool, error) {
affected, err := adapter.Engine.Insert(resource)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteResource(resource *Resource) bool {
func DeleteResource(resource *Resource) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{resource.Owner, resource.Name}).Delete(&Resource{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (resource *Resource) GetId() string {
return fmt.Sprintf("%s/%s", resource.Owner, resource.Name)
}
func AddOrUpdateResource(resource *Resource) bool {
if getResource(resource.Owner, resource.Name) == nil {
func AddOrUpdateResource(resource *Resource) (bool, error) {
if r, err := getResource(resource.Owner, resource.Name); err != nil {
return false, err
} else if r == nil {
return AddResource(resource)
} else {
return UpdateResource(resource.GetId(), resource)

View File

@ -36,79 +36,90 @@ type Role struct {
IsEnabled bool `json:"isEnabled"`
}
func GetRoleCount(owner, field, value string) int {
func GetRoleCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Role{})
if err != nil {
panic(err)
return session.Count(&Role{})
}
return int(count)
}
func GetRoles(owner string) []*Role {
func GetRoles(owner string) ([]*Role, error) {
roles := []*Role{}
err := adapter.Engine.Desc("created_time").Find(&roles, &Role{Owner: owner})
if err != nil {
panic(err)
return roles, err
}
return roles
return roles, nil
}
func GetPaginationRoles(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Role {
func GetPaginationRoles(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Role, error) {
roles := []*Role{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&roles)
if err != nil {
panic(err)
return roles, err
}
return roles
return roles, nil
}
func getRole(owner string, name string) *Role {
func getRole(owner string, name string) (*Role, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
role := Role{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&role)
if err != nil {
panic(err)
return &role, err
}
if existed {
return &role
return &role, nil
} else {
return nil
return nil, nil
}
}
func GetRole(id string) *Role {
func GetRole(id string) (*Role, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getRole(owner, name)
}
func UpdateRole(id string, role *Role) bool {
func UpdateRole(id string, role *Role) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
oldRole := getRole(owner, name)
oldRole, err := getRole(owner, name)
if err != nil {
return false, err
}
if oldRole == nil {
return false
return false, nil
}
visited := map[string]struct{}{}
permissions := GetPermissionsByRole(id)
permissions, err := GetPermissionsByRole(id)
if err != nil {
return false, err
}
for _, permission := range permissions {
removeGroupingPolicies(permission)
removePolicies(permission)
visited[permission.GetId()] = struct{}{}
}
ancestorRoles := GetAncestorRoles(id)
ancestorRoles, err := GetAncestorRoles(id)
if err != nil {
return false, err
}
for _, r := range ancestorRoles {
permissions := GetPermissionsByRole(r.GetId())
permissions, err := GetPermissionsByRole(r.GetId())
if err != nil {
return false, err
}
for _, permission := range permissions {
permissionId := permission.GetId()
if _, ok := visited[permissionId]; !ok {
@ -121,27 +132,38 @@ func UpdateRole(id string, role *Role) bool {
if name != role.Name {
err := roleChangeTrigger(name, role.Name)
if err != nil {
return false
return false, nil
}
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(role)
if err != nil {
panic(err)
return false, err
}
visited = map[string]struct{}{}
newRoleID := role.GetId()
permissions = GetPermissionsByRole(newRoleID)
permissions, err = GetPermissionsByRole(newRoleID)
if err != nil {
return false, err
}
for _, permission := range permissions {
addGroupingPolicies(permission)
addPolicies(permission)
visited[permission.GetId()] = struct{}{}
}
ancestorRoles = GetAncestorRoles(newRoleID)
ancestorRoles, err = GetAncestorRoles(newRoleID)
if err != nil {
return false, err
}
for _, r := range ancestorRoles {
permissions := GetPermissionsByRole(r.GetId())
permissions, err := GetPermissionsByRole(r.GetId())
if err != nil {
return false, err
}
for _, permission := range permissions {
permissionId := permission.GetId()
if _, ok := visited[permissionId]; !ok {
@ -151,16 +173,16 @@ func UpdateRole(id string, role *Role) bool {
}
}
return affected != 0
return affected != 0, nil
}
func AddRole(role *Role) bool {
func AddRole(role *Role) (bool, error) {
affected, err := adapter.Engine.Insert(role)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddRoles(roles []*Role) bool {
@ -202,38 +224,45 @@ func AddRolesInBatch(roles []*Role) bool {
return affected
}
func DeleteRole(role *Role) bool {
func DeleteRole(role *Role) (bool, error) {
roleId := role.GetId()
permissions := GetPermissionsByRole(roleId)
permissions, err := GetPermissionsByRole(roleId)
if err != nil {
return false, err
}
for _, permission := range permissions {
permission.Roles = util.DeleteVal(permission.Roles, roleId)
UpdatePermission(permission.GetId(), permission)
_, err := UpdatePermission(permission.GetId(), permission)
if err != nil {
return false, err
}
}
affected, err := adapter.Engine.ID(core.PK{role.Owner, role.Name}).Delete(&Role{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (role *Role) GetId() string {
return fmt.Sprintf("%s/%s", role.Owner, role.Name)
}
func GetRolesByUser(userId string) []*Role {
func GetRolesByUser(userId string) ([]*Role, error) {
roles := []*Role{}
err := adapter.Engine.Where("users like ?", "%"+userId+"\"%").Find(&roles)
if err != nil {
panic(err)
return roles, err
}
for i := range roles {
roles[i].Users = nil
}
return roles
return roles, nil
}
func roleChangeTrigger(oldName string, newName string) error {
@ -250,6 +279,7 @@ func roleChangeTrigger(oldName string, newName string) error {
if err != nil {
return err
}
for _, role := range roles {
for j, u := range role.Roles {
owner, name := util.GetOwnerAndNameFromId(u)
@ -268,6 +298,7 @@ func roleChangeTrigger(oldName string, newName string) error {
if err != nil {
return err
}
for _, permission := range permissions {
for j, u := range permission.Roles {
// u = organization/username
@ -293,17 +324,17 @@ func GetMaskedRoles(roles []*Role) []*Role {
return roles
}
func GetRolesByNamePrefix(owner string, prefix string) []*Role {
func GetRolesByNamePrefix(owner string, prefix string) ([]*Role, error) {
roles := []*Role{}
err := adapter.Engine.Where("owner=? and name like ?", owner, prefix+"%").Find(&roles)
if err != nil {
panic(err)
return roles, err
}
return roles
return roles, nil
}
func GetAncestorRoles(roleId string) []*Role {
func GetAncestorRoles(roleId string) ([]*Role, error) {
var (
result []*Role
roleMap = make(map[string]*Role)
@ -312,7 +343,11 @@ func GetAncestorRoles(roleId string) []*Role {
owner, _ := util.GetOwnerAndNameFromIdNoCheck(roleId)
allRoles := GetRoles(owner)
allRoles, err := GetRoles(owner)
if err != nil {
return nil, err
}
for _, r := range allRoles {
roleMap[r.GetId()] = r
}
@ -331,7 +366,7 @@ func GetAncestorRoles(roleId string) []*Role {
}
}
return result
return result, nil
}
// containsRole is a helper function to check if a slice of roles contains a specific roleId

View File

@ -18,21 +18,29 @@ import (
"github.com/casdoor/casdoor/xlsx"
)
func getRoleMap(owner string) map[string]*Role {
func getRoleMap(owner string) (map[string]*Role, error) {
m := map[string]*Role{}
roles := GetRoles(owner)
roles, err := GetRoles(owner)
if err != nil {
return nil, err
}
for _, role := range roles {
m[role.GetId()] = role
}
return m
return m, nil
}
func UploadRoles(owner string, fileId string) bool {
func UploadRoles(owner string, fileId string) (bool, error) {
table := xlsx.ReadXlsxFile(fileId)
oldUserMap := getRoleMap(owner)
oldUserMap, err := getRoleMap(owner)
if err != nil {
return false, err
}
newRoles := []*Role{}
for index, line := range table {
if index == 0 || parseLineItem(&line, 0) == "" {
@ -57,7 +65,7 @@ func UploadRoles(owner string, fileId string) bool {
}
if len(newRoles) == 0 {
return false
return false, nil
}
return AddRolesInBatch(newRoles)
return AddRolesInBatch(newRoles), nil
}

View File

@ -105,7 +105,11 @@ func NewSamlResponse(user *User, host string, certificate string, destination st
roles := attributes.CreateElement("saml:Attribute")
roles.CreateAttr("Name", "Roles")
roles.CreateAttr("NameFormat", "urn:oasis:names:tc:SAML:2.0:attrname-format:basic")
ExtendUserWithRolesAndPermissions(user)
err := ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, err
}
for _, role := range user.Roles {
roles.CreateElement("saml:AttributeValue").CreateAttr("xsi:type", "xs:string").Element().SetText(role.Name)
}
@ -186,7 +190,11 @@ type Attribute struct {
}
func GetSamlMeta(application *Application, host string) (*IdpEntityDescriptor, error) {
cert := getCertByApplication(application)
cert, err := getCertByApplication(application)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(cert.Certificate))
certificate := base64.StdEncoding.EncodeToString(block.Bytes)
@ -263,7 +271,11 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
}
// get certificate string
cert := getCertByApplication(application)
cert, err := getCertByApplication(application)
if err != nil {
return "", "", "", err
}
block, _ := pem.Decode([]byte(cert.Certificate))
certificate := base64.StdEncoding.EncodeToString(block.Bytes)

View File

@ -43,7 +43,10 @@ func ParseSamlResponse(samlResponse string, provider *Provider, host string) (st
}
func GenerateSamlRequest(id, relayState, host, lang string) (auth string, method string, err error) {
provider := GetProvider(id)
provider, err := GetProvider(id)
if err != nil {
return "", "", err
}
if provider.Category != "SAML" {
return "", "", fmt.Errorf(i18n.Translate(lang, "saml_sp:provider %s's category is not SAML"), provider.Name)
}
@ -92,27 +95,33 @@ func buildSp(provider *Provider, samlResponse string, host string) (*saml2.SAMLS
}
if provider.EnableSignAuthnRequest {
sp.SignAuthnRequests = true
sp.SPKeyStore = buildSpKeyStore()
sp.SPKeyStore, err = buildSpKeyStore()
if err != nil {
return nil, err
}
}
return sp, nil
}
func buildSpKeyStore() dsig.X509KeyStore {
func buildSpKeyStore() (dsig.X509KeyStore, error) {
keyPair, err := tls.LoadX509KeyPair("object/token_jwt_key.pem", "object/token_jwt_key.key")
if err != nil {
panic(err)
return nil, err
}
return &dsig.TLSCertKeyStore{
PrivateKey: keyPair.PrivateKey,
Certificate: keyPair.Certificate,
}
}, nil
}
func buildSpCertificateStore(provider *Provider, samlResponse string) (dsig.MemoryX509CertificateStore, error) {
func buildSpCertificateStore(provider *Provider, samlResponse string) (certStore dsig.MemoryX509CertificateStore, err error) {
certEncodedData := ""
if samlResponse != "" {
certEncodedData = getCertificateFromSamlResponse(samlResponse, provider.Type)
certEncodedData, err = getCertificateFromSamlResponse(samlResponse, provider.Type)
if err != nil {
return
}
} else if provider.IdP != "" {
certEncodedData = provider.IdP
}
@ -126,17 +135,18 @@ func buildSpCertificateStore(provider *Provider, samlResponse string) (dsig.Memo
return dsig.MemoryX509CertificateStore{}, err
}
certStore := dsig.MemoryX509CertificateStore{
certStore = dsig.MemoryX509CertificateStore{
Roots: []*x509.Certificate{idpCert},
}
return certStore, nil
}
func getCertificateFromSamlResponse(samlResponse string, providerType string) string {
func getCertificateFromSamlResponse(samlResponse string, providerType string) (string, error) {
de, err := base64.StdEncoding.DecodeString(samlResponse)
if err != nil {
panic(err)
return "", err
}
deStr := strings.Replace(string(de), "\n", "", -1)
tagMap := map[string]string{
"Aliyun IDaaS": "ds",
@ -145,5 +155,5 @@ func getCertificateFromSamlResponse(samlResponse string, providerType string) st
tag := tagMap[providerType]
expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)</%s:X509Certificate>", tag, tag)
res := regexp.MustCompile(expression).FindStringSubmatch(deStr)
return res[1]
return res[1], nil
}

View File

@ -36,7 +36,7 @@ type Session struct {
SessionId []string `json:"sessionId"`
}
func GetSessions(owner string) []*Session {
func GetSessions(owner string) ([]*Session, error) {
sessions := []*Session{}
var err error
if owner != "" {
@ -45,61 +45,58 @@ func GetSessions(owner string) []*Session {
err = adapter.Engine.Desc("created_time").Find(&sessions)
}
if err != nil {
panic(err)
return sessions, err
}
return sessions
return sessions, nil
}
func GetPaginationSessions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Session {
func GetPaginationSessions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Session, error) {
sessions := []*Session{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&sessions)
if err != nil {
panic(err)
return sessions, err
}
return sessions
return sessions, nil
}
func GetSessionCount(owner, field, value string) int {
func GetSessionCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Session{})
if err != nil {
panic(err)
return session.Count(&Session{})
}
return int(count)
}
func GetSingleSession(id string) *Session {
func GetSingleSession(id string) (*Session, error) {
owner, name, application := util.GetOwnerAndNameAndOtherFromId(id)
session := Session{Owner: owner, Name: name, Application: application}
get, err := adapter.Engine.Get(&session)
if err != nil {
panic(err)
return &session, err
}
if !get {
return nil
return nil, nil
}
return &session
return &session, nil
}
func UpdateSession(id string, session *Session) bool {
func UpdateSession(id string, session *Session) (bool, error) {
owner, name, application := util.GetOwnerAndNameAndOtherFromId(id)
if GetSingleSession(id) == nil {
return false
if ss, err := GetSingleSession(id); err != nil {
return false, err
} else if ss == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name, application}).Update(session)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func removeExtraSessionIds(session *Session) {
@ -108,17 +105,21 @@ func removeExtraSessionIds(session *Session) {
}
}
func AddSession(session *Session) bool {
dbSession := GetSingleSession(session.GetId())
func AddSession(session *Session) (bool, error) {
dbSession, err := GetSingleSession(session.GetId())
if err != nil {
return false, err
}
if dbSession == nil {
session.CreatedTime = util.GetCurrentTime()
affected, err := adapter.Engine.Insert(session)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
} else {
m := make(map[string]struct{})
for _, v := range dbSession.SessionId {
@ -136,10 +137,14 @@ func AddSession(session *Session) bool {
}
}
func DeleteSession(id string) bool {
func DeleteSession(id string) (bool, error) {
owner, name, application := util.GetOwnerAndNameAndOtherFromId(id)
if owner == CasdoorOrganization && application == CasdoorApplication {
session := GetSingleSession(id)
session, err := GetSingleSession(id)
if err != nil {
return false, err
}
if session != nil {
DeleteBeegoSession(session.SessionId)
}
@ -147,16 +152,19 @@ func DeleteSession(id string) bool {
affected, err := adapter.Engine.ID(core.PK{owner, name, application}).Delete(&Session{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteSessionId(id string, sessionId string) bool {
session := GetSingleSession(id)
func DeleteSessionId(id string, sessionId string) (bool, error) {
session, err := GetSingleSession(id)
if err != nil {
return false, err
}
if session == nil {
return false
return false, nil
}
owner, _, application := util.GetOwnerAndNameAndOtherFromId(id)
@ -185,17 +193,21 @@ func (session *Session) GetId() string {
return fmt.Sprintf("%s/%s/%s", session.Owner, session.Name, session.Application)
}
func IsSessionDuplicated(id string, sessionId string) bool {
session := GetSingleSession(id)
func IsSessionDuplicated(id string, sessionId string) (bool, error) {
session, err := GetSingleSession(id)
if err != nil {
return false, err
}
if session == nil {
return false
return false, nil
} else {
if len(session.SessionId) > 1 {
return true
return true, nil
} else if len(session.SessionId) < 1 {
return false
return false, nil
} else {
return session.SessionId[0] != sessionId
return session.SessionId[0] != sessionId, nil
}
}
}

View File

@ -30,11 +30,7 @@ import (
var isCloudIntranet bool
func init() {
var err error
isCloudIntranet, err = conf.GetConfigBool("isCloudIntranet")
if err != nil {
// panic(err)
}
isCloudIntranet = conf.GetConfigBool("isCloudIntranet")
}
func getProviderEndpoint(provider *Provider) string {

View File

@ -63,90 +63,87 @@ func NewSubscription(owner string, user string, plan string, duration int) *Subs
}
}
func GetSubscriptionCount(owner, field, value string) int {
func GetSubscriptionCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Subscription{})
if err != nil {
panic(err)
return session.Count(&Subscription{})
}
return int(count)
}
func GetSubscriptions(owner string) []*Subscription {
func GetSubscriptions(owner string) ([]*Subscription, error) {
subscriptions := []*Subscription{}
err := adapter.Engine.Desc("created_time").Find(&subscriptions, &Subscription{Owner: owner})
if err != nil {
panic(err)
return subscriptions, err
}
return subscriptions
return subscriptions, nil
}
func GetPaginationSubscriptions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Subscription {
func GetPaginationSubscriptions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Subscription, error) {
subscriptions := []*Subscription{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&subscriptions)
if err != nil {
panic(err)
return subscriptions, err
}
return subscriptions
return subscriptions, nil
}
func getSubscription(owner string, name string) *Subscription {
func getSubscription(owner string, name string) (*Subscription, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
subscription := Subscription{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&subscription)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &subscription
return &subscription, nil
} else {
return nil
return nil, nil
}
}
func GetSubscription(id string) *Subscription {
func GetSubscription(id string) (*Subscription, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getSubscription(owner, name)
}
func UpdateSubscription(id string, subscription *Subscription) bool {
func UpdateSubscription(id string, subscription *Subscription) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getSubscription(owner, name) == nil {
return false
if s, err := getSubscription(owner, name); err != nil {
return false, err
} else if s == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(subscription)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddSubscription(subscription *Subscription) bool {
func AddSubscription(subscription *Subscription) (bool, error) {
affected, err := adapter.Engine.Insert(subscription)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteSubscription(subscription *Subscription) bool {
func DeleteSubscription(subscription *Subscription) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{subscription.Owner, subscription.Name}).Delete(&Subscription{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (subscription *Subscription) GetId() string {

View File

@ -55,66 +55,61 @@ type Syncer struct {
Adapter *Adapter `xorm:"-" json:"-"`
}
func GetSyncerCount(owner, organization, field, value string) int {
func GetSyncerCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Syncer{Organization: organization})
if err != nil {
panic(err)
return session.Count(&Syncer{Organization: organization})
}
return int(count)
}
func GetSyncers(owner string) []*Syncer {
func GetSyncers(owner string) ([]*Syncer, error) {
syncers := []*Syncer{}
err := adapter.Engine.Desc("created_time").Find(&syncers, &Syncer{Owner: owner})
if err != nil {
panic(err)
return syncers, err
}
return syncers
return syncers, nil
}
func GetOrganizationSyncers(owner, organization string) []*Syncer {
func GetOrganizationSyncers(owner, organization string) ([]*Syncer, error) {
syncers := []*Syncer{}
err := adapter.Engine.Desc("created_time").Find(&syncers, &Syncer{Owner: owner, Organization: organization})
if err != nil {
panic(err)
return syncers, err
}
return syncers
return syncers, nil
}
func GetPaginationSyncers(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Syncer {
func GetPaginationSyncers(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Syncer, error) {
syncers := []*Syncer{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&syncers, &Syncer{Organization: organization})
if err != nil {
panic(err)
return syncers, err
}
return syncers
return syncers, nil
}
func getSyncer(owner string, name string) *Syncer {
func getSyncer(owner string, name string) (*Syncer, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
syncer := Syncer{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&syncer)
if err != nil {
panic(err)
return &syncer, err
}
if existed {
return &syncer
return &syncer, nil
} else {
return nil
return nil, nil
}
}
func GetSyncer(id string) *Syncer {
func GetSyncer(id string) (*Syncer, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getSyncer(owner, name)
}
@ -137,10 +132,12 @@ func GetMaskedSyncers(syncers []*Syncer) []*Syncer {
return syncers
}
func UpdateSyncer(id string, syncer *Syncer) bool {
func UpdateSyncer(id string, syncer *Syncer) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getSyncer(owner, name) == nil {
return false
if s, err := getSyncer(owner, name); err != nil {
return false, err
} else if s == nil {
return false, nil
}
session := adapter.Engine.ID(core.PK{owner, name}).AllCols()
@ -149,56 +146,66 @@ func UpdateSyncer(id string, syncer *Syncer) bool {
}
affected, err := session.Update(syncer)
if err != nil {
panic(err)
return false, err
}
if affected == 1 {
addSyncerJob(syncer)
err = addSyncerJob(syncer)
if err != nil {
return false, err
}
}
return affected != 0
return affected != 0, nil
}
func updateSyncerErrorText(syncer *Syncer, line string) (bool, error) {
s, err := getSyncer(syncer.Owner, syncer.Name)
if err != nil {
return false, err
}
func updateSyncerErrorText(syncer *Syncer, line string) bool {
s := getSyncer(syncer.Owner, syncer.Name)
if s == nil {
return false
return false, nil
}
s.ErrorText = s.ErrorText + line
affected, err := adapter.Engine.ID(core.PK{s.Owner, s.Name}).Cols("error_text").Update(s)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddSyncer(syncer *Syncer) bool {
func AddSyncer(syncer *Syncer) (bool, error) {
affected, err := adapter.Engine.Insert(syncer)
if err != nil {
panic(err)
return false, err
}
if affected == 1 {
addSyncerJob(syncer)
err = addSyncerJob(syncer)
if err != nil {
return false, err
}
}
return affected != 0
return affected != 0, nil
}
func DeleteSyncer(syncer *Syncer) bool {
func DeleteSyncer(syncer *Syncer) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{syncer.Owner, syncer.Name}).Delete(&Syncer{})
if err != nil {
panic(err)
return false, err
}
if affected == 1 {
deleteSyncerJob(syncer)
}
return affected != 0
return affected != 0, nil
}
func (syncer *Syncer) GetId() string {

View File

@ -19,22 +19,25 @@ type Affiliation struct {
Name string `xorm:"varchar(128)" json:"name"`
}
func (syncer *Syncer) getAffiliations() []*Affiliation {
func (syncer *Syncer) getAffiliations() ([]*Affiliation, error) {
affiliations := []*Affiliation{}
err := syncer.Adapter.Engine.Table(syncer.AffiliationTable).Asc("id").Find(&affiliations)
if err != nil {
panic(err)
return nil, err
}
return affiliations
return affiliations, nil
}
func (syncer *Syncer) getAffiliationMap() ([]*Affiliation, map[int]string) {
affiliations := syncer.getAffiliations()
func (syncer *Syncer) getAffiliationMap() ([]*Affiliation, map[int]string, error) {
affiliations, err := syncer.getAffiliations()
if err != nil {
return nil, nil, err
}
m := map[int]string{}
for _, affiliation := range affiliations {
m[affiliation.Id] = affiliation.Name
}
return affiliations, m
return affiliations, m, nil
}

View File

@ -43,11 +43,11 @@ func clearCron(name string) {
}
}
func addSyncerJob(syncer *Syncer) {
func addSyncerJob(syncer *Syncer) error {
deleteSyncerJob(syncer)
if !syncer.IsEnabled {
return
return nil
}
syncer.initAdapter()
@ -58,10 +58,11 @@ func addSyncerJob(syncer *Syncer) {
cron := getCronMap(syncer.Name)
_, err := cron.AddFunc(schedule, syncer.syncUsers)
if err != nil {
panic(err)
return err
}
cron.Start()
return nil
}
func deleteSyncerJob(syncer *Syncer) {

View File

@ -16,46 +16,74 @@ package object
import "fmt"
func getDbSyncerForUser(user *User) *Syncer {
syncers := GetSyncers("admin")
func getDbSyncerForUser(user *User) (*Syncer, error) {
syncers, err := GetSyncers("admin")
if err != nil {
return nil, err
}
for _, syncer := range syncers {
if syncer.Organization == user.Owner && syncer.IsEnabled && syncer.Type == "Database" {
return syncer
return syncer, nil
}
}
return nil
return nil, nil
}
func getEnabledSyncerForOrganization(organization string) (*Syncer, error) {
syncers, err := GetSyncers("admin")
if err != nil {
return nil, err
}
func getEnabledSyncerForOrganization(organization string) *Syncer {
syncers := GetSyncers("admin")
for _, syncer := range syncers {
if syncer.Organization == organization && syncer.IsEnabled {
return syncer
return syncer, nil
}
}
return nil
return nil, nil
}
func AddUserToOriginalDatabase(user *User) error {
syncer, err := getEnabledSyncerForOrganization(user.Owner)
if err != nil {
return err
}
func AddUserToOriginalDatabase(user *User) {
syncer := getEnabledSyncerForOrganization(user.Owner)
if syncer == nil {
return
return nil
}
updatedOUser := syncer.createOriginalUserFromUser(user)
syncer.addUser(updatedOUser)
_, err = syncer.addUser(updatedOUser)
if err != nil {
return err
}
fmt.Printf("Add from user to oUser: %v\n", updatedOUser)
return nil
}
func UpdateUserToOriginalDatabase(user *User) {
syncer := getEnabledSyncerForOrganization(user.Owner)
func UpdateUserToOriginalDatabase(user *User) error {
syncer, err := getEnabledSyncerForOrganization(user.Owner)
if err != nil {
return err
}
if syncer == nil {
return
return nil
}
newUser := GetUser(user.GetId())
newUser, err := GetUser(user.GetId())
if err != nil {
return err
}
updatedOUser := syncer.createOriginalUserFromUser(newUser)
syncer.updateUser(updatedOUser)
fmt.Printf("Update from user to oUser: %v\n", updatedOUser)
_, err = syncer.updateUser(updatedOUser)
if err != nil {
return err
}
fmt.Printf("Update from user to oUser: %v\n", updatedOUser)
return nil
}

View File

@ -37,7 +37,7 @@ func (syncer *Syncer) syncUsers() {
var affiliationMap map[int]string
if syncer.AffiliationTable != "" {
_, affiliationMap = syncer.getAffiliationMap()
_, affiliationMap, err = syncer.getAffiliationMap()
}
newUsers := []*User{}
@ -86,13 +86,19 @@ func (syncer *Syncer) syncUsers() {
}
}
}
AddUsersInBatch(newUsers)
_, err = AddUsersInBatch(newUsers)
if err != nil {
panic(err)
}
for _, user := range users {
id := user.Id
if _, ok := oUserMap[id]; !ok {
newOUser := syncer.createOriginalUserFromUser(user)
syncer.addUser(newOUser)
_, err = syncer.addUser(newOUser)
if err != nil {
panic(err)
}
fmt.Printf("New oUser: %v\n", newOUser)
}
}

View File

@ -122,14 +122,18 @@ func (syncer *Syncer) updateUser(user *OriginalUser) (bool, error) {
}
func (syncer *Syncer) updateUserForOriginalFields(user *User) (bool, error) {
var err error
owner, name := util.GetOwnerAndNameFromId(user.GetId())
oldUser := getUserById(owner, name)
if oldUser == nil {
return false, nil
oldUser, err := getUserById(owner, name)
if oldUser == nil || err != nil {
return false, err
}
if user.Avatar != oldUser.Avatar && user.Avatar != "" {
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
if err != nil {
return false, err
}
}
columns := syncer.getCasdoorColumns()
@ -175,7 +179,11 @@ func (syncer *Syncer) initAdapter() {
}
func RunSyncUsersJob() {
syncers := GetSyncers("admin")
syncers, err := GetSyncers("admin")
if err != nil {
panic(err)
}
for _, syncer := range syncers {
addSyncerJob(syncer)
}

View File

@ -23,7 +23,7 @@ import (
func TestGetUsers(t *testing.T) {
InitConfig()
syncers := GetSyncers("admin")
syncers, _ := GetSyncers("admin")
syncer := syncers[0]
syncer.initAdapter()
users, _ := syncer.getOriginalUsers()

View File

@ -15,7 +15,11 @@
package object
func (syncer *Syncer) getUsers() []*User {
users := GetUsers(syncer.Organization)
users, err := GetUsers(syncer.Organization)
if err != nil {
panic(err)
}
return users
}

View File

@ -91,67 +91,54 @@ type IntrospectionResponse struct {
Jti string `json:"jti,omitempty"`
}
func GetTokenCount(owner, organization, field, value string) int {
func GetTokenCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Token{Organization: organization})
if err != nil {
panic(err)
return session.Count(&Token{Organization: organization})
}
return int(count)
}
func GetTokens(owner string, organization string) []*Token {
func GetTokens(owner string, organization string) ([]*Token, error) {
tokens := []*Token{}
err := adapter.Engine.Desc("created_time").Find(&tokens, &Token{Owner: owner, Organization: organization})
if err != nil {
panic(err)
return tokens, err
}
return tokens
}
func GetPaginationTokens(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Token {
func GetPaginationTokens(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Token, error) {
tokens := []*Token{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&tokens, &Token{Organization: organization})
if err != nil {
panic(err)
return tokens, err
}
return tokens
}
func getToken(owner string, name string) *Token {
func getToken(owner string, name string) (*Token, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
token := Token{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&token)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &token
return &token, nil
}
return nil
return nil, nil
}
func getTokenByCode(code string) *Token {
func getTokenByCode(code string) (*Token, error) {
token := Token{Code: code}
existed, err := adapter.Engine.Get(&token)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &token
return &token, nil
}
return nil
return nil, nil
}
func updateUsedByCode(token *Token) bool {
@ -163,7 +150,7 @@ func updateUsedByCode(token *Token) bool {
return affected != 0
}
func GetToken(id string) *Token {
func GetToken(id string) (*Token, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getToken(owner, name)
}
@ -172,124 +159,155 @@ func (token *Token) GetId() string {
return fmt.Sprintf("%s/%s", token.Owner, token.Name)
}
func UpdateToken(id string, token *Token) bool {
func UpdateToken(id string, token *Token) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getToken(owner, name) == nil {
return false
if t, err := getToken(owner, name); err != nil {
return false, err
} else if t == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(token)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddToken(token *Token) bool {
func AddToken(token *Token) (bool, error) {
affected, err := adapter.Engine.Insert(token)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteToken(token *Token) bool {
func DeleteToken(token *Token) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{token.Owner, token.Name}).Delete(&Token{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func ExpireTokenByAccessToken(accessToken string) (bool, *Application, *Token) {
func ExpireTokenByAccessToken(accessToken string) (bool, *Application, *Token, error) {
token := Token{AccessToken: accessToken}
existed, err := adapter.Engine.Get(&token)
if err != nil {
panic(err)
return false, nil, nil, err
}
if !existed {
return false, nil, nil
return false, nil, nil, nil
}
token.ExpiresIn = 0
affected, err := adapter.Engine.ID(core.PK{token.Owner, token.Name}).Cols("expires_in").Update(&token)
if err != nil {
panic(err)
return false, nil, nil, err
}
application := getApplication(token.Owner, token.Application)
return affected != 0, application, &token
application, err := getApplication(token.Owner, token.Application)
if err != nil {
return false, nil, nil, err
}
func GetTokenByAccessToken(accessToken string) *Token {
return affected != 0, application, &token, nil
}
func GetTokenByAccessToken(accessToken string) (*Token, error) {
// Check if the accessToken is in the database
token := Token{AccessToken: accessToken}
existed, err := adapter.Engine.Get(&token)
if err != nil || !existed {
return nil
}
return &token
if err != nil {
return nil, err
}
func GetTokenByTokenAndApplication(token string, application string) *Token {
if !existed {
return nil, nil
}
return &token, nil
}
func GetTokenByTokenAndApplication(token string, application string) (*Token, error) {
tokenResult := Token{}
existed, err := adapter.Engine.Where("(refresh_token = ? or access_token = ? ) and application = ?", token, token, application).Get(&tokenResult)
if err != nil || !existed {
return nil
}
return &tokenResult
if err != nil {
return nil, err
}
func CheckOAuthLogin(clientId string, responseType string, redirectUri string, scope string, state string, lang string) (string, *Application) {
if !existed {
return nil, nil
}
return &tokenResult, nil
}
func CheckOAuthLogin(clientId string, responseType string, redirectUri string, scope string, state string, lang string) (string, *Application, error) {
if responseType != "code" && responseType != "token" && responseType != "id_token" {
return fmt.Sprintf(i18n.Translate(lang, "token:Grant_type: %s is not supported in this application"), responseType), nil
return fmt.Sprintf(i18n.Translate(lang, "token:Grant_type: %s is not supported in this application"), responseType), nil, nil
}
application, err := GetApplicationByClientId(clientId)
if err != nil {
return "", nil, err
}
application := GetApplicationByClientId(clientId)
if application == nil {
return i18n.Translate(lang, "token:Invalid client_id"), nil
return i18n.Translate(lang, "token:Invalid client_id"), nil, nil
}
if !application.IsRedirectUriValid(redirectUri) {
return fmt.Sprintf(i18n.Translate(lang, "token:Redirect URI: %s doesn't exist in the allowed Redirect URI list"), redirectUri), application
return fmt.Sprintf(i18n.Translate(lang, "token:Redirect URI: %s doesn't exist in the allowed Redirect URI list"), redirectUri), application, nil
}
// Mask application for /api/get-app-login
application.ClientSecret = ""
return "", application
return "", application, nil
}
func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string, host string, lang string) (*Code, error) {
user, err := GetUser(userId)
if err != nil {
return nil, err
}
func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string, host string, lang string) *Code {
user := GetUser(userId)
if user == nil {
return &Code{
Message: fmt.Sprintf("general:The user: %s doesn't exist", userId),
Code: "",
}
}, nil
}
if user.IsForbidden {
return &Code{
Message: "error: the user is forbidden to sign in, please contact the administrator",
Code: "",
}
}, nil
}
msg, application, err := CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, lang)
if err != nil {
return nil, err
}
msg, application := CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, lang)
if msg != "" {
return &Code{
Message: msg,
Code: "",
}
}, nil
}
ExtendUserWithRolesAndPermissions(user)
err = ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, err
}
accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, nonce, scope, host)
if err != nil {
panic(err)
return nil, err
}
if challenge == "null" {
@ -313,21 +331,28 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU
CodeIsUsed: false,
CodeExpireIn: time.Now().Add(time.Minute * 5).Unix(),
}
AddToken(token)
_, err = AddToken(token)
if err != nil {
return nil, err
}
return &Code{
Message: "",
Code: token.Code,
}
}, nil
}
func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string, scope string, username string, password string, host string, refreshToken string, tag string, avatar string, lang string) (interface{}, error) {
application, err := GetApplicationByClientId(clientId)
if err != nil {
return nil, err
}
func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string, scope string, username string, password string, host string, refreshToken string, tag string, avatar string, lang string) interface{} {
application := GetApplicationByClientId(clientId)
if application == nil {
return &TokenError{
Error: InvalidClient,
ErrorDescription: "client_id is invalid",
}
}, nil
}
// Check if grantType is allowed in the current application
@ -336,32 +361,44 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code
return &TokenError{
Error: UnsupportedGrantType,
ErrorDescription: fmt.Sprintf("grant_type: %s is not supported in this application", grantType),
}
}, nil
}
var token *Token
var tokenError *TokenError
switch grantType {
case "authorization_code": // Authorization Code Grant
token, tokenError = GetAuthorizationCodeToken(application, clientSecret, code, verifier)
token, tokenError, err = GetAuthorizationCodeToken(application, clientSecret, code, verifier)
case "password": // Resource Owner Password Credentials Grant
token, tokenError = GetPasswordToken(application, username, password, scope, host)
token, tokenError, err = GetPasswordToken(application, username, password, scope, host)
case "client_credentials": // Client Credentials Grant
token, tokenError = GetClientCredentialsToken(application, clientSecret, scope, host)
token, tokenError, err = GetClientCredentialsToken(application, clientSecret, scope, host)
case "refresh_token":
return RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host)
refreshToken2, err := RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host)
if err != nil {
return nil, err
}
return refreshToken2, nil
}
if err != nil {
return nil, err
}
if tag == "wechat_miniprogram" {
// Wechat Mini Program
token, tokenError = GetWechatMiniProgramToken(application, code, host, username, avatar, lang)
token, tokenError, err = GetWechatMiniProgramToken(application, code, host, username, avatar, lang)
if err != nil {
return nil, err
}
}
if tokenError != nil {
return tokenError
return tokenError, nil
}
token.CodeIsUsed = true
go updateUsedByCode(token)
tokenWrapper := &TokenWrapper{
@ -373,29 +410,33 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code
Scope: token.Scope,
}
return tokenWrapper
return tokenWrapper, nil
}
func RefreshToken(grantType string, refreshToken string, scope string, clientId string, clientSecret string, host string) interface{} {
func RefreshToken(grantType string, refreshToken string, scope string, clientId string, clientSecret string, host string) (interface{}, error) {
// check parameters
if grantType != "refresh_token" {
return &TokenError{
Error: UnsupportedGrantType,
ErrorDescription: "grant_type should be refresh_token",
}, nil
}
application, err := GetApplicationByClientId(clientId)
if err != nil {
return nil, err
}
application := GetApplicationByClientId(clientId)
if application == nil {
return &TokenError{
Error: InvalidClient,
ErrorDescription: "client_id is invalid",
}
}, nil
}
if clientSecret != "" && application.ClientSecret != clientSecret {
return &TokenError{
Error: InvalidClient,
ErrorDescription: "client_secret is invalid",
}
}, nil
}
// check whether the refresh token is valid, and has not expired.
token := Token{RefreshToken: refreshToken}
@ -404,33 +445,44 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId
return &TokenError{
Error: InvalidGrant,
ErrorDescription: "refresh token is invalid, expired or revoked",
}
}, nil
}
cert, err := getCertByApplication(application)
if err != nil {
return nil, err
}
cert := getCertByApplication(application)
_, err = ParseJwtToken(refreshToken, cert)
if err != nil {
return &TokenError{
Error: InvalidGrant,
ErrorDescription: fmt.Sprintf("parse refresh token error: %s", err.Error()),
}
}, nil
}
// generate a new token
user := getUser(application.Organization, token.User)
user, err := getUser(application.Organization, token.User)
if err != nil {
return nil, err
}
if user.IsForbidden {
return &TokenError{
Error: InvalidGrant,
ErrorDescription: "the user is forbidden to sign in, please contact the administrator",
}
}, nil
}
ExtendUserWithRolesAndPermissions(user)
err = ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, err
}
newAccessToken, newRefreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host)
if err != nil {
return &TokenError{
Error: EndpointError,
ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()),
}
}, nil
}
newToken := &Token{
@ -447,8 +499,15 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId
Scope: scope,
TokenType: "Bearer",
}
AddToken(newToken)
DeleteToken(&token)
_, err = AddToken(newToken)
if err != nil {
return nil, err
}
_, err = DeleteToken(&token)
if err != nil {
return nil, err
}
tokenWrapper := &TokenWrapper{
AccessToken: newToken.AccessToken,
@ -459,7 +518,7 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId
Scope: newToken.Scope,
}
return tokenWrapper
return tokenWrapper, nil
}
// PkceChallenge: base64-URL-encoded SHA256 hash of verifier, per rfc 7636
@ -486,34 +545,38 @@ func IsGrantTypeValid(method string, grantTypes []string) bool {
// GetAuthorizationCodeToken
// Authorization code flow
func GetAuthorizationCodeToken(application *Application, clientSecret string, code string, verifier string) (*Token, *TokenError) {
func GetAuthorizationCodeToken(application *Application, clientSecret string, code string, verifier string) (*Token, *TokenError, error) {
if code == "" {
return nil, &TokenError{
Error: InvalidRequest,
ErrorDescription: "authorization code should not be empty",
}
}, nil
}
token, err := getTokenByCode(code)
if err != nil {
return nil, nil, err
}
token := getTokenByCode(code)
if token == nil {
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "authorization code is invalid",
}
}, nil
}
if token.CodeIsUsed {
// anti replay attacks
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "authorization code has been used",
}
}, nil
}
if token.CodeChallenge != "" && pkceChallenge(verifier) != token.CodeChallenge {
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "verifier is invalid",
}
}, nil
}
if application.ClientSecret != clientSecret {
@ -523,13 +586,13 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co
return nil, &TokenError{
Error: InvalidClient,
ErrorDescription: "client_secret is invalid",
}
}, nil
} else {
if clientSecret != "" {
return nil, &TokenError{
Error: InvalidClient,
ErrorDescription: "client_secret is invalid",
}
}, nil
}
}
}
@ -538,7 +601,7 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "the token is for wrong application (client_id)",
}
}, nil
}
if time.Now().Unix() > token.CodeExpireIn {
@ -546,42 +609,50 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "authorization code has expired",
}, nil
}
}
return token, nil
return token, nil, nil
}
// GetPasswordToken
// Resource Owner Password Credentials flow
func GetPasswordToken(application *Application, username string, password string, scope string, host string) (*Token, *TokenError) {
user := getUser(application.Organization, username)
func GetPasswordToken(application *Application, username string, password string, scope string, host string) (*Token, *TokenError, error) {
user, err := getUser(application.Organization, username)
if err != nil {
return nil, nil, err
}
if user == nil {
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "the user does not exist",
}
}, nil
}
msg := CheckPassword(user, password, "en")
if msg != "" {
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "invalid username or password",
}
}, nil
}
if user.IsForbidden {
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "the user is forbidden to sign in, please contact the administrator",
}
}, nil
}
err = ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, nil, err
}
ExtendUserWithRolesAndPermissions(user)
accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host)
if err != nil {
return nil, &TokenError{
Error: EndpointError,
ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()),
}
}, nil
}
token := &Token{
Owner: application.Owner,
@ -598,18 +669,22 @@ func GetPasswordToken(application *Application, username string, password string
TokenType: "Bearer",
CodeIsUsed: true,
}
AddToken(token)
return token, nil
_, err = AddToken(token)
if err != nil {
return nil, nil, err
}
return token, nil, nil
}
// GetClientCredentialsToken
// Client Credentials flow
func GetClientCredentialsToken(application *Application, clientSecret string, scope string, host string) (*Token, *TokenError) {
func GetClientCredentialsToken(application *Application, clientSecret string, scope string, host string) (*Token, *TokenError, error) {
if application.ClientSecret != clientSecret {
return nil, &TokenError{
Error: InvalidClient,
ErrorDescription: "client_secret is invalid",
}
}, nil
}
nullUser := &User{
Owner: application.Owner,
@ -623,7 +698,7 @@ func GetClientCredentialsToken(application *Application, clientSecret string, sc
return nil, &TokenError{
Error: EndpointError,
ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()),
}
}, nil
}
token := &Token{
Owner: application.Owner,
@ -639,18 +714,27 @@ func GetClientCredentialsToken(application *Application, clientSecret string, sc
TokenType: "Bearer",
CodeIsUsed: true,
}
AddToken(token)
return token, nil
_, err = AddToken(token)
if err != nil {
return nil, nil, err
}
return token, nil, nil
}
// GetTokenByUser
// Implicit flow
func GetTokenByUser(application *Application, user *User, scope string, host string) (*Token, error) {
ExtendUserWithRolesAndPermissions(user)
err := ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, err
}
accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host)
if err != nil {
return nil, err
}
token := &Token{
Owner: application.Owner,
Name: tokenName,
@ -666,43 +750,56 @@ func GetTokenByUser(application *Application, user *User, scope string, host str
TokenType: "Bearer",
CodeIsUsed: true,
}
AddToken(token)
_, err = AddToken(token)
if err != nil {
return nil, err
}
return token, nil
}
// GetWechatMiniProgramToken
// Wechat Mini Program flow
func GetWechatMiniProgramToken(application *Application, code string, host string, username string, avatar string, lang string) (*Token, *TokenError) {
func GetWechatMiniProgramToken(application *Application, code string, host string, username string, avatar string, lang string) (*Token, *TokenError, error) {
mpProvider := GetWechatMiniProgramProvider(application)
if mpProvider == nil {
return nil, &TokenError{
Error: InvalidClient,
ErrorDescription: "the application does not support wechat mini program",
}, nil
}
provider, err := GetProvider(util.GetId("admin", mpProvider.Name))
if err != nil {
return nil, nil, err
}
provider := GetProvider(util.GetId("admin", mpProvider.Name))
mpIdp := idp.NewWeChatMiniProgramIdProvider(provider.ClientId, provider.ClientSecret)
session, err := mpIdp.GetSessionByCode(code)
if err != nil {
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: fmt.Sprintf("get wechat mini program session error: %s", err.Error()),
}, nil
}
}
openId, unionId := session.Openid, session.Unionid
if openId == "" && unionId == "" {
return nil, &TokenError{
Error: InvalidRequest,
ErrorDescription: "the wechat mini program session is invalid",
}, nil
}
user, err := getUserByWechatId(openId, unionId)
if err != nil {
return nil, nil, err
}
user := getUserByWechatId(openId, unionId)
if user == nil {
if !application.EnableSignUp {
return nil, &TokenError{
Error: InvalidGrant,
ErrorDescription: "the application does not allow to sign up new account",
}
}, nil
}
// Add new user
var name string
@ -730,16 +827,23 @@ func GetWechatMiniProgramToken(application *Application, code string, host strin
UserPropertiesWechatUnionId: unionId,
},
}
AddUser(user)
_, err = AddUser(user)
if err != nil {
return nil, nil, err
}
}
err = ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, nil, err
}
ExtendUserWithRolesAndPermissions(user)
accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", "", host)
if err != nil {
return nil, &TokenError{
Error: EndpointError,
ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()),
}
}, nil
}
token := &Token{
@ -757,6 +861,9 @@ func GetWechatMiniProgramToken(application *Application, code string, host strin
TokenType: "Bearer",
CodeIsUsed: true,
}
AddToken(token)
return token, nil
_, err = AddToken(token)
if err != nil {
return nil, nil, err
}
return token, nil, nil
}

View File

@ -177,7 +177,9 @@ func StoreCasTokenForProxyTicket(token *CasAuthenticationSuccess, targetService,
}
func GenerateCasToken(userId string, service string) (string, error) {
if user := GetUser(userId); user != nil {
if user, err := GetUser(userId); err != nil {
return "", err
} else if user != nil {
authenticationSuccess := CasAuthenticationSuccess{
User: user.Name,
Attributes: &CasAttributes{
@ -232,18 +234,31 @@ func GetValidationBySaml(samlRequest string, host string) (string, string, error
return "", "", fmt.Errorf("ticket %s found", ticket)
}
user := GetUser(userId)
user, err := GetUser(userId)
if err != nil {
return "", "", err
}
if user == nil {
return "", "", fmt.Errorf("user %s found", userId)
}
application := GetApplicationByUser(user)
application, err := GetApplicationByUser(user)
if err != nil {
return "", "", err
}
if application == nil {
return "", "", fmt.Errorf("application for user %s found", userId)
}
samlResponse := NewSamlResponse11(user, request.RequestID, host)
cert := getCertByApplication(application)
cert, err := getCertByApplication(application)
if err != nil {
return "", "", err
}
block, _ := pem.Decode([]byte(cert.Certificate))
certificate := base64.StdEncoding.EncodeToString(block.Bytes)
randomKeyStore := &X509Key{

View File

@ -273,7 +273,10 @@ func generateJwtToken(application *Application, user *User, nonce string, scope
refreshToken = jwt.NewWithClaims(jwt.SigningMethodRS256, claimsWithoutThirdIdp)
}
cert := getCertByApplication(application)
cert, err := getCertByApplication(application)
if err != nil {
return "", "", "", err
}
// RSA private key
key, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(cert.PrivateKey))
@ -316,5 +319,10 @@ func ParseJwtToken(token string, cert *Cert) (*Claims, error) {
}
func ParseJwtTokenByApplication(token string, application *Application) (*Claims, error) {
return ParseJwtToken(token, getCertByApplication(application))
cert, err := getCertByApplication(application)
if err != nil {
return nil, err
}
return ParseJwtToken(token, cert)
}

View File

@ -191,217 +191,206 @@ type ManagedAccount struct {
SigninUrl string `xorm:"varchar(200)" json:"signinUrl"`
}
func GetGlobalUserCount(field, value string) int {
func GetGlobalUserCount(field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Count(&User{})
if err != nil {
panic(err)
return session.Count(&User{})
}
return int(count)
}
func GetGlobalUsers() []*User {
func GetGlobalUsers() ([]*User, error) {
users := []*User{}
err := adapter.Engine.Desc("created_time").Find(&users)
if err != nil {
panic(err)
return nil, err
}
return users
return users, nil
}
func GetPaginationGlobalUsers(offset, limit int, field, value, sortField, sortOrder string) []*User {
func GetPaginationGlobalUsers(offset, limit int, field, value, sortField, sortOrder string) ([]*User, error) {
users := []*User{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Find(&users)
if err != nil {
panic(err)
return nil, err
}
return users
return users, nil
}
func GetUserCount(owner, field, value string) int {
func GetUserCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&User{})
if err != nil {
panic(err)
return session.Count(&User{})
}
return int(count)
func GetOnlineUserCount(owner string, isOnline int) (int64, error) {
return adapter.Engine.Where("is_online = ?", isOnline).Count(&User{Owner: owner})
}
func GetOnlineUserCount(owner string, isOnline int) int {
count, err := adapter.Engine.Where("is_online = ?", isOnline).Count(&User{Owner: owner})
if err != nil {
panic(err)
}
return int(count)
}
func GetUsers(owner string) []*User {
func GetUsers(owner string) ([]*User, error) {
users := []*User{}
err := adapter.Engine.Desc("created_time").Find(&users, &User{Owner: owner})
if err != nil {
panic(err)
return nil, err
}
return users
return users, nil
}
func GetUsersByTag(owner string, tag string) []*User {
func GetUsersByTag(owner string, tag string) ([]*User, error) {
users := []*User{}
err := adapter.Engine.Desc("created_time").Find(&users, &User{Owner: owner, Tag: tag})
if err != nil {
panic(err)
return nil, err
}
return users
return users, nil
}
func GetSortedUsers(owner string, sorter string, limit int) []*User {
func GetSortedUsers(owner string, sorter string, limit int) ([]*User, error) {
users := []*User{}
err := adapter.Engine.Desc(sorter).Limit(limit, 0).Find(&users, &User{Owner: owner})
if err != nil {
panic(err)
return nil, err
}
return users
return users, nil
}
func GetPaginationUsers(owner string, offset, limit int, field, value, sortField, sortOrder string) []*User {
func GetPaginationUsers(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*User, error) {
users := []*User{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&users)
if err != nil {
panic(err)
return nil, err
}
return users
return users, nil
}
func getUser(owner string, name string) *User {
func getUser(owner string, name string) (*User, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
user := User{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&user)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &user
return &user, nil
} else {
return nil
return nil, nil
}
}
func getUserById(owner string, id string) *User {
func getUserById(owner string, id string) (*User, error) {
if owner == "" || id == "" {
return nil
return nil, nil
}
user := User{Owner: owner, Id: id}
existed, err := adapter.Engine.Get(&user)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &user
return &user, nil
} else {
return nil
return nil, nil
}
}
func getUserByWechatId(wechatOpenId string, wechatUnionId string) *User {
func getUserByWechatId(wechatOpenId string, wechatUnionId string) (*User, error) {
if wechatUnionId == "" {
wechatUnionId = wechatOpenId
}
user := &User{}
existed, err := adapter.Engine.Where("wechat = ? OR wechat = ?", wechatOpenId, wechatUnionId).Get(user)
if err != nil {
panic(err)
return nil, err
}
if existed {
return user
return user, nil
} else {
return nil
return nil, nil
}
}
func GetUserByEmail(owner string, email string) *User {
func GetUserByEmail(owner string, email string) (*User, error) {
if owner == "" || email == "" {
return nil
return nil, nil
}
user := User{Owner: owner, Email: email}
existed, err := adapter.Engine.Get(&user)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &user
return &user, nil
} else {
return nil
return nil, nil
}
}
func GetUserByPhone(owner string, phone string) *User {
func GetUserByPhone(owner string, phone string) (*User, error) {
if owner == "" || phone == "" {
return nil
return nil, nil
}
user := User{Owner: owner, Phone: phone}
existed, err := adapter.Engine.Get(&user)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &user
return &user, nil
} else {
return nil
return nil, nil
}
}
func GetUserByUserId(owner string, userId string) *User {
func GetUserByUserId(owner string, userId string) (*User, error) {
if owner == "" || userId == "" {
return nil
return nil, nil
}
user := User{Owner: owner, Id: userId}
existed, err := adapter.Engine.Get(&user)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &user
return &user, nil
} else {
return nil
return nil, nil
}
}
func GetUser(id string) *User {
func GetUser(id string) (*User, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getUser(owner, name)
}
func GetUserNoCheck(id string) *User {
func GetUserNoCheck(id string) (*User, error) {
owner, name := util.GetOwnerAndNameFromIdNoCheck(id)
return getUser(owner, name)
}
func GetMaskedUser(user *User) *User {
func GetMaskedUser(user *User, errs ...error) (*User, error) {
if len(errs) > 0 && errs[0] != nil {
return nil, errs[0]
}
if user == nil {
return nil
return nil, nil
}
if user.Password != "" {
@ -419,51 +408,69 @@ func GetMaskedUser(user *User) *User {
user.MultiFactorAuths[i] = GetMaskedProps(props)
}
}
return user
return user, nil
}
func GetMaskedUsers(users []*User) []*User {
func GetMaskedUsers(users []*User, errs ...error) ([]*User, error) {
if len(errs) > 0 && errs[0] != nil {
return nil, errs[0]
}
var err error
for _, user := range users {
user = GetMaskedUser(user)
user, err = GetMaskedUser(user)
if err != nil {
return nil, err
}
return users
}
return users, nil
}
func GetLastUser(owner string) *User {
func GetLastUser(owner string) (*User, error) {
user := User{Owner: owner}
existed, err := adapter.Engine.Desc("created_time", "id").Get(&user)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &user
return &user, nil
}
return nil
return nil, nil
}
func UpdateUser(id string, user *User, columns []string, isAdmin bool) bool {
func UpdateUser(id string, user *User, columns []string, isAdmin bool) (bool, error) {
var err error
owner, name := util.GetOwnerAndNameFromIdNoCheck(id)
oldUser := getUser(owner, name)
oldUser, err := getUser(owner, name)
if err != nil {
return false, err
}
if oldUser == nil {
return false
return false, nil
}
if name != user.Name {
err := userChangeTrigger(name, user.Name)
if err != nil {
return false
return false, nil
}
}
if user.Password == "***" {
user.Password = oldUser.Password
}
user.UpdateUserHash()
err = user.UpdateUserHash()
if err != nil {
panic(err)
}
if user.Avatar != oldUser.Avatar && user.Avatar != "" && user.PermanentAvatar != "*" {
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
if err != nil {
return false, err
}
}
if len(columns) == 0 {
@ -487,77 +494,105 @@ func UpdateUser(id string, user *User, columns []string, isAdmin bool) bool {
affected, err := adapter.Engine.ID(core.PK{owner, name}).Cols(columns...).Update(user)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func UpdateUserForAllFields(id string, user *User) bool {
func UpdateUserForAllFields(id string, user *User) (bool, error) {
var err error
owner, name := util.GetOwnerAndNameFromId(id)
oldUser := getUser(owner, name)
oldUser, err := getUser(owner, name)
if err != nil {
return false, err
}
if oldUser == nil {
return false
return false, nil
}
if name != user.Name {
err := userChangeTrigger(name, user.Name)
if err != nil {
return false
return false, nil
}
}
user.UpdateUserHash()
err = user.UpdateUserHash()
if err != nil {
return false, err
}
if user.Avatar != oldUser.Avatar && user.Avatar != "" {
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
if err != nil {
return false, err
}
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(user)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddUser(user *User) bool {
func AddUser(user *User) (bool, error) {
var err error
if user.Id == "" {
user.Id = util.GenerateId()
}
if user.Owner == "" || user.Name == "" {
return false
return false, nil
}
organization := GetOrganizationByUser(user)
organization, _ := GetOrganizationByUser(user)
if organization == nil {
return false
return false, nil
}
user.UpdateUserPassword(organization)
user.UpdateUserHash()
user.PreHash = user.Hash
updated := user.refreshAvatar()
if updated && user.PermanentAvatar != "*" {
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
err = user.UpdateUserHash()
if err != nil {
return false, err
}
user.Ranking = GetUserCount(user.Owner, "", "") + 1
user.PreHash = user.Hash
updated, err := user.refreshAvatar()
if err != nil {
return false, err
}
if updated && user.PermanentAvatar != "*" {
user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
if err != nil {
return false, err
}
}
count, err := GetUserCount(user.Owner, "", "")
if err != nil {
return false, err
}
user.Ranking = int(count + 1)
affected, err := adapter.Engine.Insert(user)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddUsers(users []*User) bool {
func AddUsers(users []*User) (bool, error) {
var err error
if len(users) == 0 {
return false
return false, nil
}
// organization := GetOrganizationByUser(users[0])
@ -565,27 +600,34 @@ func AddUsers(users []*User) bool {
// this function is only used for syncer or batch upload, so no need to encrypt the password
// user.UpdateUserPassword(organization)
user.UpdateUserHash()
err = user.UpdateUserHash()
if err != nil {
return false, err
}
user.PreHash = user.Hash
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
if err != nil {
return false, err
}
}
affected, err := adapter.Engine.Insert(users)
if err != nil {
if !strings.Contains(err.Error(), "Duplicate entry") {
panic(err)
return false, err
}
}
return affected != 0
return affected != 0, nil
}
func AddUsersInBatch(users []*User) bool {
func AddUsersInBatch(users []*User) (bool, error) {
batchSize := conf.GetConfigBatchSize()
if len(users) == 0 {
return false
return false, nil
}
affected := false
@ -599,24 +641,29 @@ func AddUsersInBatch(users []*User) bool {
tmp := users[start:end]
// TODO: save to log instead of standard output
// fmt.Printf("Add users: [%d - %d].\n", start, end)
if AddUsers(tmp) {
if ok, err := AddUsers(tmp); err != nil {
return false, err
} else if ok {
affected = true
}
}
return affected
return affected, nil
}
func DeleteUser(user *User) bool {
func DeleteUser(user *User) (bool, error) {
// Forced offline the user first
DeleteSession(util.GetSessionId(user.Owner, user.Name, CasdoorApplication))
_, err := DeleteSession(util.GetSessionId(user.Owner, user.Name, CasdoorApplication))
if err != nil {
return false, err
}
affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Delete(&User{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func GetUserInfo(user *User, scope string, aud string, host string) *Userinfo {
@ -644,7 +691,7 @@ func GetUserInfo(user *User, scope string, aud string, host string) *Userinfo {
return &resp
}
func LinkUserAccount(user *User, field string, value string) bool {
func LinkUserAccount(user *User, field string, value string) (bool, error) {
return SetUserField(user, field, value)
}
@ -656,13 +703,18 @@ func isUserIdGlobalAdmin(userId string) bool {
return strings.HasPrefix(userId, "built-in/")
}
func ExtendUserWithRolesAndPermissions(user *User) {
func ExtendUserWithRolesAndPermissions(user *User) (err error) {
if user == nil {
return
}
user.Roles = GetRolesByUser(user.GetId())
user.Permissions = GetPermissionsByUser(user.GetId())
user.Roles, err = GetRolesByUser(user.GetId())
if err != nil {
return
}
user.Permissions, err = GetPermissionsByUser(user.GetId())
return
}
func userChangeTrigger(oldName string, newName string) error {
@ -679,6 +731,7 @@ func userChangeTrigger(oldName string, newName string) error {
if err != nil {
return err
}
for _, role := range roles {
for j, u := range role.Users {
// u = organization/username
@ -722,7 +775,7 @@ func userChangeTrigger(oldName string, newName string) error {
return session.Commit()
}
func (user *User) refreshAvatar() bool {
func (user *User) refreshAvatar() (bool, error) {
var err error
var fileBuffer *bytes.Buffer
var ext string
@ -732,13 +785,13 @@ func (user *User) refreshAvatar() bool {
client := proxy.ProxyHttpClient
has, err := hasGravatar(client, user.Email)
if err != nil {
panic(err)
return false, err
}
if has {
fileBuffer, ext, err = getGravatarFileBuffer(client, user.Email)
if err != nil {
panic(err)
return false, err
}
}
}
@ -746,17 +799,20 @@ func (user *User) refreshAvatar() bool {
if fileBuffer == nil && strings.Contains(user.Avatar, "Identicon") {
fileBuffer, ext, err = getIdenticonFileBuffer(user.Name)
if err != nil {
panic(err)
return false, err
}
}
if fileBuffer != nil {
avatarUrl := getPermanentAvatarUrlFromBuffer(user.Owner, user.Name, fileBuffer, ext, true)
avatarUrl, err := getPermanentAvatarUrlFromBuffer(user.Owner, user.Name, fileBuffer, ext, true)
if err != nil {
return false, err
}
user.Avatar = avatarUrl
return true
return true, nil
}
return false
return false, nil
}
func (user *User) IsMfaEnabled() bool {

View File

@ -16,18 +16,27 @@ package object
import "github.com/casdoor/casdoor/cred"
func calculateHash(user *User) string {
syncer := getDbSyncerForUser(user)
func calculateHash(user *User) (string, error) {
syncer, err := getDbSyncerForUser(user)
if err != nil {
return "", err
}
if syncer == nil {
return ""
return "", nil
}
return syncer.calculateHash(user)
return syncer.calculateHash(user), nil
}
func (user *User) UpdateUserHash() error {
hash, err := calculateHash(user)
if err != nil {
return err
}
func (user *User) UpdateUserHash() {
hash := calculateHash(user)
user.Hash = hash
return nil
}
func (user *User) UpdateUserPassword(organization *Organization) {

View File

@ -36,7 +36,7 @@ func updateUserColumn(column string, user *User) bool {
func TestSyncAvatarsFromGitHub(t *testing.T) {
InitConfig()
users := GetGlobalUsers()
users, _ := GetGlobalUsers()
for _, user := range users {
if user.GitHub == "" {
continue
@ -50,7 +50,7 @@ func TestSyncAvatarsFromGitHub(t *testing.T) {
func TestSyncIds(t *testing.T) {
InitConfig()
users := GetGlobalUsers()
users, _ := GetGlobalUsers()
for _, user := range users {
if user.Id != "" {
continue
@ -64,13 +64,16 @@ func TestSyncIds(t *testing.T) {
func TestSyncHashes(t *testing.T) {
InitConfig()
users := GetGlobalUsers()
users, _ := GetGlobalUsers()
for _, user := range users {
if user.Hash != "" {
continue
}
user.UpdateUserHash()
err := user.UpdateUserHash()
if err != nil {
panic(err)
}
updateUserColumn("hash", user)
}
}
@ -92,7 +95,7 @@ func TestGetMaskedUsers(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetMaskedUsers(tt.args.users); !reflect.DeepEqual(got, tt.want) {
if got, _ := GetMaskedUsers(tt.args.users); !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetMaskedUsers() = %v, want %v", got, tt.want)
}
})
@ -102,7 +105,7 @@ func TestGetMaskedUsers(t *testing.T) {
func TestGetUserByField(t *testing.T) {
InitConfig()
user := GetUserByField("built-in", "DingTalk", "test")
user, _ := GetUserByField("built-in", "DingTalk", "test")
if user != nil {
t.Logf("%+v", user)
} else {
@ -115,7 +118,7 @@ func TestGetEmailsForUsers(t *testing.T) {
emailMap := map[string]int{}
emails := []string{}
users := GetUsers("built-in")
users, _ := GetUsers("built-in")
for _, user := range users {
if user.Email == "" {
continue

View File

@ -22,15 +22,18 @@ import (
"github.com/casdoor/casdoor/xlsx"
)
func getUserMap(owner string) map[string]*User {
func getUserMap(owner string) (map[string]*User, error) {
m := map[string]*User{}
users := GetUsers(owner)
users, err := GetUsers(owner)
if err != nil {
return m, err
}
for _, user := range users {
m[user.GetId()] = user
}
return m
return m, nil
}
func parseLineItem(line *[]string, i int) string {
@ -70,10 +73,14 @@ func parseListItem(lines *[]string, i int) []string {
return trimmedItems
}
func UploadUsers(owner string, fileId string) bool {
func UploadUsers(owner string, fileId string) (bool, error) {
table := xlsx.ReadXlsxFile(fileId)
oldUserMap := getUserMap(owner)
oldUserMap, err := getUserMap(owner)
if err != nil {
return false, err
}
newUsers := []*User{}
for index, line := range table {
if index == 0 || parseLineItem(&line, 0) == "" {
@ -135,7 +142,7 @@ func UploadUsers(owner string, fileId string) bool {
}
if len(newUsers) == 0 {
return false
return false, nil
}
return AddUsersInBatch(newUsers)
}

View File

@ -24,62 +24,70 @@ import (
"github.com/xorm-io/core"
)
func GetUserByField(organizationName string, field string, value string) *User {
func GetUserByField(organizationName string, field string, value string) (*User, error) {
if field == "" || value == "" {
return nil
return nil, nil
}
user := User{Owner: organizationName}
existed, err := adapter.Engine.Where(fmt.Sprintf("%s=?", strings.ToLower(field)), value).Get(&user)
if err != nil {
panic(err)
return nil, err
}
if existed {
return &user
return &user, nil
} else {
return nil
return nil, nil
}
}
func HasUserByField(organizationName string, field string, value string) bool {
return GetUserByField(organizationName, field, value) != nil
user, err := GetUserByField(organizationName, field, value)
if err != nil {
panic(err)
}
return user != nil
}
func GetUserByFields(organization string, field string) *User {
func GetUserByFields(organization string, field string) (*User, error) {
// check username
user := GetUserByField(organization, "name", field)
if user != nil {
return user
user, err := GetUserByField(organization, "name", field)
if err != nil || user != nil {
return user, err
}
// check email
if strings.Contains(field, "@") {
user = GetUserByField(organization, "email", field)
if user != nil {
return user
user, err = GetUserByField(organization, "email", field)
if user != nil || err != nil {
return user, err
}
}
// check phone
user = GetUserByField(organization, "phone", field)
if user != nil {
return user
user, err = GetUserByField(organization, "phone", field)
if user != nil || err != nil {
return user, err
}
// check ID card
user = GetUserByField(organization, "id_card", field)
if user != nil {
return user
user, err = GetUserByField(organization, "id_card", field)
if user != nil || err != nil {
return user, err
}
return nil
return nil, nil
}
func SetUserField(user *User, field string, value string) bool {
func SetUserField(user *User, field string, value string) (bool, error) {
bean := make(map[string]interface{})
if field == "password" {
organization := GetOrganizationByUser(user)
organization, err := GetOrganizationByUser(user)
if err != nil {
return false, err
}
user.UpdateUserPassword(organization)
bean[strings.ToLower(field)] = user.Password
bean["password_type"] = user.PasswordType
@ -89,17 +97,25 @@ func SetUserField(user *User, field string, value string) bool {
affected, err := adapter.Engine.Table(user).ID(core.PK{user.Owner, user.Name}).Update(bean)
if err != nil {
panic(err)
return false, err
}
user, err = getUser(user.Owner, user.Name)
if err != nil {
return false, err
}
err = user.UpdateUserHash()
if err != nil {
return false, err
}
user = getUser(user.Owner, user.Name)
user.UpdateUserHash()
_, err = adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("hash").Update(user)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func GetUserField(user *User, field string) string {
@ -121,7 +137,7 @@ func setUserProperty(user *User, field string, value string) {
}
}
func SetUserOAuthProperties(organization *Organization, user *User, providerType string, userInfo *idp.UserInfo) bool {
func SetUserOAuthProperties(organization *Organization, user *User, providerType string, userInfo *idp.UserInfo) (bool, error) {
if userInfo.Id != "" {
propertyName := fmt.Sprintf("oauth_%s_id", providerType)
setUserProperty(user, propertyName, userInfo.Id)
@ -164,11 +180,10 @@ func SetUserOAuthProperties(organization *Organization, user *User, providerType
}
}
affected := UpdateUserForAllFields(user.GetId(), user)
return affected
return UpdateUserForAllFields(user.GetId(), user)
}
func ClearUserOAuthProperties(user *User, providerType string) bool {
func ClearUserOAuthProperties(user *User, providerType string) (bool, error) {
for k := range user.Properties {
prefix := fmt.Sprintf("oauth_%s_", providerType)
if strings.HasPrefix(k, prefix) {
@ -178,14 +193,18 @@ func ClearUserOAuthProperties(user *User, providerType string) bool {
affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("properties").Update(user)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func CheckPermissionForUpdateUser(oldUser, newUser *User, isAdmin bool, lang string) (bool, string) {
organization := GetOrganizationByUser(oldUser)
organization, err := GetOrganizationByUser(oldUser)
if err != nil {
return false, err.Error()
}
var itemsChanged []*AccountItem
if oldUser.Owner != newUser.Owner {
@ -310,7 +329,7 @@ func (user *User) GetCountryCode(countryCode string) string {
return user.CountryCode
}
if org := GetOrganizationByUser(user); org != nil && len(org.CountryCodes) > 0 {
if org, _ := GetOrganizationByUser(user); org != nil && len(org.CountryCodes) > 0 {
return org.CountryCodes[0]
}
return ""

View File

@ -16,6 +16,7 @@ package object
import (
"encoding/base64"
"fmt"
"net/url"
"strings"
@ -24,14 +25,14 @@ import (
"github.com/go-webauthn/webauthn/webauthn"
)
func GetWebAuthnObject(host string) *webauthn.WebAuthn {
func GetWebAuthnObject(host string) (*webauthn.WebAuthn, error) {
var err error
_, originBackend := getOriginFromHost(host)
localUrl, err := url.Parse(originBackend)
if err != nil {
panic("error when parsing origin:" + err.Error())
return nil, fmt.Errorf("error when parsing origin:" + err.Error())
}
webAuthn, err := webauthn.New(&webauthn.Config{
@ -41,10 +42,10 @@ func GetWebAuthnObject(host string) *webauthn.WebAuthn {
// RPIcon: "https://duo.com/logo.png", // Optional icon URL for your site
})
if err != nil {
panic(err)
return nil, err
}
return webAuthn
return webAuthn, nil
}
// WebAuthnID
@ -84,17 +85,17 @@ func (user *User) CredentialExcludeList() []protocol.CredentialDescriptor {
return credentialExcludeList
}
func (user *User) AddCredentials(credential webauthn.Credential, isGlobalAdmin bool) bool {
func (user *User) AddCredentials(credential webauthn.Credential, isGlobalAdmin bool) (bool, error) {
user.WebauthnCredentials = append(user.WebauthnCredentials, credential)
return UpdateUser(user.GetId(), user, []string{"webauthnCredentials"}, isGlobalAdmin)
}
func (user *User) DeleteCredentials(credentialIdBase64 string) bool {
func (user *User) DeleteCredentials(credentialIdBase64 string) (bool, error) {
for i, credential := range user.WebauthnCredentials {
if base64.StdEncoding.EncodeToString(credential.ID) == credentialIdBase64 {
user.WebauthnCredentials = append(user.WebauthnCredentials[0:i], user.WebauthnCredentials[i+1:]...)
return UpdateUserForAllFields(user.GetId(), user)
}
}
return false
return false, nil
}

View File

@ -151,21 +151,24 @@ func AddToVerificationRecord(user *User, provider *Provider, remoteAddr, recordT
return nil
}
func getVerificationRecord(dest string) *VerificationRecord {
func getVerificationRecord(dest string) (*VerificationRecord, error) {
var record VerificationRecord
record.Receiver = dest
has, err := adapter.Engine.Desc("time").Where("is_used = false").Get(&record)
if err != nil {
panic(err)
return nil, err
}
if !has {
return nil
return nil, nil
}
return &record
return &record, nil
}
func CheckVerificationCode(dest, code, lang string) *VerifyResult {
record := getVerificationRecord(dest)
record, err := getVerificationRecord(dest)
if err != nil {
panic(err)
}
if record == nil {
return &VerifyResult{noRecordError, i18n.Translate(lang, "verification:Code has not been sent yet!")}
@ -188,17 +191,15 @@ func CheckVerificationCode(dest, code, lang string) *VerifyResult {
return &VerifyResult{VerificationSuccess, ""}
}
func DisableVerificationCode(dest string) {
record := getVerificationRecord(dest)
if record == nil {
func DisableVerificationCode(dest string) (err error) {
record, err := getVerificationRecord(dest)
if record == nil || err != nil {
return
}
record.IsUsed = true
_, err := adapter.Engine.ID(core.PK{record.Owner, record.Name}).AllCols().Update(record)
if err != nil {
panic(err)
}
_, err = adapter.Engine.ID(core.PK{record.Owner, record.Name}).AllCols().Update(record)
return
}
func CheckSigninCode(user *User, dest, code, lang string) string {

View File

@ -42,100 +42,97 @@ type Webhook struct {
IsEnabled bool `json:"isEnabled"`
}
func GetWebhookCount(owner, organization, field, value string) int {
func GetWebhookCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Webhook{Organization: organization})
if err != nil {
panic(err)
return session.Count(&Webhook{Organization: organization})
}
return int(count)
}
func GetWebhooks(owner string, organization string) []*Webhook {
func GetWebhooks(owner string, organization string) ([]*Webhook, error) {
webhooks := []*Webhook{}
err := adapter.Engine.Desc("created_time").Find(&webhooks, &Webhook{Owner: owner, Organization: organization})
if err != nil {
panic(err)
return webhooks, err
}
return webhooks
return webhooks, nil
}
func GetPaginationWebhooks(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Webhook {
func GetPaginationWebhooks(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Webhook, error) {
webhooks := []*Webhook{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&webhooks, &Webhook{Organization: organization})
if err != nil {
panic(err)
return nil, err
}
return webhooks
return webhooks, nil
}
func getWebhooksByOrganization(organization string) []*Webhook {
func getWebhooksByOrganization(organization string) ([]*Webhook, error) {
webhooks := []*Webhook{}
err := adapter.Engine.Desc("created_time").Find(&webhooks, &Webhook{Organization: organization})
if err != nil {
panic(err)
return webhooks, err
}
return webhooks
return webhooks, nil
}
func getWebhook(owner string, name string) *Webhook {
func getWebhook(owner string, name string) (*Webhook, error) {
if owner == "" || name == "" {
return nil
return nil, nil
}
webhook := Webhook{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&webhook)
if err != nil {
panic(err)
return &webhook, err
}
if existed {
return &webhook
return &webhook, nil
} else {
return nil
return nil, nil
}
}
func GetWebhook(id string) *Webhook {
func GetWebhook(id string) (*Webhook, error) {
owner, name := util.GetOwnerAndNameFromId(id)
return getWebhook(owner, name)
}
func UpdateWebhook(id string, webhook *Webhook) bool {
func UpdateWebhook(id string, webhook *Webhook) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id)
if getWebhook(owner, name) == nil {
return false
if w, err := getWebhook(owner, name); err != nil {
return false, err
} else if w == nil {
return false, nil
}
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(webhook)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func AddWebhook(webhook *Webhook) bool {
func AddWebhook(webhook *Webhook) (bool, error) {
affected, err := adapter.Engine.Insert(webhook)
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func DeleteWebhook(webhook *Webhook) bool {
func DeleteWebhook(webhook *Webhook) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{webhook.Owner, webhook.Name}).Delete(&Webhook{})
if err != nil {
panic(err)
return false, err
}
return affected != 0
return affected != 0, nil
}
func (p *Webhook) GetId() string {

Some files were not shown because too many files have changed in this diff Show More