feat: return most backend API errors to frontend (#1836)

* feat: return most backend API errros to frontend

Signed-off-by: yehong <239859435@qq.com>

* refactor: reduce int type change

Signed-off-by: yehong <239859435@qq.com>

* feat: return err backend in token.go

Signed-off-by: yehong <239859435@qq.com>

---------

Signed-off-by: yehong <239859435@qq.com>
This commit is contained in:
yehong
2023-05-30 15:49:39 +08:00
committed by GitHub
parent 34151c0095
commit 02e692a300
105 changed files with 3788 additions and 1734 deletions

View File

@ -154,7 +154,11 @@ func IsAllowed(subOwner string, subName string, method string, urlPath string, o
} }
} }
user := object.GetUser(util.GetId(subOwner, subName)) user, err := object.GetUser(util.GetId(subOwner, subName))
if err != nil {
panic(err)
}
if user != nil && user.IsAdmin && (subOwner == objOwner || (objOwner == "admin")) { if user != nil && user.IsAdmin && (subOwner == objOwner || (objOwner == "admin")) {
return true return true
} }

View File

@ -16,7 +16,6 @@ package conf
import ( import (
"encoding/json" "encoding/json"
"fmt"
"os" "os"
"runtime" "runtime"
"strconv" "strconv"
@ -73,14 +72,13 @@ func GetConfigString(key string) string {
return res return res
} }
func GetConfigBool(key string) (bool, error) { func GetConfigBool(key string) bool {
value := GetConfigString(key) value := GetConfigString(key)
if value == "true" { if value == "true" {
return true, nil return true
} else if value == "false" { } else {
return false, nil return false
} }
return false, fmt.Errorf("value %s cannot be converted into bool", value)
} }
func GetConfigInt64(key string) (int64, error) { func GetConfigInt64(key string) (int64, error) {

View File

@ -87,7 +87,7 @@ func TestGetConfBool(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
for _, scenery := range scenarios { for _, scenery := range scenarios {
t.Run(scenery.description, func(t *testing.T) { t.Run(scenery.description, func(t *testing.T) {
actual, err := GetConfigBool(scenery.input) actual := GetConfigBool(scenery.input)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, scenery.expected, actual) assert.Equal(t, scenery.expected, actual)
}) })

View File

@ -78,13 +78,23 @@ func (c *ApiController) Signup() {
return return
} }
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
if !application.EnableSignUp { if !application.EnableSignUp {
c.ResponseError(c.T("account:The application does not allow to sign up new account")) c.ResponseError(c.T("account:The application does not allow to sign up new account"))
return return
} }
organization := object.GetOrganization(util.GetId("admin", authForm.Organization)) organization, err := object.GetOrganization(util.GetId("admin", authForm.Organization))
if err != nil {
c.ResponseError(c.T(err.Error()))
return
}
msg := object.CheckUserSignup(application, organization, &authForm, c.GetAcceptLanguage()) msg := object.CheckUserSignup(application, organization, &authForm, c.GetAcceptLanguage())
if msg != "" { if msg != "" {
c.ResponseError(msg) c.ResponseError(msg)
@ -111,7 +121,11 @@ func (c *ApiController) Signup() {
id := util.GenerateId() id := util.GenerateId()
if application.GetSignupItemRule("ID") == "Incremental" { if application.GetSignupItemRule("ID") == "Incremental" {
lastUser := object.GetLastUser(authForm.Organization) lastUser, err := object.GetLastUser(authForm.Organization)
if err != nil {
c.ResponseError(err.Error())
return
}
lastIdInt := -1 lastIdInt := -1
if lastUser != nil { if lastUser != nil {
@ -173,25 +187,47 @@ func (c *ApiController) Signup() {
} }
} }
affected := object.AddUser(user) affected, err := object.AddUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
if !affected { if !affected {
c.ResponseError(c.T("account:Failed to add user"), util.StructToJson(user)) c.ResponseError(c.T("account:Failed to add user"), util.StructToJson(user))
return return
} }
object.AddUserToOriginalDatabase(user) err = object.AddUserToOriginalDatabase(user)
if err != nil {
c.ResponseError(err.Error())
return
}
if application.HasPromptPage() { if application.HasPromptPage() {
// The prompt page needs the user to be signed in // The prompt page needs the user to be signed in
c.SetSessionUsername(user.GetId()) c.SetSessionUsername(user.GetId())
} }
object.DisableVerificationCode(authForm.Email) err = object.DisableVerificationCode(authForm.Email)
object.DisableVerificationCode(checkPhone) if err != nil {
c.ResponseError(err.Error())
return
}
err = object.DisableVerificationCode(checkPhone)
if err != nil {
c.ResponseError(err.Error())
return
}
isSignupFromPricing := authForm.Plan != "" && authForm.Pricing != "" isSignupFromPricing := authForm.Plan != "" && authForm.Pricing != ""
if isSignupFromPricing { if isSignupFromPricing {
object.Subscribe(organization.Name, user.Name, authForm.Plan, authForm.Pricing) _, err = object.Subscribe(organization.Name, user.Name, authForm.Plan, authForm.Pricing)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
record := object.NewRecord(c.Ctx) record := object.NewRecord(c.Ctx)
@ -231,7 +267,11 @@ func (c *ApiController) Logout() {
c.ClearUserSession() c.ClearUserSession()
owner, username := util.GetOwnerAndNameFromId(user) owner, username := util.GetOwnerAndNameFromId(user)
object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID()) _, err := object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID())
if err != nil {
c.ResponseError(err.Error())
return
}
util.LogInfo(c.Ctx, "API: [%s] logged out", user) util.LogInfo(c.Ctx, "API: [%s] logged out", user)
@ -252,7 +292,12 @@ func (c *ApiController) Logout() {
return return
} }
affected, application, token := object.ExpireTokenByAccessToken(accessToken) affected, application, token, err := object.ExpireTokenByAccessToken(accessToken)
if err != nil {
c.ResponseError(err.Error())
return
}
if !affected { if !affected {
c.ResponseError(c.T("token:Token not found, invalid accessToken")) c.ResponseError(c.T("token:Token not found, invalid accessToken"))
return return
@ -272,7 +317,12 @@ func (c *ApiController) Logout() {
// TODO https://github.com/casdoor/casdoor/pull/1494#discussion_r1095675265 // TODO https://github.com/casdoor/casdoor/pull/1494#discussion_r1095675265
owner, username := util.GetOwnerAndNameFromId(user) owner, username := util.GetOwnerAndNameFromId(user)
object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID()) _, err := object.DeleteSessionId(util.GetSessionId(owner, username, object.CasdoorApplication), c.Ctx.Input.CruSession.SessionID())
if err != nil {
c.ResponseError(err.Error())
return
}
util.LogInfo(c.Ctx, "API: [%s] logged out", user) util.LogInfo(c.Ctx, "API: [%s] logged out", user)
c.Ctx.Redirect(http.StatusFound, fmt.Sprintf("%s?state=%s", strings.TrimRight(redirectUri, "/"), state)) c.Ctx.Redirect(http.StatusFound, fmt.Sprintf("%s?state=%s", strings.TrimRight(redirectUri, "/"), state))
@ -290,6 +340,7 @@ func (c *ApiController) Logout() {
// @Success 200 {object} controllers.Response The Response object // @Success 200 {object} controllers.Response The Response object
// @router /get-account [get] // @router /get-account [get]
func (c *ApiController) GetAccount() { func (c *ApiController) GetAccount() {
var err error
user, ok := c.RequireSignedInUser() user, ok := c.RequireSignedInUser()
if !ok { if !ok {
return return
@ -297,20 +348,39 @@ func (c *ApiController) GetAccount() {
managedAccounts := c.Input().Get("managedAccounts") managedAccounts := c.Input().Get("managedAccounts")
if managedAccounts == "1" { if managedAccounts == "1" {
user = object.ExtendManagedAccountsWithUser(user) user, err = object.ExtendManagedAccountsWithUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
object.ExtendUserWithRolesAndPermissions(user) err = object.ExtendUserWithRolesAndPermissions(user)
if err != nil {
c.ResponseError(err.Error())
return
}
user.Permissions = object.GetMaskedPermissions(user.Permissions) user.Permissions = object.GetMaskedPermissions(user.Permissions)
user.Roles = object.GetMaskedRoles(user.Roles) user.Roles = object.GetMaskedRoles(user.Roles)
organization := object.GetMaskedOrganization(object.GetOrganizationByUser(user)) organization, err := object.GetMaskedOrganization(object.GetOrganizationByUser(user))
if err != nil {
c.ResponseError(err.Error())
return
}
u, err := object.GetMaskedUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
resp := Response{ resp := Response{
Status: "ok", Status: "ok",
Sub: user.Id, Sub: user.Id,
Name: user.Name, Name: user.Name,
Data: object.GetMaskedUser(user), Data: u,
Data2: organization, Data2: organization,
} }
c.Data["json"] = resp c.Data["json"] = resp
@ -391,7 +461,12 @@ func (c *ApiController) GetCaptcha() {
if captchaProvider != nil { if captchaProvider != nil {
if captchaProvider.Type == "Default" { if captchaProvider.Type == "Default" {
id, img := object.GetCaptcha() id, img, err := object.GetCaptcha()
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(Captcha{Type: captchaProvider.Type, CaptchaId: id, CaptchaImage: img}) c.ResponseOk(Captcha{Type: captchaProvider.Type, CaptchaId: id, CaptchaImage: img})
return return
} else if captchaProvider.Type != "" { } else if captchaProvider.Type != "" {

View File

@ -40,21 +40,35 @@ func (c *ApiController) GetApplications() {
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization") organization := c.Input().Get("organization")
var err error
if limit == "" || page == "" { if limit == "" || page == "" {
var applications []*object.Application var applications []*object.Application
if organization == "" { if organization == "" {
applications = object.GetApplications(owner) applications, err = object.GetApplications(owner)
} else { } else {
applications = object.GetOrganizationApplications(owner, organization) applications, err = object.GetOrganizationApplications(owner, organization)
}
if err != nil {
panic(err)
} }
c.Data["json"] = object.GetMaskedApplications(applications, userId) c.Data["json"] = object.GetMaskedApplications(applications, userId)
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetApplicationCount(owner, field, value))) count, err := object.GetApplicationCount(owner, field, value)
applications := object.GetMaskedApplications(object.GetPaginationApplications(owner, paginator.Offset(), limit, field, value, sortField, sortOrder), userId) if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
app, err := object.GetPaginationApplications(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
panic(err)
}
applications := object.GetMaskedApplications(app, userId)
c.ResponseOk(applications, paginator.Nums()) c.ResponseOk(applications, paginator.Nums())
} }
} }
@ -69,8 +83,12 @@ func (c *ApiController) GetApplications() {
func (c *ApiController) GetApplication() { func (c *ApiController) GetApplication() {
userId := c.GetSessionUsername() userId := c.GetSessionUsername()
id := c.Input().Get("id") id := c.Input().Get("id")
app, err := object.GetApplication(id)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedApplication(object.GetApplication(id), userId) c.Data["json"] = object.GetMaskedApplication(app, userId)
c.ServeJSON() c.ServeJSON()
} }
@ -84,13 +102,22 @@ func (c *ApiController) GetApplication() {
func (c *ApiController) GetUserApplication() { func (c *ApiController) GetUserApplication() {
userId := c.GetSessionUsername() userId := c.GetSessionUsername()
id := c.Input().Get("id") id := c.Input().Get("id")
user := object.GetUser(id) user, err := object.GetUser(id)
if err != nil {
panic(err)
}
if user == nil { if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), id)) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), id))
return return
} }
c.Data["json"] = object.GetMaskedApplication(object.GetApplicationByUser(user), userId) app, err := object.GetApplicationByUser(user)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedApplication(app, userId)
c.ServeJSON() c.ServeJSON()
} }
@ -118,13 +145,30 @@ func (c *ApiController) GetOrganizationApplications() {
} }
if limit == "" || page == "" { if limit == "" || page == "" {
applications := object.GetOrganizationApplications(owner, organization) applications, err := object.GetOrganizationApplications(owner, organization)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedApplications(applications, userId) c.Data["json"] = object.GetMaskedApplications(applications, userId)
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetOrganizationApplicationCount(owner, organization, field, value)))
applications := object.GetMaskedApplications(object.GetPaginationOrganizationApplications(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder), userId) count, err := object.GetOrganizationApplicationCount(owner, organization, field, value)
if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
app, err := object.GetPaginationOrganizationApplications(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
applications := object.GetMaskedApplications(app, userId)
c.ResponseOk(applications, paginator.Nums()) c.ResponseOk(applications, paginator.Nums())
} }
} }
@ -166,8 +210,13 @@ func (c *ApiController) AddApplication() {
return return
} }
count := object.GetApplicationCount("", "", "") count, err := object.GetApplicationCount("", "", "")
if err := checkQuotaForApplication(count); err != nil { if err != nil {
c.ResponseError(err.Error())
return
}
if err := checkQuotaForApplication(int(count)); err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
} }

View File

@ -93,7 +93,12 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob
c.ResponseError(c.T("auth:Challenge method should be S256")) c.ResponseError(c.T("auth:Challenge method should be S256"))
return return
} }
code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge, c.Ctx.Request.Host, c.GetAcceptLanguage()) code, err := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge, c.Ctx.Request.Host, c.GetAcceptLanguage())
if err != nil {
c.ResponseError(err.Error(), nil)
return
}
resp = codeToResponse(code) resp = codeToResponse(code)
if application.EnableSigninSession || application.HasPromptPage() { if application.EnableSigninSession || application.HasPromptPage() {
@ -142,12 +147,16 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob
} }
if resp.Status == "ok" { if resp.Status == "ok" {
object.AddSession(&object.Session{ _, err = object.AddSession(&object.Session{
Owner: user.Owner, Owner: user.Owner,
Name: user.Name, Name: user.Name,
Application: application.Name, Application: application.Name,
SessionId: []string{c.Ctx.Input.CruSession.SessionID()}, SessionId: []string{c.Ctx.Input.CruSession.SessionID()},
}) })
if err != nil {
c.ResponseError(err.Error(), nil)
return
}
} }
return resp return resp
@ -171,7 +180,12 @@ func (c *ApiController) GetApplicationLogin() {
scope := c.Input().Get("scope") scope := c.Input().Get("scope")
state := c.Input().Get("state") state := c.Input().Get("state")
msg, application := object.CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, c.GetAcceptLanguage()) msg, application, err := object.CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, c.GetAcceptLanguage())
if err != nil {
c.ResponseError(err.Error())
return
}
application = object.GetMaskedApplication(application, "") application = object.GetMaskedApplication(application, "")
if msg != "" { if msg != "" {
c.ResponseError(msg, application) c.ResponseError(msg, application)
@ -248,7 +262,10 @@ func (c *ApiController) Login() {
var msg string var msg string
if authForm.Password == "" { if authForm.Password == "" {
if user = object.GetUserByFields(authForm.Organization, authForm.Username); user == nil { if user, err = object.GetUserByFields(authForm.Organization, authForm.Username); err != nil {
c.ResponseError(err.Error(), nil)
return
} else if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(authForm.Organization, authForm.Username))) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(authForm.Organization, authForm.Username)))
return return
} }
@ -272,9 +289,18 @@ func (c *ApiController) Login() {
} }
// disable the verification code // disable the verification code
object.DisableVerificationCode(checkDest) err := object.DisableVerificationCode(checkDest)
if err != nil {
c.ResponseError(err.Error(), nil)
return
}
} else { } else {
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error(), nil)
return
}
if application == nil { if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return return
@ -284,7 +310,10 @@ func (c *ApiController) Login() {
return return
} }
var enableCaptcha bool var enableCaptcha bool
if enableCaptcha = object.CheckToEnableCaptcha(application, authForm.Organization, authForm.Username); enableCaptcha { if enableCaptcha, err = object.CheckToEnableCaptcha(application, authForm.Organization, authForm.Username); err != nil {
c.ResponseError(err.Error())
return
} else if enableCaptcha {
isHuman, err := captcha.VerifyCaptchaByCaptchaType(authForm.CaptchaType, authForm.CaptchaToken, authForm.ClientSecret) isHuman, err := captcha.VerifyCaptchaByCaptchaType(authForm.CaptchaType, authForm.CaptchaToken, authForm.ClientSecret)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
@ -304,7 +333,12 @@ func (c *ApiController) Login() {
if msg != "" { if msg != "" {
resp = &Response{Status: "error", Msg: msg} resp = &Response{Status: "error", Msg: msg}
} else { } else {
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil { if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return return
@ -312,7 +346,11 @@ func (c *ApiController) Login() {
resp = c.HandleLoggedIn(application, user, &authForm) resp = c.HandleLoggedIn(application, user, &authForm)
organization := object.GetOrganizationByUser(user) organization, err := object.GetOrganizationByUser(user)
if err != nil {
c.ResponseError(err.Error())
}
if user != nil && organization.HasRequiredMfa() && !user.IsMfaEnabled() { if user != nil && organization.HasRequiredMfa() && !user.IsMfaEnabled() {
resp.Msg = object.RequiredMfa resp.Msg = object.RequiredMfa
} }
@ -325,18 +363,34 @@ func (c *ApiController) Login() {
} else if authForm.Provider != "" { } else if authForm.Provider != "" {
var application *object.Application var application *object.Application
if authForm.ClientId != "" { if authForm.ClientId != "" {
application = object.GetApplicationByClientId(authForm.ClientId) application, err = object.GetApplicationByClientId(authForm.ClientId)
if err != nil {
c.ResponseError(err.Error())
return
}
} else { } else {
application = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) application, err = object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
} }
if application == nil { if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return return
} }
organization, err := object.GetOrganization(util.GetId("admin", application.Organization))
if err != nil {
c.ResponseError(c.T(err.Error()))
}
provider, err := object.GetProvider(util.GetId("admin", authForm.Provider))
if err != nil {
c.ResponseError(err.Error())
return
}
organization := object.GetOrganization(util.GetId("admin", application.Organization))
provider := object.GetProvider(util.GetId("admin", authForm.Provider))
providerItem := application.GetProviderItem(provider.Name) providerItem := application.GetProviderItem(provider.Name)
if !providerItem.IsProviderVisible() { if !providerItem.IsProviderVisible() {
c.ResponseError(fmt.Sprintf(c.T("auth:The provider: %s is not enabled for the application"), provider.Name)) c.ResponseError(fmt.Sprintf(c.T("auth:The provider: %s is not enabled for the application"), provider.Name))
@ -396,9 +450,17 @@ func (c *ApiController) Login() {
if authForm.Method == "signup" { if authForm.Method == "signup" {
user := &object.User{} user := &object.User{}
if provider.Category == "SAML" { if provider.Category == "SAML" {
user = object.GetUser(util.GetId(application.Organization, userInfo.Id)) user, err = object.GetUser(util.GetId(application.Organization, userInfo.Id))
if err != nil {
c.ResponseError(err.Error())
return
}
} else if provider.Category == "OAuth" { } else if provider.Category == "OAuth" {
user = object.GetUserByField(application.Organization, provider.Type, userInfo.Id) user, err = object.GetUserByField(application.Organization, provider.Type, userInfo.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
if user != nil && !user.IsDeleted { if user != nil && !user.IsDeleted {
@ -419,12 +481,20 @@ func (c *ApiController) Login() {
if application.EnableLinkWithEmail { if application.EnableLinkWithEmail {
if userInfo.Email != "" { if userInfo.Email != "" {
// Find existing user with Email // Find existing user with Email
user = object.GetUserByField(application.Organization, "email", userInfo.Email) user, err = object.GetUserByField(application.Organization, "email", userInfo.Email)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
if user == nil && userInfo.Phone != "" { if user == nil && userInfo.Phone != "" {
// Find existing user with phone number // Find existing user with phone number
user = object.GetUserByField(application.Organization, "phone", userInfo.Phone) user, err = object.GetUserByField(application.Organization, "phone", userInfo.Phone)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
} }
@ -440,7 +510,12 @@ func (c *ApiController) Login() {
} }
// Handle username conflicts // Handle username conflicts
tmpUser := object.GetUser(util.GetId(application.Organization, userInfo.Username)) tmpUser, err := object.GetUser(util.GetId(application.Organization, userInfo.Username))
if err != nil {
c.ResponseError(err.Error())
return
}
if tmpUser != nil { if tmpUser != nil {
uid, err := uuid.NewRandom() uid, err := uuid.NewRandom()
if err != nil { if err != nil {
@ -453,7 +528,13 @@ func (c *ApiController) Login() {
} }
properties := map[string]string{} properties := map[string]string{}
properties["no"] = strconv.Itoa(object.GetUserCount(application.Organization, "", "") + 2) count, err := object.GetUserCount(application.Organization, "", "")
if err != nil {
c.ResponseError(err.Error())
return
}
properties["no"] = strconv.Itoa(int(count + 2))
initScore, err := organization.GetInitScore() initScore, err := organization.GetInitScore()
if err != nil { if err != nil {
c.ResponseError(fmt.Errorf(c.T("account:Get init score failed, error: %w"), err).Error()) c.ResponseError(fmt.Errorf(c.T("account:Get init score failed, error: %w"), err).Error())
@ -482,7 +563,12 @@ func (c *ApiController) Login() {
Properties: properties, Properties: properties,
} }
affected := object.AddUser(user) affected, err := object.AddUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
if !affected { if !affected {
c.ResponseError(fmt.Sprintf(c.T("auth:Failed to create user, user information is invalid: %s"), util.StructToJson(user))) c.ResponseError(fmt.Sprintf(c.T("auth:Failed to create user, user information is invalid: %s"), util.StructToJson(user)))
return return
@ -490,8 +576,17 @@ func (c *ApiController) Login() {
} }
// sync info from 3rd-party if possible // sync info from 3rd-party if possible
object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) _, err := object.SetUserOAuthProperties(organization, user, provider.Type, userInfo)
object.LinkUserAccount(user, provider.Type, userInfo.Id) if err != nil {
c.ResponseError(err.Error())
return
}
_, err = object.LinkUserAccount(user, provider.Type, userInfo.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
resp = c.HandleLoggedIn(application, user, &authForm) resp = c.HandleLoggedIn(application, user, &authForm)
@ -516,18 +611,36 @@ func (c *ApiController) Login() {
return return
} }
oldUser := object.GetUserByField(application.Organization, provider.Type, userInfo.Id) oldUser, err := object.GetUserByField(application.Organization, provider.Type, userInfo.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
if oldUser != nil { if oldUser != nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The account for provider: %s and username: %s (%s) is already linked to another account: %s (%s)"), provider.Type, userInfo.Username, userInfo.DisplayName, oldUser.Name, oldUser.DisplayName)) c.ResponseError(fmt.Sprintf(c.T("auth:The account for provider: %s and username: %s (%s) is already linked to another account: %s (%s)"), provider.Type, userInfo.Username, userInfo.DisplayName, oldUser.Name, oldUser.DisplayName))
return return
} }
user := object.GetUser(userId) user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
// sync info from 3rd-party if possible // sync info from 3rd-party if possible
object.SetUserOAuthProperties(organization, user, provider.Type, userInfo) _, err = object.SetUserOAuthProperties(organization, user, provider.Type, userInfo)
if err != nil {
c.ResponseError(err.Error())
return
}
isLinked, err := object.LinkUserAccount(user, provider.Type, userInfo.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
isLinked := object.LinkUserAccount(user, provider.Type, userInfo.Id)
if isLinked { if isLinked {
resp = &Response{Status: "ok", Msg: "", Data: isLinked} resp = &Response{Status: "ok", Msg: "", Data: isLinked}
} else { } else {
@ -536,7 +649,11 @@ func (c *ApiController) Login() {
} }
} else if c.getMfaSessionData() != nil { } else if c.getMfaSessionData() != nil {
mfaSession := c.getMfaSessionData() mfaSession := c.getMfaSessionData()
user := object.GetUser(mfaSession.UserId) user, err := object.GetUser(mfaSession.UserId)
if err != nil {
c.ResponseError(err.Error())
return
}
if authForm.Passcode != "" { if authForm.Passcode != "" {
MfaUtil := object.GetMfaUtil(authForm.MfaType, user.GetPreferMfa(false)) MfaUtil := object.GetMfaUtil(authForm.MfaType, user.GetPreferMfa(false))
@ -554,7 +671,12 @@ func (c *ApiController) Login() {
} }
} }
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil { if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return return
@ -569,7 +691,12 @@ func (c *ApiController) Login() {
} else { } else {
if c.GetSessionUsername() != "" { if c.GetSessionUsername() != "" {
// user already signed in to Casdoor, so let the user click the avatar button to do the quick sign-in // user already signed in to Casdoor, so let the user click the avatar button to do the quick sign-in
application := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application)) application, err := object.GetApplication(fmt.Sprintf("admin/%s", authForm.Application))
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil { if application == nil {
c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application)) c.ResponseError(fmt.Sprintf(c.T("auth:The application: %s does not exist"), authForm.Application))
return return
@ -624,8 +751,9 @@ func (c *ApiController) HandleSamlLogin() {
func (c *ApiController) HandleOfficialAccountEvent() { func (c *ApiController) HandleOfficialAccountEvent() {
respBytes, err := ioutil.ReadAll(c.Ctx.Request.Body) respBytes, err := ioutil.ReadAll(c.Ctx.Request.Body)
if err != nil { if err != nil {
c.ResponseError(err.Error()) panic(err)
} }
var data struct { var data struct {
MsgType string `xml:"MsgType"` MsgType string `xml:"MsgType"`
Event string `xml:"Event"` Event string `xml:"Event"`
@ -633,8 +761,9 @@ func (c *ApiController) HandleOfficialAccountEvent() {
} }
err = xml.Unmarshal(respBytes, &data) err = xml.Unmarshal(respBytes, &data)
if err != nil { if err != nil {
c.ResponseError(err.Error()) panic(err)
} }
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
if data.EventKey != "" { if data.EventKey != "" {
@ -670,7 +799,12 @@ func (c *ApiController) GetWebhookEventType() {
func (c *ApiController) GetCaptchaStatus() { func (c *ApiController) GetCaptchaStatus() {
organization := c.Input().Get("organization") organization := c.Input().Get("organization")
userId := c.Input().Get("user_id") userId := c.Input().Get("user_id")
user := object.GetUserByFields(organization, userId) user, err := object.GetUserByFields(organization, userId)
if err != nil {
c.ResponseError(err.Error())
return
}
var captchaEnabled bool var captchaEnabled bool
if user != nil && user.SigninWrongTimes >= object.SigninWrongTimesLimit { if user != nil && user.SigninWrongTimes >= object.SigninWrongTimesLimit {
captchaEnabled = true captchaEnabled = true

View File

@ -72,11 +72,15 @@ func (c *ApiController) isGlobalAdmin() (bool, *object.User) {
func (c *ApiController) getCurrentUser() *object.User { func (c *ApiController) getCurrentUser() *object.User {
var user *object.User var user *object.User
var err error
userId := c.GetSessionUsername() userId := c.GetSessionUsername()
if userId == "" { if userId == "" {
user = nil user = nil
} else { } else {
user = object.GetUser(userId) user, err = object.GetUser(userId)
if err != nil {
panic(err)
}
} }
return user return user
} }
@ -106,7 +110,11 @@ func (c *ApiController) GetSessionApplication() *object.Application {
if clientId == nil { if clientId == nil {
return nil return nil
} }
application := object.GetApplicationByClientId(clientId.(string)) application, err := object.GetApplicationByClientId(clientId.(string))
if err != nil {
panic(err)
}
return application return application
} }
@ -192,8 +200,10 @@ func (c *ApiController) setExpireForSession() {
}) })
} }
func wrapActionResponse(affected bool) *Response { func wrapActionResponse(affected bool, e ...error) *Response {
if affected { if len(e) != 0 && e[0] != nil {
return &Response{Status: "error", Msg: e[0].Error()}
} else if affected {
return &Response{Status: "ok", Msg: "", Data: "Affected"} return &Response{Status: "ok", Msg: "", Data: "Affected"}
} else { } else {
return &Response{Status: "ok", Msg: "", Data: "Unaffected"} return &Response{Status: "ok", Msg: "", Data: "Unaffected"}

View File

@ -33,19 +33,40 @@ func (c *ApiController) GetCasbinAdapters() {
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization") organization := c.Input().Get("organization")
if limit == "" || page == "" { if limit == "" || page == "" {
adapters := object.GetCasbinAdapters(owner, organization) adapters, err := object.GetCasbinAdapters(owner, organization)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(adapters) c.ResponseOk(adapters)
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetCasbinAdapterCount(owner, organization, field, value))) count, err := object.GetCasbinAdapterCount(owner, organization, field, value)
adapters := object.GetPaginationCasbinAdapters(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
adapters, err := object.GetPaginationCasbinAdapters(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(adapters, paginator.Nums()) c.ResponseOk(adapters, paginator.Nums())
} }
} }
func (c *ApiController) GetCasbinAdapter() { func (c *ApiController) GetCasbinAdapter() {
id := c.Input().Get("id") id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id) adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(adapter) c.ResponseOk(adapter)
} }
@ -89,7 +110,11 @@ func (c *ApiController) DeleteCasbinAdapter() {
func (c *ApiController) SyncPolicies() { func (c *ApiController) SyncPolicies() {
id := c.Input().Get("id") id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id) adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
policies, err := object.SyncPolicies(adapter) policies, err := object.SyncPolicies(adapter)
if err != nil { if err != nil {
@ -102,9 +127,14 @@ func (c *ApiController) SyncPolicies() {
func (c *ApiController) UpdatePolicy() { func (c *ApiController) UpdatePolicy() {
id := c.Input().Get("id") id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id) adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
var policies []xormadapter.CasbinRule var policies []xormadapter.CasbinRule
err := json.Unmarshal(c.Ctx.Input.RequestBody, &policies) err = json.Unmarshal(c.Ctx.Input.RequestBody, &policies)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
@ -121,9 +151,14 @@ func (c *ApiController) UpdatePolicy() {
func (c *ApiController) AddPolicy() { func (c *ApiController) AddPolicy() {
id := c.Input().Get("id") id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id) adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
var policy xormadapter.CasbinRule var policy xormadapter.CasbinRule
err := json.Unmarshal(c.Ctx.Input.RequestBody, &policy) err = json.Unmarshal(c.Ctx.Input.RequestBody, &policy)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
@ -140,9 +175,14 @@ func (c *ApiController) AddPolicy() {
func (c *ApiController) RemovePolicy() { func (c *ApiController) RemovePolicy() {
id := c.Input().Get("id") id := c.Input().Get("id")
adapter := object.GetCasbinAdapter(id) adapter, err := object.GetCasbinAdapter(id)
if err != nil {
c.ResponseError(err.Error())
return
}
var policy xormadapter.CasbinRule var policy xormadapter.CasbinRule
err := json.Unmarshal(c.Ctx.Input.RequestBody, &policy) err = json.Unmarshal(c.Ctx.Input.RequestBody, &policy)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return

View File

@ -37,13 +37,28 @@ func (c *ApiController) GetCerts() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedCerts(object.GetCerts(owner)) maskedCerts, err := object.GetMaskedCerts(object.GetCerts(owner))
if err != nil {
panic(err)
}
c.Data["json"] = maskedCerts
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetCertCount(owner, field, value))) count, err := object.GetCertCount(owner, field, value)
certs := object.GetMaskedCerts(object.GetPaginationCerts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
certs, err := object.GetMaskedCerts(object.GetPaginationCerts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
if err != nil {
panic(err)
}
c.ResponseOk(certs, paginator.Nums()) c.ResponseOk(certs, paginator.Nums())
} }
} }
@ -61,13 +76,28 @@ func (c *ApiController) GetGlobleCerts() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedCerts(object.GetGlobleCerts()) maskedCerts, err := object.GetMaskedCerts(object.GetGlobleCerts())
if err != nil {
panic(err)
}
c.Data["json"] = maskedCerts
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalCertsCount(field, value))) count, err := object.GetGlobalCertsCount(field, value)
certs := object.GetMaskedCerts(object.GetPaginationGlobalCerts(paginator.Offset(), limit, field, value, sortField, sortOrder)) if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
certs, err := object.GetMaskedCerts(object.GetPaginationGlobalCerts(paginator.Offset(), limit, field, value, sortField, sortOrder))
if err != nil {
panic(err)
}
c.ResponseOk(certs, paginator.Nums()) c.ResponseOk(certs, paginator.Nums())
} }
} }
@ -81,8 +111,12 @@ func (c *ApiController) GetGlobleCerts() {
// @router /get-cert [get] // @router /get-cert [get]
func (c *ApiController) GetCert() { func (c *ApiController) GetCert() {
id := c.Input().Get("id") id := c.Input().Get("id")
cert, err := object.GetCert(id)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedCert(object.GetCert(id)) c.Data["json"] = object.GetMaskedCert(cert)
c.ServeJSON() c.ServeJSON()
} }

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetChats() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedChats(object.GetChats(owner)) maskedChats, err := object.GetMaskedChats(object.GetChats(owner))
if err != nil {
panic(err)
}
c.Data["json"] = maskedChats
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetChatCount(owner, field, value))) count, err := object.GetChatCount(owner, field, value)
chats := object.GetMaskedChats(object.GetPaginationChats(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
chats, err := object.GetMaskedChats(object.GetPaginationChats(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(chats, paginator.Nums()) c.ResponseOk(chats, paginator.Nums())
} }
} }
@ -58,7 +75,12 @@ func (c *ApiController) GetChats() {
func (c *ApiController) GetChat() { func (c *ApiController) GetChat() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetMaskedChat(object.GetChat(id)) maskedChat, err := object.GetMaskedChat(object.GetChat(id))
if err != nil {
panic(err)
}
c.Data["json"] = maskedChat
c.ServeJSON() c.ServeJSON()
} }

View File

@ -34,8 +34,7 @@ func (c *ApiController) Enforce() {
} }
if permissionId != "" { if permissionId != "" {
c.Data["json"] = object.Enforce(permissionId, &request) c.ResponseOk(object.Enforce(permissionId, &request))
c.ServeJSON()
return return
} }
@ -44,9 +43,15 @@ func (c *ApiController) Enforce() {
if modelId != "" { if modelId != "" {
owner, modelName := util.GetOwnerAndNameFromId(modelId) owner, modelName := util.GetOwnerAndNameFromId(modelId)
permissions = object.GetPermissionsByModel(owner, modelName) permissions, err = object.GetPermissionsByModel(owner, modelName)
if err != nil {
panic(err)
}
} else { } else {
permissions = object.GetPermissionsByResource(resourceId) permissions, err = object.GetPermissionsByResource(resourceId)
if err != nil {
panic(err)
}
} }
for _, permission := range permissions { for _, permission := range permissions {
@ -63,8 +68,7 @@ func (c *ApiController) BatchEnforce() {
var requests []object.CasbinRequest var requests []object.CasbinRequest
err := json.Unmarshal(c.Ctx.Input.RequestBody, &requests) err := json.Unmarshal(c.Ctx.Input.RequestBody, &requests)
if err != nil { if err != nil {
c.ResponseError(err.Error()) panic(err)
return
} }
if permissionId != "" { if permissionId != "" {
@ -72,14 +76,17 @@ func (c *ApiController) BatchEnforce() {
c.ServeJSON() c.ServeJSON()
} else { } else {
owner, modelName := util.GetOwnerAndNameFromId(modelId) owner, modelName := util.GetOwnerAndNameFromId(modelId)
permissions := object.GetPermissionsByModel(owner, modelName) permissions, err := object.GetPermissionsByModel(owner, modelName)
if err != nil {
panic(err)
}
res := [][]bool{} res := [][]bool{}
for _, permission := range permissions { for _, permission := range permissions {
res = append(res, object.BatchEnforce(permission.GetId(), &requests)) res = append(res, object.BatchEnforce(permission.GetId(), &requests))
} }
c.Data["json"] = res
c.ServeJSON() c.ResponseOk(res)
} }
} }
@ -90,8 +97,7 @@ func (c *ApiController) GetAllObjects() {
return return
} }
c.Data["json"] = object.GetAllObjects(userId) c.ResponseOk(object.GetAllObjects(userId))
c.ServeJSON()
} }
func (c *ApiController) GetAllActions() { func (c *ApiController) GetAllActions() {
@ -101,8 +107,7 @@ func (c *ApiController) GetAllActions() {
return return
} }
c.Data["json"] = object.GetAllActions(userId) c.ResponseOk(object.GetAllActions(userId))
c.ServeJSON()
} }
func (c *ApiController) GetAllRoles() { func (c *ApiController) GetAllRoles() {
@ -112,6 +117,5 @@ func (c *ApiController) GetAllRoles() {
return return
} }
c.Data["json"] = object.GetAllRoles(userId) c.ResponseOk(object.GetAllRoles(userId))
c.ServeJSON()
} }

View File

@ -45,7 +45,11 @@ func (c *ApiController) GetLdapUsers() {
id := c.Input().Get("id") id := c.Input().Get("id")
_, ldapId := util.GetOwnerAndNameFromId(id) _, ldapId := util.GetOwnerAndNameFromId(id)
ldapServer := object.GetLdap(ldapId) ldapServer, err := object.GetLdap(ldapId)
if err != nil {
c.ResponseError(err.Error())
return
}
conn, err := ldapServer.GetLdapConn() conn, err := ldapServer.GetLdapConn()
if err != nil { if err != nil {
@ -76,7 +80,11 @@ func (c *ApiController) GetLdapUsers() {
for i, user := range users { for i, user := range users {
uuids[i] = user.GetLdapUuid() uuids[i] = user.GetLdapUuid()
} }
existUuids := object.GetExistUuids(ldapServer.Owner, uuids) existUuids, err := object.GetExistUuids(ldapServer.Owner, uuids)
if err != nil {
c.ResponseError(err.Error())
return
}
resp := LdapResp{ resp := LdapResp{
Users: object.AutoAdjustLdapUser(users), Users: object.AutoAdjustLdapUser(users),
@ -128,17 +136,23 @@ func (c *ApiController) AddLdap() {
return return
} }
if object.CheckLdapExist(&ldap) { if ok, err := object.CheckLdapExist(&ldap); err != nil {
c.ResponseError(err.Error())
return
} else if ok {
c.ResponseError(c.T("ldap:Ldap server exist")) c.ResponseError(c.T("ldap:Ldap server exist"))
return return
} }
affected := object.AddLdap(&ldap) resp := wrapActionResponse(object.AddLdap(&ldap))
resp := wrapActionResponse(affected)
resp.Data2 = ldap resp.Data2 = ldap
if ldap.AutoSync != 0 { if ldap.AutoSync != 0 {
object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id) err = object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
c.Data["json"] = resp c.Data["json"] = resp
@ -157,11 +171,24 @@ func (c *ApiController) UpdateLdap() {
return return
} }
prevLdap := object.GetLdap(ldap.Id) prevLdap, err := object.GetLdap(ldap.Id)
affected := object.UpdateLdap(&ldap) if err != nil {
c.ResponseError(err.Error())
return
}
affected, err := object.UpdateLdap(&ldap)
if err != nil {
c.ResponseError(err.Error())
return
}
if ldap.AutoSync != 0 { if ldap.AutoSync != 0 {
object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id) err := object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id)
if err != nil {
c.ResponseError(err.Error())
return
}
} else if ldap.AutoSync == 0 && prevLdap.AutoSync != 0 { } else if ldap.AutoSync == 0 && prevLdap.AutoSync != 0 {
object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id) object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id)
} }
@ -182,7 +209,11 @@ func (c *ApiController) DeleteLdap() {
return return
} }
affected := object.DeleteLdap(&ldap) affected, err := object.DeleteLdap(&ldap)
if err != nil {
c.ResponseError(err.Error())
return
}
object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id) object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id)
@ -204,7 +235,11 @@ func (c *ApiController) SyncLdapUsers() {
return return
} }
object.UpdateLdapSyncTime(ldapId) err = object.UpdateLdapSyncTime(ldapId)
if err != nil {
c.ResponseError(err.Error())
return
}
exist, failed, _ := object.SyncLdapUsers(owner, users, ldapId) exist, failed, _ := object.SyncLdapUsers(owner, users, ldapId)

View File

@ -53,7 +53,12 @@ func (c *ApiController) Unlink() {
if user.Id == unlinkedUser.Id && !user.IsGlobalAdmin { if user.Id == unlinkedUser.Id && !user.IsGlobalAdmin {
// if the user is unlinking themselves, should check the provider can be unlinked, if not, we should return an error. // if the user is unlinking themselves, should check the provider can be unlinked, if not, we should return an error.
application := object.GetApplicationByUser(user) application, err := object.GetApplicationByUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil { if application == nil {
c.ResponseError(c.T("link:You can't unlink yourself, you are not a member of any application")) c.ResponseError(c.T("link:You can't unlink yourself, you are not a member of any application"))
return return
@ -88,8 +93,17 @@ func (c *ApiController) Unlink() {
return return
} }
object.ClearUserOAuthProperties(&unlinkedUser, providerType) _, err = object.ClearUserOAuthProperties(&unlinkedUser, providerType)
if err != nil {
c.ResponseError(err.Error())
return
}
_, err = object.LinkUserAccount(&unlinkedUser, providerType, "")
if err != nil {
c.ResponseError(err.Error())
return
}
object.LinkUserAccount(&unlinkedUser, providerType, "")
c.ResponseOk() c.ResponseOk()
} }

View File

@ -44,18 +44,35 @@ func (c *ApiController) GetMessages() {
organization := c.Input().Get("organization") organization := c.Input().Get("organization")
if limit == "" || page == "" { if limit == "" || page == "" {
var messages []*object.Message var messages []*object.Message
var err error
if chat == "" { if chat == "" {
messages = object.GetMessages(owner) messages, err = object.GetMessages(owner)
} else { } else {
messages = object.GetChatMessages(chat) messages, err = object.GetChatMessages(chat)
}
if err != nil {
panic(err)
} }
c.Data["json"] = object.GetMaskedMessages(messages) c.Data["json"] = object.GetMaskedMessages(messages)
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetMessageCount(owner, organization, field, value))) count, err := object.GetMessageCount(owner, organization, field, value)
messages := object.GetMaskedMessages(object.GetPaginationMessages(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
paginationMessages, err := object.GetPaginationMessages(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
messages := object.GetMaskedMessages(paginationMessages)
c.ResponseOk(messages, paginator.Nums()) c.ResponseOk(messages, paginator.Nums())
} }
} }
@ -69,8 +86,12 @@ func (c *ApiController) GetMessages() {
// @router /get-message [get] // @router /get-message [get]
func (c *ApiController) GetMessage() { func (c *ApiController) GetMessage() {
id := c.Input().Get("id") id := c.Input().Get("id")
message, err := object.GetMessage(id)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedMessage(object.GetMessage(id)) c.Data["json"] = object.GetMaskedMessage(message)
c.ServeJSON() c.ServeJSON()
} }
@ -96,7 +117,12 @@ func (c *ApiController) GetMessageAnswer() {
c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache") c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive") c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
message := object.GetMessage(id) message, err := object.GetMessage(id)
if err != nil {
c.ResponseError(err.Error())
return
}
if message == nil { if message == nil {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id))
return return
@ -108,7 +134,12 @@ func (c *ApiController) GetMessageAnswer() {
} }
chatId := util.GetId("admin", message.Chat) chatId := util.GetId("admin", message.Chat)
chat := object.GetChat(chatId) chat, err := object.GetChat(chatId)
if err != nil {
c.ResponseError(err.Error())
return
}
if chat == nil || chat.Organization != message.Organization { if chat == nil || chat.Organization != message.Organization {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId))
return return
@ -119,14 +150,19 @@ func (c *ApiController) GetMessageAnswer() {
return return
} }
questionMessage := object.GetMessage(message.ReplyTo) questionMessage, err := object.GetMessage(message.ReplyTo)
if questionMessage == nil { if questionMessage == nil {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id)) c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The message: %s is not found"), id))
return return
} }
providerId := util.GetId(chat.Owner, chat.User2) providerId := util.GetId(chat.Owner, chat.User2)
provider := object.GetProvider(providerId) provider, err := object.GetProvider(providerId)
if err != nil {
c.ResponseError(err.Error())
return
}
if provider == nil { if provider == nil {
c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId)) c.ResponseErrorStream(fmt.Sprintf(c.T("chat:The provider: %s is not found"), providerId))
return return
@ -148,7 +184,7 @@ func (c *ApiController) GetMessageAnswer() {
fmt.Printf("Question: [%s]\n", questionMessage.Text) fmt.Printf("Question: [%s]\n", questionMessage.Text)
fmt.Printf("Answer: [") fmt.Printf("Answer: [")
err := ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder) err = ai.QueryAnswerStream(authToken, question, c.Ctx.ResponseWriter, &stringBuilder)
if err != nil { if err != nil {
c.ResponseErrorStream(err.Error()) c.ResponseErrorStream(err.Error())
return return
@ -165,7 +201,10 @@ func (c *ApiController) GetMessageAnswer() {
answer := stringBuilder.String() answer := stringBuilder.String()
message.Text = answer message.Text = answer
object.UpdateMessage(message.GetId(), message) _, err = object.UpdateMessage(message.GetId(), message)
if err != nil {
panic(err)
}
} }
// UpdateMessage // UpdateMessage
@ -208,14 +247,24 @@ func (c *ApiController) AddMessage() {
var chat *object.Chat var chat *object.Chat
if message.Chat != "" { if message.Chat != "" {
chatId := util.GetId("admin", message.Chat) chatId := util.GetId("admin", message.Chat)
chat = object.GetChat(chatId) chat, err = object.GetChat(chatId)
if err != nil {
c.ResponseError(err.Error())
return
}
if chat == nil || chat.Organization != message.Organization { if chat == nil || chat.Organization != message.Organization {
c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId)) c.ResponseError(fmt.Sprintf(c.T("chat:The chat: %s is not found"), chatId))
return return
} }
} }
affected := object.AddMessage(&message) affected, err := object.AddMessage(&message)
if err != nil {
c.ResponseError(err.Error())
return
}
if affected { if affected {
if chat != nil && chat.Type == "AI" { if chat != nil && chat.Type == "AI" {
answerMessage := &object.Message{ answerMessage := &object.Message{
@ -228,7 +277,11 @@ func (c *ApiController) AddMessage() {
Author: "AI", Author: "AI",
Text: "", Text: "",
} }
object.AddMessage(answerMessage) _, err = object.AddMessage(answerMessage)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
} }

View File

@ -46,7 +46,12 @@ func (c *ApiController) MfaSetupInitiate() {
if MfaUtil == nil { if MfaUtil == nil {
c.ResponseError("Invalid auth type") c.ResponseError("Invalid auth type")
} }
user := object.GetUser(userId) user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError("User doesn't exist") c.ResponseError("User doesn't exist")
return return
@ -105,14 +110,19 @@ func (c *ApiController) MfaSetupEnable() {
name := c.Ctx.Request.Form.Get("name") name := c.Ctx.Request.Form.Get("name")
authType := c.Ctx.Request.Form.Get("type") authType := c.Ctx.Request.Form.Get("type")
user := object.GetUser(util.GetId(owner, name)) user, err := object.GetUser(util.GetId(owner, name))
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError("User doesn't exist") c.ResponseError("User doesn't exist")
return return
} }
twoFactor := object.GetMfaUtil(authType, nil) twoFactor := object.GetMfaUtil(authType, nil)
err := twoFactor.Enable(c.Ctx, user) err = twoFactor.Enable(c.Ctx, user)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
@ -136,7 +146,12 @@ func (c *ApiController) DeleteMfa() {
name := c.Ctx.Request.Form.Get("name") name := c.Ctx.Request.Form.Get("name")
userId := util.GetId(owner, name) userId := util.GetId(owner, name)
user := object.GetUser(userId) user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError("User doesn't exist") c.ResponseError("User doesn't exist")
return return
@ -151,7 +166,12 @@ func (c *ApiController) DeleteMfa() {
} }
} }
user.MultiFactorAuths = mfaProps user.MultiFactorAuths = mfaProps
object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser()) _, err = object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser())
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(user.MultiFactorAuths) c.ResponseOk(user.MultiFactorAuths)
} }
@ -170,7 +190,12 @@ func (c *ApiController) SetPreferredMfa() {
name := c.Ctx.Request.Form.Get("name") name := c.Ctx.Request.Form.Get("name")
userId := util.GetId(owner, name) userId := util.GetId(owner, name)
user := object.GetUser(userId) user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError("User doesn't exist") c.ResponseError("User doesn't exist")
return return
@ -185,7 +210,11 @@ func (c *ApiController) SetPreferredMfa() {
} }
} }
object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser()) _, err = object.UpdateUser(userId, user, []string{"multi_factor_auths"}, user.IsAdminUser())
if err != nil {
c.ResponseError(err.Error())
return
}
for i, mfaProp := range mfaProps { for i, mfaProp := range mfaProps {
mfaProps[i] = object.GetMaskedProps(mfaProp) mfaProps[i] = object.GetMaskedProps(mfaProp)

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetModels() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetModels(owner) models, err := object.GetModels(owner)
if err != nil {
panic(err)
}
c.Data["json"] = models
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetModelCount(owner, field, value))) count, err := object.GetModelCount(owner, field, value)
models := object.GetPaginationModels(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
models, err := object.GetPaginationModels(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(models, paginator.Nums()) c.ResponseOk(models, paginator.Nums())
} }
} }
@ -58,7 +75,12 @@ func (c *ApiController) GetModels() {
func (c *ApiController) GetModel() { func (c *ApiController) GetModel() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetModel(id) model, err := object.GetModel(id)
if err != nil {
panic(err)
}
c.Data["json"] = model
c.ServeJSON() c.ServeJSON()
} }

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetOrganizations() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedOrganizations(object.GetOrganizations(owner)) maskedOrganizations, err := object.GetMaskedOrganizations(object.GetOrganizations(owner))
if err != nil {
panic(err)
}
c.Data["json"] = maskedOrganizations
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetOrganizationCount(owner, field, value))) count, err := object.GetOrganizationCount(owner, field, value)
organizations := object.GetMaskedOrganizations(object.GetPaginationOrganizations(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
organizations, err := object.GetMaskedOrganizations(object.GetPaginationOrganizations(owner, paginator.Offset(), limit, field, value, sortField, sortOrder))
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(organizations, paginator.Nums()) c.ResponseOk(organizations, paginator.Nums())
} }
} }
@ -58,7 +75,12 @@ func (c *ApiController) GetOrganizations() {
func (c *ApiController) GetOrganization() { func (c *ApiController) GetOrganization() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetMaskedOrganization(object.GetOrganization(id)) maskedOrganization, err := object.GetMaskedOrganization(object.GetOrganization(id))
if err != nil {
panic(err)
}
c.Data["json"] = maskedOrganization
c.ServeJSON() c.ServeJSON()
} }
@ -99,8 +121,13 @@ func (c *ApiController) AddOrganization() {
return return
} }
count := object.GetOrganizationCount("", "", "") count, err := object.GetOrganizationCount("", "", "")
if err := checkQuotaForOrganization(count); err != nil { if err != nil {
c.ResponseError(err.Error())
return
}
if err = checkQuotaForOrganization(int(count)); err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
} }
@ -158,6 +185,11 @@ func (c *ApiController) GetDefaultApplication() {
// @router /get-organization-names [get] // @router /get-organization-names [get]
func (c *ApiController) GetOrganizationNames() { func (c *ApiController) GetOrganizationNames() {
owner := c.Input().Get("owner") owner := c.Input().Get("owner")
organizationNames := object.GetOrganizationsByFields(owner, "name") organizationNames, err := object.GetOrganizationsByFields(owner, "name")
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(organizationNames) c.ResponseOk(organizationNames)
} }

View File

@ -37,13 +37,28 @@ func (c *ApiController) GetPayments() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetPayments(owner) payments, err := object.GetPayments(owner)
if err != nil {
panic(err)
}
c.Data["json"] = payments
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPaymentCount(owner, field, value))) count, err := object.GetPaymentCount(owner, field, value)
payments := object.GetPaginationPayments(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
payments, err := object.GetPaginationPayments(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
panic(err)
}
c.ResponseOk(payments, paginator.Nums()) c.ResponseOk(payments, paginator.Nums())
} }
} }
@ -62,7 +77,12 @@ func (c *ApiController) GetUserPayments() {
organization := c.Input().Get("organization") organization := c.Input().Get("organization")
user := c.Input().Get("user") user := c.Input().Get("user")
payments := object.GetUserPayments(owner, organization, user) payments, err := object.GetUserPayments(owner, organization, user)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(payments) c.ResponseOk(payments)
} }
@ -76,7 +96,12 @@ func (c *ApiController) GetUserPayments() {
func (c *ApiController) GetPayment() { func (c *ApiController) GetPayment() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetPayment(id) payment, err := object.GetPayment(id)
if err != nil {
panic(err)
}
c.Data["json"] = payment
c.ServeJSON() c.ServeJSON()
} }
@ -177,7 +202,12 @@ func (c *ApiController) NotifyPayment() {
func (c *ApiController) InvoicePayment() { func (c *ApiController) InvoicePayment() {
id := c.Input().Get("id") id := c.Input().Get("id")
payment := object.GetPayment(id) payment, err := object.GetPayment(id)
if err != nil {
c.ResponseError(err.Error())
return
}
invoiceUrl, err := object.InvoicePayment(payment) invoiceUrl, err := object.InvoicePayment(payment)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())

View File

@ -37,13 +37,28 @@ func (c *ApiController) GetPermissions() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetPermissions(owner) permissions, err := object.GetPermissions(owner)
if err != nil {
panic(err)
}
c.Data["json"] = permissions
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPermissionCount(owner, field, value))) count, err := object.GetPermissionCount(owner, field, value)
permissions := object.GetPaginationPermissions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
panic(err)
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
permissions, err := object.GetPaginationPermissions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
panic(err)
}
c.ResponseOk(permissions, paginator.Nums()) c.ResponseOk(permissions, paginator.Nums())
} }
} }
@ -60,7 +75,12 @@ func (c *ApiController) GetPermissionsBySubmitter() {
return return
} }
permissions := object.GetPermissionsBySubmitter(user.Owner, user.Name) permissions, err := object.GetPermissionsBySubmitter(user.Owner, user.Name)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(permissions, len(permissions)) c.ResponseOk(permissions, len(permissions))
return return
} }
@ -74,7 +94,12 @@ func (c *ApiController) GetPermissionsBySubmitter() {
// @router /get-permissions-by-role [get] // @router /get-permissions-by-role [get]
func (c *ApiController) GetPermissionsByRole() { func (c *ApiController) GetPermissionsByRole() {
id := c.Input().Get("id") id := c.Input().Get("id")
permissions := object.GetPermissionsByRole(id) permissions, err := object.GetPermissionsByRole(id)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(permissions, len(permissions)) c.ResponseOk(permissions, len(permissions))
return return
} }
@ -89,7 +114,12 @@ func (c *ApiController) GetPermissionsByRole() {
func (c *ApiController) GetPermission() { func (c *ApiController) GetPermission() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetPermission(id) permission, err := object.GetPermission(id)
if err != nil {
panic(err)
}
c.Data["json"] = permission
c.ServeJSON() c.ServeJSON()
} }

View File

@ -41,7 +41,11 @@ func (c *ApiController) UploadPermissions() {
return return
} }
affected := object.UploadPermissions(owner, fileId) affected, err := object.UploadPermissions(owner, fileId)
if err != nil {
c.ResponseError(err.Error())
}
if affected { if affected {
c.ResponseOk() c.ResponseOk()
} else { } else {

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetPlans() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetPlans(owner) plans, err := object.GetPlans(owner)
if err != nil {
panic(err)
}
c.Data["json"] = plans
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPlanCount(owner, field, value))) count, err := object.GetPlanCount(owner, field, value)
plan := object.GetPaginatedPlans(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
plan, err := object.GetPaginatedPlans(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(plan, paginator.Nums()) c.ResponseOk(plan, paginator.Nums())
} }
} }
@ -60,10 +77,16 @@ func (c *ApiController) GetPlan() {
id := c.Input().Get("id") id := c.Input().Get("id")
includeOption := c.Input().Get("includeOption") == "true" includeOption := c.Input().Get("includeOption") == "true"
plan := object.GetPlan(id) plan, err := object.GetPlan(id)
if err != nil {
panic(err)
}
if includeOption { if includeOption {
options := object.GetPermissionsByRole(plan.Role) options, err := object.GetPermissionsByRole(plan.Role)
if err != nil {
panic(err)
}
for _, option := range options { for _, option := range options {
plan.Options = append(plan.Options, option.DisplayName) plan.Options = append(plan.Options, option.DisplayName)

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetPricings() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetPricings(owner) pricings, err := object.GetPricings(owner)
if err != nil {
panic(err)
}
c.Data["json"] = pricings
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetPricingCount(owner, field, value))) count, err := object.GetPricingCount(owner, field, value)
pricing := object.GetPaginatedPricings(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
pricing, err := object.GetPaginatedPricings(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(pricing, paginator.Nums()) c.ResponseOk(pricing, paginator.Nums())
} }
} }
@ -58,7 +75,10 @@ func (c *ApiController) GetPricings() {
func (c *ApiController) GetPricing() { func (c *ApiController) GetPricing() {
id := c.Input().Get("id") id := c.Input().Get("id")
pricing := object.GetPricing(id) pricing, err := object.GetPricing(id)
if err != nil {
panic(err)
}
c.Data["json"] = pricing c.Data["json"] = pricing
c.ServeJSON() c.ServeJSON()

View File

@ -38,13 +38,30 @@ func (c *ApiController) GetProducts() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetProducts(owner) products, err := object.GetProducts(owner)
if err != nil {
panic(err)
}
c.Data["json"] = products
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetProductCount(owner, field, value))) count, err := object.GetProductCount(owner, field, value)
products := object.GetPaginationProducts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
products, err := object.GetPaginationProducts(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(products, paginator.Nums()) c.ResponseOk(products, paginator.Nums())
} }
} }
@ -59,8 +76,15 @@ func (c *ApiController) GetProducts() {
func (c *ApiController) GetProduct() { func (c *ApiController) GetProduct() {
id := c.Input().Get("id") id := c.Input().Get("id")
product := object.GetProduct(id) product, err := object.GetProduct(id)
object.ExtendProductWithProviders(product) if err != nil {
panic(err)
}
err = object.ExtendProductWithProviders(product)
if err != nil {
panic(err)
}
c.Data["json"] = product c.Data["json"] = product
c.ServeJSON() c.ServeJSON()
@ -145,7 +169,12 @@ func (c *ApiController) BuyProduct() {
return return
} }
user := object.GetUser(userId) user, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), userId)) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), userId))
return return

View File

@ -44,12 +44,29 @@ func (c *ApiController) GetProviders() {
} }
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedProviders(object.GetProviders(owner), isMaskEnabled) providers, err := object.GetProviders(owner)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedProviders(providers, isMaskEnabled)
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetProviderCount(owner, field, value))) count, err := object.GetProviderCount(owner, field, value)
providers := object.GetMaskedProviders(object.GetPaginationProviders(owner, paginator.Offset(), limit, field, value, sortField, sortOrder), isMaskEnabled) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
paginationProviders, err := object.GetPaginationProviders(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
providers := object.GetMaskedProviders(paginationProviders, isMaskEnabled)
c.ResponseOk(providers, paginator.Nums()) c.ResponseOk(providers, paginator.Nums())
} }
} }
@ -74,12 +91,29 @@ func (c *ApiController) GetGlobalProviders() {
} }
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedProviders(object.GetGlobalProviders(), isMaskEnabled) globalProviders, err := object.GetGlobalProviders()
if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedProviders(globalProviders, isMaskEnabled)
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalProviderCount(field, value))) count, err := object.GetGlobalProviderCount(field, value)
providers := object.GetMaskedProviders(object.GetPaginationGlobalProviders(paginator.Offset(), limit, field, value, sortField, sortOrder), isMaskEnabled) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
paginationGlobalProviders, err := object.GetPaginationGlobalProviders(paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
providers := object.GetMaskedProviders(paginationGlobalProviders, isMaskEnabled)
c.ResponseOk(providers, paginator.Nums()) c.ResponseOk(providers, paginator.Nums())
} }
} }
@ -98,8 +132,13 @@ func (c *ApiController) GetProvider() {
if !ok { if !ok {
return return
} }
provider, err := object.GetProvider(id)
if err != nil {
c.ResponseError(err.Error())
return
}
c.Data["json"] = object.GetMaskedProvider(object.GetProvider(id), isMaskEnabled) c.Data["json"] = object.GetMaskedProvider(provider, isMaskEnabled)
c.ServeJSON() c.ServeJSON()
} }
@ -140,8 +179,13 @@ func (c *ApiController) AddProvider() {
return return
} }
count := object.GetProviderCount("", "", "") count, err := object.GetProviderCount("", "", "")
if err := checkQuotaForProvider(count); err != nil { if err != nil {
c.ResponseError(err.Error())
return
}
if err := checkQuotaForProvider(int(count)); err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
} }

View File

@ -42,14 +42,31 @@ func (c *ApiController) GetRecords() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetRecords() records, err := object.GetRecords()
if err != nil {
panic(err)
}
c.Data["json"] = records
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
filterRecord := &object.Record{Organization: organization} filterRecord := &object.Record{Organization: organization}
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetRecordCount(field, value, filterRecord))) count, err := object.GetRecordCount(field, value, filterRecord)
records := object.GetPaginationRecords(paginator.Offset(), limit, field, value, sortField, sortOrder, filterRecord) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
records, err := object.GetPaginationRecords(paginator.Offset(), limit, field, value, sortField, sortOrder, filterRecord)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(records, paginator.Nums()) c.ResponseOk(records, paginator.Nums())
} }
} }
@ -67,11 +84,15 @@ func (c *ApiController) GetRecordsByFilter() {
record := &object.Record{} record := &object.Record{}
err := util.JsonToStruct(body, record) err := util.JsonToStruct(body, record)
if err != nil { if err != nil {
c.ResponseError(err.Error()) panic(err)
return
} }
c.Data["json"] = object.GetRecordsByField(record) records, err := object.GetRecordsByField(record)
if err != nil {
panic(err)
}
c.Data["json"] = records
c.ServeJSON() c.ServeJSON()
} }

View File

@ -51,12 +51,28 @@ func (c *ApiController) GetResources() {
} }
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetResources(owner, user) resources, err := object.GetResources(owner, user)
if err != nil {
panic(err)
}
c.Data["json"] = resources
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetResourceCount(owner, user, field, value))) count, err := object.GetResourceCount(owner, user, field, value)
resources := object.GetPaginationResources(owner, user, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
resources, err := object.GetPaginationResources(owner, user, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(resources, paginator.Nums()) c.ResponseOk(resources, paginator.Nums())
} }
} }
@ -68,7 +84,12 @@ func (c *ApiController) GetResources() {
func (c *ApiController) GetResource() { func (c *ApiController) GetResource() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetResource(id) resource, err := object.GetResource(id)
if err != nil {
panic(err)
}
c.Data["json"] = resource
c.ServeJSON() c.ServeJSON()
} }
@ -187,7 +208,10 @@ func (c *ApiController) UploadResource() {
index := len(fullFilePath) - len(ext) index := len(fullFilePath) - len(ext)
for i := 1; ; i++ { for i := 1; ; i++ {
_, objectKey := object.GetUploadFileUrl(provider, fullFilePath, true) _, objectKey := object.GetUploadFileUrl(provider, fullFilePath, true)
if object.GetResourceCount(owner, username, "name", objectKey) == 0 { if count, err := object.GetResourceCount(owner, username, "name", objectKey); err != nil {
c.ResponseError(err.Error())
return
} else if count == 0 {
break break
} }
@ -223,20 +247,39 @@ func (c *ApiController) UploadResource() {
Url: fileUrl, Url: fileUrl,
Description: description, Description: description,
} }
object.AddOrUpdateResource(resource) _, err = object.AddOrUpdateResource(resource)
if err != nil {
c.ResponseError(err.Error())
return
}
switch tag { switch tag {
case "avatar": case "avatar":
user := object.GetUserNoCheck(util.GetId(owner, username)) user, err := object.GetUserNoCheck(util.GetId(owner, username))
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError(c.T("resource:User is nil for tag: avatar")) c.ResponseError(c.T("resource:User is nil for tag: avatar"))
return return
} }
user.Avatar = fileUrl user.Avatar = fileUrl
object.UpdateUser(user.GetId(), user, []string{"avatar"}, false) _, err = object.UpdateUser(user.GetId(), user, []string{"avatar"}, false)
if err != nil {
c.ResponseError(err.Error())
return
}
case "termsOfUse": case "termsOfUse":
user := object.GetUserNoCheck(util.GetId(owner, username)) user, err := object.GetUserNoCheck(util.GetId(owner, username))
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(owner, username))) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(owner, username)))
return return
@ -248,9 +291,18 @@ func (c *ApiController) UploadResource() {
} }
_, applicationId := util.GetOwnerAndNameFromIdNoCheck(strings.TrimRight(fullFilePath, ".html")) _, applicationId := util.GetOwnerAndNameFromIdNoCheck(strings.TrimRight(fullFilePath, ".html"))
applicationObj := object.GetApplication(applicationId) applicationObj, err := object.GetApplication(applicationId)
if err != nil {
c.ResponseError(err.Error())
return
}
applicationObj.TermsOfUse = fileUrl applicationObj.TermsOfUse = fileUrl
object.UpdateApplication(applicationId, applicationObj) _, err = object.UpdateApplication(applicationId, applicationObj)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
c.ResponseOk(fileUrl, objectKey) c.ResponseOk(fileUrl, objectKey)

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetRoles() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetRoles(owner) roles, err := object.GetRoles(owner)
if err != nil {
panic(err)
}
c.Data["json"] = roles
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetRoleCount(owner, field, value))) count, err := object.GetRoleCount(owner, field, value)
roles := object.GetPaginationRoles(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
roles, err := object.GetPaginationRoles(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(roles, paginator.Nums()) c.ResponseOk(roles, paginator.Nums())
} }
} }
@ -58,7 +75,12 @@ func (c *ApiController) GetRoles() {
func (c *ApiController) GetRole() { func (c *ApiController) GetRole() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetRole(id) role, err := object.GetRole(id)
if err != nil {
panic(err)
}
c.Data["json"] = role
c.ServeJSON() c.ServeJSON()
} }

View File

@ -41,7 +41,11 @@ func (c *ApiController) UploadRoles() {
return return
} }
affected := object.UploadRoles(owner, fileId) affected, err := object.UploadRoles(owner, fileId)
if err != nil {
c.ResponseError(err.Error())
}
if affected { if affected {
c.ResponseOk() c.ResponseOk()
} else { } else {

View File

@ -23,7 +23,12 @@ import (
func (c *ApiController) GetSamlMeta() { func (c *ApiController) GetSamlMeta() {
host := c.Ctx.Request.Host host := c.Ctx.Request.Host
paramApp := c.Input().Get("application") paramApp := c.Input().Get("application")
application := object.GetApplication(paramApp) application, err := object.GetApplication(paramApp)
if err != nil {
c.ResponseError(err.Error())
return
}
if application == nil { if application == nil {
c.ResponseError(fmt.Sprintf(c.T("saml:Application %s not found"), paramApp)) c.ResponseError(fmt.Sprintf(c.T("saml:Application %s not found"), paramApp))
return return

View File

@ -61,7 +61,12 @@ func (c *ApiController) SendEmail() {
var provider *object.Provider var provider *object.Provider
if emailForm.Provider != "" { if emailForm.Provider != "" {
// called by frontend's TestEmailWidget, provider name is set by frontend // called by frontend's TestEmailWidget, provider name is set by frontend
provider = object.GetProvider(util.GetId("admin", emailForm.Provider)) provider, err = object.GetProvider(util.GetId("admin", emailForm.Provider))
if err != nil {
c.ResponseError(err.Error())
return
}
} else { } else {
// called by Casdoor SDK via Client ID & Client Secret, so the used Email provider will be the application' Email provider or the default Email provider // called by Casdoor SDK via Client ID & Client Secret, so the used Email provider will be the application' Email provider or the default Email provider
var ok bool var ok bool

View File

@ -37,13 +37,29 @@ func (c *ApiController) GetSessions() {
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
owner := c.Input().Get("owner") owner := c.Input().Get("owner")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetSessions(owner) sessions, err := object.GetSessions(owner)
if err != nil {
panic(err)
}
c.Data["json"] = sessions
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSessionCount(owner, field, value))) count, err := object.GetSessionCount(owner, field, value)
sessions := object.GetPaginationSessions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
sessions, err := object.GetPaginationSessions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(sessions, paginator.Nums()) c.ResponseOk(sessions, paginator.Nums())
} }
} }
@ -58,7 +74,12 @@ func (c *ApiController) GetSessions() {
func (c *ApiController) GetSingleSession() { func (c *ApiController) GetSingleSession() {
id := c.Input().Get("sessionPkId") id := c.Input().Get("sessionPkId")
c.Data["json"] = object.GetSingleSession(id) session, err := object.GetSingleSession(id)
if err != nil {
panic(err)
}
c.Data["json"] = session
c.ServeJSON() c.ServeJSON()
} }
@ -132,7 +153,11 @@ func (c *ApiController) IsSessionDuplicated() {
id := c.Input().Get("sessionPkId") id := c.Input().Get("sessionPkId")
sessionId := c.Input().Get("sessionId") sessionId := c.Input().Get("sessionId")
isUserSessionDuplicated := object.IsSessionDuplicated(id, sessionId) isUserSessionDuplicated, err := object.IsSessionDuplicated(id, sessionId)
if err != nil {
panic(err)
}
c.Data["json"] = &Response{Status: "ok", Msg: "", Data: isUserSessionDuplicated} c.Data["json"] = &Response{Status: "ok", Msg: "", Data: isUserSessionDuplicated}
c.ServeJSON() c.ServeJSON()

View File

@ -37,13 +37,30 @@ func (c *ApiController) GetSubscriptions() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetSubscriptions(owner) subscriptions, err := object.GetSubscriptions(owner)
if err != nil {
panic(err)
}
c.Data["json"] = subscriptions
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSubscriptionCount(owner, field, value))) count, err := object.GetSubscriptionCount(owner, field, value)
subscription := object.GetPaginationSubscriptions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
subscription, err := object.GetPaginationSubscriptions(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(subscription, paginator.Nums()) c.ResponseOk(subscription, paginator.Nums())
} }
} }
@ -58,7 +75,10 @@ func (c *ApiController) GetSubscriptions() {
func (c *ApiController) GetSubscription() { func (c *ApiController) GetSubscription() {
id := c.Input().Get("id") id := c.Input().Get("id")
subscription := object.GetSubscription(id) subscription, err := object.GetSubscription(id)
if err != nil {
panic(err)
}
c.Data["json"] = subscription c.Data["json"] = subscription
c.ServeJSON() c.ServeJSON()

View File

@ -38,13 +38,30 @@ func (c *ApiController) GetSyncers() {
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization") organization := c.Input().Get("organization")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetOrganizationSyncers(owner, organization) organizationSyncers, err := object.GetOrganizationSyncers(owner, organization)
if err != nil {
panic(err)
}
c.Data["json"] = organizationSyncers
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetSyncerCount(owner, organization, field, value))) count, err := object.GetSyncerCount(owner, organization, field, value)
syncers := object.GetPaginationSyncers(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
syncers, err := object.GetPaginationSyncers(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(syncers, paginator.Nums()) c.ResponseOk(syncers, paginator.Nums())
} }
} }
@ -59,7 +76,12 @@ func (c *ApiController) GetSyncers() {
func (c *ApiController) GetSyncer() { func (c *ApiController) GetSyncer() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetSyncer(id) syncer, err := object.GetSyncer(id)
if err != nil {
panic(err)
}
c.Data["json"] = syncer
c.ServeJSON() c.ServeJSON()
} }
@ -132,7 +154,11 @@ func (c *ApiController) DeleteSyncer() {
// @router /run-syncer [get] // @router /run-syncer [get]
func (c *ApiController) RunSyncer() { func (c *ApiController) RunSyncer() {
id := c.Input().Get("id") id := c.Input().Get("id")
syncer := object.GetSyncer(id) syncer, err := object.GetSyncer(id)
if err != nil {
c.ResponseError(err.Error())
return
}
object.RunSyncer(syncer) object.RunSyncer(syncer)

View File

@ -41,12 +41,28 @@ func (c *ApiController) GetTokens() {
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization") organization := c.Input().Get("organization")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetTokens(owner, organization) token, err := object.GetTokens(owner, organization)
if err != nil {
panic(err)
}
c.Data["json"] = token
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetTokenCount(owner, organization, field, value))) count, err := object.GetTokenCount(owner, organization, field, value)
tokens := object.GetPaginationTokens(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
tokens, err := object.GetPaginationTokens(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(tokens, paginator.Nums()) c.ResponseOk(tokens, paginator.Nums())
} }
} }
@ -60,8 +76,12 @@ func (c *ApiController) GetTokens() {
// @router /get-token [get] // @router /get-token [get]
func (c *ApiController) GetToken() { func (c *ApiController) GetToken() {
id := c.Input().Get("id") id := c.Input().Get("id")
token, err := object.GetToken(id)
if err != nil {
panic(err)
}
c.Data["json"] = object.GetToken(id) c.Data["json"] = token
c.ServeJSON() c.ServeJSON()
} }
@ -171,8 +191,12 @@ func (c *ApiController) GetOAuthToken() {
} }
} }
host := c.Ctx.Request.Host host := c.Ctx.Request.Host
oAuthtoken, err := object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier, scope, username, password, host, refreshToken, tag, avatar, c.GetAcceptLanguage())
if err != nil {
panic(err)
}
c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier, scope, username, password, host, refreshToken, tag, avatar, c.GetAcceptLanguage()) c.Data["json"] = oAuthtoken
c.SetTokenErrorHttpStatus() c.SetTokenErrorHttpStatus()
c.ServeJSON() c.ServeJSON()
} }
@ -210,7 +234,12 @@ func (c *ApiController) RefreshToken() {
} }
} }
c.Data["json"] = object.RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host) refreshToken2, err := object.RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host)
if err != nil {
panic(err)
}
c.Data["json"] = refreshToken2
c.SetTokenErrorHttpStatus() c.SetTokenErrorHttpStatus()
c.ServeJSON() c.ServeJSON()
} }
@ -245,7 +274,11 @@ func (c *ApiController) IntrospectToken() {
return return
} }
} }
application := object.GetApplicationByClientId(clientId) application, err := object.GetApplicationByClientId(clientId)
if err != nil {
panic(err)
}
if application == nil || application.ClientSecret != clientSecret { if application == nil || application.ClientSecret != clientSecret {
c.ResponseError(c.T("token:Invalid application or wrong clientSecret")) c.ResponseError(c.T("token:Invalid application or wrong clientSecret"))
c.Data["json"] = &object.TokenError{ c.Data["json"] = &object.TokenError{
@ -254,7 +287,11 @@ func (c *ApiController) IntrospectToken() {
c.SetTokenErrorHttpStatus() c.SetTokenErrorHttpStatus()
return return
} }
token := object.GetTokenByTokenAndApplication(tokenValue, application.Name) token, err := object.GetTokenByTokenAndApplication(tokenValue, application.Name)
if err != nil {
panic(err)
}
if token == nil { if token == nil {
c.Data["json"] = &object.IntrospectionResponse{Active: false} c.Data["json"] = &object.IntrospectionResponse{Active: false}
c.ServeJSON() c.ServeJSON()

View File

@ -37,14 +37,36 @@ func (c *ApiController) GetGlobalUsers() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedUsers(object.GetGlobalUsers()) maskedUsers, err := object.GetMaskedUsers(object.GetGlobalUsers())
if err != nil {
panic(err)
}
c.Data["json"] = maskedUsers
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetGlobalUserCount(field, value))) count, err := object.GetGlobalUserCount(field, value)
users := object.GetPaginationGlobalUsers(paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
users = object.GetMaskedUsers(users) c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
users, err := object.GetPaginationGlobalUsers(paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
users, err = object.GetMaskedUsers(users)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(users, paginator.Nums()) c.ResponseOk(users, paginator.Nums())
} }
} }
@ -64,14 +86,36 @@ func (c *ApiController) GetUsers() {
value := c.Input().Get("value") value := c.Input().Get("value")
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetMaskedUsers(object.GetUsers(owner)) maskedUsers, err := object.GetMaskedUsers(object.GetUsers(owner))
if err != nil {
panic(err)
}
c.Data["json"] = maskedUsers
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetUserCount(owner, field, value))) count, err := object.GetUserCount(owner, field, value)
users := object.GetPaginationUsers(owner, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
users = object.GetMaskedUsers(users) c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
users, err := object.GetPaginationUsers(owner, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
users, err = object.GetMaskedUsers(users)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(users, paginator.Nums()) c.ResponseOk(users, paginator.Nums())
} }
} }
@ -93,10 +137,14 @@ func (c *ApiController) GetUser() {
phone := c.Input().Get("phone") phone := c.Input().Get("phone")
userId := c.Input().Get("userId") userId := c.Input().Get("userId")
owner := c.Input().Get("owner") owner := c.Input().Get("owner")
var err error
var userFromUserId *object.User var userFromUserId *object.User
if userId != "" && owner != "" { if userId != "" && owner != "" {
userFromUserId = object.GetUserByUserId(owner, userId) userFromUserId, err = object.GetUserByUserId(owner, userId)
if err != nil {
panic(err)
}
id = util.GetId(userFromUserId.Owner, userFromUserId.Name) id = util.GetId(userFromUserId.Owner, userFromUserId.Name)
} }
@ -104,7 +152,11 @@ func (c *ApiController) GetUser() {
owner = util.GetOwnerFromId(id) owner = util.GetOwnerFromId(id)
} }
organization := object.GetOrganization(util.GetId("admin", owner)) organization, err := object.GetOrganization(util.GetId("admin", owner))
if err != nil {
panic(err)
}
if !organization.IsProfilePublic { if !organization.IsProfilePublic {
requestUserId := c.GetSessionUsername() requestUserId := c.GetSessionUsername()
hasPermission, err := object.CheckUserPermission(requestUserId, id, false, c.GetAcceptLanguage()) hasPermission, err := object.CheckUserPermission(requestUserId, id, false, c.GetAcceptLanguage())
@ -117,18 +169,30 @@ func (c *ApiController) GetUser() {
var user *object.User var user *object.User
switch { switch {
case email != "": case email != "":
user = object.GetUserByEmail(owner, email) user, err = object.GetUserByEmail(owner, email)
case phone != "": case phone != "":
user = object.GetUserByPhone(owner, phone) user, err = object.GetUserByPhone(owner, phone)
case userId != "": case userId != "":
user = userFromUserId user = userFromUserId
default: default:
user = object.GetUser(id) user, err = object.GetUser(id)
} }
object.ExtendUserWithRolesAndPermissions(user) if err != nil {
panic(err)
}
c.Data["json"] = object.GetMaskedUser(user) err = object.ExtendUserWithRolesAndPermissions(user)
if err != nil {
panic(err)
}
maskedUser, err := object.GetMaskedUser(user)
if err != nil {
panic(err)
}
c.Data["json"] = maskedUser
c.ServeJSON() c.ServeJSON()
} }
@ -158,7 +222,12 @@ func (c *ApiController) UpdateUser() {
return return
} }
} }
oldUser := object.GetUser(id) oldUser, err := object.GetUser(id)
if err != nil {
c.ResponseError(err.Error())
return
}
if oldUser == nil { if oldUser == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), id)) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), id))
return return
@ -185,9 +254,18 @@ func (c *ApiController) UpdateUser() {
columns = strings.Split(columnsStr, ",") columns = strings.Split(columnsStr, ",")
} }
affected := object.UpdateUser(id, &user, columns, isAdmin) affected, err := object.UpdateUser(id, &user, columns, isAdmin)
if err != nil {
c.ResponseError(err.Error())
return
}
if affected { if affected {
object.UpdateUserToOriginalDatabase(&user) err = object.UpdateUserToOriginalDatabase(&user)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
c.Data["json"] = wrapActionResponse(affected) c.Data["json"] = wrapActionResponse(affected)
@ -209,8 +287,13 @@ func (c *ApiController) AddUser() {
return return
} }
count := object.GetUserCount("", "", "") count, err := object.GetUserCount("", "", "")
if err := checkQuotaForUser(count); err != nil { if err != nil {
c.ResponseError(err.Error())
return
}
if err := checkQuotaForUser(int(count)); err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
} }
@ -261,7 +344,12 @@ func (c *ApiController) GetEmailAndPhone() {
organization := c.Ctx.Request.Form.Get("organization") organization := c.Ctx.Request.Form.Get("organization")
username := c.Ctx.Request.Form.Get("username") username := c.Ctx.Request.Form.Get("username")
user := object.GetUserByFields(organization, username) user, err := object.GetUserByFields(organization, username)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(organization, username))) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(organization, username)))
return return
@ -335,7 +423,11 @@ func (c *ApiController) SetPassword() {
c.SetSession("verifiedCode", "") c.SetSession("verifiedCode", "")
} }
targetUser := object.GetUser(userId) targetUser, err := object.GetUser(userId)
if err != nil {
c.ResponseError(err.Error())
return
}
if oldPassword != "" { if oldPassword != "" {
msg := object.CheckPassword(targetUser, oldPassword, c.GetAcceptLanguage()) msg := object.CheckPassword(targetUser, oldPassword, c.GetAcceptLanguage())
@ -346,7 +438,12 @@ func (c *ApiController) SetPassword() {
} }
targetUser.Password = newPassword targetUser.Password = newPassword
object.SetUserField(targetUser, "password", targetUser.Password) _, err = object.SetUserField(targetUser, "password", targetUser.Password)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk() c.ResponseOk()
} }
@ -384,7 +481,12 @@ func (c *ApiController) GetSortedUsers() {
sorter := c.Input().Get("sorter") sorter := c.Input().Get("sorter")
limit := util.ParseInt(c.Input().Get("limit")) limit := util.ParseInt(c.Input().Get("limit"))
c.Data["json"] = object.GetMaskedUsers(object.GetSortedUsers(owner, sorter, limit)) maskedUsers, err := object.GetMaskedUsers(object.GetSortedUsers(owner, sorter, limit))
if err != nil {
panic(err)
}
c.Data["json"] = maskedUsers
c.ServeJSON() c.ServeJSON()
} }
@ -400,11 +502,16 @@ func (c *ApiController) GetUserCount() {
owner := c.Input().Get("owner") owner := c.Input().Get("owner")
isOnline := c.Input().Get("isOnline") isOnline := c.Input().Get("isOnline")
count := 0 var count int64
var err error
if isOnline == "" { if isOnline == "" {
count = object.GetUserCount(owner, "", "") count, err = object.GetUserCount(owner, "", "")
} else { } else {
count = object.GetOnlineUserCount(owner, util.ParseInt(isOnline)) count, err = object.GetOnlineUserCount(owner, util.ParseInt(isOnline))
}
if err != nil {
c.ResponseError(err.Error())
return
} }
c.Data["json"] = count c.Data["json"] = count

View File

@ -57,7 +57,12 @@ func (c *ApiController) UploadUsers() {
return return
} }
affected := object.UploadUsers(owner, fileId) affected, err := object.UploadUsers(owner, fileId)
if err != nil {
c.ResponseError(err.Error())
return
}
if affected { if affected {
c.ResponseOk() c.ResponseOk()
} else { } else {

View File

@ -92,7 +92,11 @@ func (c *ApiController) RequireSignedInUser() (*object.User, bool) {
return nil, false return nil, false
} }
user := object.GetUser(userId) user, err := object.GetUser(userId)
if err != nil {
panic(err)
}
if user == nil { if user == nil {
c.ClearUserSession() c.ClearUserSession()
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), userId)) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), userId))
@ -138,7 +142,11 @@ func (c *ApiController) IsMaskedEnabled() (bool, bool) {
func (c *ApiController) GetProviderFromContext(category string) (*object.Provider, *object.User, bool) { func (c *ApiController) GetProviderFromContext(category string) (*object.Provider, *object.User, bool) {
providerName := c.Input().Get("provider") providerName := c.Input().Get("provider")
if providerName != "" { if providerName != "" {
provider := object.GetProvider(util.GetId("admin", providerName)) provider, err := object.GetProvider(util.GetId("admin", providerName))
if err != nil {
panic(err)
}
if provider == nil { if provider == nil {
c.ResponseError(fmt.Sprintf(c.T("util:The provider: %s is not found"), providerName)) c.ResponseError(fmt.Sprintf(c.T("util:The provider: %s is not found"), providerName))
return nil, nil, false return nil, nil, false
@ -151,13 +159,21 @@ func (c *ApiController) GetProviderFromContext(category string) (*object.Provide
return nil, nil, false return nil, nil, false
} }
application, user := object.GetApplicationByUserId(userId) application, user, err := object.GetApplicationByUserId(userId)
if err != nil {
panic(err)
}
if application == nil { if application == nil {
c.ResponseError(fmt.Sprintf(c.T("util:No application is found for userId: %s"), userId)) c.ResponseError(fmt.Sprintf(c.T("util:No application is found for userId: %s"), userId))
return nil, nil, false return nil, nil, false
} }
provider := application.GetProviderByCategory(category) provider, err := application.GetProviderByCategory(category)
if err != nil {
panic(err)
}
if provider == nil { if provider == nil {
c.ResponseError(fmt.Sprintf(c.T("util:No provider for category: %s is found for application: %s"), category, application.Name)) c.ResponseError(fmt.Sprintf(c.T("util:No provider for category: %s is found for application: %s"), category, application.Name))
return nil, nil, false return nil, nil, false

View File

@ -66,8 +66,17 @@ func (c *ApiController) SendVerificationCode() {
} }
} }
application := object.GetApplication(vform.ApplicationId) application, err := object.GetApplication(vform.ApplicationId)
organization := object.GetOrganization(util.GetId(application.Owner, application.Organization)) if err != nil {
c.ResponseError(err.Error())
return
}
organization, err := object.GetOrganization(util.GetId(application.Owner, application.Organization))
if err != nil {
c.ResponseError(c.T(err.Error()))
}
if organization == nil { if organization == nil {
c.ResponseError(c.T("check:Organization does not exist")) c.ResponseError(c.T("check:Organization does not exist"))
return return
@ -77,12 +86,20 @@ func (c *ApiController) SendVerificationCode() {
// checkUser != "", means method is ForgetVerification // checkUser != "", means method is ForgetVerification
if vform.CheckUser != "" { if vform.CheckUser != "" {
owner := application.Organization owner := application.Organization
user = object.GetUser(util.GetId(owner, vform.CheckUser)) user, err = object.GetUser(util.GetId(owner, vform.CheckUser))
if err != nil {
c.ResponseError(err.Error())
return
}
} }
// mfaSessionData != nil, means method is MfaSetupVerification // mfaSessionData != nil, means method is MfaSetupVerification
if mfaSessionData := c.getMfaSessionData(); mfaSessionData != nil { if mfaSessionData := c.getMfaSessionData(); mfaSessionData != nil {
user = object.GetUser(mfaSessionData.UserId) user, err = object.GetUser(mfaSessionData.UserId)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
sendResp := errors.New("invalid dest type") sendResp := errors.New("invalid dest type")
@ -99,7 +116,12 @@ func (c *ApiController) SendVerificationCode() {
vform.Dest = user.Email vform.Dest = user.Email
} }
user = object.GetUserByEmail(organization.Name, vform.Dest) user, err = object.GetUserByEmail(organization.Name, vform.Dest)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError(c.T("verification:the user does not exist, please sign up first")) c.ResponseError(c.T("verification:the user does not exist, please sign up first"))
return return
@ -113,7 +135,12 @@ func (c *ApiController) SendVerificationCode() {
} }
} }
provider := application.GetEmailProvider() provider, err := application.GetEmailProvider()
if err != nil {
c.ResponseError(err.Error())
return
}
sendResp = object.SendVerificationCodeToEmail(organization, user, provider, remoteAddr, vform.Dest) sendResp = object.SendVerificationCodeToEmail(organization, user, provider, remoteAddr, vform.Dest)
case object.VerifyTypePhone: case object.VerifyTypePhone:
if vform.Method == LoginVerification || vform.Method == ForgetVerification { if vform.Method == LoginVerification || vform.Method == ForgetVerification {
@ -121,7 +148,10 @@ func (c *ApiController) SendVerificationCode() {
vform.Dest = user.Phone vform.Dest = user.Phone
} }
if user = object.GetUserByPhone(organization.Name, vform.Dest); user == nil { if user, err = object.GetUserByPhone(organization.Name, vform.Dest); err != nil {
c.ResponseError(err.Error())
return
} else if user == nil {
c.ResponseError(c.T("verification:the user does not exist, please sign up first")) c.ResponseError(c.T("verification:the user does not exist, please sign up first"))
return return
} }
@ -140,7 +170,12 @@ func (c *ApiController) SendVerificationCode() {
vform.CountryCode = mfaProps.CountryCode vform.CountryCode = mfaProps.CountryCode
} }
provider := application.GetSmsProvider() provider, err := application.GetSmsProvider()
if err != nil {
c.ResponseError(err.Error())
return
}
if phone, ok := util.GetE164Number(vform.Dest, vform.CountryCode); !ok { if phone, ok := util.GetE164Number(vform.Dest, vform.CountryCode); !ok {
c.ResponseError(fmt.Sprintf(c.T("verification:Phone number is invalid in your region %s"), vform.CountryCode)) c.ResponseError(fmt.Sprintf(c.T("verification:Phone number is invalid in your region %s"), vform.CountryCode))
return return
@ -213,7 +248,12 @@ func (c *ApiController) ResetEmailOrPhone() {
} }
checkDest := dest checkDest := dest
organization := object.GetOrganizationByUser(user) organization, err := object.GetOrganizationByUser(user)
if err != nil {
c.ResponseError(c.T(err.Error()))
return
}
if destType == object.VerifyTypePhone { if destType == object.VerifyTypePhone {
if object.HasUserByField(user.Owner, "phone", dest) { if object.HasUserByField(user.Owner, "phone", dest) {
c.ResponseError(c.T("check:Phone already exists")) c.ResponseError(c.T("check:Phone already exists"))
@ -260,16 +300,25 @@ func (c *ApiController) ResetEmailOrPhone() {
switch destType { switch destType {
case object.VerifyTypeEmail: case object.VerifyTypeEmail:
user.Email = dest user.Email = dest
object.SetUserField(user, "email", user.Email) _, err = object.SetUserField(user, "email", user.Email)
case object.VerifyTypePhone: case object.VerifyTypePhone:
user.Phone = dest user.Phone = dest
object.SetUserField(user, "phone", user.Phone) _, err = object.SetUserField(user, "phone", user.Phone)
default: default:
c.ResponseError(c.T("verification:Unknown type")) c.ResponseError(c.T("verification:Unknown type"))
return return
} }
if err != nil {
c.ResponseError(err.Error())
return
}
err = object.DisableVerificationCode(checkDest)
if err != nil {
c.ResponseError(err.Error())
return
}
object.DisableVerificationCode(checkDest)
c.ResponseOk() c.ResponseOk()
} }
@ -287,7 +336,11 @@ func (c *ApiController) VerifyCode() {
var user *object.User var user *object.User
if authForm.Name != "" { if authForm.Name != "" {
user = object.GetUserByFields(authForm.Organization, authForm.Name) user, err = object.GetUserByFields(authForm.Organization, authForm.Name)
if err != nil {
c.ResponseError(err.Error())
return
}
} }
var checkDest string var checkDest string
@ -302,7 +355,10 @@ func (c *ApiController) VerifyCode() {
} }
} }
if user = object.GetUserByFields(authForm.Organization, authForm.Username); user == nil { if user, err = object.GetUserByFields(authForm.Organization, authForm.Username); err != nil {
c.ResponseError(err.Error())
return
} else if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(authForm.Organization, authForm.Username))) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(authForm.Organization, authForm.Username)))
return return
} }
@ -321,7 +377,11 @@ func (c *ApiController) VerifyCode() {
c.ResponseError(result.Msg) c.ResponseError(result.Msg)
return return
} }
object.DisableVerificationCode(checkDest) err = object.DisableVerificationCode(checkDest)
if err != nil {
c.ResponseError(err.Error())
return
}
c.SetSession("verifiedCode", authForm.Code) c.SetSession("verifiedCode", authForm.Code)
c.ResponseOk() c.ResponseOk()

View File

@ -33,7 +33,12 @@ import (
// @Success 200 {object} protocol.CredentialCreation The CredentialCreationOptions object // @Success 200 {object} protocol.CredentialCreation The CredentialCreationOptions object
// @router /webauthn/signup/begin [get] // @router /webauthn/signup/begin [get]
func (c *ApiController) WebAuthnSignupBegin() { func (c *ApiController) WebAuthnSignupBegin() {
webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host) webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}
user := c.getCurrentUser() user := c.getCurrentUser()
if user == nil { if user == nil {
c.ResponseError(c.T("general:Please login first")) c.ResponseError(c.T("general:Please login first"))
@ -64,7 +69,12 @@ func (c *ApiController) WebAuthnSignupBegin() {
// @Success 200 {object} Response "The Response object" // @Success 200 {object} Response "The Response object"
// @router /webauthn/signup/finish [post] // @router /webauthn/signup/finish [post]
func (c *ApiController) WebAuthnSignupFinish() { func (c *ApiController) WebAuthnSignupFinish() {
webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host) webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}
user := c.getCurrentUser() user := c.getCurrentUser()
if user == nil { if user == nil {
c.ResponseError(c.T("general:Please login first")) c.ResponseError(c.T("general:Please login first"))
@ -84,7 +94,12 @@ func (c *ApiController) WebAuthnSignupFinish() {
return return
} }
isGlobalAdmin := c.IsGlobalAdmin() isGlobalAdmin := c.IsGlobalAdmin()
user.AddCredentials(*credential, isGlobalAdmin) _, err = user.AddCredentials(*credential, isGlobalAdmin)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk() c.ResponseOk()
} }
@ -97,10 +112,20 @@ func (c *ApiController) WebAuthnSignupFinish() {
// @Success 200 {object} protocol.CredentialAssertion The CredentialAssertion object // @Success 200 {object} protocol.CredentialAssertion The CredentialAssertion object
// @router /webauthn/signin/begin [get] // @router /webauthn/signin/begin [get]
func (c *ApiController) WebAuthnSigninBegin() { func (c *ApiController) WebAuthnSigninBegin() {
webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host) webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}
userOwner := c.Input().Get("owner") userOwner := c.Input().Get("owner")
userName := c.Input().Get("name") userName := c.Input().Get("name")
user := object.GetUserByFields(userOwner, userName) user, err := object.GetUserByFields(userOwner, userName)
if err != nil {
c.ResponseError(err.Error())
return
}
if user == nil { if user == nil {
c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(userOwner, userName))) c.ResponseError(fmt.Sprintf(c.T("general:The user: %s doesn't exist"), util.GetId(userOwner, userName)))
return return
@ -129,7 +154,12 @@ func (c *ApiController) WebAuthnSigninBegin() {
// @router /webauthn/signin/finish [post] // @router /webauthn/signin/finish [post]
func (c *ApiController) WebAuthnSigninFinish() { func (c *ApiController) WebAuthnSigninFinish() {
responseType := c.Input().Get("responseType") responseType := c.Input().Get("responseType")
webauthnObj := object.GetWebAuthnObject(c.Ctx.Request.Host) webauthnObj, err := object.GetWebAuthnObject(c.Ctx.Request.Host)
if err != nil {
c.ResponseError(err.Error())
return
}
sessionObj := c.GetSession("authentication") sessionObj := c.GetSession("authentication")
sessionData, ok := sessionObj.(webauthn.SessionData) sessionData, ok := sessionObj.(webauthn.SessionData)
if !ok { if !ok {
@ -138,8 +168,13 @@ func (c *ApiController) WebAuthnSigninFinish() {
} }
c.Ctx.Request.Body = io.NopCloser(bytes.NewBuffer(c.Ctx.Input.RequestBody)) c.Ctx.Request.Body = io.NopCloser(bytes.NewBuffer(c.Ctx.Input.RequestBody))
userId := string(sessionData.UserID) userId := string(sessionData.UserID)
user := object.GetUser(userId) user, err := object.GetUser(userId)
_, err := webauthnObj.FinishLogin(user, sessionData, c.Ctx.Request) if err != nil {
c.ResponseError(err.Error())
return
}
_, err = webauthnObj.FinishLogin(user, sessionData, c.Ctx.Request)
if err != nil { if err != nil {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
@ -147,7 +182,12 @@ func (c *ApiController) WebAuthnSigninFinish() {
c.SetSessionUsername(userId) c.SetSessionUsername(userId)
util.LogInfo(c.Ctx, "API: [%s] signed in", userId) util.LogInfo(c.Ctx, "API: [%s] signed in", userId)
application := object.GetApplicationByUser(user) application, err := object.GetApplicationByUser(user)
if err != nil {
c.ResponseError(err.Error())
return
}
var authForm form.AuthForm var authForm form.AuthForm
authForm.Type = responseType authForm.Type = responseType
resp := c.HandleLoggedIn(application, user, &authForm) resp := c.HandleLoggedIn(application, user, &authForm)

View File

@ -38,13 +38,31 @@ func (c *ApiController) GetWebhooks() {
sortField := c.Input().Get("sortField") sortField := c.Input().Get("sortField")
sortOrder := c.Input().Get("sortOrder") sortOrder := c.Input().Get("sortOrder")
organization := c.Input().Get("organization") organization := c.Input().Get("organization")
if limit == "" || page == "" { if limit == "" || page == "" {
c.Data["json"] = object.GetWebhooks(owner, organization) webhooks, err := object.GetWebhooks(owner, organization)
if err != nil {
panic(err)
}
c.Data["json"] = webhooks
c.ServeJSON() c.ServeJSON()
} else { } else {
limit := util.ParseInt(limit) limit := util.ParseInt(limit)
paginator := pagination.SetPaginator(c.Ctx, limit, int64(object.GetWebhookCount(owner, organization, field, value))) count, err := object.GetWebhookCount(owner, organization, field, value)
webhooks := object.GetPaginationWebhooks(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder) if err != nil {
c.ResponseError(err.Error())
return
}
paginator := pagination.SetPaginator(c.Ctx, limit, count)
webhooks, err := object.GetPaginationWebhooks(owner, organization, paginator.Offset(), limit, field, value, sortField, sortOrder)
if err != nil {
c.ResponseError(err.Error())
return
}
c.ResponseOk(webhooks, paginator.Nums()) c.ResponseOk(webhooks, paginator.Nums())
} }
} }
@ -59,7 +77,12 @@ func (c *ApiController) GetWebhooks() {
func (c *ApiController) GetWebhook() { func (c *ApiController) GetWebhook() {
id := c.Input().Get("id") id := c.Input().Get("id")
c.Data["json"] = object.GetWebhook(id) webhook, err := object.GetWebhook(id)
if err != nil {
panic(err)
}
c.Data["json"] = webhook
c.ServeJSON() c.ServeJSON()
} }

View File

@ -84,6 +84,7 @@ func stringInSlice(value string, list []string) bool {
} }
func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int) { func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int) {
var err error
r := m.GetSearchRequest() r := m.GetSearchRequest()
name, org, code := getNameAndOrgFromFilter(string(r.BaseObject()), r.FilterString()) name, org, code := getNameAndOrgFromFilter(string(r.BaseObject()), r.FilterString())
@ -93,11 +94,19 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int)
if name == "*" && m.Client.IsOrgAdmin { // get all users from organization 'org' if name == "*" && m.Client.IsOrgAdmin { // get all users from organization 'org'
if m.Client.IsGlobalAdmin && org == "*" { if m.Client.IsGlobalAdmin && org == "*" {
filteredUsers = object.GetGlobalUsers()
filteredUsers, err = object.GetGlobalUsers()
if err != nil {
panic(err)
}
return filteredUsers, ldap.LDAPResultSuccess return filteredUsers, ldap.LDAPResultSuccess
} }
if m.Client.IsGlobalAdmin || org == m.Client.OrgName { if m.Client.IsGlobalAdmin || org == m.Client.OrgName {
filteredUsers = object.GetUsers(org) filteredUsers, err = object.GetUsers(org)
if err != nil {
panic(err)
}
return filteredUsers, ldap.LDAPResultSuccess return filteredUsers, ldap.LDAPResultSuccess
} else { } else {
return nil, ldap.LDAPResultInsufficientAccessRights return nil, ldap.LDAPResultInsufficientAccessRights
@ -112,13 +121,21 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int)
return nil, ldap.LDAPResultInsufficientAccessRights return nil, ldap.LDAPResultInsufficientAccessRights
} }
user := object.GetUser(userId) user, err := object.GetUser(userId)
if err != nil {
panic(err)
}
if user != nil { if user != nil {
filteredUsers = append(filteredUsers, user) filteredUsers = append(filteredUsers, user)
return filteredUsers, ldap.LDAPResultSuccess return filteredUsers, ldap.LDAPResultSuccess
} }
organization := object.GetOrganization(util.GetId("admin", org)) organization, err := object.GetOrganization(util.GetId("admin", org))
if err != nil {
panic(err)
}
if organization == nil { if organization == nil {
return nil, ldap.LDAPResultNoSuchObject return nil, ldap.LDAPResultNoSuchObject
} }
@ -127,7 +144,11 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int)
return nil, ldap.LDAPResultNoSuchObject return nil, ldap.LDAPResultNoSuchObject
} }
users := object.GetUsersByTag(org, name) users, err := object.GetUsersByTag(org, name)
if err != nil {
panic(err)
}
filteredUsers = append(filteredUsers, users...) filteredUsers = append(filteredUsers, users...)
return filteredUsers, ldap.LDAPResultSuccess return filteredUsers, ldap.LDAPResultSuccess
} }
@ -137,7 +158,11 @@ func GetFilteredUsers(m *ldap.Message) (filteredUsers []*object.User, code int)
// TODO not handle salt yet // TODO not handle salt yet
// @return {md5}5f4dcc3b5aa765d61d8327deb882cf99 // @return {md5}5f4dcc3b5aa765d61d8327deb882cf99
func getUserPasswordWithType(user *object.User) string { func getUserPasswordWithType(user *object.User) string {
org := object.GetOrganizationByUser(user) org, err := object.GetOrganizationByUser(user)
if err != nil {
panic(err)
}
if org.PasswordType == "" || org.PasswordType == "plain" { if org.PasswordType == "" || org.PasswordType == "plain" {
return user.Password return user.Password
} }

View File

@ -119,7 +119,7 @@ func (a *Adapter) close() {
} }
func (a *Adapter) createTable() { func (a *Adapter) createTable() {
showSql, _ := conf.GetConfigBool("showSql") showSql := conf.GetConfigBool("showSql")
a.Engine.ShowSQL(showSql) a.Engine.ShowSQL(showSql)
tableNamePrefix := conf.GetConfigString("tableNamePrefix") tableNamePrefix := conf.GetConfigString("tableNamePrefix")

View File

@ -79,134 +79,155 @@ type Application struct {
FormBackgroundUrl string `xorm:"varchar(200)" json:"formBackgroundUrl"` FormBackgroundUrl string `xorm:"varchar(200)" json:"formBackgroundUrl"`
} }
func GetApplicationCount(owner, field, value string) int { func GetApplicationCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Application{}) return session.Count(&Application{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetOrganizationApplicationCount(owner, Organization, field, value string) int { func GetOrganizationApplicationCount(owner, Organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Application{Organization: Organization}) return session.Count(&Application{Organization: Organization})
if err != nil {
panic(err)
}
return int(count)
} }
func GetApplications(owner string) []*Application { func GetApplications(owner string) ([]*Application, error) {
applications := []*Application{} applications := []*Application{}
err := adapter.Engine.Desc("created_time").Find(&applications, &Application{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&applications, &Application{Owner: owner})
if err != nil { if err != nil {
panic(err) return applications, err
} }
return applications return applications, nil
} }
func GetOrganizationApplications(owner string, organization string) []*Application { func GetOrganizationApplications(owner string, organization string) ([]*Application, error) {
applications := []*Application{} applications := []*Application{}
err := adapter.Engine.Desc("created_time").Find(&applications, &Application{Organization: organization}) err := adapter.Engine.Desc("created_time").Find(&applications, &Application{Organization: organization})
if err != nil { if err != nil {
panic(err) return applications, err
} }
return applications return applications, nil
} }
func GetPaginationApplications(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Application { func GetPaginationApplications(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Application, error) {
var applications []*Application var applications []*Application
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&applications) err := session.Find(&applications)
if err != nil { if err != nil {
panic(err) return applications, err
} }
return applications return applications, nil
} }
func GetPaginationOrganizationApplications(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Application { func GetPaginationOrganizationApplications(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Application, error) {
applications := []*Application{} applications := []*Application{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&applications, &Application{Organization: organization}) err := session.Find(&applications, &Application{Organization: organization})
if err != nil { if err != nil {
panic(err) return applications, err
} }
return applications return applications, nil
} }
func getProviderMap(owner string) map[string]*Provider { func getProviderMap(owner string) (m map[string]*Provider, err error) {
providers := GetProviders(owner) providers, err := GetProviders(owner)
m := map[string]*Provider{} if err != nil {
return nil, err
}
m = map[string]*Provider{}
for _, provider := range providers { for _, provider := range providers {
// Get QRCode only once // Get QRCode only once
if provider.Type == "WeChat" && provider.DisableSsl && provider.Content == "" { if provider.Type == "WeChat" && provider.DisableSsl && provider.Content == "" {
provider.Content, _ = idp.GetWechatOfficialAccountQRCode(provider.ClientId2, provider.ClientSecret2) provider.Content, err = idp.GetWechatOfficialAccountQRCode(provider.ClientId2, provider.ClientSecret2)
if err != nil {
return
}
UpdateProvider(provider.Owner+"/"+provider.Name, provider) UpdateProvider(provider.Owner+"/"+provider.Name, provider)
} }
m[provider.Name] = GetMaskedProvider(provider, true) m[provider.Name] = GetMaskedProvider(provider, true)
} }
return m
return m, err
} }
func extendApplicationWithProviders(application *Application) { func extendApplicationWithProviders(application *Application) (err error) {
m := getProviderMap(application.Organization) m, err := getProviderMap(application.Organization)
if err != nil {
return err
}
for _, providerItem := range application.Providers { for _, providerItem := range application.Providers {
if provider, ok := m[providerItem.Name]; ok { if provider, ok := m[providerItem.Name]; ok {
providerItem.Provider = provider providerItem.Provider = provider
} }
} }
return
} }
func extendApplicationWithOrg(application *Application) { func extendApplicationWithOrg(application *Application) (err error) {
organization := getOrganization(application.Owner, application.Organization) organization, err := getOrganization(application.Owner, application.Organization)
application.OrganizationObj = organization application.OrganizationObj = organization
return
} }
func getApplication(owner string, name string) *Application { func getApplication(owner string, name string) (*Application, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
application := Application{Owner: owner, Name: name} application := Application{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&application) existed, err := adapter.Engine.Get(&application)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
extendApplicationWithProviders(&application) err = extendApplicationWithProviders(&application)
extendApplicationWithOrg(&application) if err != nil {
return &application return nil, err
}
err = extendApplicationWithOrg(&application)
if err != nil {
return nil, err
}
return &application, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetApplicationByOrganizationName(organization string) *Application { func GetApplicationByOrganizationName(organization string) (*Application, error) {
application := Application{} application := Application{}
existed, err := adapter.Engine.Where("organization=?", organization).Get(&application) existed, err := adapter.Engine.Where("organization=?", organization).Get(&application)
if err != nil { if err != nil {
panic(err) return nil, nil
} }
if existed { if existed {
extendApplicationWithProviders(&application) err = extendApplicationWithProviders(&application)
extendApplicationWithOrg(&application) if err != nil {
return &application return nil, err
}
err = extendApplicationWithOrg(&application)
if err != nil {
return nil, err
}
return &application, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetApplicationByUser(user *User) *Application { func GetApplicationByUser(user *User) (*Application, error) {
if user.SignupApplication != "" { if user.SignupApplication != "" {
return getApplication("admin", user.SignupApplication) return getApplication("admin", user.SignupApplication)
} else { } else {
@ -214,38 +235,46 @@ func GetApplicationByUser(user *User) *Application {
} }
} }
func GetApplicationByUserId(userId string) (*Application, *User) { func GetApplicationByUserId(userId string) (application *Application, user *User, err error) {
var application *Application
owner, name := util.GetOwnerAndNameFromId(userId) owner, name := util.GetOwnerAndNameFromId(userId)
if owner == "app" { if owner == "app" {
application = getApplication("admin", name) application, err = getApplication("admin", name)
return application, nil return
} }
user := GetUser(userId) user, err = GetUser(userId)
application = GetApplicationByUser(user) if err != nil {
return nil, nil, err
return application, user }
application, err = GetApplicationByUser(user)
return
} }
func GetApplicationByClientId(clientId string) *Application { func GetApplicationByClientId(clientId string) (*Application, error) {
application := Application{} application := Application{}
existed, err := adapter.Engine.Where("client_id=?", clientId).Get(&application) existed, err := adapter.Engine.Where("client_id=?", clientId).Get(&application)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
extendApplicationWithProviders(&application) err = extendApplicationWithProviders(&application)
extendApplicationWithOrg(&application) if err != nil {
return &application return nil, err
}
err = extendApplicationWithOrg(&application)
if err != nil {
return nil, err
}
return &application, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetApplication(id string) *Application { func GetApplication(id string) (*Application, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getApplication(owner, name) return getApplication(owner, name)
} }
@ -288,11 +317,11 @@ func GetMaskedApplications(applications []*Application, userId string) []*Applic
return applications return applications
} }
func UpdateApplication(id string, application *Application) bool { func UpdateApplication(id string, application *Application) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
oldApplication := getApplication(owner, name) oldApplication, err := getApplication(owner, name)
if oldApplication == nil { if oldApplication == nil {
return false return false, err
} }
if name == "app-built-in" { if name == "app-built-in" {
@ -300,14 +329,19 @@ func UpdateApplication(id string, application *Application) bool {
} }
if name != application.Name { if name != application.Name {
err := applicationChangeTrigger(name, application.Name) err = applicationChangeTrigger(name, application.Name)
if err != nil { if err != nil {
return false return false, err
} }
} }
if oldApplication.ClientId != application.ClientId && GetApplicationByClientId(application.ClientId) != nil { applicationByClientId, err := GetApplicationByClientId(application.ClientId)
return false if err != nil {
return false, err
}
if oldApplication.ClientId != application.ClientId && applicationByClientId != nil {
return false, err
} }
for _, providerItem := range application.Providers { for _, providerItem := range application.Providers {
@ -320,13 +354,13 @@ func UpdateApplication(id string, application *Application) bool {
} }
affected, err := session.Update(application) affected, err := session.Update(application)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddApplication(application *Application) bool { func AddApplication(application *Application) (bool, error) {
if application.Owner == "" { if application.Owner == "" {
application.Owner = "admin" application.Owner = "admin"
} }
@ -339,32 +373,39 @@ func AddApplication(application *Application) bool {
if application.ClientSecret == "" { if application.ClientSecret == "" {
application.ClientSecret = util.GenerateClientSecret() application.ClientSecret = util.GenerateClientSecret()
} }
if GetApplicationByClientId(application.ClientId) != nil {
return false app, err := GetApplicationByClientId(application.ClientId)
if err != nil {
return false, err
} }
if app != nil {
return false, nil
}
for _, providerItem := range application.Providers { for _, providerItem := range application.Providers {
providerItem.Provider = nil providerItem.Provider = nil
} }
affected, err := adapter.Engine.Insert(application) affected, err := adapter.Engine.Insert(application)
if err != nil { if err != nil {
panic(err) return false, nil
} }
return affected != 0 return affected != 0, nil
} }
func DeleteApplication(application *Application) bool { func DeleteApplication(application *Application) (bool, error) {
if application.Name == "app-built-in" { if application.Name == "app-built-in" {
return false return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{application.Owner, application.Name}).Delete(&Application{}) affected, err := adapter.Engine.ID(core.PK{application.Owner, application.Name}).Delete(&Application{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (application *Application) GetId() string { func (application *Application) GetId() string {
@ -383,33 +424,43 @@ func (application *Application) IsRedirectUriValid(redirectUri string) bool {
return isValid return isValid
} }
func IsOriginAllowed(origin string) bool { func IsOriginAllowed(origin string) (bool, error) {
applications := GetApplications("") applications, err := GetApplications("")
if err != nil {
return false, err
}
for _, application := range applications { for _, application := range applications {
if application.IsRedirectUriValid(origin) { if application.IsRedirectUriValid(origin) {
return true return true, nil
} }
} }
return false return false, nil
} }
func getApplicationMap(organization string) map[string]*Application { func getApplicationMap(organization string) (map[string]*Application, error) {
applications := GetOrganizationApplications("admin", organization)
applicationMap := make(map[string]*Application) applicationMap := make(map[string]*Application)
applications, err := GetOrganizationApplications("admin", organization)
if err != nil {
return applicationMap, err
}
for _, application := range applications { for _, application := range applications {
applicationMap[application.Name] = application applicationMap[application.Name] = application
} }
return applicationMap return applicationMap, nil
} }
func ExtendManagedAccountsWithUser(user *User) *User { func ExtendManagedAccountsWithUser(user *User) (*User, error) {
if user.ManagedAccounts == nil || len(user.ManagedAccounts) == 0 { if user.ManagedAccounts == nil || len(user.ManagedAccounts) == 0 {
return user return user, nil
} }
applicationMap := getApplicationMap(user.Owner) applicationMap, err := getApplicationMap(user.Owner)
if err != nil {
return user, err
}
var managedAccounts []ManagedAccount var managedAccounts []ManagedAccount
for _, managedAccount := range user.ManagedAccounts { for _, managedAccount := range user.ManagedAccounts {
@ -421,7 +472,7 @@ func ExtendManagedAccountsWithUser(user *User) *User {
} }
user.ManagedAccounts = managedAccounts user.ManagedAccounts = managedAccounts
return user return user, nil
} }
func applicationChangeTrigger(oldName string, newName string) error { func applicationChangeTrigger(oldName string, newName string) error {

View File

@ -14,8 +14,12 @@
package object package object
func (application *Application) GetProviderByCategory(category string) *Provider { func (application *Application) GetProviderByCategory(category string) (*Provider, error) {
providers := GetProviders(application.Organization) providers, err := GetProviders(application.Organization)
if err != nil {
return nil, err
}
m := map[string]*Provider{} m := map[string]*Provider{}
for _, provider := range providers { for _, provider := range providers {
if provider.Category != category { if provider.Category != category {
@ -27,22 +31,22 @@ func (application *Application) GetProviderByCategory(category string) *Provider
for _, providerItem := range application.Providers { for _, providerItem := range application.Providers {
if provider, ok := m[providerItem.Name]; ok { if provider, ok := m[providerItem.Name]; ok {
return provider return provider, nil
} }
} }
return nil return nil, nil
} }
func (application *Application) GetEmailProvider() *Provider { func (application *Application) GetEmailProvider() (*Provider, error) {
return application.GetProviderByCategory("Email") return application.GetProviderByCategory("Email")
} }
func (application *Application) GetSmsProvider() *Provider { func (application *Application) GetSmsProvider() (*Provider, error) {
return application.GetProviderByCategory("SMS") return application.GetProviderByCategory("SMS")
} }
func (application *Application) GetStorageProvider() *Provider { func (application *Application) GetStorageProvider() (*Provider, error) {
return application.GetProviderByCategory("Storage") return application.GetProviderByCategory("Storage")
} }

View File

@ -28,7 +28,11 @@ var defaultStorageProvider *Provider = nil
func InitDefaultStorageProvider() { func InitDefaultStorageProvider() {
defaultStorageProviderStr := conf.GetConfigString("defaultStorageProvider") defaultStorageProviderStr := conf.GetConfigString("defaultStorageProvider")
if defaultStorageProviderStr != "" { if defaultStorageProviderStr != "" {
defaultStorageProvider = getProvider("admin", defaultStorageProviderStr) var err error
defaultStorageProvider, err = getProvider("admin", defaultStorageProviderStr)
if err != nil {
panic(err)
}
} }
} }
@ -50,40 +54,44 @@ func downloadFile(url string) (*bytes.Buffer, error) {
return fileBuffer, nil return fileBuffer, nil
} }
func getPermanentAvatarUrl(organization string, username string, url string, upload bool) string { func getPermanentAvatarUrl(organization string, username string, url string, upload bool) (string, error) {
if url == "" { if url == "" {
return "" return "", nil
} }
if defaultStorageProvider == nil { if defaultStorageProvider == nil {
return "" return "", nil
} }
fullFilePath := fmt.Sprintf("/avatar/%s/%s.png", organization, username) fullFilePath := fmt.Sprintf("/avatar/%s/%s.png", organization, username)
uploadedFileUrl, _ := GetUploadFileUrl(defaultStorageProvider, fullFilePath, false) uploadedFileUrl, _ := GetUploadFileUrl(defaultStorageProvider, fullFilePath, false)
if upload { if upload {
DownloadAndUpload(url, fullFilePath, "en") if err := DownloadAndUpload(url, fullFilePath, "en"); err != nil {
return "", err
}
} }
return uploadedFileUrl return uploadedFileUrl, nil
} }
func DownloadAndUpload(url string, fullFilePath string, lang string) { func DownloadAndUpload(url string, fullFilePath string, lang string) (err error) {
fileBuffer, err := downloadFile(url) fileBuffer, err := downloadFile(url)
if err != nil { if err != nil {
panic(err) return
} }
_, _, err = UploadFileSafe(defaultStorageProvider, fullFilePath, fileBuffer, lang) _, _, err = UploadFileSafe(defaultStorageProvider, fullFilePath, fileBuffer, lang)
if err != nil { if err != nil {
panic(err) return
} }
return
} }
func getPermanentAvatarUrlFromBuffer(organization string, username string, fileBuffer *bytes.Buffer, ext string, upload bool) string { func getPermanentAvatarUrlFromBuffer(organization string, username string, fileBuffer *bytes.Buffer, ext string, upload bool) (string, error) {
if defaultStorageProvider == nil { if defaultStorageProvider == nil {
return "" return "", nil
} }
fullFilePath := fmt.Sprintf("/avatar/%s/%s%s", organization, username, ext) fullFilePath := fmt.Sprintf("/avatar/%s/%s%s", organization, username, ext)
@ -92,9 +100,9 @@ func getPermanentAvatarUrlFromBuffer(organization string, username string, fileB
if upload { if upload {
_, _, err := UploadFileSafe(defaultStorageProvider, fullFilePath, fileBuffer, "en") _, _, err := UploadFileSafe(defaultStorageProvider, fullFilePath, fileBuffer, "en")
if err != nil { if err != nil {
panic(err) return "", err
} }
} }
return uploadedFileUrl return uploadedFileUrl, nil
} }

View File

@ -27,13 +27,21 @@ func TestSyncPermanentAvatars(t *testing.T) {
InitDefaultStorageProvider() InitDefaultStorageProvider()
proxy.InitHttpClient() proxy.InitHttpClient()
users := GetGlobalUsers() users, err := GetGlobalUsers()
if err != nil {
panic(err)
}
for i, user := range users { for i, user := range users {
if user.Avatar == "" { if user.Avatar == "" {
continue continue
} }
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
if err != nil {
panic(err)
}
updateUserColumn("permanent_avatar", user) updateUserColumn("permanent_avatar", user)
fmt.Printf("[%d/%d]: Update user: [%s]'s permanent avatar: %s\n", i, len(users), user.GetId(), user.PermanentAvatar) fmt.Printf("[%d/%d]: Update user: [%s]'s permanent avatar: %s\n", i, len(users), user.GetId(), user.PermanentAvatar)
} }
@ -44,16 +52,27 @@ func TestUpdateAvatars(t *testing.T) {
InitDefaultStorageProvider() InitDefaultStorageProvider()
proxy.InitHttpClient() proxy.InitHttpClient()
users := GetUsers("casdoor") users, err := GetUsers("casdoor")
if err != nil {
panic(err)
}
for _, user := range users { for _, user := range users {
if strings.HasPrefix(user.Avatar, "http") { if strings.HasPrefix(user.Avatar, "http") {
continue continue
} }
updated := user.refreshAvatar() updated, err := user.refreshAvatar()
if err != nil {
panic(err)
}
if updated { if updated {
user.PermanentAvatar = "*" user.PermanentAvatar = "*"
UpdateUser(user.GetId(), user, []string{"avatar"}, true) _, err = UpdateUser(user.GetId(), user, []string{"avatar"}, true)
if err != nil {
panic(err)
}
} }
} }
} }

View File

@ -20,17 +20,17 @@ import (
"github.com/dchest/captcha" "github.com/dchest/captcha"
) )
func GetCaptcha() (string, []byte) { func GetCaptcha() (string, []byte, error) {
id := captcha.NewLen(5) id := captcha.NewLen(5)
var buffer bytes.Buffer var buffer bytes.Buffer
err := captcha.WriteImage(&buffer, id, 200, 80) err := captcha.WriteImage(&buffer, id, 200, 80)
if err != nil { if err != nil {
panic(err) return "", nil, err
} }
return id, buffer.Bytes() return id, buffer.Bytes(), nil
} }
func VerifyCaptcha(id string, digits string) bool { func VerifyCaptcha(id string, digits string) bool {

View File

@ -46,64 +46,59 @@ type CasbinAdapter struct {
Adapter *xormadapter.Adapter `xorm:"-" json:"-"` Adapter *xormadapter.Adapter `xorm:"-" json:"-"`
} }
func GetCasbinAdapterCount(owner, organization, field, value string) int { func GetCasbinAdapterCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&CasbinAdapter{Organization: organization}) return session.Count(&CasbinAdapter{Organization: organization})
if err != nil {
panic(err)
}
return int(count)
} }
func GetCasbinAdapters(owner string, organization string) []*CasbinAdapter { func GetCasbinAdapters(owner string, organization string) ([]*CasbinAdapter, error) {
adapters := []*CasbinAdapter{} adapters := []*CasbinAdapter{}
err := adapter.Engine.Where("owner = ? and organization = ?", owner, organization).Find(&adapters) err := adapter.Engine.Where("owner = ? and organization = ?", owner, organization).Find(&adapters)
if err != nil { if err != nil {
panic(err) return adapters, err
} }
return adapters return adapters, nil
} }
func GetPaginationCasbinAdapters(owner, organization string, page, limit int, field, value, sort, order string) []*CasbinAdapter { func GetPaginationCasbinAdapters(owner, organization string, page, limit int, field, value, sort, order string) ([]*CasbinAdapter, error) {
session := GetSession(owner, page, limit, field, value, sort, order) session := GetSession(owner, page, limit, field, value, sort, order)
adapters := []*CasbinAdapter{} adapters := []*CasbinAdapter{}
err := session.Find(&adapters, &CasbinAdapter{Organization: organization}) err := session.Find(&adapters, &CasbinAdapter{Organization: organization})
if err != nil { if err != nil {
panic(err) return adapters, err
} }
return adapters return adapters, nil
} }
func getCasbinAdapter(owner, name string) *CasbinAdapter { func getCasbinAdapter(owner, name string) (*CasbinAdapter, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
casbinAdapter := CasbinAdapter{Owner: owner, Name: name} casbinAdapter := CasbinAdapter{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&casbinAdapter) existed, err := adapter.Engine.Get(&casbinAdapter)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &casbinAdapter return &casbinAdapter, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetCasbinAdapter(id string) *CasbinAdapter { func GetCasbinAdapter(id string) (*CasbinAdapter, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getCasbinAdapter(owner, name) return getCasbinAdapter(owner, name)
} }
func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) bool { func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getCasbinAdapter(owner, name) == nil { if casbinAdapter, err := getCasbinAdapter(owner, name); casbinAdapter == nil {
return false return false, err
} }
session := adapter.Engine.ID(core.PK{owner, name}).AllCols() session := adapter.Engine.ID(core.PK{owner, name}).AllCols()
@ -112,28 +107,28 @@ func UpdateCasbinAdapter(id string, casbinAdapter *CasbinAdapter) bool {
} }
affected, err := session.Update(casbinAdapter) affected, err := session.Update(casbinAdapter)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddCasbinAdapter(casbinAdapter *CasbinAdapter) bool { func AddCasbinAdapter(casbinAdapter *CasbinAdapter) (bool, error) {
affected, err := adapter.Engine.Insert(casbinAdapter) affected, err := adapter.Engine.Insert(casbinAdapter)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteCasbinAdapter(casbinAdapter *CasbinAdapter) bool { func DeleteCasbinAdapter(casbinAdapter *CasbinAdapter) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{casbinAdapter.Owner, casbinAdapter.Name}).Delete(&CasbinAdapter{}) affected, err := adapter.Engine.ID(core.PK{casbinAdapter.Owner, casbinAdapter.Name}).Delete(&CasbinAdapter{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (casbinAdapter *CasbinAdapter) GetId() string { func (casbinAdapter *CasbinAdapter) GetId() string {
@ -214,7 +209,11 @@ func matrixToCasbinRules(Ptype string, policies [][]string) []*xormadapter.Casbi
} }
func SyncPolicies(casbinAdapter *CasbinAdapter) ([]*xormadapter.CasbinRule, error) { func SyncPolicies(casbinAdapter *CasbinAdapter) ([]*xormadapter.CasbinRule, error) {
modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model) modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model)
if err != nil {
return nil, err
}
enforcer, err := initEnforcer(modelObj, casbinAdapter) enforcer, err := initEnforcer(modelObj, casbinAdapter)
if err != nil { if err != nil {
return nil, err return nil, err
@ -229,7 +228,11 @@ func SyncPolicies(casbinAdapter *CasbinAdapter) ([]*xormadapter.CasbinRule, erro
} }
func UpdatePolicy(oldPolicy, newPolicy []string, casbinAdapter *CasbinAdapter) (bool, error) { func UpdatePolicy(oldPolicy, newPolicy []string, casbinAdapter *CasbinAdapter) (bool, error) {
modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model) modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model)
if err != nil {
return false, err
}
enforcer, err := initEnforcer(modelObj, casbinAdapter) enforcer, err := initEnforcer(modelObj, casbinAdapter)
if err != nil { if err != nil {
return false, err return false, err
@ -243,7 +246,11 @@ func UpdatePolicy(oldPolicy, newPolicy []string, casbinAdapter *CasbinAdapter) (
} }
func AddPolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) { func AddPolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) {
modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model) modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model)
if err != nil {
return false, err
}
enforcer, err := initEnforcer(modelObj, casbinAdapter) enforcer, err := initEnforcer(modelObj, casbinAdapter)
if err != nil { if err != nil {
return false, err return false, err
@ -257,7 +264,11 @@ func AddPolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) {
} }
func RemovePolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) { func RemovePolicy(policy []string, casbinAdapter *CasbinAdapter) (bool, error) {
modelObj := getModel(casbinAdapter.Owner, casbinAdapter.Model) modelObj, err := getModel(casbinAdapter.Owner, casbinAdapter.Model)
if err != nil {
return false, err
}
enforcer, err := initEnforcer(modelObj, casbinAdapter) enforcer, err := initEnforcer(modelObj, casbinAdapter)
if err != nil { if err != nil {
return false, err return false, err

View File

@ -47,137 +47,133 @@ func GetMaskedCert(cert *Cert) *Cert {
return cert return cert
} }
func GetMaskedCerts(certs []*Cert) []*Cert { func GetMaskedCerts(certs []*Cert, err error) ([]*Cert, error) {
if err != nil {
return nil, err
}
for _, cert := range certs { for _, cert := range certs {
cert = GetMaskedCert(cert) cert = GetMaskedCert(cert)
} }
return certs return certs, nil
} }
func GetCertCount(owner, field, value string) int { func GetCertCount(owner, field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "") session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Cert{}) return session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Cert{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetCerts(owner string) []*Cert { func GetCerts(owner string) ([]*Cert, error) {
certs := []*Cert{} certs := []*Cert{}
err := adapter.Engine.Where("owner = ? or owner = ? ", "admin", owner).Desc("created_time").Find(&certs, &Cert{}) err := adapter.Engine.Where("owner = ? or owner = ? ", "admin", owner).Desc("created_time").Find(&certs, &Cert{})
if err != nil { if err != nil {
panic(err) return certs, err
} }
return certs return certs, nil
} }
func GetPaginationCerts(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Cert { func GetPaginationCerts(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Cert, error) {
certs := []*Cert{} certs := []*Cert{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder) session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Where("owner = ? or owner = ? ", "admin", owner).Find(&certs) err := session.Where("owner = ? or owner = ? ", "admin", owner).Find(&certs)
if err != nil { if err != nil {
panic(err) return certs, err
} }
return certs return certs, nil
} }
func GetGlobalCertsCount(field, value string) int { func GetGlobalCertsCount(field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "") session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Count(&Cert{}) return session.Count(&Cert{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetGlobleCerts() []*Cert { func GetGlobleCerts() ([]*Cert, error) {
certs := []*Cert{} certs := []*Cert{}
err := adapter.Engine.Desc("created_time").Find(&certs) err := adapter.Engine.Desc("created_time").Find(&certs)
if err != nil { if err != nil {
panic(err) return certs, err
} }
return certs return certs, nil
} }
func GetPaginationGlobalCerts(offset, limit int, field, value, sortField, sortOrder string) []*Cert { func GetPaginationGlobalCerts(offset, limit int, field, value, sortField, sortOrder string) ([]*Cert, error) {
certs := []*Cert{} certs := []*Cert{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder) session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Find(&certs) err := session.Find(&certs)
if err != nil { if err != nil {
panic(err) return certs, err
} }
return certs return certs, nil
} }
func getCert(owner string, name string) *Cert { func getCert(owner string, name string) (*Cert, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
cert := Cert{Owner: owner, Name: name} cert := Cert{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&cert) existed, err := adapter.Engine.Get(&cert)
if err != nil { if err != nil {
panic(err) return &cert, err
} }
if existed { if existed {
return &cert return &cert, nil
} else { } else {
return nil return nil, nil
} }
} }
func getCertByName(name string) *Cert { func getCertByName(name string) (*Cert, error) {
if name == "" { if name == "" {
return nil return nil, nil
} }
cert := Cert{Name: name} cert := Cert{Name: name}
existed, err := adapter.Engine.Get(&cert) existed, err := adapter.Engine.Get(&cert)
if err != nil { if err != nil {
panic(err) return &cert, nil
} }
if existed { if existed {
return &cert return &cert, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetCert(id string) *Cert { func GetCert(id string) (*Cert, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getCert(owner, name) return getCert(owner, name)
} }
func UpdateCert(id string, cert *Cert) bool { func UpdateCert(id string, cert *Cert) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getCert(owner, name) == nil { if c, err := getCert(owner, name); err != nil {
return false return false, err
} else if c == nil {
return false, nil
} }
if name != cert.Name { if name != cert.Name {
err := certChangeTrigger(name, cert.Name) err := certChangeTrigger(name, cert.Name)
if err != nil { if err != nil {
return false return false, nil
} }
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(cert) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(cert)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddCert(cert *Cert) bool { func AddCert(cert *Cert) (bool, error) {
if cert.Certificate == "" || cert.PrivateKey == "" { if cert.Certificate == "" || cert.PrivateKey == "" {
certificate, privateKey := generateRsaKeys(cert.BitSize, cert.ExpireInYears, cert.Name, cert.Owner) certificate, privateKey := generateRsaKeys(cert.BitSize, cert.ExpireInYears, cert.Name, cert.Owner)
cert.Certificate = certificate cert.Certificate = certificate
@ -186,26 +182,26 @@ func AddCert(cert *Cert) bool {
affected, err := adapter.Engine.Insert(cert) affected, err := adapter.Engine.Insert(cert)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteCert(cert *Cert) bool { func DeleteCert(cert *Cert) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{cert.Owner, cert.Name}).Delete(&Cert{}) affected, err := adapter.Engine.ID(core.PK{cert.Owner, cert.Name}).Delete(&Cert{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (p *Cert) GetId() string { func (p *Cert) GetId() string {
return fmt.Sprintf("%s/%s", p.Owner, p.Name) return fmt.Sprintf("%s/%s", p.Owner, p.Name)
} }
func getCertByApplication(application *Application) *Cert { func getCertByApplication(application *Application) (*Cert, error) {
if application.Cert != "" { if application.Cert != "" {
return getCertByName(application.Cert) return getCertByName(application.Cert)
} else { } else {
@ -213,7 +209,7 @@ func getCertByApplication(application *Application) *Cert {
} }
} }
func GetDefaultCert() *Cert { func GetDefaultCert() (*Cert, error) {
return getCert("admin", "cert-built-in") return getCert("admin", "cert-built-in")
} }

View File

@ -37,92 +37,104 @@ type Chat struct {
MessageCount int `json:"messageCount"` MessageCount int `json:"messageCount"`
} }
func GetMaskedChat(chat *Chat) *Chat { func GetMaskedChat(chat *Chat, err ...error) (*Chat, error) {
if len(err) > 0 && err[0] != nil {
return nil, err[0]
}
if chat == nil { if chat == nil {
return nil return nil, nil
} }
return chat return chat, nil
} }
func GetMaskedChats(chats []*Chat) []*Chat { func GetMaskedChats(chats []*Chat, errs ...error) ([]*Chat, error) {
if len(errs) > 0 && errs[0] != nil {
return nil, errs[0]
}
var err error
for _, chat := range chats { for _, chat := range chats {
chat = GetMaskedChat(chat) chat, err = GetMaskedChat(chat)
if err != nil {
return nil, err
}
} }
return chats return chats, nil
} }
func GetChatCount(owner, field, value string) int { func GetChatCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Chat{}) return session.Count(&Chat{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetChats(owner string) []*Chat { func GetChats(owner string) ([]*Chat, error) {
chats := []*Chat{} chats := []*Chat{}
err := adapter.Engine.Desc("created_time").Find(&chats, &Chat{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&chats, &Chat{Owner: owner})
if err != nil { if err != nil {
panic(err) return chats, err
} }
return chats return chats, nil
} }
func GetPaginationChats(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Chat { func GetPaginationChats(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Chat, error) {
chats := []*Chat{} chats := []*Chat{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&chats) err := session.Find(&chats)
if err != nil { if err != nil {
panic(err) return chats, err
} }
return chats return chats, nil
} }
func getChat(owner string, name string) *Chat { func getChat(owner string, name string) (*Chat, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
chat := Chat{Owner: owner, Name: name} chat := Chat{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&chat) existed, err := adapter.Engine.Get(&chat)
if err != nil { if err != nil {
panic(err) return &chat, err
} }
if existed { if existed {
return &chat return &chat, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetChat(id string) *Chat { func GetChat(id string) (*Chat, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getChat(owner, name) return getChat(owner, name)
} }
func UpdateChat(id string, chat *Chat) bool { func UpdateChat(id string, chat *Chat) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getChat(owner, name) == nil { if c, err := getChat(owner, name); err != nil {
return false return false, err
} else if c == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(chat) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(chat)
if err != nil { if err != nil {
panic(err) return false, nil
} }
return affected != 0 return affected != 0, nil
} }
func AddChat(chat *Chat) bool { func AddChat(chat *Chat) (bool, error) {
if chat.Type == "AI" && chat.User2 == "" { if chat.Type == "AI" && chat.User2 == "" {
provider := getDefaultAiProvider() provider, err := getDefaultAiProvider()
if err != nil {
return false, err
}
if provider != nil { if provider != nil {
chat.User2 = provider.Name chat.User2 = provider.Name
} }
@ -130,23 +142,23 @@ func AddChat(chat *Chat) bool {
affected, err := adapter.Engine.Insert(chat) affected, err := adapter.Engine.Insert(chat)
if err != nil { if err != nil {
panic(err) return false, nil
} }
return affected != 0 return affected != 0, nil
} }
func DeleteChat(chat *Chat) bool { func DeleteChat(chat *Chat) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{chat.Owner, chat.Name}).Delete(&Chat{}) affected, err := adapter.Engine.ID(core.PK{chat.Owner, chat.Name}).Delete(&Chat{})
if err != nil { if err != nil {
panic(err) return false, err
} }
if affected != 0 { if affected != 0 {
return DeleteChatMessages(chat.Name) return DeleteChatMessages(chat.Name)
} }
return affected != 0 return affected != 0, nil
} }
func (p *Chat) GetId() string { func (p *Chat) GetId() string {

View File

@ -170,7 +170,11 @@ func CheckPassword(user *User, password string, lang string, options ...bool) st
} }
} }
organization := GetOrganizationByUser(user) organization, err := GetOrganizationByUser(user)
if err != nil {
panic(err)
}
if organization == nil { if organization == nil {
return i18n.Translate(lang, "check:Organization does not exist") return i18n.Translate(lang, "check:Organization does not exist")
} }
@ -200,7 +204,11 @@ func CheckPassword(user *User, password string, lang string, options ...bool) st
} }
func checkLdapUserPassword(user *User, password string, lang string) string { func checkLdapUserPassword(user *User, password string, lang string) string {
ldaps := GetLdaps(user.Owner) ldaps, err := GetLdaps(user.Owner)
if err != nil {
return err.Error()
}
ldapLoginSuccess := false ldapLoginSuccess := false
hit := false hit := false
@ -247,7 +255,11 @@ func CheckUserPassword(organization string, username string, password string, la
if len(options) > 0 { if len(options) > 0 {
enableCaptcha = options[0] enableCaptcha = options[0]
} }
user := GetUserByFields(organization, username) user, err := GetUserByFields(organization, username)
if err != nil {
panic(err)
}
if user == nil || user.IsDeleted { if user == nil || user.IsDeleted {
return nil, fmt.Sprintf(i18n.Translate(lang, "general:The user: %s doesn't exist"), util.GetId(organization, username)) return nil, fmt.Sprintf(i18n.Translate(lang, "general:The user: %s doesn't exist"), util.GetId(organization, username))
} }
@ -284,7 +296,11 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string)
userOwner := util.GetOwnerFromId(userId) userOwner := util.GetOwnerFromId(userId)
if userId != "" { if userId != "" {
targetUser := GetUser(userId) targetUser, err := GetUser(userId)
if err != nil {
panic(err)
}
if targetUser == nil { if targetUser == nil {
if strings.HasPrefix(requestUserId, "built-in/") { if strings.HasPrefix(requestUserId, "built-in/") {
return true, nil return true, nil
@ -300,7 +316,11 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string)
if strings.HasPrefix(requestUserId, "app/") { if strings.HasPrefix(requestUserId, "app/") {
hasPermission = true hasPermission = true
} else { } else {
requestUser := GetUser(requestUserId) requestUser, err := GetUser(requestUserId)
if err != nil {
return false, err
}
if requestUser == nil { if requestUser == nil {
return false, fmt.Errorf(i18n.Translate(lang, "check:Session outdated, please login again")) return false, fmt.Errorf(i18n.Translate(lang, "check:Session outdated, please login again"))
} }
@ -321,13 +341,17 @@ func CheckUserPermission(requestUserId, userId string, strict bool, lang string)
} }
func CheckAccessPermission(userId string, application *Application) (bool, error) { func CheckAccessPermission(userId string, application *Application) (bool, error) {
var err error
if userId == "built-in/admin" { if userId == "built-in/admin" {
return true, nil return true, nil
} }
permissions := GetPermissions(application.Organization) permissions, err := GetPermissions(application.Organization)
if err != nil {
return false, err
}
allowed := true allowed := true
var err error
for _, permission := range permissions { for _, permission := range permissions {
if !permission.IsEnabled || len(permission.Users) == 0 { if !permission.IsEnabled || len(permission.Users) == 0 {
continue continue
@ -403,9 +427,9 @@ func CheckUpdateUser(oldUser, user *User, lang string) string {
return "" return ""
} }
func CheckToEnableCaptcha(application *Application, organization, username string) bool { func CheckToEnableCaptcha(application *Application, organization, username string) (bool, error) {
if len(application.Providers) == 0 { if len(application.Providers) == 0 {
return false return false, nil
} }
for _, providerItem := range application.Providers { for _, providerItem := range application.Providers {
@ -414,12 +438,15 @@ func CheckToEnableCaptcha(application *Application, organization, username strin
} }
if providerItem.Provider.Category == "Captcha" { if providerItem.Provider.Category == "Captcha" {
if providerItem.Rule == "Dynamic" { if providerItem.Rule == "Dynamic" {
user := GetUserByFields(organization, username) user, err := GetUserByFields(organization, username)
return user != nil && user.SigninWrongTimes >= SigninWrongTimesLimit if err != nil {
return false, err
}
return user != nil && user.SigninWrongTimes >= SigninWrongTimesLimit, nil
} }
return providerItem.Rule == "Always" return providerItem.Rule == "Always", nil
} }
} }
return false return false, nil
} }

View File

@ -74,7 +74,11 @@ func getBuiltInAccountItems() []*AccountItem {
} }
func initBuiltInOrganization() bool { func initBuiltInOrganization() bool {
organization := getOrganization("admin", "built-in") organization, err := getOrganization("admin", "built-in")
if err != nil {
panic(err)
}
if organization != nil { if organization != nil {
return true return true
} }
@ -96,12 +100,19 @@ func initBuiltInOrganization() bool {
EnableSoftDeletion: false, EnableSoftDeletion: false,
IsProfilePublic: false, IsProfilePublic: false,
} }
AddOrganization(organization) _, err = AddOrganization(organization)
if err != nil {
panic(err)
}
return false return false
} }
func initBuiltInUser() { func initBuiltInUser() {
user := getUser("built-in", "admin") user, err := getUser("built-in", "admin")
if err != nil {
panic(err)
}
if user != nil { if user != nil {
return return
} }
@ -131,11 +142,18 @@ func initBuiltInUser() {
CreatedIp: "127.0.0.1", CreatedIp: "127.0.0.1",
Properties: make(map[string]string), Properties: make(map[string]string),
} }
AddUser(user) _, err = AddUser(user)
if err != nil {
panic(err)
}
} }
func initBuiltInApplication() { func initBuiltInApplication() {
application := getApplication("admin", "app-built-in") application, err := getApplication("admin", "app-built-in")
if err != nil {
panic(err)
}
if application != nil { if application != nil {
return return
} }
@ -168,7 +186,10 @@ func initBuiltInApplication() {
ExpireInHours: 168, ExpireInHours: 168,
FormOffset: 2, FormOffset: 2,
} }
AddApplication(application) _, err = AddApplication(application)
if err != nil {
panic(err)
}
} }
func readTokenFromFile() (string, string) { func readTokenFromFile() (string, string) {
@ -187,7 +208,11 @@ func readTokenFromFile() (string, string) {
func initBuiltInCert() { func initBuiltInCert() {
tokenJwtCertificate, tokenJwtPrivateKey := readTokenFromFile() tokenJwtCertificate, tokenJwtPrivateKey := readTokenFromFile()
cert := getCert("admin", "cert-built-in") cert, err := getCert("admin", "cert-built-in")
if err != nil {
panic(err)
}
if cert != nil { if cert != nil {
return return
} }
@ -205,11 +230,18 @@ func initBuiltInCert() {
Certificate: tokenJwtCertificate, Certificate: tokenJwtCertificate,
PrivateKey: tokenJwtPrivateKey, PrivateKey: tokenJwtPrivateKey,
} }
AddCert(cert) _, err = AddCert(cert)
if err != nil {
panic(err)
}
} }
func initBuiltInLdap() { func initBuiltInLdap() {
ldap := GetLdap("ldap-built-in") ldap, err := GetLdap("ldap-built-in")
if err != nil {
panic(err)
}
if ldap != nil { if ldap != nil {
return return
} }
@ -226,11 +258,18 @@ func initBuiltInLdap() {
AutoSync: 0, AutoSync: 0,
LastSync: "", LastSync: "",
} }
AddLdap(ldap) _, err = AddLdap(ldap)
if err != nil {
panic(err)
}
} }
func initBuiltInProvider() { func initBuiltInProvider() {
provider := GetProvider(util.GetId("admin", "provider_captcha_default")) provider, err := GetProvider(util.GetId("admin", "provider_captcha_default"))
if err != nil {
panic(err)
}
if provider != nil { if provider != nil {
return return
} }
@ -243,7 +282,10 @@ func initBuiltInProvider() {
Category: "Captcha", Category: "Captcha",
Type: "Default", Type: "Default",
} }
AddProvider(provider) _, err = AddProvider(provider)
if err != nil {
panic(err)
}
} }
func initWebAuthn() { func initWebAuthn() {
@ -251,7 +293,11 @@ func initWebAuthn() {
} }
func initBuiltInModel() { func initBuiltInModel() {
model := GetModel("built-in/model-built-in") model, err := GetModel("built-in/model-built-in")
if err != nil {
panic(err)
}
if model != nil { if model != nil {
return return
} }
@ -274,11 +320,17 @@ e = some(where (p.eft == allow))
[matchers] [matchers]
m = r.sub == p.sub && r.obj == p.obj && r.act == p.act`, m = r.sub == p.sub && r.obj == p.obj && r.act == p.act`,
} }
AddModel(model) _, err = AddModel(model)
if err != nil {
panic(err)
}
} }
func initBuiltInPermission() { func initBuiltInPermission() {
permission := GetPermission("built-in/permission-built-in") permission, err := GetPermission("built-in/permission-built-in")
if err != nil {
panic(err)
}
if permission != nil { if permission != nil {
return return
} }
@ -298,5 +350,8 @@ func initBuiltInPermission() {
Effect: "Allow", Effect: "Allow",
IsEnabled: true, IsEnabled: true,
} }
AddPermission(permission) _, err = AddPermission(permission)
if err != nil {
panic(err)
}
} }

View File

@ -35,7 +35,11 @@ type InitData struct {
} }
func InitFromFile() { func InitFromFile() {
initData := readInitDataFromFile("./init_data.json") initData, err := readInitDataFromFile("./init_data.json")
if err != nil {
panic(err)
}
if initData != nil { if initData != nil {
for _, organization := range initData.Organizations { for _, organization := range initData.Organizations {
initDefinedOrganization(organization) initDefinedOrganization(organization)
@ -85,9 +89,9 @@ func InitFromFile() {
} }
} }
func readInitDataFromFile(filePath string) *InitData { func readInitDataFromFile(filePath string) (*InitData, error) {
if !util.FileExist(filePath) { if !util.FileExist(filePath) {
return nil return nil, nil
} }
s := util.ReadStringFromPath(filePath) s := util.ReadStringFromPath(filePath)
@ -111,7 +115,7 @@ func readInitDataFromFile(filePath string) *InitData {
} }
err := util.JsonToStruct(s, data) err := util.JsonToStruct(s, data)
if err != nil { if err != nil {
panic(err) return nil, err
} }
// transform nil slice to empty slice // transform nil slice to empty slice
@ -170,142 +174,246 @@ func readInitDataFromFile(filePath string) *InitData {
} }
} }
return data return data, nil
} }
func initDefinedOrganization(organization *Organization) { func initDefinedOrganization(organization *Organization) {
existed := getOrganization(organization.Owner, organization.Name) existed, err := getOrganization(organization.Owner, organization.Name)
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
organization.CreatedTime = util.GetCurrentTime() organization.CreatedTime = util.GetCurrentTime()
organization.AccountItems = getBuiltInAccountItems() organization.AccountItems = getBuiltInAccountItems()
AddOrganization(organization) _, err = AddOrganization(organization)
if err != nil {
panic(err)
}
} }
func initDefinedApplication(application *Application) { func initDefinedApplication(application *Application) {
existed := getApplication(application.Owner, application.Name) existed, err := getApplication(application.Owner, application.Name)
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
application.CreatedTime = util.GetCurrentTime() application.CreatedTime = util.GetCurrentTime()
AddApplication(application) _, err = AddApplication(application)
if err != nil {
panic(err)
}
} }
func initDefinedUser(user *User) { func initDefinedUser(user *User) {
existed := getUser(user.Owner, user.Name) existed, err := getUser(user.Owner, user.Name)
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
user.CreatedTime = util.GetCurrentTime() user.CreatedTime = util.GetCurrentTime()
user.Id = util.GenerateId() user.Id = util.GenerateId()
user.Properties = make(map[string]string) user.Properties = make(map[string]string)
AddUser(user) _, err = AddUser(user)
if err != nil {
panic(err)
}
} }
func initDefinedCert(cert *Cert) { func initDefinedCert(cert *Cert) {
existed := getCert(cert.Owner, cert.Name) existed, err := getCert(cert.Owner, cert.Name)
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
cert.CreatedTime = util.GetCurrentTime() cert.CreatedTime = util.GetCurrentTime()
AddCert(cert) _, err = AddCert(cert)
if err != nil {
panic(err)
}
} }
func initDefinedLdap(ldap *Ldap) { func initDefinedLdap(ldap *Ldap) {
existed := GetLdap(ldap.Id) existed, err := GetLdap(ldap.Id)
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
AddLdap(ldap) _, err = AddLdap(ldap)
if err != nil {
panic(err)
}
} }
func initDefinedProvider(provider *Provider) { func initDefinedProvider(provider *Provider) {
existed := GetProvider(util.GetId("admin", provider.Name)) existed, err := GetProvider(util.GetId("admin", provider.Name))
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
AddProvider(provider) _, err = AddProvider(provider)
if err != nil {
panic(err)
}
} }
func initDefinedModel(model *Model) { func initDefinedModel(model *Model) {
existed := GetModel(model.GetId()) existed, err := GetModel(model.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
model.CreatedTime = util.GetCurrentTime() model.CreatedTime = util.GetCurrentTime()
AddModel(model) _, err = AddModel(model)
if err != nil {
panic(err)
}
} }
func initDefinedPermission(permission *Permission) { func initDefinedPermission(permission *Permission) {
existed := GetPermission(permission.GetId()) existed, err := GetPermission(permission.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
permission.CreatedTime = util.GetCurrentTime() permission.CreatedTime = util.GetCurrentTime()
AddPermission(permission) _, err = AddPermission(permission)
if err != nil {
panic(err)
}
} }
func initDefinedPayment(payment *Payment) { func initDefinedPayment(payment *Payment) {
existed := GetPayment(payment.GetId()) existed, err := GetPayment(payment.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
payment.CreatedTime = util.GetCurrentTime() payment.CreatedTime = util.GetCurrentTime()
AddPayment(payment) _, err = AddPayment(payment)
if err != nil {
panic(err)
}
} }
func initDefinedProduct(product *Product) { func initDefinedProduct(product *Product) {
existed := GetProduct(product.GetId()) existed, err := GetProduct(product.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
product.CreatedTime = util.GetCurrentTime() product.CreatedTime = util.GetCurrentTime()
AddProduct(product) _, err = AddProduct(product)
if err != nil {
panic(err)
}
} }
func initDefinedResource(resource *Resource) { func initDefinedResource(resource *Resource) {
existed := GetResource(resource.GetId()) existed, err := GetResource(resource.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
resource.CreatedTime = util.GetCurrentTime() resource.CreatedTime = util.GetCurrentTime()
AddResource(resource) _, err = AddResource(resource)
if err != nil {
panic(err)
}
} }
func initDefinedRole(role *Role) { func initDefinedRole(role *Role) {
existed := GetRole(role.GetId()) existed, err := GetRole(role.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
role.CreatedTime = util.GetCurrentTime() role.CreatedTime = util.GetCurrentTime()
AddRole(role) _, err = AddRole(role)
if err != nil {
panic(err)
}
} }
func initDefinedSyncer(syncer *Syncer) { func initDefinedSyncer(syncer *Syncer) {
existed := GetSyncer(syncer.GetId()) existed, err := GetSyncer(syncer.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
syncer.CreatedTime = util.GetCurrentTime() syncer.CreatedTime = util.GetCurrentTime()
AddSyncer(syncer) _, err = AddSyncer(syncer)
if err != nil {
panic(err)
}
} }
func initDefinedToken(token *Token) { func initDefinedToken(token *Token) {
existed := GetToken(token.GetId()) existed, err := GetToken(token.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
token.CreatedTime = util.GetCurrentTime() token.CreatedTime = util.GetCurrentTime()
AddToken(token) _, err = AddToken(token)
if err != nil {
panic(err)
}
} }
func initDefinedWebhook(webhook *Webhook) { func initDefinedWebhook(webhook *Webhook) {
existed := GetWebhook(webhook.GetId()) existed, err := GetWebhook(webhook.GetId())
if err != nil {
panic(err)
}
if existed != nil { if existed != nil {
return return
} }
webhook.CreatedTime = util.GetCurrentTime() webhook.CreatedTime = util.GetCurrentTime()
AddWebhook(webhook) _, err = AddWebhook(webhook)
if err != nil {
panic(err)
}
} }

View File

@ -37,7 +37,7 @@ type Ldap struct {
LastSync string `xorm:"varchar(100)" json:"lastSync"` LastSync string `xorm:"varchar(100)" json:"lastSync"`
} }
func AddLdap(ldap *Ldap) bool { func AddLdap(ldap *Ldap) (bool, error) {
if len(ldap.Id) == 0 { if len(ldap.Id) == 0 {
ldap.Id = util.GenerateId() ldap.Id = util.GenerateId()
} }
@ -48,13 +48,13 @@ func AddLdap(ldap *Ldap) bool {
affected, err := adapter.Engine.Insert(ldap) affected, err := adapter.Engine.Insert(ldap)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func CheckLdapExist(ldap *Ldap) bool { func CheckLdapExist(ldap *Ldap) (bool, error) {
var result []*Ldap var result []*Ldap
err := adapter.Engine.Find(&result, &Ldap{ err := adapter.Engine.Find(&result, &Ldap{
Owner: ldap.Owner, Owner: ldap.Owner,
@ -65,63 +65,65 @@ func CheckLdapExist(ldap *Ldap) bool {
BaseDn: ldap.BaseDn, BaseDn: ldap.BaseDn,
}) })
if err != nil { if err != nil {
panic(err) return false, err
} }
if len(result) > 0 { if len(result) > 0 {
return true return true, nil
} }
return false return false, nil
} }
func GetLdaps(owner string) []*Ldap { func GetLdaps(owner string) ([]*Ldap, error) {
var ldaps []*Ldap var ldaps []*Ldap
err := adapter.Engine.Desc("created_time").Find(&ldaps, &Ldap{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&ldaps, &Ldap{Owner: owner})
if err != nil { if err != nil {
panic(err) return ldaps, err
} }
return ldaps return ldaps, nil
} }
func GetLdap(id string) *Ldap { func GetLdap(id string) (*Ldap, error) {
if util.IsStringsEmpty(id) { if util.IsStringsEmpty(id) {
return nil return nil, nil
} }
ldap := Ldap{Id: id} ldap := Ldap{Id: id}
existed, err := adapter.Engine.Get(&ldap) existed, err := adapter.Engine.Get(&ldap)
if err != nil { if err != nil {
panic(err) return &ldap, nil
} }
if existed { if existed {
return &ldap return &ldap, nil
} else { } else {
return nil return nil, nil
} }
} }
func UpdateLdap(ldap *Ldap) bool { func UpdateLdap(ldap *Ldap) (bool, error) {
if GetLdap(ldap.Id) == nil { if l, err := GetLdap(ldap.Id); err != nil {
return false return false, nil
} else if l == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(ldap.Id).Cols("owner", "server_name", "host", affected, err := adapter.Engine.ID(ldap.Id).Cols("owner", "server_name", "host",
"port", "enable_ssl", "username", "password", "base_dn", "filter", "filter_fields", "auto_sync").Update(ldap) "port", "enable_ssl", "username", "password", "base_dn", "filter", "filter_fields", "auto_sync").Update(ldap)
if err != nil { if err != nil {
panic(err) return false, nil
} }
return affected != 0 return affected != 0, nil
} }
func DeleteLdap(ldap *Ldap) bool { func DeleteLdap(ldap *Ldap) (bool, error) {
affected, err := adapter.Engine.ID(ldap.Id).Delete(&Ldap{}) affected, err := adapter.Engine.ID(ldap.Id).Delete(&Ldap{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }

View File

@ -18,7 +18,10 @@ var globalLdapAutoSynchronizer *LdapAutoSynchronizer
func InitLdapAutoSynchronizer() { func InitLdapAutoSynchronizer() {
globalLdapAutoSynchronizer = NewLdapAutoSynchronizer() globalLdapAutoSynchronizer = NewLdapAutoSynchronizer()
globalLdapAutoSynchronizer.LdapAutoSynchronizerStartUpAll() err := globalLdapAutoSynchronizer.LdapAutoSynchronizerStartUpAll()
if err != nil {
panic(err)
}
} }
func NewLdapAutoSynchronizer() *LdapAutoSynchronizer { func NewLdapAutoSynchronizer() *LdapAutoSynchronizer {
@ -37,7 +40,11 @@ func (l *LdapAutoSynchronizer) StartAutoSync(ldapId string) error {
l.Lock() l.Lock()
defer l.Unlock() defer l.Unlock()
ldap := GetLdap(ldapId) ldap, err := GetLdap(ldapId)
if err != nil {
return err
}
if ldap == nil { if ldap == nil {
return fmt.Errorf("ldap %s doesn't exist", ldapId) return fmt.Errorf("ldap %s doesn't exist", ldapId)
} }
@ -49,7 +56,12 @@ func (l *LdapAutoSynchronizer) StartAutoSync(ldapId string) error {
stopChan := make(chan struct{}) stopChan := make(chan struct{})
l.ldapIdToStopChan[ldapId] = stopChan l.ldapIdToStopChan[ldapId] = stopChan
logs.Info(fmt.Sprintf("autoSync started for %s", ldap.Id)) logs.Info(fmt.Sprintf("autoSync started for %s", ldap.Id))
util.SafeGoroutine(func() { l.syncRoutine(ldap, stopChan) }) util.SafeGoroutine(func() {
err := l.syncRoutine(ldap, stopChan)
if err != nil {
panic(err)
}
})
return nil return nil
} }
@ -63,18 +75,22 @@ func (l *LdapAutoSynchronizer) StopAutoSync(ldapId string) {
} }
// autosync goroutine // autosync goroutine
func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) { func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) error {
ticker := time.NewTicker(time.Duration(ldap.AutoSync) * time.Minute) ticker := time.NewTicker(time.Duration(ldap.AutoSync) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-stopChan: case <-stopChan:
logs.Info(fmt.Sprintf("autoSync goroutine for %s stopped", ldap.Id)) logs.Info(fmt.Sprintf("autoSync goroutine for %s stopped", ldap.Id))
return return nil
case <-ticker.C: case <-ticker.C:
} }
UpdateLdapSyncTime(ldap.Id) err := UpdateLdapSyncTime(ldap.Id)
if err != nil {
return err
}
// fetch all users // fetch all users
conn, err := ldap.GetLdapConn() conn, err := ldap.GetLdapConn()
if err != nil { if err != nil {
@ -100,24 +116,35 @@ func (l *LdapAutoSynchronizer) syncRoutine(ldap *Ldap, stopChan chan struct{}) {
// LdapAutoSynchronizerStartUpAll // LdapAutoSynchronizerStartUpAll
// start all autosync goroutine for existing ldap servers in each organizations // start all autosync goroutine for existing ldap servers in each organizations
func (l *LdapAutoSynchronizer) LdapAutoSynchronizerStartUpAll() { func (l *LdapAutoSynchronizer) LdapAutoSynchronizerStartUpAll() error {
organizations := []*Organization{} organizations := []*Organization{}
err := adapter.Engine.Desc("created_time").Find(&organizations) err := adapter.Engine.Desc("created_time").Find(&organizations)
if err != nil { if err != nil {
logs.Info("failed to Star up LdapAutoSynchronizer; ") logs.Info("failed to Star up LdapAutoSynchronizer; ")
} }
for _, org := range organizations { for _, org := range organizations {
for _, ldap := range GetLdaps(org.Name) { ldaps, err := GetLdaps(org.Name)
if err != nil {
return err
}
for _, ldap := range ldaps {
if ldap.AutoSync != 0 { if ldap.AutoSync != 0 {
l.StartAutoSync(ldap.Id) err = l.StartAutoSync(ldap.Id)
if err != nil {
return err
}
} }
} }
} }
return nil
} }
func UpdateLdapSyncTime(ldapId string) { func UpdateLdapSyncTime(ldapId string) error {
_, err := adapter.Engine.ID(ldapId).Update(&Ldap{LastSync: util.GetCurrentTime()}) _, err := adapter.Engine.ID(ldapId).Update(&Ldap{LastSync: util.GetCurrentTime()})
if err != nil { if err != nil {
panic(err) return err
} }
return nil
} }

View File

@ -255,8 +255,12 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
uuids = append(uuids, user.Uuid) uuids = append(uuids, user.Uuid)
} }
organization := getOrganization("admin", owner) organization, err := getOrganization("admin", owner)
ldap := GetLdap(ldapId) if err != nil {
panic(err)
}
ldap, err := GetLdap(ldapId)
var dc []string var dc []string
for _, basedn := range strings.Split(ldap.BaseDn, ",") { for _, basedn := range strings.Split(ldap.BaseDn, ",") {
@ -275,7 +279,11 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
tag := strings.Join(ou, ".") tag := strings.Join(ou, ".")
for _, syncUser := range syncUsers { for _, syncUser := range syncUsers {
existUuids := GetExistUuids(owner, uuids) existUuids, err := GetExistUuids(owner, uuids)
if err != nil {
return nil, nil, err
}
found := false found := false
if len(existUuids) > 0 { if len(existUuids) > 0 {
for _, existUuid := range existUuids { for _, existUuid := range existUuids {
@ -287,10 +295,19 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
} }
if !found { if !found {
score, _ := organization.GetInitScore() score, err := organization.GetInitScore()
if err != nil {
return nil, nil, err
}
name, err := syncUser.buildLdapUserName()
if err != nil {
return nil, nil, err
}
newUser := &User{ newUser := &User{
Owner: owner, Owner: owner,
Name: syncUser.buildLdapUserName(), Name: name,
CreatedTime: util.GetCurrentTime(), CreatedTime: util.GetCurrentTime(),
DisplayName: syncUser.buildLdapDisplayName(), DisplayName: syncUser.buildLdapDisplayName(),
Avatar: organization.DefaultAvatar, Avatar: organization.DefaultAvatar,
@ -303,7 +320,11 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
Ldap: syncUser.Uuid, Ldap: syncUser.Uuid,
} }
affected := AddUser(newUser) affected, err := AddUser(newUser)
if err != nil {
return nil, nil, err
}
if !affected { if !affected {
failedUsers = append(failedUsers, syncUser) failedUsers = append(failedUsers, syncUser)
continue continue
@ -314,38 +335,38 @@ func SyncLdapUsers(owner string, syncUsers []LdapUser, ldapId string) (existUser
return existUsers, failedUsers, err return existUsers, failedUsers, err
} }
func GetExistUuids(owner string, uuids []string) []string { func GetExistUuids(owner string, uuids []string) ([]string, error) {
var existUuids []string var existUuids []string
err := adapter.Engine.Table("user").Where("owner = ?", owner).Cols("ldap"). err := adapter.Engine.Table("user").Where("owner = ?", owner).Cols("ldap").
In("ldap", uuids).Select("DISTINCT ldap").Find(&existUuids) In("ldap", uuids).Select("DISTINCT ldap").Find(&existUuids)
if err != nil { if err != nil {
panic(err) return existUuids, err
} }
return existUuids return existUuids, nil
} }
func (ldapUser *LdapUser) buildLdapUserName() string { func (ldapUser *LdapUser) buildLdapUserName() (string, error) {
user := User{} user := User{}
uidWithNumber := fmt.Sprintf("%s_%s", ldapUser.Uid, ldapUser.UidNumber) uidWithNumber := fmt.Sprintf("%s_%s", ldapUser.Uid, ldapUser.UidNumber)
has, err := adapter.Engine.Where("name = ? or name = ?", ldapUser.Uid, uidWithNumber).Get(&user) has, err := adapter.Engine.Where("name = ? or name = ?", ldapUser.Uid, uidWithNumber).Get(&user)
if err != nil { if err != nil {
panic(err) return "", err
} }
if has { if has {
if user.Name == ldapUser.Uid { if user.Name == ldapUser.Uid {
return uidWithNumber return uidWithNumber, nil
} }
return fmt.Sprintf("%s_%s", uidWithNumber, randstr.Hex(6)) return fmt.Sprintf("%s_%s", uidWithNumber, randstr.Hex(6)), nil
} }
if ldapUser.Uid != "" { if ldapUser.Uid != "" {
return ldapUser.Uid return ldapUser.Uid, nil
} }
return ldapUser.Cn return ldapUser.Cn, nil
} }
func (ldapUser *LdapUser) buildLdapDisplayName() string { func (ldapUser *LdapUser) buildLdapDisplayName() string {

View File

@ -48,109 +48,94 @@ func GetMaskedMessages(messages []*Message) []*Message {
return messages return messages
} }
func GetMessageCount(owner, organization, field, value string) int { func GetMessageCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Message{Organization: organization}) return session.Count(&Message{Organization: organization})
if err != nil {
panic(err)
}
return int(count)
} }
func GetMessages(owner string) []*Message { func GetMessages(owner string) ([]*Message, error) {
messages := []*Message{} messages := []*Message{}
err := adapter.Engine.Desc("created_time").Find(&messages, &Message{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&messages, &Message{Owner: owner})
if err != nil { return messages, err
panic(err)
}
return messages
} }
func GetChatMessages(chat string) []*Message { func GetChatMessages(chat string) ([]*Message, error) {
messages := []*Message{} messages := []*Message{}
err := adapter.Engine.Asc("created_time").Find(&messages, &Message{Chat: chat}) err := adapter.Engine.Asc("created_time").Find(&messages, &Message{Chat: chat})
if err != nil { return messages, err
panic(err)
}
return messages
} }
func GetPaginationMessages(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Message { func GetPaginationMessages(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Message, error) {
messages := []*Message{} messages := []*Message{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&messages, &Message{Organization: organization}) err := session.Find(&messages, &Message{Organization: organization})
if err != nil { return messages, err
panic(err)
}
return messages
} }
func getMessage(owner string, name string) *Message { func getMessage(owner string, name string) (*Message, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
message := Message{Owner: owner, Name: name} message := Message{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&message) existed, err := adapter.Engine.Get(&message)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &message return &message, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetMessage(id string) *Message { func GetMessage(id string) (*Message, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getMessage(owner, name) return getMessage(owner, name)
} }
func UpdateMessage(id string, message *Message) bool { func UpdateMessage(id string, message *Message) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getMessage(owner, name) == nil { if m, err := getMessage(owner, name); err != nil {
return false return false, err
} else if m == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(message) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(message)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddMessage(message *Message) bool { func AddMessage(message *Message) (bool, error) {
affected, err := adapter.Engine.Insert(message) affected, err := adapter.Engine.Insert(message)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteMessage(message *Message) bool { func DeleteMessage(message *Message) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{message.Owner, message.Name}).Delete(&Message{}) affected, err := adapter.Engine.ID(core.PK{message.Owner, message.Name}).Delete(&Message{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteChatMessages(chat string) bool { func DeleteChatMessages(chat string) (bool, error) {
affected, err := adapter.Engine.Delete(&Message{Chat: chat}) affected, err := adapter.Engine.Delete(&Message{Chat: chat})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (p *Message) GetId() string { func (p *Message) GetId() string {

View File

@ -83,7 +83,11 @@ func RecoverTfs(user *User, recoveryCode string) error {
return fmt.Errorf("recovery code not found") return fmt.Errorf("recovery code not found")
} }
affected := UpdateUser(user.GetId(), user, []string{"two_factor_auth"}, user.IsAdminUser()) affected, err := UpdateUser(user.GetId(), user, []string{"two_factor_auth"}, user.IsAdminUser())
if err != nil {
return err
}
if !affected { if !affected {
return fmt.Errorf("") return fmt.Errorf("")
} }

View File

@ -100,7 +100,11 @@ func (mfa *SmsMfa) Enable(ctx *context.Context, user *User) error {
} }
user.MultiFactorAuths = append(user.MultiFactorAuths, mfa.Config) user.MultiFactorAuths = append(user.MultiFactorAuths, mfa.Config)
affected := UpdateUser(user.GetId(), user, []string{"multi_factor_auths"}, user.IsAdminUser()) affected, err := UpdateUser(user.GetId(), user, []string{"multi_factor_auths"}, user.IsAdminUser())
if err != nil {
return err
}
if !affected { if !affected {
return fmt.Errorf("failed to enable two factor authentication") return fmt.Errorf("failed to enable two factor authentication")
} }

View File

@ -44,5 +44,8 @@ func DoMigration() {
} }
m := migrate.New(adapter.Engine, options, migrations) m := migrate.New(adapter.Engine, options, migrations)
m.Migrate() err := m.Migrate()
if err != nil {
panic(err)
}
} }

View File

@ -32,56 +32,51 @@ type Model struct {
IsEnabled bool `json:"isEnabled"` IsEnabled bool `json:"isEnabled"`
} }
func GetModelCount(owner, field, value string) int { func GetModelCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Model{}) return session.Count(&Model{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetModels(owner string) []*Model { func GetModels(owner string) ([]*Model, error) {
models := []*Model{} models := []*Model{}
err := adapter.Engine.Desc("created_time").Find(&models, &Model{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&models, &Model{Owner: owner})
if err != nil { if err != nil {
panic(err) return models, err
} }
return models return models, nil
} }
func GetPaginationModels(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Model { func GetPaginationModels(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Model, error) {
models := []*Model{} models := []*Model{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&models) err := session.Find(&models)
if err != nil { if err != nil {
panic(err) return models, err
} }
return models return models, nil
} }
func getModel(owner string, name string) *Model { func getModel(owner string, name string) (*Model, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
m := Model{Owner: owner, Name: name} m := Model{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&m) existed, err := adapter.Engine.Get(&m)
if err != nil { if err != nil {
panic(err) return &m, err
} }
if existed { if existed {
return &m return &m, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetModel(id string) *Model { func GetModel(id string) (*Model, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getModel(owner, name) return getModel(owner, name)
} }
@ -92,48 +87,53 @@ func UpdateModelWithCheck(id string, modelObj *Model) error {
if err != nil { if err != nil {
return err return err
} }
UpdateModel(id, modelObj) _, err = UpdateModel(id, modelObj)
if err != nil {
return err
}
return nil return nil
} }
func UpdateModel(id string, modelObj *Model) bool { func UpdateModel(id string, modelObj *Model) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getModel(owner, name) == nil { if m, err := getModel(owner, name); err != nil {
return false return false, err
} else if m == nil {
return false, nil
} }
if name != modelObj.Name { if name != modelObj.Name {
err := modelChangeTrigger(name, modelObj.Name) err := modelChangeTrigger(name, modelObj.Name)
if err != nil { if err != nil {
return false return false, err
} }
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(modelObj) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(modelObj)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, err
} }
func AddModel(model *Model) bool { func AddModel(model *Model) (bool, error) {
affected, err := adapter.Engine.Insert(model) affected, err := adapter.Engine.Insert(model)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteModel(model *Model) bool { func DeleteModel(model *Model) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{model.Owner, model.Name}).Delete(&Model{}) affected, err := adapter.Engine.ID(core.PK{model.Owner, model.Name}).Delete(&Model{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (model *Model) GetId() string { func (model *Model) GetId() string {

View File

@ -110,8 +110,12 @@ func GetOidcDiscovery(host string) OidcDiscovery {
} }
func GetJsonWebKeySet() (jose.JSONWebKeySet, error) { func GetJsonWebKeySet() (jose.JSONWebKeySet, error) {
certs := GetCerts("admin")
jwks := jose.JSONWebKeySet{} jwks := jose.JSONWebKeySet{}
certs, err := GetCerts("admin")
if err != nil {
return jwks, err
}
// follows the protocol rfc 7517(draft) // follows the protocol rfc 7517(draft)
// link here: https://self-issued.info/docs/draft-ietf-jose-json-web-key.html // link here: https://self-issued.info/docs/draft-ietf-jose-json-web-key.html
// or https://datatracker.ietf.org/doc/html/draft-ietf-jose-json-web-key // or https://datatracker.ietf.org/doc/html/draft-ietf-jose-json-web-key

View File

@ -70,92 +70,102 @@ type Organization struct {
AccountItems []*AccountItem `xorm:"varchar(3000)" json:"accountItems"` AccountItems []*AccountItem `xorm:"varchar(3000)" json:"accountItems"`
} }
func GetOrganizationCount(owner, field, value string) int { func GetOrganizationCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Organization{}) return session.Count(&Organization{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetOrganizations(owner string) []*Organization { func GetOrganizations(owner string) ([]*Organization, error) {
organizations := []*Organization{} organizations := []*Organization{}
err := adapter.Engine.Desc("created_time").Find(&organizations, &Organization{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&organizations, &Organization{Owner: owner})
if err != nil { if err != nil {
panic(err) return nil, err
} }
return organizations return organizations, nil
} }
func GetOrganizationsByFields(owner string, fields ...string) []*Organization { func GetOrganizationsByFields(owner string, fields ...string) ([]*Organization, error) {
organizations := []*Organization{} organizations := []*Organization{}
err := adapter.Engine.Desc("created_time").Cols(fields...).Find(&organizations, &Organization{Owner: owner}) err := adapter.Engine.Desc("created_time").Cols(fields...).Find(&organizations, &Organization{Owner: owner})
if err != nil { if err != nil {
panic(err) return nil, err
} }
return organizations return organizations, nil
} }
func GetPaginationOrganizations(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Organization { func GetPaginationOrganizations(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Organization, error) {
organizations := []*Organization{} organizations := []*Organization{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&organizations) err := session.Find(&organizations)
if err != nil { if err != nil {
panic(err) return nil, err
} }
return organizations return organizations, nil
} }
func getOrganization(owner string, name string) *Organization { func getOrganization(owner string, name string) (*Organization, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
organization := Organization{Owner: owner, Name: name} organization := Organization{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&organization) existed, err := adapter.Engine.Get(&organization)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &organization return &organization, nil
} }
return nil return nil, nil
} }
func GetOrganization(id string) *Organization { func GetOrganization(id string) (*Organization, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getOrganization(owner, name) return getOrganization(owner, name)
} }
func GetMaskedOrganization(organization *Organization) *Organization { func GetMaskedOrganization(organization *Organization, errs ...error) (*Organization, error) {
if len(errs) > 0 && errs[0] != nil {
return nil, errs[0]
}
if organization == nil { if organization == nil {
return nil return nil, nil
} }
if organization.MasterPassword != "" { if organization.MasterPassword != "" {
organization.MasterPassword = "***" organization.MasterPassword = "***"
} }
return organization return organization, nil
} }
func GetMaskedOrganizations(organizations []*Organization) []*Organization { func GetMaskedOrganizations(organizations []*Organization, errs ...error) ([]*Organization, error) {
for _, organization := range organizations { if len(errs) > 0 && errs[0] != nil {
organization = GetMaskedOrganization(organization) return nil, errs[0]
} }
return organizations
var err error
for _, organization := range organizations {
organization, err = GetMaskedOrganization(organization)
if err != nil {
return nil, err
}
}
return organizations, nil
} }
func UpdateOrganization(id string, organization *Organization) bool { func UpdateOrganization(id string, organization *Organization) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getOrganization(owner, name) == nil { if org, err := getOrganization(owner, name); err != nil {
return false return false, err
} else if org == nil {
return false, nil
} }
if name == "built-in" { if name == "built-in" {
@ -165,7 +175,7 @@ func UpdateOrganization(id string, organization *Organization) bool {
if name != organization.Name { if name != organization.Name {
err := organizationChangeTrigger(name, organization.Name) err := organizationChangeTrigger(name, organization.Name)
if err != nil { if err != nil {
return false return false, nil
} }
} }
@ -183,35 +193,35 @@ func UpdateOrganization(id string, organization *Organization) bool {
} }
affected, err := session.Update(organization) affected, err := session.Update(organization)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddOrganization(organization *Organization) bool { func AddOrganization(organization *Organization) (bool, error) {
affected, err := adapter.Engine.Insert(organization) affected, err := adapter.Engine.Insert(organization)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteOrganization(organization *Organization) bool { func DeleteOrganization(organization *Organization) (bool, error) {
if organization.Name == "built-in" { if organization.Name == "built-in" {
return false return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{organization.Owner, organization.Name}).Delete(&Organization{}) affected, err := adapter.Engine.ID(core.PK{organization.Owner, organization.Name}).Delete(&Organization{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func GetOrganizationByUser(user *User) *Organization { func GetOrganizationByUser(user *User) (*Organization, error) {
return getOrganization("admin", user.Owner) return getOrganization("admin", user.Owner)
} }
@ -248,13 +258,21 @@ func CheckAccountItemModifyRule(accountItem *AccountItem, isAdmin bool, lang str
} }
func GetDefaultApplication(id string) (*Application, error) { func GetDefaultApplication(id string) (*Application, error) {
organization := GetOrganization(id) organization, err := GetOrganization(id)
if err != nil {
return nil, err
}
if organization == nil { if organization == nil {
return nil, fmt.Errorf("The organization: %s does not exist", id) return nil, fmt.Errorf("The organization: %s does not exist", id)
} }
if organization.DefaultApplication != "" { if organization.DefaultApplication != "" {
defaultApplication := getApplication("admin", organization.DefaultApplication) defaultApplication, err := getApplication("admin", organization.DefaultApplication)
if err != nil {
return nil, err
}
if defaultApplication == nil { if defaultApplication == nil {
return nil, fmt.Errorf("The default application: %s does not exist", organization.DefaultApplication) return nil, fmt.Errorf("The default application: %s does not exist", organization.DefaultApplication)
} else { } else {
@ -263,9 +281,9 @@ func GetDefaultApplication(id string) (*Application, error) {
} }
applications := []*Application{} applications := []*Application{}
err := adapter.Engine.Asc("created_time").Find(&applications, &Application{Organization: organization.Name}) err = adapter.Engine.Asc("created_time").Find(&applications, &Application{Organization: organization.Name})
if err != nil { if err != nil {
panic(err) return nil, err
} }
if len(applications) == 0 { if len(applications) == 0 {
@ -280,8 +298,15 @@ func GetDefaultApplication(id string) (*Application, error) {
} }
} }
extendApplicationWithProviders(defaultApplication) err = extendApplicationWithProviders(defaultApplication)
extendApplicationWithOrg(defaultApplication) if err != nil {
return nil, err
}
err = extendApplicationWithOrg(defaultApplication)
if err != nil {
return nil, err
}
return defaultApplication, nil return defaultApplication, nil
} }

View File

@ -56,74 +56,71 @@ type Payment struct {
InvoiceUrl string `xorm:"varchar(255)" json:"invoiceUrl"` InvoiceUrl string `xorm:"varchar(255)" json:"invoiceUrl"`
} }
func GetPaymentCount(owner, field, value string) int { func GetPaymentCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Payment{}) return session.Count(&Payment{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetPayments(owner string) []*Payment { func GetPayments(owner string) ([]*Payment, error) {
payments := []*Payment{} payments := []*Payment{}
err := adapter.Engine.Desc("created_time").Find(&payments, &Payment{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&payments, &Payment{Owner: owner})
if err != nil { if err != nil {
panic(err) return nil, err
} }
return payments return payments, nil
} }
func GetUserPayments(owner string, organization string, user string) []*Payment { func GetUserPayments(owner string, organization string, user string) ([]*Payment, error) {
payments := []*Payment{} payments := []*Payment{}
err := adapter.Engine.Desc("created_time").Find(&payments, &Payment{Owner: owner, Organization: organization, User: user}) err := adapter.Engine.Desc("created_time").Find(&payments, &Payment{Owner: owner, Organization: organization, User: user})
if err != nil { if err != nil {
panic(err) return nil, err
} }
return payments return payments, nil
} }
func GetPaginationPayments(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Payment { func GetPaginationPayments(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Payment, error) {
payments := []*Payment{} payments := []*Payment{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&payments) err := session.Find(&payments)
if err != nil { if err != nil {
panic(err) return nil, err
} }
return payments return payments, nil
} }
func getPayment(owner string, name string) *Payment { func getPayment(owner string, name string) (*Payment, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
payment := Payment{Owner: owner, Name: name} payment := Payment{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&payment) existed, err := adapter.Engine.Get(&payment)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &payment return &payment, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetPayment(id string) *Payment { func GetPayment(id string) (*Payment, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getPayment(owner, name) return getPayment(owner, name)
} }
func UpdatePayment(id string, payment *Payment) bool { func UpdatePayment(id string, payment *Payment) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getPayment(owner, name) == nil { if p, err := getPayment(owner, name); err != nil {
return false return false, err
} else if p == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(payment) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(payment)
@ -131,42 +128,53 @@ func UpdatePayment(id string, payment *Payment) bool {
panic(err) panic(err)
} }
return affected != 0 return affected != 0, nil
} }
func AddPayment(payment *Payment) bool { func AddPayment(payment *Payment) (bool, error) {
affected, err := adapter.Engine.Insert(payment) affected, err := adapter.Engine.Insert(payment)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeletePayment(payment *Payment) bool { func DeletePayment(payment *Payment) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{payment.Owner, payment.Name}).Delete(&Payment{}) affected, err := adapter.Engine.ID(core.PK{payment.Owner, payment.Name}).Delete(&Payment{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func notifyPayment(request *http.Request, body []byte, owner string, providerName string, productName string, paymentName string) (*Payment, error, string) { func notifyPayment(request *http.Request, body []byte, owner string, providerName string, productName string, paymentName string) (*Payment, error, string) {
provider := getProvider(owner, providerName) provider, err := getProvider(owner, providerName)
if err != nil {
panic(err)
}
pProvider, cert, err := provider.getPaymentProvider() pProvider, cert, err := provider.getPaymentProvider()
if err != nil { if err != nil {
panic(err) panic(err)
} }
payment := getPayment(owner, paymentName) payment, err := getPayment(owner, paymentName)
if err != nil {
panic(err)
}
if payment == nil { if payment == nil {
err = fmt.Errorf("the payment: %s does not exist", paymentName) err = fmt.Errorf("the payment: %s does not exist", paymentName)
return nil, err, pProvider.GetResponseError(err) return nil, err, pProvider.GetResponseError(err)
} }
product := getProduct(owner, productName) product, err := getProduct(owner, productName)
if err != nil {
panic(err)
}
if product == nil { if product == nil {
err = fmt.Errorf("the product: %s does not exist", productName) err = fmt.Errorf("the product: %s does not exist", productName)
return payment, err, pProvider.GetResponseError(err) return payment, err, pProvider.GetResponseError(err)
@ -201,14 +209,21 @@ func NotifyPayment(request *http.Request, body []byte, owner string, providerNam
payment.State = "Paid" payment.State = "Paid"
} }
UpdatePayment(payment.GetId(), payment) _, err = UpdatePayment(payment.GetId(), payment)
if err != nil {
panic(err)
}
} }
return err, errorResponse return err, errorResponse
} }
func invoicePayment(payment *Payment) (string, error) { func invoicePayment(payment *Payment) (string, error) {
provider := getProvider(payment.Owner, payment.Provider) provider, err := getProvider(payment.Owner, payment.Provider)
if err != nil {
panic(err)
}
if provider == nil { if provider == nil {
return "", fmt.Errorf("the payment provider: %s does not exist", payment.Provider) return "", fmt.Errorf("the payment provider: %s does not exist", payment.Provider)
} }
@ -237,7 +252,11 @@ func InvoicePayment(payment *Payment) (string, error) {
} }
payment.InvoiceUrl = invoiceUrl payment.InvoiceUrl = invoiceUrl
affected := UpdatePayment(payment.GetId(), payment) affected, err := UpdatePayment(payment.GetId(), payment)
if err != nil {
return "", err
}
if !affected { if !affected {
return "", fmt.Errorf("failed to update the payment: %s", payment.Name) return "", fmt.Errorf("failed to update the payment: %s", payment.Name)
} }

View File

@ -66,96 +66,97 @@ func (p *Permission) GetId() string {
return util.GetId(p.Owner, p.Name) return util.GetId(p.Owner, p.Name)
} }
func GetPermissionCount(owner, field, value string) int { func GetPermissionCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Permission{}) return session.Count(&Permission{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetPermissions(owner string) []*Permission { func GetPermissions(owner string) ([]*Permission, error) {
permissions := []*Permission{} permissions := []*Permission{}
err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner})
if err != nil { if err != nil {
panic(err) return permissions, err
} }
return permissions return permissions, nil
} }
func GetPaginationPermissions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Permission { func GetPaginationPermissions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Permission, error) {
permissions := []*Permission{} permissions := []*Permission{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&permissions) err := session.Find(&permissions)
if err != nil { if err != nil {
panic(err) return permissions, err
} }
return permissions return permissions, nil
} }
func getPermission(owner string, name string) *Permission { func getPermission(owner string, name string) (*Permission, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
permission := Permission{Owner: owner, Name: name} permission := Permission{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&permission) existed, err := adapter.Engine.Get(&permission)
if err != nil { if err != nil {
panic(err) return &permission, err
} }
if existed { if existed {
return &permission return &permission, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetPermission(id string) *Permission { func GetPermission(id string) (*Permission, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getPermission(owner, name) return getPermission(owner, name)
} }
// checkPermissionValid verifies if the permission is valid // checkPermissionValid verifies if the permission is valid
func checkPermissionValid(permission *Permission) { func checkPermissionValid(permission *Permission) error {
enforcer := getEnforcer(permission) enforcer := getEnforcer(permission)
enforcer.EnableAutoSave(false) enforcer.EnableAutoSave(false)
policies := getPolicies(permission) policies := getPolicies(permission)
_, err := enforcer.AddPolicies(policies) _, err := enforcer.AddPolicies(policies)
if err != nil { if err != nil {
panic(err) return err
} }
if !HasRoleDefinition(enforcer.GetModel()) { if !HasRoleDefinition(enforcer.GetModel()) {
permission.Roles = []string{} permission.Roles = []string{}
return return nil
} }
groupingPolicies := getGroupingPolicies(permission) groupingPolicies := getGroupingPolicies(permission)
if len(groupingPolicies) > 0 { if len(groupingPolicies) > 0 {
_, err := enforcer.AddGroupingPolicies(groupingPolicies) _, err := enforcer.AddGroupingPolicies(groupingPolicies)
if err != nil { if err != nil {
panic(err) return err
} }
} }
return nil
} }
func UpdatePermission(id string, permission *Permission) bool { func UpdatePermission(id string, permission *Permission) (bool, error) {
checkPermissionValid(permission) err := checkPermissionValid(permission)
if err != nil {
return false, err
}
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
oldPermission := getPermission(owner, name) oldPermission, err := getPermission(owner, name)
if oldPermission == nil { if oldPermission == nil {
return false return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(permission) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(permission)
if err != nil { if err != nil {
panic(err) return false, err
} }
if affected != 0 { if affected != 0 {
@ -166,7 +167,7 @@ func UpdatePermission(id string, permission *Permission) bool {
if isEmpty { if isEmpty {
err = adapter.Engine.DropTables(oldPermission.Adapter) err = adapter.Engine.DropTables(oldPermission.Adapter)
if err != nil { if err != nil {
panic(err) return false, err
} }
} }
} }
@ -174,13 +175,13 @@ func UpdatePermission(id string, permission *Permission) bool {
addPolicies(permission) addPolicies(permission)
} }
return affected != 0 return affected != 0, nil
} }
func AddPermission(permission *Permission) bool { func AddPermission(permission *Permission) (bool, error) {
affected, err := adapter.Engine.Insert(permission) affected, err := adapter.Engine.Insert(permission)
if err != nil { if err != nil {
panic(err) return false, err
} }
if affected != 0 { if affected != 0 {
@ -188,7 +189,7 @@ func AddPermission(permission *Permission) bool {
addPolicies(permission) addPolicies(permission)
} }
return affected != 0 return affected != 0, nil
} }
func AddPermissions(permissions []*Permission) bool { func AddPermissions(permissions []*Permission) bool {
@ -239,10 +240,10 @@ func AddPermissionsInBatch(permissions []*Permission) bool {
return affected return affected
} }
func DeletePermission(permission *Permission) bool { func DeletePermission(permission *Permission) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{permission.Owner, permission.Name}).Delete(&Permission{}) affected, err := adapter.Engine.ID(core.PK{permission.Owner, permission.Name}).Delete(&Permission{})
if err != nil { if err != nil {
panic(err) return false, err
} }
if affected != 0 { if affected != 0 {
@ -253,67 +254,67 @@ func DeletePermission(permission *Permission) bool {
if isEmpty { if isEmpty {
err = adapter.Engine.DropTables(permission.Adapter) err = adapter.Engine.DropTables(permission.Adapter)
if err != nil { if err != nil {
panic(err) return false, err
} }
} }
} }
} }
return affected != 0 return affected != 0, nil
} }
func GetPermissionsByUser(userId string) []*Permission { func GetPermissionsByUser(userId string) ([]*Permission, error) {
permissions := []*Permission{} permissions := []*Permission{}
err := adapter.Engine.Where("users like ?", "%"+userId+"\"%").Find(&permissions) err := adapter.Engine.Where("users like ?", "%"+userId+"\"%").Find(&permissions)
if err != nil { if err != nil {
panic(err) return permissions, err
} }
for i := range permissions { for i := range permissions {
permissions[i].Users = nil permissions[i].Users = nil
} }
return permissions return permissions, nil
} }
func GetPermissionsByRole(roleId string) []*Permission { func GetPermissionsByRole(roleId string) ([]*Permission, error) {
permissions := []*Permission{} permissions := []*Permission{}
err := adapter.Engine.Where("roles like ?", "%"+roleId+"\"%").Find(&permissions) err := adapter.Engine.Where("roles like ?", "%"+roleId+"\"%").Find(&permissions)
if err != nil { if err != nil {
panic(err) return permissions, err
} }
return permissions return permissions, nil
} }
func GetPermissionsByResource(resourceId string) []*Permission { func GetPermissionsByResource(resourceId string) ([]*Permission, error) {
permissions := []*Permission{} permissions := []*Permission{}
err := adapter.Engine.Where("resources like ?", "%"+resourceId+"\"%").Find(&permissions) err := adapter.Engine.Where("resources like ?", "%"+resourceId+"\"%").Find(&permissions)
if err != nil { if err != nil {
panic(err) return permissions, err
} }
return permissions return permissions, nil
} }
func GetPermissionsBySubmitter(owner string, submitter string) []*Permission { func GetPermissionsBySubmitter(owner string, submitter string) ([]*Permission, error) {
permissions := []*Permission{} permissions := []*Permission{}
err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner, Submitter: submitter}) err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner, Submitter: submitter})
if err != nil { if err != nil {
panic(err) return permissions, err
} }
return permissions return permissions, nil
} }
func GetPermissionsByModel(owner string, model string) []*Permission { func GetPermissionsByModel(owner string, model string) ([]*Permission, error) {
permissions := []*Permission{} permissions := []*Permission{}
err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner, Model: model}) err := adapter.Engine.Desc("created_time").Find(&permissions, &Permission{Owner: owner, Model: model})
if err != nil { if err != nil {
panic(err) return permissions, err
} }
return permissions return permissions, nil
} }
func ContainsAsterisk(userId string, users []string) bool { func ContainsAsterisk(userId string, users []string) bool {

View File

@ -29,7 +29,11 @@ import (
func getEnforcer(permission *Permission) *casbin.Enforcer { func getEnforcer(permission *Permission) *casbin.Enforcer {
tableName := "permission_rule" tableName := "permission_rule"
if len(permission.Adapter) != 0 { if len(permission.Adapter) != 0 {
adapterObj := getCasbinAdapter(permission.Owner, permission.Adapter) adapterObj, err := getCasbinAdapter(permission.Owner, permission.Adapter)
if err != nil {
panic(err)
}
if adapterObj != nil && adapterObj.Table != "" { if adapterObj != nil && adapterObj.Table != "" {
tableName = adapterObj.Table tableName = adapterObj.Table
} }
@ -42,7 +46,11 @@ func getEnforcer(permission *Permission) *casbin.Enforcer {
panic(err) panic(err)
} }
permissionModel := getModel(permission.Owner, permission.Model) permissionModel, err := getModel(permission.Owner, permission.Model)
if err != nil {
panic(err)
}
m := model.Model{} m := model.Model{}
if permissionModel != nil { if permissionModel != nil {
m, err = GetBuiltInModel(permissionModel.ModelText) m, err = GetBuiltInModel(permissionModel.ModelText)
@ -122,21 +130,30 @@ func getPolicies(permission *Permission) [][]string {
return policies return policies
} }
func getRolesInRole(roleId string, visited map[string]struct{}) []*Role { func getRolesInRole(roleId string, visited map[string]struct{}) ([]*Role, error) {
role := GetRole(roleId) role, err := GetRole(roleId)
if err != nil {
return []*Role{}, err
}
if role == nil { if role == nil {
return []*Role{} return []*Role{}, nil
} }
visited[roleId] = struct{}{} visited[roleId] = struct{}{}
roles := []*Role{role} roles := []*Role{role}
for _, subRole := range role.Roles { for _, subRole := range role.Roles {
if _, ok := visited[subRole]; !ok { if _, ok := visited[subRole]; !ok {
roles = append(roles, getRolesInRole(subRole, visited)...) r, err := getRolesInRole(subRole, visited)
if err != nil {
return []*Role{}, err
}
roles = append(roles, r...)
} }
} }
return roles return roles, nil
} }
func getGroupingPolicies(permission *Permission) [][]string { func getGroupingPolicies(permission *Permission) [][]string {
@ -147,8 +164,10 @@ func getGroupingPolicies(permission *Permission) [][]string {
for _, roleId := range permission.Roles { for _, roleId := range permission.Roles {
visited := map[string]struct{}{} visited := map[string]struct{}{}
rolesInRole := getRolesInRole(roleId, visited) rolesInRole, err := getRolesInRole(roleId, visited)
if err != nil {
panic(err)
}
for _, role := range rolesInRole { for _, role := range rolesInRole {
roleId := role.GetId() roleId := role.GetId()
for _, subUser := range role.Users { for _, subUser := range role.Users {
@ -223,7 +242,11 @@ func removePolicies(permission *Permission) {
type CasbinRequest = []interface{} type CasbinRequest = []interface{}
func Enforce(permissionId string, request *CasbinRequest) bool { func Enforce(permissionId string, request *CasbinRequest) bool {
permission := GetPermission(permissionId) permission, err := GetPermission(permissionId)
if err != nil {
panic(err)
}
enforcer := getEnforcer(permission) enforcer := getEnforcer(permission)
allow, err := enforcer.Enforce(*request...) allow, err := enforcer.Enforce(*request...)
@ -234,7 +257,11 @@ func Enforce(permissionId string, request *CasbinRequest) bool {
} }
func BatchEnforce(permissionId string, requests *[]CasbinRequest) []bool { func BatchEnforce(permissionId string, requests *[]CasbinRequest) []bool {
permission := GetPermission(permissionId) permission, err := GetPermission(permissionId)
if err != nil {
panic(err)
}
enforcer := getEnforcer(permission) enforcer := getEnforcer(permission)
allow, err := enforcer.BatchEnforce(*requests) allow, err := enforcer.BatchEnforce(*requests)
if err != nil { if err != nil {
@ -244,9 +271,18 @@ func BatchEnforce(permissionId string, requests *[]CasbinRequest) []bool {
} }
func getAllValues(userId string, fn func(enforcer *casbin.Enforcer) []string) []string { func getAllValues(userId string, fn func(enforcer *casbin.Enforcer) []string) []string {
permissions := GetPermissionsByUser(userId) permissions, err := GetPermissionsByUser(userId)
if err != nil {
panic(err)
}
for _, role := range GetAllRoles(userId) { for _, role := range GetAllRoles(userId) {
permissions = append(permissions, GetPermissionsByRole(role)...) permissionsByRole, err := GetPermissionsByRole(role)
if err != nil {
panic(err)
}
permissions = append(permissions, permissionsByRole...)
} }
var values []string var values []string
@ -270,7 +306,11 @@ func GetAllActions(userId string) []string {
} }
func GetAllRoles(userId string) []string { func GetAllRoles(userId string) []string {
roles := GetRolesByUser(userId) roles, err := GetRolesByUser(userId)
if err != nil {
panic(err)
}
var res []string var res []string
for _, role := range roles { for _, role := range roles {
res = append(res, role.Name) res = append(res, role.Name)

View File

@ -18,21 +18,29 @@ import (
"github.com/casdoor/casdoor/xlsx" "github.com/casdoor/casdoor/xlsx"
) )
func getPermissionMap(owner string) map[string]*Permission { func getPermissionMap(owner string) (map[string]*Permission, error) {
m := map[string]*Permission{} m := map[string]*Permission{}
permissions := GetPermissions(owner) permissions, err := GetPermissions(owner)
if err != nil {
return nil, err
}
for _, permission := range permissions { for _, permission := range permissions {
m[permission.GetId()] = permission m[permission.GetId()] = permission
} }
return m return m, err
} }
func UploadPermissions(owner string, fileId string) bool { func UploadPermissions(owner string, fileId string) (bool, error) {
table := xlsx.ReadXlsxFile(fileId) table := xlsx.ReadXlsxFile(fileId)
oldUserMap := getPermissionMap(owner) oldUserMap, err := getPermissionMap(owner)
if err != nil {
return false, err
}
newPermissions := []*Permission{} newPermissions := []*Permission{}
for index, line := range table { for index, line := range table {
if index == 0 || parseLineItem(&line, 0) == "" { if index == 0 || parseLineItem(&line, 0) == "" {
@ -71,7 +79,7 @@ func UploadPermissions(owner string, fileId string) bool {
} }
if len(newPermissions) == 0 { if len(newPermissions) == 0 {
return false return false, nil
} }
return AddPermissionsInBatch(newPermissions) return AddPermissionsInBatch(newPermissions), nil
} }

View File

@ -37,109 +37,115 @@ type Plan struct {
Options []string `xorm:"-" json:"options"` Options []string `xorm:"-" json:"options"`
} }
func GetPlanCount(owner, field, value string) int { func GetPlanCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Plan{}) return session.Count(&Plan{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetPlans(owner string) []*Plan { func GetPlans(owner string) ([]*Plan, error) {
plans := []*Plan{} plans := []*Plan{}
err := adapter.Engine.Desc("created_time").Find(&plans, &Plan{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&plans, &Plan{Owner: owner})
if err != nil { if err != nil {
panic(err) return plans, err
} }
return plans return plans, nil
} }
func GetPaginatedPlans(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Plan { func GetPaginatedPlans(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Plan, error) {
plans := []*Plan{} plans := []*Plan{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&plans) err := session.Find(&plans)
if err != nil { if err != nil {
panic(err) return plans, err
} }
return plans return plans, nil
} }
func getPlan(owner, name string) *Plan { func getPlan(owner, name string) (*Plan, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
plan := Plan{Owner: owner, Name: name} plan := Plan{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&plan) existed, err := adapter.Engine.Get(&plan)
if err != nil { if err != nil {
panic(err) return &plan, err
} }
if existed { if existed {
return &plan return &plan, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetPlan(id string) *Plan { func GetPlan(id string) (*Plan, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getPlan(owner, name) return getPlan(owner, name)
} }
func UpdatePlan(id string, plan *Plan) bool { func UpdatePlan(id string, plan *Plan) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getPlan(owner, name) == nil { if p, err := getPlan(owner, name); err != nil {
return false return false, err
} else if p == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(plan) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(plan)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddPlan(plan *Plan) bool { func AddPlan(plan *Plan) (bool, error) {
affected, err := adapter.Engine.Insert(plan) affected, err := adapter.Engine.Insert(plan)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeletePlan(plan *Plan) bool { func DeletePlan(plan *Plan) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{plan.Owner, plan.Name}).Delete(plan) affected, err := adapter.Engine.ID(core.PK{plan.Owner, plan.Name}).Delete(plan)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (plan *Plan) GetId() string { func (plan *Plan) GetId() string {
return fmt.Sprintf("%s/%s", plan.Owner, plan.Name) return fmt.Sprintf("%s/%s", plan.Owner, plan.Name)
} }
func Subscribe(owner string, user string, plan string, pricing string) *Subscription { func Subscribe(owner string, user string, plan string, pricing string) (*Subscription, error) {
selectedPricing := GetPricing(fmt.Sprintf("%s/%s", owner, pricing)) selectedPricing, err := GetPricing(fmt.Sprintf("%s/%s", owner, pricing))
if err != nil {
return nil, err
}
valid := selectedPricing != nil && selectedPricing.IsEnabled valid := selectedPricing != nil && selectedPricing.IsEnabled
if !valid { if !valid {
return nil return nil, nil
} }
planBelongToPricing := selectedPricing.HasPlan(owner, plan) planBelongToPricing, err := selectedPricing.HasPlan(owner, plan)
if err != nil {
return nil, err
}
if planBelongToPricing { if planBelongToPricing {
newSubscription := NewSubscription(owner, user, plan, selectedPricing.TrialDuration) newSubscription := NewSubscription(owner, user, plan, selectedPricing.TrialDuration)
affected := AddSubscription(newSubscription) affected, err := AddSubscription(newSubscription)
if err != nil {
return nil, err
}
if affected { if affected {
return newSubscription return newSubscription, nil
} }
} }
return nil return nil, nil
} }

View File

@ -42,96 +42,97 @@ type Pricing struct {
State string `xorm:"varchar(100)" json:"state"` State string `xorm:"varchar(100)" json:"state"`
} }
func GetPricingCount(owner, field, value string) int { func GetPricingCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Pricing{}) return session.Count(&Pricing{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetPricings(owner string) []*Pricing { func GetPricings(owner string) ([]*Pricing, error) {
pricings := []*Pricing{} pricings := []*Pricing{}
err := adapter.Engine.Desc("created_time").Find(&pricings, &Pricing{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&pricings, &Pricing{Owner: owner})
if err != nil { if err != nil {
panic(err) return pricings, err
} }
return pricings
return pricings, nil
} }
func GetPaginatedPricings(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Pricing { func GetPaginatedPricings(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Pricing, error) {
pricings := []*Pricing{} pricings := []*Pricing{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&pricings) err := session.Find(&pricings)
if err != nil { if err != nil {
panic(err) return pricings, err
} }
return pricings return pricings, nil
} }
func getPricing(owner, name string) *Pricing { func getPricing(owner, name string) (*Pricing, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
pricing := Pricing{Owner: owner, Name: name} pricing := Pricing{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&pricing) existed, err := adapter.Engine.Get(&pricing)
if err != nil { if err != nil {
panic(err) return &pricing, err
} }
if existed { if existed {
return &pricing return &pricing, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetPricing(id string) *Pricing { func GetPricing(id string) (*Pricing, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getPricing(owner, name) return getPricing(owner, name)
} }
func UpdatePricing(id string, pricing *Pricing) bool { func UpdatePricing(id string, pricing *Pricing) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getPricing(owner, name) == nil { if p, err := getPricing(owner, name); err != nil {
return false return false, err
} else if p == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(pricing) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(pricing)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddPricing(pricing *Pricing) bool { func AddPricing(pricing *Pricing) (bool, error) {
affected, err := adapter.Engine.Insert(pricing) affected, err := adapter.Engine.Insert(pricing)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeletePricing(pricing *Pricing) bool { func DeletePricing(pricing *Pricing) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{pricing.Owner, pricing.Name}).Delete(pricing) affected, err := adapter.Engine.ID(core.PK{pricing.Owner, pricing.Name}).Delete(pricing)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (pricing *Pricing) GetId() string { func (pricing *Pricing) GetId() string {
return fmt.Sprintf("%s/%s", pricing.Owner, pricing.Name) return fmt.Sprintf("%s/%s", pricing.Owner, pricing.Name)
} }
func (pricing *Pricing) HasPlan(owner string, plan string) bool { func (pricing *Pricing) HasPlan(owner string, plan string) (bool, error) {
selectedPlan := GetPlan(fmt.Sprintf("%s/%s", owner, plan)) selectedPlan, err := GetPlan(fmt.Sprintf("%s/%s", owner, plan))
if err != nil {
return false, err
}
if selectedPlan == nil { if selectedPlan == nil {
return false return false, nil
} }
result := false result := false
@ -143,5 +144,5 @@ func (pricing *Pricing) HasPlan(owner string, plan string) bool {
} }
} }
return result return result, nil
} }

View File

@ -43,90 +43,87 @@ type Product struct {
ProviderObjs []*Provider `xorm:"-" json:"providerObjs"` ProviderObjs []*Provider `xorm:"-" json:"providerObjs"`
} }
func GetProductCount(owner, field, value string) int { func GetProductCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Product{}) return session.Count(&Product{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetProducts(owner string) []*Product { func GetProducts(owner string) ([]*Product, error) {
products := []*Product{} products := []*Product{}
err := adapter.Engine.Desc("created_time").Find(&products, &Product{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&products, &Product{Owner: owner})
if err != nil { if err != nil {
panic(err) return products, err
} }
return products return products, nil
} }
func GetPaginationProducts(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Product { func GetPaginationProducts(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Product, error) {
products := []*Product{} products := []*Product{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&products) err := session.Find(&products)
if err != nil { if err != nil {
panic(err) return products, err
} }
return products return products, nil
} }
func getProduct(owner string, name string) *Product { func getProduct(owner string, name string) (*Product, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
product := Product{Owner: owner, Name: name} product := Product{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&product) existed, err := adapter.Engine.Get(&product)
if err != nil { if err != nil {
panic(err) return &product, nil
} }
if existed { if existed {
return &product return &product, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetProduct(id string) *Product { func GetProduct(id string) (*Product, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getProduct(owner, name) return getProduct(owner, name)
} }
func UpdateProduct(id string, product *Product) bool { func UpdateProduct(id string, product *Product) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getProduct(owner, name) == nil { if p, err := getProduct(owner, name); err != nil {
return false return false, err
} else if p == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(product) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(product)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddProduct(product *Product) bool { func AddProduct(product *Product) (bool, error) {
affected, err := adapter.Engine.Insert(product) affected, err := adapter.Engine.Insert(product)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteProduct(product *Product) bool { func DeleteProduct(product *Product) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{product.Owner, product.Name}).Delete(&Product{}) affected, err := adapter.Engine.ID(core.PK{product.Owner, product.Name}).Delete(&Product{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (product *Product) GetId() string { func (product *Product) GetId() string {
@ -143,7 +140,11 @@ func (product *Product) isValidProvider(provider *Provider) bool {
} }
func (product *Product) getProvider(providerId string) (*Provider, error) { func (product *Product) getProvider(providerId string) (*Provider, error) {
provider := getProvider(product.Owner, providerId) provider, err := getProvider(product.Owner, providerId)
if err != nil {
return nil, err
}
if provider == nil { if provider == nil {
return nil, fmt.Errorf("the payment provider: %s does not exist", providerId) return nil, fmt.Errorf("the payment provider: %s does not exist", providerId)
} }
@ -156,7 +157,11 @@ func (product *Product) getProvider(providerId string) (*Provider, error) {
} }
func BuyProduct(id string, providerName string, user *User, host string) (string, error) { func BuyProduct(id string, providerName string, user *User, host string) (string, error) {
product := GetProduct(id) product, err := GetProduct(id)
if err != nil {
return "", err
}
if product == nil { if product == nil {
return "", fmt.Errorf("the product: %s does not exist", id) return "", fmt.Errorf("the product: %s does not exist", id)
} }
@ -205,7 +210,11 @@ func BuyProduct(id string, providerName string, user *User, host string) (string
ReturnUrl: product.ReturnUrl, ReturnUrl: product.ReturnUrl,
State: "Created", State: "Created",
} }
affected := AddPayment(&payment) affected, err := AddPayment(&payment)
if err != nil {
return "", err
}
if !affected { if !affected {
return "", fmt.Errorf("failed to add payment: %s", util.StructToJson(payment)) return "", fmt.Errorf("failed to add payment: %s", util.StructToJson(payment))
} }
@ -213,17 +222,23 @@ func BuyProduct(id string, providerName string, user *User, host string) (string
return payUrl, err return payUrl, err
} }
func ExtendProductWithProviders(product *Product) { func ExtendProductWithProviders(product *Product) error {
if product == nil { if product == nil {
return return nil
} }
product.ProviderObjs = []*Provider{} product.ProviderObjs = []*Provider{}
m := getProviderMap(product.Owner) m, err := getProviderMap(product.Owner)
if err != nil {
return err
}
for _, providerItem := range product.Providers { for _, providerItem := range product.Providers {
if provider, ok := m[providerItem]; ok { if provider, ok := m[providerItem]; ok {
product.ProviderObjs = append(product.ProviderObjs, provider) product.ProviderObjs = append(product.ProviderObjs, provider)
} }
} }
return nil
} }

View File

@ -27,9 +27,9 @@ import (
func TestProduct(t *testing.T) { func TestProduct(t *testing.T) {
InitConfig() InitConfig()
product := GetProduct("admin/product_123") product, _ := GetProduct("admin/product_123")
provider := getProvider(product.Owner, "provider_pay_alipay") provider, _ := getProvider(product.Owner, "provider_pay_alipay")
cert := getCert(product.Owner, "cert-pay-alipay") cert, _ := getCert(product.Owner, "cert-pay-alipay")
pProvider, err := pp.GetPaymentProvider(provider.Type, provider.ClientId, provider.ClientSecret, provider.Host, cert.Certificate, cert.PrivateKey, cert.AuthorityPublicKey, cert.AuthorityRootPublicKey, provider.ClientId2) pProvider, err := pp.GetPaymentProvider(provider.Type, provider.ClientId, provider.ClientSecret, provider.Host, cert.Certificate, cert.PrivateKey, cert.AuthorityPublicKey, cert.AuthorityRootPublicKey, provider.ClientId2)
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -103,103 +103,93 @@ func GetMaskedProviders(providers []*Provider, isMaskEnabled bool) []*Provider {
return providers return providers
} }
func GetProviderCount(owner, field, value string) int { func GetProviderCount(owner, field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "") session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Provider{}) return session.Where("owner = ? or owner = ? ", "admin", owner).Count(&Provider{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetGlobalProviderCount(field, value string) int { func GetGlobalProviderCount(field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "") session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Count(&Provider{}) return session.Count(&Provider{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetProviders(owner string) []*Provider { func GetProviders(owner string) ([]*Provider, error) {
providers := []*Provider{} providers := []*Provider{}
err := adapter.Engine.Where("owner = ? or owner = ? ", "admin", owner).Desc("created_time").Find(&providers, &Provider{}) err := adapter.Engine.Where("owner = ? or owner = ? ", "admin", owner).Desc("created_time").Find(&providers, &Provider{})
if err != nil { if err != nil {
panic(err) return providers, err
} }
return providers return providers, nil
} }
func GetGlobalProviders() []*Provider { func GetGlobalProviders() ([]*Provider, error) {
providers := []*Provider{} providers := []*Provider{}
err := adapter.Engine.Desc("created_time").Find(&providers) err := adapter.Engine.Desc("created_time").Find(&providers)
if err != nil { if err != nil {
panic(err) return providers, err
} }
return providers return providers, nil
} }
func GetPaginationProviders(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Provider { func GetPaginationProviders(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Provider, error) {
providers := []*Provider{} providers := []*Provider{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder) session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Where("owner = ? or owner = ? ", "admin", owner).Find(&providers) err := session.Where("owner = ? or owner = ? ", "admin", owner).Find(&providers)
if err != nil { if err != nil {
panic(err) return providers, err
} }
return providers return providers, nil
} }
func GetPaginationGlobalProviders(offset, limit int, field, value, sortField, sortOrder string) []*Provider { func GetPaginationGlobalProviders(offset, limit int, field, value, sortField, sortOrder string) ([]*Provider, error) {
providers := []*Provider{} providers := []*Provider{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder) session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Find(&providers) err := session.Find(&providers)
if err != nil { if err != nil {
panic(err) return providers, err
} }
return providers return providers, nil
} }
func getProvider(owner string, name string) *Provider { func getProvider(owner string, name string) (*Provider, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
provider := Provider{Name: name} provider := Provider{Name: name}
existed, err := adapter.Engine.Get(&provider) existed, err := adapter.Engine.Get(&provider)
if err != nil { if err != nil {
panic(err) return &provider, err
} }
if existed { if existed {
return &provider return &provider, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetProvider(id string) *Provider { func GetProvider(id string) (*Provider, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getProvider(owner, name) return getProvider(owner, name)
} }
func getDefaultAiProvider() *Provider { func getDefaultAiProvider() (*Provider, error) {
provider := Provider{Owner: "admin", Category: "AI"} provider := Provider{Owner: "admin", Category: "AI"}
existed, err := adapter.Engine.Get(&provider) existed, err := adapter.Engine.Get(&provider)
if err != nil { if err != nil {
panic(err) return &provider, err
} }
if !existed { if !existed {
return nil return nil, nil
} }
return &provider return &provider, nil
} }
func GetWechatMiniProgramProvider(application *Application) *Provider { func GetWechatMiniProgramProvider(application *Application) *Provider {
@ -212,16 +202,18 @@ func GetWechatMiniProgramProvider(application *Application) *Provider {
return nil return nil
} }
func UpdateProvider(id string, provider *Provider) bool { func UpdateProvider(id string, provider *Provider) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getProvider(owner, name) == nil { if p, err := getProvider(owner, name); err != nil {
return false return false, err
} else if p == nil {
return false, nil
} }
if name != provider.Name { if name != provider.Name {
err := providerChangeTrigger(name, provider.Name) err := providerChangeTrigger(name, provider.Name)
if err != nil { if err != nil {
return false return false, nil
} }
} }
@ -238,37 +230,41 @@ func UpdateProvider(id string, provider *Provider) bool {
affected, err := session.Update(provider) affected, err := session.Update(provider)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddProvider(provider *Provider) bool { func AddProvider(provider *Provider) (bool, error) {
provider.Endpoint = util.GetEndPoint(provider.Endpoint) provider.Endpoint = util.GetEndPoint(provider.Endpoint)
provider.IntranetEndpoint = util.GetEndPoint(provider.IntranetEndpoint) provider.IntranetEndpoint = util.GetEndPoint(provider.IntranetEndpoint)
affected, err := adapter.Engine.Insert(provider) affected, err := adapter.Engine.Insert(provider)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteProvider(provider *Provider) bool { func DeleteProvider(provider *Provider) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{provider.Owner, provider.Name}).Delete(&Provider{}) affected, err := adapter.Engine.ID(core.PK{provider.Owner, provider.Name}).Delete(&Provider{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (p *Provider) getPaymentProvider() (pp.PaymentProvider, *Cert, error) { func (p *Provider) getPaymentProvider() (pp.PaymentProvider, *Cert, error) {
cert := &Cert{} cert := &Cert{}
if p.Cert != "" { if p.Cert != "" {
cert = getCert(p.Owner, p.Cert) cert, err := getCert(p.Owner, p.Cert)
if err != nil {
return nil, nil, err
}
if cert == nil { if cert == nil {
return nil, nil, fmt.Errorf("the cert: %s does not exist", p.Cert) return nil, nil, fmt.Errorf("the cert: %s does not exist", p.Cert)
} }
@ -309,7 +305,11 @@ func GetCaptchaProviderByApplication(applicationId, isCurrentProvider, lang stri
if isCurrentProvider == "true" { if isCurrentProvider == "true" {
return GetCaptchaProviderByOwnerName(applicationId, lang) return GetCaptchaProviderByOwnerName(applicationId, lang)
} }
application := GetApplication(applicationId) application, err := GetApplication(applicationId)
if err != nil {
return nil, err
}
if application == nil || len(application.Providers) == 0 { if application == nil || len(application.Providers) == 0 {
return nil, fmt.Errorf(i18n.Translate(lang, "provider:Invalid application id")) return nil, fmt.Errorf(i18n.Translate(lang, "provider:Invalid application id"))
} }

View File

@ -26,11 +26,7 @@ import (
var logPostOnly bool var logPostOnly bool
func init() { func init() {
var err error logPostOnly = conf.GetConfigBool("logPostOnly")
logPostOnly, err = conf.GetConfigBool("logPostOnly")
if err != nil {
// panic(err)
}
} }
type Record struct { type Record struct {
@ -108,49 +104,48 @@ func AddRecord(record *Record) bool {
return affected != 0 return affected != 0
} }
func GetRecordCount(field, value string, filterRecord *Record) int { func GetRecordCount(field, value string, filterRecord *Record) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "") session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Count(filterRecord) return session.Count(filterRecord)
if err != nil {
panic(err)
}
return int(count)
} }
func GetRecords() []*Record { func GetRecords() ([]*Record, error) {
records := []*Record{} records := []*Record{}
err := adapter.Engine.Desc("id").Find(&records) err := adapter.Engine.Desc("id").Find(&records)
if err != nil { if err != nil {
panic(err) return records, err
} }
return records return records, nil
} }
func GetPaginationRecords(offset, limit int, field, value, sortField, sortOrder string, filterRecord *Record) []*Record { func GetPaginationRecords(offset, limit int, field, value, sortField, sortOrder string, filterRecord *Record) ([]*Record, error) {
records := []*Record{} records := []*Record{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder) session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Find(&records, filterRecord) err := session.Find(&records, filterRecord)
if err != nil { if err != nil {
panic(err) return records, err
} }
return records return records, nil
} }
func GetRecordsByField(record *Record) []*Record { func GetRecordsByField(record *Record) ([]*Record, error) {
records := []*Record{} records := []*Record{}
err := adapter.Engine.Find(&records, record) err := adapter.Engine.Find(&records, record)
if err != nil { if err != nil {
panic(err) return records, err
} }
return records return records, nil
} }
func SendWebhooks(record *Record) error { func SendWebhooks(record *Record) error {
webhooks := getWebhooksByOrganization(record.Organization) webhooks, err := getWebhooksByOrganization(record.Organization)
if err != nil {
return err
}
for _, webhook := range webhooks { for _, webhook := range webhooks {
if !webhook.IsEnabled { if !webhook.IsEnabled {
continue continue
@ -166,7 +161,11 @@ func SendWebhooks(record *Record) error {
if matched { if matched {
if webhook.IsUserExtended { if webhook.IsUserExtended {
user := GetMaskedUser(getUser(record.Organization, record.User)) user, err := GetMaskedUser(getUser(record.Organization, record.User))
if err != nil {
return err
}
record.ExtendedUser = user record.ExtendedUser = user
} }

View File

@ -39,17 +39,12 @@ type Resource struct {
Description string `xorm:"varchar(255)" json:"description"` Description string `xorm:"varchar(255)" json:"description"`
} }
func GetResourceCount(owner, user, field, value string) int { func GetResourceCount(owner, user, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Resource{User: user}) return session.Count(&Resource{User: user})
if err != nil {
panic(err)
}
return int(count)
} }
func GetResources(owner string, user string) []*Resource { func GetResources(owner string, user string) ([]*Resource, error) {
if owner == "built-in" { if owner == "built-in" {
owner = "" owner = ""
user = "" user = ""
@ -58,13 +53,13 @@ func GetResources(owner string, user string) []*Resource {
resources := []*Resource{} resources := []*Resource{}
err := adapter.Engine.Desc("created_time").Find(&resources, &Resource{Owner: owner, User: user}) err := adapter.Engine.Desc("created_time").Find(&resources, &Resource{Owner: owner, User: user})
if err != nil { if err != nil {
panic(err) return resources, err
} }
return resources return resources, err
} }
func GetPaginationResources(owner, user string, offset, limit int, field, value, sortField, sortOrder string) []*Resource { func GetPaginationResources(owner, user string, offset, limit int, field, value, sortField, sortOrder string) ([]*Resource, error) {
if owner == "built-in" { if owner == "built-in" {
owner = "" owner = ""
user = "" user = ""
@ -74,70 +69,74 @@ func GetPaginationResources(owner, user string, offset, limit int, field, value,
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&resources, &Resource{User: user}) err := session.Find(&resources, &Resource{User: user})
if err != nil { if err != nil {
panic(err) return resources, err
} }
return resources return resources, nil
} }
func getResource(owner string, name string) *Resource { func getResource(owner string, name string) (*Resource, error) {
resource := Resource{Owner: owner, Name: name} resource := Resource{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&resource) existed, err := adapter.Engine.Get(&resource)
if err != nil { if err != nil {
panic(err) return &resource, err
} }
if existed { if existed {
return &resource return &resource, nil
} }
return nil return nil, nil
} }
func GetResource(id string) *Resource { func GetResource(id string) (*Resource, error) {
owner, name := util.GetOwnerAndNameFromIdNoCheck(id) owner, name := util.GetOwnerAndNameFromIdNoCheck(id)
return getResource(owner, name) return getResource(owner, name)
} }
func UpdateResource(id string, resource *Resource) bool { func UpdateResource(id string, resource *Resource) (bool, error) {
owner, name := util.GetOwnerAndNameFromIdNoCheck(id) owner, name := util.GetOwnerAndNameFromIdNoCheck(id)
if getResource(owner, name) == nil { if r, err := getResource(owner, name); err != nil {
return false return false, err
} else if r == nil {
return false, nil
} }
_, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(resource) _, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(resource)
if err != nil { if err != nil {
panic(err) return false, err
} }
// return affected != 0 // return affected != 0
return true return true, nil
} }
func AddResource(resource *Resource) bool { func AddResource(resource *Resource) (bool, error) {
affected, err := adapter.Engine.Insert(resource) affected, err := adapter.Engine.Insert(resource)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteResource(resource *Resource) bool { func DeleteResource(resource *Resource) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{resource.Owner, resource.Name}).Delete(&Resource{}) affected, err := adapter.Engine.ID(core.PK{resource.Owner, resource.Name}).Delete(&Resource{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (resource *Resource) GetId() string { func (resource *Resource) GetId() string {
return fmt.Sprintf("%s/%s", resource.Owner, resource.Name) return fmt.Sprintf("%s/%s", resource.Owner, resource.Name)
} }
func AddOrUpdateResource(resource *Resource) bool { func AddOrUpdateResource(resource *Resource) (bool, error) {
if getResource(resource.Owner, resource.Name) == nil { if r, err := getResource(resource.Owner, resource.Name); err != nil {
return false, err
} else if r == nil {
return AddResource(resource) return AddResource(resource)
} else { } else {
return UpdateResource(resource.GetId(), resource) return UpdateResource(resource.GetId(), resource)

View File

@ -36,79 +36,90 @@ type Role struct {
IsEnabled bool `json:"isEnabled"` IsEnabled bool `json:"isEnabled"`
} }
func GetRoleCount(owner, field, value string) int { func GetRoleCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Role{}) return session.Count(&Role{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetRoles(owner string) []*Role { func GetRoles(owner string) ([]*Role, error) {
roles := []*Role{} roles := []*Role{}
err := adapter.Engine.Desc("created_time").Find(&roles, &Role{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&roles, &Role{Owner: owner})
if err != nil { if err != nil {
panic(err) return roles, err
} }
return roles return roles, nil
} }
func GetPaginationRoles(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Role { func GetPaginationRoles(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Role, error) {
roles := []*Role{} roles := []*Role{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&roles) err := session.Find(&roles)
if err != nil { if err != nil {
panic(err) return roles, err
} }
return roles return roles, nil
} }
func getRole(owner string, name string) *Role { func getRole(owner string, name string) (*Role, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
role := Role{Owner: owner, Name: name} role := Role{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&role) existed, err := adapter.Engine.Get(&role)
if err != nil { if err != nil {
panic(err) return &role, err
} }
if existed { if existed {
return &role return &role, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetRole(id string) *Role { func GetRole(id string) (*Role, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getRole(owner, name) return getRole(owner, name)
} }
func UpdateRole(id string, role *Role) bool { func UpdateRole(id string, role *Role) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
oldRole := getRole(owner, name) oldRole, err := getRole(owner, name)
if err != nil {
return false, err
}
if oldRole == nil { if oldRole == nil {
return false return false, nil
} }
visited := map[string]struct{}{} visited := map[string]struct{}{}
permissions := GetPermissionsByRole(id) permissions, err := GetPermissionsByRole(id)
if err != nil {
return false, err
}
for _, permission := range permissions { for _, permission := range permissions {
removeGroupingPolicies(permission) removeGroupingPolicies(permission)
removePolicies(permission) removePolicies(permission)
visited[permission.GetId()] = struct{}{} visited[permission.GetId()] = struct{}{}
} }
ancestorRoles := GetAncestorRoles(id) ancestorRoles, err := GetAncestorRoles(id)
if err != nil {
return false, err
}
for _, r := range ancestorRoles { for _, r := range ancestorRoles {
permissions := GetPermissionsByRole(r.GetId()) permissions, err := GetPermissionsByRole(r.GetId())
if err != nil {
return false, err
}
for _, permission := range permissions { for _, permission := range permissions {
permissionId := permission.GetId() permissionId := permission.GetId()
if _, ok := visited[permissionId]; !ok { if _, ok := visited[permissionId]; !ok {
@ -121,27 +132,38 @@ func UpdateRole(id string, role *Role) bool {
if name != role.Name { if name != role.Name {
err := roleChangeTrigger(name, role.Name) err := roleChangeTrigger(name, role.Name)
if err != nil { if err != nil {
return false return false, nil
} }
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(role) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(role)
if err != nil { if err != nil {
panic(err) return false, err
} }
visited = map[string]struct{}{} visited = map[string]struct{}{}
newRoleID := role.GetId() newRoleID := role.GetId()
permissions = GetPermissionsByRole(newRoleID) permissions, err = GetPermissionsByRole(newRoleID)
if err != nil {
return false, err
}
for _, permission := range permissions { for _, permission := range permissions {
addGroupingPolicies(permission) addGroupingPolicies(permission)
addPolicies(permission) addPolicies(permission)
visited[permission.GetId()] = struct{}{} visited[permission.GetId()] = struct{}{}
} }
ancestorRoles = GetAncestorRoles(newRoleID) ancestorRoles, err = GetAncestorRoles(newRoleID)
if err != nil {
return false, err
}
for _, r := range ancestorRoles { for _, r := range ancestorRoles {
permissions := GetPermissionsByRole(r.GetId()) permissions, err := GetPermissionsByRole(r.GetId())
if err != nil {
return false, err
}
for _, permission := range permissions { for _, permission := range permissions {
permissionId := permission.GetId() permissionId := permission.GetId()
if _, ok := visited[permissionId]; !ok { if _, ok := visited[permissionId]; !ok {
@ -151,16 +173,16 @@ func UpdateRole(id string, role *Role) bool {
} }
} }
return affected != 0 return affected != 0, nil
} }
func AddRole(role *Role) bool { func AddRole(role *Role) (bool, error) {
affected, err := adapter.Engine.Insert(role) affected, err := adapter.Engine.Insert(role)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddRoles(roles []*Role) bool { func AddRoles(roles []*Role) bool {
@ -202,38 +224,45 @@ func AddRolesInBatch(roles []*Role) bool {
return affected return affected
} }
func DeleteRole(role *Role) bool { func DeleteRole(role *Role) (bool, error) {
roleId := role.GetId() roleId := role.GetId()
permissions := GetPermissionsByRole(roleId) permissions, err := GetPermissionsByRole(roleId)
if err != nil {
return false, err
}
for _, permission := range permissions { for _, permission := range permissions {
permission.Roles = util.DeleteVal(permission.Roles, roleId) permission.Roles = util.DeleteVal(permission.Roles, roleId)
UpdatePermission(permission.GetId(), permission) _, err := UpdatePermission(permission.GetId(), permission)
if err != nil {
return false, err
}
} }
affected, err := adapter.Engine.ID(core.PK{role.Owner, role.Name}).Delete(&Role{}) affected, err := adapter.Engine.ID(core.PK{role.Owner, role.Name}).Delete(&Role{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (role *Role) GetId() string { func (role *Role) GetId() string {
return fmt.Sprintf("%s/%s", role.Owner, role.Name) return fmt.Sprintf("%s/%s", role.Owner, role.Name)
} }
func GetRolesByUser(userId string) []*Role { func GetRolesByUser(userId string) ([]*Role, error) {
roles := []*Role{} roles := []*Role{}
err := adapter.Engine.Where("users like ?", "%"+userId+"\"%").Find(&roles) err := adapter.Engine.Where("users like ?", "%"+userId+"\"%").Find(&roles)
if err != nil { if err != nil {
panic(err) return roles, err
} }
for i := range roles { for i := range roles {
roles[i].Users = nil roles[i].Users = nil
} }
return roles return roles, nil
} }
func roleChangeTrigger(oldName string, newName string) error { func roleChangeTrigger(oldName string, newName string) error {
@ -250,6 +279,7 @@ func roleChangeTrigger(oldName string, newName string) error {
if err != nil { if err != nil {
return err return err
} }
for _, role := range roles { for _, role := range roles {
for j, u := range role.Roles { for j, u := range role.Roles {
owner, name := util.GetOwnerAndNameFromId(u) owner, name := util.GetOwnerAndNameFromId(u)
@ -268,6 +298,7 @@ func roleChangeTrigger(oldName string, newName string) error {
if err != nil { if err != nil {
return err return err
} }
for _, permission := range permissions { for _, permission := range permissions {
for j, u := range permission.Roles { for j, u := range permission.Roles {
// u = organization/username // u = organization/username
@ -293,17 +324,17 @@ func GetMaskedRoles(roles []*Role) []*Role {
return roles return roles
} }
func GetRolesByNamePrefix(owner string, prefix string) []*Role { func GetRolesByNamePrefix(owner string, prefix string) ([]*Role, error) {
roles := []*Role{} roles := []*Role{}
err := adapter.Engine.Where("owner=? and name like ?", owner, prefix+"%").Find(&roles) err := adapter.Engine.Where("owner=? and name like ?", owner, prefix+"%").Find(&roles)
if err != nil { if err != nil {
panic(err) return roles, err
} }
return roles return roles, nil
} }
func GetAncestorRoles(roleId string) []*Role { func GetAncestorRoles(roleId string) ([]*Role, error) {
var ( var (
result []*Role result []*Role
roleMap = make(map[string]*Role) roleMap = make(map[string]*Role)
@ -312,7 +343,11 @@ func GetAncestorRoles(roleId string) []*Role {
owner, _ := util.GetOwnerAndNameFromIdNoCheck(roleId) owner, _ := util.GetOwnerAndNameFromIdNoCheck(roleId)
allRoles := GetRoles(owner) allRoles, err := GetRoles(owner)
if err != nil {
return nil, err
}
for _, r := range allRoles { for _, r := range allRoles {
roleMap[r.GetId()] = r roleMap[r.GetId()] = r
} }
@ -331,7 +366,7 @@ func GetAncestorRoles(roleId string) []*Role {
} }
} }
return result return result, nil
} }
// containsRole is a helper function to check if a slice of roles contains a specific roleId // containsRole is a helper function to check if a slice of roles contains a specific roleId

View File

@ -18,21 +18,29 @@ import (
"github.com/casdoor/casdoor/xlsx" "github.com/casdoor/casdoor/xlsx"
) )
func getRoleMap(owner string) map[string]*Role { func getRoleMap(owner string) (map[string]*Role, error) {
m := map[string]*Role{} m := map[string]*Role{}
roles := GetRoles(owner) roles, err := GetRoles(owner)
if err != nil {
return nil, err
}
for _, role := range roles { for _, role := range roles {
m[role.GetId()] = role m[role.GetId()] = role
} }
return m return m, nil
} }
func UploadRoles(owner string, fileId string) bool { func UploadRoles(owner string, fileId string) (bool, error) {
table := xlsx.ReadXlsxFile(fileId) table := xlsx.ReadXlsxFile(fileId)
oldUserMap := getRoleMap(owner) oldUserMap, err := getRoleMap(owner)
if err != nil {
return false, err
}
newRoles := []*Role{} newRoles := []*Role{}
for index, line := range table { for index, line := range table {
if index == 0 || parseLineItem(&line, 0) == "" { if index == 0 || parseLineItem(&line, 0) == "" {
@ -57,7 +65,7 @@ func UploadRoles(owner string, fileId string) bool {
} }
if len(newRoles) == 0 { if len(newRoles) == 0 {
return false return false, nil
} }
return AddRolesInBatch(newRoles) return AddRolesInBatch(newRoles), nil
} }

View File

@ -105,7 +105,11 @@ func NewSamlResponse(user *User, host string, certificate string, destination st
roles := attributes.CreateElement("saml:Attribute") roles := attributes.CreateElement("saml:Attribute")
roles.CreateAttr("Name", "Roles") roles.CreateAttr("Name", "Roles")
roles.CreateAttr("NameFormat", "urn:oasis:names:tc:SAML:2.0:attrname-format:basic") roles.CreateAttr("NameFormat", "urn:oasis:names:tc:SAML:2.0:attrname-format:basic")
ExtendUserWithRolesAndPermissions(user) err := ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, err
}
for _, role := range user.Roles { for _, role := range user.Roles {
roles.CreateElement("saml:AttributeValue").CreateAttr("xsi:type", "xs:string").Element().SetText(role.Name) roles.CreateElement("saml:AttributeValue").CreateAttr("xsi:type", "xs:string").Element().SetText(role.Name)
} }
@ -186,7 +190,11 @@ type Attribute struct {
} }
func GetSamlMeta(application *Application, host string) (*IdpEntityDescriptor, error) { func GetSamlMeta(application *Application, host string) (*IdpEntityDescriptor, error) {
cert := getCertByApplication(application) cert, err := getCertByApplication(application)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(cert.Certificate)) block, _ := pem.Decode([]byte(cert.Certificate))
certificate := base64.StdEncoding.EncodeToString(block.Bytes) certificate := base64.StdEncoding.EncodeToString(block.Bytes)
@ -263,7 +271,11 @@ func GetSamlResponse(application *Application, user *User, samlRequest string, h
} }
// get certificate string // get certificate string
cert := getCertByApplication(application) cert, err := getCertByApplication(application)
if err != nil {
return "", "", "", err
}
block, _ := pem.Decode([]byte(cert.Certificate)) block, _ := pem.Decode([]byte(cert.Certificate))
certificate := base64.StdEncoding.EncodeToString(block.Bytes) certificate := base64.StdEncoding.EncodeToString(block.Bytes)

View File

@ -43,7 +43,10 @@ func ParseSamlResponse(samlResponse string, provider *Provider, host string) (st
} }
func GenerateSamlRequest(id, relayState, host, lang string) (auth string, method string, err error) { func GenerateSamlRequest(id, relayState, host, lang string) (auth string, method string, err error) {
provider := GetProvider(id) provider, err := GetProvider(id)
if err != nil {
return "", "", err
}
if provider.Category != "SAML" { if provider.Category != "SAML" {
return "", "", fmt.Errorf(i18n.Translate(lang, "saml_sp:provider %s's category is not SAML"), provider.Name) return "", "", fmt.Errorf(i18n.Translate(lang, "saml_sp:provider %s's category is not SAML"), provider.Name)
} }
@ -92,27 +95,33 @@ func buildSp(provider *Provider, samlResponse string, host string) (*saml2.SAMLS
} }
if provider.EnableSignAuthnRequest { if provider.EnableSignAuthnRequest {
sp.SignAuthnRequests = true sp.SignAuthnRequests = true
sp.SPKeyStore = buildSpKeyStore() sp.SPKeyStore, err = buildSpKeyStore()
if err != nil {
return nil, err
}
} }
return sp, nil return sp, nil
} }
func buildSpKeyStore() dsig.X509KeyStore { func buildSpKeyStore() (dsig.X509KeyStore, error) {
keyPair, err := tls.LoadX509KeyPair("object/token_jwt_key.pem", "object/token_jwt_key.key") keyPair, err := tls.LoadX509KeyPair("object/token_jwt_key.pem", "object/token_jwt_key.key")
if err != nil { if err != nil {
panic(err) return nil, err
} }
return &dsig.TLSCertKeyStore{ return &dsig.TLSCertKeyStore{
PrivateKey: keyPair.PrivateKey, PrivateKey: keyPair.PrivateKey,
Certificate: keyPair.Certificate, Certificate: keyPair.Certificate,
} }, nil
} }
func buildSpCertificateStore(provider *Provider, samlResponse string) (dsig.MemoryX509CertificateStore, error) { func buildSpCertificateStore(provider *Provider, samlResponse string) (certStore dsig.MemoryX509CertificateStore, err error) {
certEncodedData := "" certEncodedData := ""
if samlResponse != "" { if samlResponse != "" {
certEncodedData = getCertificateFromSamlResponse(samlResponse, provider.Type) certEncodedData, err = getCertificateFromSamlResponse(samlResponse, provider.Type)
if err != nil {
return
}
} else if provider.IdP != "" { } else if provider.IdP != "" {
certEncodedData = provider.IdP certEncodedData = provider.IdP
} }
@ -126,17 +135,18 @@ func buildSpCertificateStore(provider *Provider, samlResponse string) (dsig.Memo
return dsig.MemoryX509CertificateStore{}, err return dsig.MemoryX509CertificateStore{}, err
} }
certStore := dsig.MemoryX509CertificateStore{ certStore = dsig.MemoryX509CertificateStore{
Roots: []*x509.Certificate{idpCert}, Roots: []*x509.Certificate{idpCert},
} }
return certStore, nil return certStore, nil
} }
func getCertificateFromSamlResponse(samlResponse string, providerType string) string { func getCertificateFromSamlResponse(samlResponse string, providerType string) (string, error) {
de, err := base64.StdEncoding.DecodeString(samlResponse) de, err := base64.StdEncoding.DecodeString(samlResponse)
if err != nil { if err != nil {
panic(err) return "", err
} }
deStr := strings.Replace(string(de), "\n", "", -1) deStr := strings.Replace(string(de), "\n", "", -1)
tagMap := map[string]string{ tagMap := map[string]string{
"Aliyun IDaaS": "ds", "Aliyun IDaaS": "ds",
@ -145,5 +155,5 @@ func getCertificateFromSamlResponse(samlResponse string, providerType string) st
tag := tagMap[providerType] tag := tagMap[providerType]
expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)</%s:X509Certificate>", tag, tag) expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)</%s:X509Certificate>", tag, tag)
res := regexp.MustCompile(expression).FindStringSubmatch(deStr) res := regexp.MustCompile(expression).FindStringSubmatch(deStr)
return res[1] return res[1], nil
} }

View File

@ -36,7 +36,7 @@ type Session struct {
SessionId []string `json:"sessionId"` SessionId []string `json:"sessionId"`
} }
func GetSessions(owner string) []*Session { func GetSessions(owner string) ([]*Session, error) {
sessions := []*Session{} sessions := []*Session{}
var err error var err error
if owner != "" { if owner != "" {
@ -45,61 +45,58 @@ func GetSessions(owner string) []*Session {
err = adapter.Engine.Desc("created_time").Find(&sessions) err = adapter.Engine.Desc("created_time").Find(&sessions)
} }
if err != nil { if err != nil {
panic(err) return sessions, err
} }
return sessions return sessions, nil
} }
func GetPaginationSessions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Session { func GetPaginationSessions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Session, error) {
sessions := []*Session{} sessions := []*Session{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&sessions) err := session.Find(&sessions)
if err != nil { if err != nil {
panic(err) return sessions, err
} }
return sessions return sessions, nil
} }
func GetSessionCount(owner, field, value string) int { func GetSessionCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Session{}) return session.Count(&Session{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetSingleSession(id string) *Session { func GetSingleSession(id string) (*Session, error) {
owner, name, application := util.GetOwnerAndNameAndOtherFromId(id) owner, name, application := util.GetOwnerAndNameAndOtherFromId(id)
session := Session{Owner: owner, Name: name, Application: application} session := Session{Owner: owner, Name: name, Application: application}
get, err := adapter.Engine.Get(&session) get, err := adapter.Engine.Get(&session)
if err != nil { if err != nil {
panic(err) return &session, err
} }
if !get { if !get {
return nil return nil, nil
} }
return &session return &session, nil
} }
func UpdateSession(id string, session *Session) bool { func UpdateSession(id string, session *Session) (bool, error) {
owner, name, application := util.GetOwnerAndNameAndOtherFromId(id) owner, name, application := util.GetOwnerAndNameAndOtherFromId(id)
if GetSingleSession(id) == nil { if ss, err := GetSingleSession(id); err != nil {
return false return false, err
} else if ss == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name, application}).Update(session) affected, err := adapter.Engine.ID(core.PK{owner, name, application}).Update(session)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func removeExtraSessionIds(session *Session) { func removeExtraSessionIds(session *Session) {
@ -108,17 +105,21 @@ func removeExtraSessionIds(session *Session) {
} }
} }
func AddSession(session *Session) bool { func AddSession(session *Session) (bool, error) {
dbSession := GetSingleSession(session.GetId()) dbSession, err := GetSingleSession(session.GetId())
if err != nil {
return false, err
}
if dbSession == nil { if dbSession == nil {
session.CreatedTime = util.GetCurrentTime() session.CreatedTime = util.GetCurrentTime()
affected, err := adapter.Engine.Insert(session) affected, err := adapter.Engine.Insert(session)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} else { } else {
m := make(map[string]struct{}) m := make(map[string]struct{})
for _, v := range dbSession.SessionId { for _, v := range dbSession.SessionId {
@ -136,10 +137,14 @@ func AddSession(session *Session) bool {
} }
} }
func DeleteSession(id string) bool { func DeleteSession(id string) (bool, error) {
owner, name, application := util.GetOwnerAndNameAndOtherFromId(id) owner, name, application := util.GetOwnerAndNameAndOtherFromId(id)
if owner == CasdoorOrganization && application == CasdoorApplication { if owner == CasdoorOrganization && application == CasdoorApplication {
session := GetSingleSession(id) session, err := GetSingleSession(id)
if err != nil {
return false, err
}
if session != nil { if session != nil {
DeleteBeegoSession(session.SessionId) DeleteBeegoSession(session.SessionId)
} }
@ -147,16 +152,19 @@ func DeleteSession(id string) bool {
affected, err := adapter.Engine.ID(core.PK{owner, name, application}).Delete(&Session{}) affected, err := adapter.Engine.ID(core.PK{owner, name, application}).Delete(&Session{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteSessionId(id string, sessionId string) bool { func DeleteSessionId(id string, sessionId string) (bool, error) {
session := GetSingleSession(id) session, err := GetSingleSession(id)
if err != nil {
return false, err
}
if session == nil { if session == nil {
return false return false, nil
} }
owner, _, application := util.GetOwnerAndNameAndOtherFromId(id) owner, _, application := util.GetOwnerAndNameAndOtherFromId(id)
@ -185,17 +193,21 @@ func (session *Session) GetId() string {
return fmt.Sprintf("%s/%s/%s", session.Owner, session.Name, session.Application) return fmt.Sprintf("%s/%s/%s", session.Owner, session.Name, session.Application)
} }
func IsSessionDuplicated(id string, sessionId string) bool { func IsSessionDuplicated(id string, sessionId string) (bool, error) {
session := GetSingleSession(id) session, err := GetSingleSession(id)
if err != nil {
return false, err
}
if session == nil { if session == nil {
return false return false, nil
} else { } else {
if len(session.SessionId) > 1 { if len(session.SessionId) > 1 {
return true return true, nil
} else if len(session.SessionId) < 1 { } else if len(session.SessionId) < 1 {
return false return false, nil
} else { } else {
return session.SessionId[0] != sessionId return session.SessionId[0] != sessionId, nil
} }
} }
} }

View File

@ -30,11 +30,7 @@ import (
var isCloudIntranet bool var isCloudIntranet bool
func init() { func init() {
var err error isCloudIntranet = conf.GetConfigBool("isCloudIntranet")
isCloudIntranet, err = conf.GetConfigBool("isCloudIntranet")
if err != nil {
// panic(err)
}
} }
func getProviderEndpoint(provider *Provider) string { func getProviderEndpoint(provider *Provider) string {

View File

@ -63,90 +63,87 @@ func NewSubscription(owner string, user string, plan string, duration int) *Subs
} }
} }
func GetSubscriptionCount(owner, field, value string) int { func GetSubscriptionCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Subscription{}) return session.Count(&Subscription{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetSubscriptions(owner string) []*Subscription { func GetSubscriptions(owner string) ([]*Subscription, error) {
subscriptions := []*Subscription{} subscriptions := []*Subscription{}
err := adapter.Engine.Desc("created_time").Find(&subscriptions, &Subscription{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&subscriptions, &Subscription{Owner: owner})
if err != nil { if err != nil {
panic(err) return subscriptions, err
} }
return subscriptions return subscriptions, nil
} }
func GetPaginationSubscriptions(owner string, offset, limit int, field, value, sortField, sortOrder string) []*Subscription { func GetPaginationSubscriptions(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Subscription, error) {
subscriptions := []*Subscription{} subscriptions := []*Subscription{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&subscriptions) err := session.Find(&subscriptions)
if err != nil { if err != nil {
panic(err) return subscriptions, err
} }
return subscriptions return subscriptions, nil
} }
func getSubscription(owner string, name string) *Subscription { func getSubscription(owner string, name string) (*Subscription, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
subscription := Subscription{Owner: owner, Name: name} subscription := Subscription{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&subscription) existed, err := adapter.Engine.Get(&subscription)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &subscription return &subscription, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetSubscription(id string) *Subscription { func GetSubscription(id string) (*Subscription, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getSubscription(owner, name) return getSubscription(owner, name)
} }
func UpdateSubscription(id string, subscription *Subscription) bool { func UpdateSubscription(id string, subscription *Subscription) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getSubscription(owner, name) == nil { if s, err := getSubscription(owner, name); err != nil {
return false return false, err
} else if s == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(subscription) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(subscription)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddSubscription(subscription *Subscription) bool { func AddSubscription(subscription *Subscription) (bool, error) {
affected, err := adapter.Engine.Insert(subscription) affected, err := adapter.Engine.Insert(subscription)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteSubscription(subscription *Subscription) bool { func DeleteSubscription(subscription *Subscription) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{subscription.Owner, subscription.Name}).Delete(&Subscription{}) affected, err := adapter.Engine.ID(core.PK{subscription.Owner, subscription.Name}).Delete(&Subscription{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (subscription *Subscription) GetId() string { func (subscription *Subscription) GetId() string {

View File

@ -55,66 +55,61 @@ type Syncer struct {
Adapter *Adapter `xorm:"-" json:"-"` Adapter *Adapter `xorm:"-" json:"-"`
} }
func GetSyncerCount(owner, organization, field, value string) int { func GetSyncerCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Syncer{Organization: organization}) return session.Count(&Syncer{Organization: organization})
if err != nil {
panic(err)
}
return int(count)
} }
func GetSyncers(owner string) []*Syncer { func GetSyncers(owner string) ([]*Syncer, error) {
syncers := []*Syncer{} syncers := []*Syncer{}
err := adapter.Engine.Desc("created_time").Find(&syncers, &Syncer{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&syncers, &Syncer{Owner: owner})
if err != nil { if err != nil {
panic(err) return syncers, err
} }
return syncers return syncers, nil
} }
func GetOrganizationSyncers(owner, organization string) []*Syncer { func GetOrganizationSyncers(owner, organization string) ([]*Syncer, error) {
syncers := []*Syncer{} syncers := []*Syncer{}
err := adapter.Engine.Desc("created_time").Find(&syncers, &Syncer{Owner: owner, Organization: organization}) err := adapter.Engine.Desc("created_time").Find(&syncers, &Syncer{Owner: owner, Organization: organization})
if err != nil { if err != nil {
panic(err) return syncers, err
} }
return syncers return syncers, nil
} }
func GetPaginationSyncers(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Syncer { func GetPaginationSyncers(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Syncer, error) {
syncers := []*Syncer{} syncers := []*Syncer{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&syncers, &Syncer{Organization: organization}) err := session.Find(&syncers, &Syncer{Organization: organization})
if err != nil { if err != nil {
panic(err) return syncers, err
} }
return syncers return syncers, nil
} }
func getSyncer(owner string, name string) *Syncer { func getSyncer(owner string, name string) (*Syncer, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
syncer := Syncer{Owner: owner, Name: name} syncer := Syncer{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&syncer) existed, err := adapter.Engine.Get(&syncer)
if err != nil { if err != nil {
panic(err) return &syncer, err
} }
if existed { if existed {
return &syncer return &syncer, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetSyncer(id string) *Syncer { func GetSyncer(id string) (*Syncer, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getSyncer(owner, name) return getSyncer(owner, name)
} }
@ -137,10 +132,12 @@ func GetMaskedSyncers(syncers []*Syncer) []*Syncer {
return syncers return syncers
} }
func UpdateSyncer(id string, syncer *Syncer) bool { func UpdateSyncer(id string, syncer *Syncer) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getSyncer(owner, name) == nil { if s, err := getSyncer(owner, name); err != nil {
return false return false, err
} else if s == nil {
return false, nil
} }
session := adapter.Engine.ID(core.PK{owner, name}).AllCols() session := adapter.Engine.ID(core.PK{owner, name}).AllCols()
@ -149,56 +146,66 @@ func UpdateSyncer(id string, syncer *Syncer) bool {
} }
affected, err := session.Update(syncer) affected, err := session.Update(syncer)
if err != nil { if err != nil {
panic(err) return false, err
} }
if affected == 1 { if affected == 1 {
addSyncerJob(syncer) err = addSyncerJob(syncer)
if err != nil {
return false, err
}
} }
return affected != 0 return affected != 0, nil
} }
func updateSyncerErrorText(syncer *Syncer, line string) bool { func updateSyncerErrorText(syncer *Syncer, line string) (bool, error) {
s := getSyncer(syncer.Owner, syncer.Name) s, err := getSyncer(syncer.Owner, syncer.Name)
if err != nil {
return false, err
}
if s == nil { if s == nil {
return false return false, nil
} }
s.ErrorText = s.ErrorText + line s.ErrorText = s.ErrorText + line
affected, err := adapter.Engine.ID(core.PK{s.Owner, s.Name}).Cols("error_text").Update(s) affected, err := adapter.Engine.ID(core.PK{s.Owner, s.Name}).Cols("error_text").Update(s)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddSyncer(syncer *Syncer) bool { func AddSyncer(syncer *Syncer) (bool, error) {
affected, err := adapter.Engine.Insert(syncer) affected, err := adapter.Engine.Insert(syncer)
if err != nil { if err != nil {
panic(err) return false, err
} }
if affected == 1 { if affected == 1 {
addSyncerJob(syncer) err = addSyncerJob(syncer)
if err != nil {
return false, err
}
} }
return affected != 0 return affected != 0, nil
} }
func DeleteSyncer(syncer *Syncer) bool { func DeleteSyncer(syncer *Syncer) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{syncer.Owner, syncer.Name}).Delete(&Syncer{}) affected, err := adapter.Engine.ID(core.PK{syncer.Owner, syncer.Name}).Delete(&Syncer{})
if err != nil { if err != nil {
panic(err) return false, err
} }
if affected == 1 { if affected == 1 {
deleteSyncerJob(syncer) deleteSyncerJob(syncer)
} }
return affected != 0 return affected != 0, nil
} }
func (syncer *Syncer) GetId() string { func (syncer *Syncer) GetId() string {

View File

@ -19,22 +19,25 @@ type Affiliation struct {
Name string `xorm:"varchar(128)" json:"name"` Name string `xorm:"varchar(128)" json:"name"`
} }
func (syncer *Syncer) getAffiliations() []*Affiliation { func (syncer *Syncer) getAffiliations() ([]*Affiliation, error) {
affiliations := []*Affiliation{} affiliations := []*Affiliation{}
err := syncer.Adapter.Engine.Table(syncer.AffiliationTable).Asc("id").Find(&affiliations) err := syncer.Adapter.Engine.Table(syncer.AffiliationTable).Asc("id").Find(&affiliations)
if err != nil { if err != nil {
panic(err) return nil, err
} }
return affiliations return affiliations, nil
} }
func (syncer *Syncer) getAffiliationMap() ([]*Affiliation, map[int]string) { func (syncer *Syncer) getAffiliationMap() ([]*Affiliation, map[int]string, error) {
affiliations := syncer.getAffiliations() affiliations, err := syncer.getAffiliations()
if err != nil {
return nil, nil, err
}
m := map[int]string{} m := map[int]string{}
for _, affiliation := range affiliations { for _, affiliation := range affiliations {
m[affiliation.Id] = affiliation.Name m[affiliation.Id] = affiliation.Name
} }
return affiliations, m return affiliations, m, nil
} }

View File

@ -43,11 +43,11 @@ func clearCron(name string) {
} }
} }
func addSyncerJob(syncer *Syncer) { func addSyncerJob(syncer *Syncer) error {
deleteSyncerJob(syncer) deleteSyncerJob(syncer)
if !syncer.IsEnabled { if !syncer.IsEnabled {
return return nil
} }
syncer.initAdapter() syncer.initAdapter()
@ -58,10 +58,11 @@ func addSyncerJob(syncer *Syncer) {
cron := getCronMap(syncer.Name) cron := getCronMap(syncer.Name)
_, err := cron.AddFunc(schedule, syncer.syncUsers) _, err := cron.AddFunc(schedule, syncer.syncUsers)
if err != nil { if err != nil {
panic(err) return err
} }
cron.Start() cron.Start()
return nil
} }
func deleteSyncerJob(syncer *Syncer) { func deleteSyncerJob(syncer *Syncer) {

View File

@ -16,46 +16,74 @@ package object
import "fmt" import "fmt"
func getDbSyncerForUser(user *User) *Syncer { func getDbSyncerForUser(user *User) (*Syncer, error) {
syncers := GetSyncers("admin") syncers, err := GetSyncers("admin")
if err != nil {
return nil, err
}
for _, syncer := range syncers { for _, syncer := range syncers {
if syncer.Organization == user.Owner && syncer.IsEnabled && syncer.Type == "Database" { if syncer.Organization == user.Owner && syncer.IsEnabled && syncer.Type == "Database" {
return syncer return syncer, nil
} }
} }
return nil return nil, nil
} }
func getEnabledSyncerForOrganization(organization string) *Syncer { func getEnabledSyncerForOrganization(organization string) (*Syncer, error) {
syncers := GetSyncers("admin") syncers, err := GetSyncers("admin")
if err != nil {
return nil, err
}
for _, syncer := range syncers { for _, syncer := range syncers {
if syncer.Organization == organization && syncer.IsEnabled { if syncer.Organization == organization && syncer.IsEnabled {
return syncer return syncer, nil
} }
} }
return nil return nil, nil
} }
func AddUserToOriginalDatabase(user *User) { func AddUserToOriginalDatabase(user *User) error {
syncer := getEnabledSyncerForOrganization(user.Owner) syncer, err := getEnabledSyncerForOrganization(user.Owner)
if err != nil {
return err
}
if syncer == nil { if syncer == nil {
return return nil
} }
updatedOUser := syncer.createOriginalUserFromUser(user) updatedOUser := syncer.createOriginalUserFromUser(user)
syncer.addUser(updatedOUser) _, err = syncer.addUser(updatedOUser)
fmt.Printf("Add from user to oUser: %v\n", updatedOUser) if err != nil {
} return err
func UpdateUserToOriginalDatabase(user *User) {
syncer := getEnabledSyncerForOrganization(user.Owner)
if syncer == nil {
return
} }
newUser := GetUser(user.GetId()) fmt.Printf("Add from user to oUser: %v\n", updatedOUser)
return nil
}
func UpdateUserToOriginalDatabase(user *User) error {
syncer, err := getEnabledSyncerForOrganization(user.Owner)
if err != nil {
return err
}
if syncer == nil {
return nil
}
newUser, err := GetUser(user.GetId())
if err != nil {
return err
}
updatedOUser := syncer.createOriginalUserFromUser(newUser) updatedOUser := syncer.createOriginalUserFromUser(newUser)
syncer.updateUser(updatedOUser) _, err = syncer.updateUser(updatedOUser)
if err != nil {
return err
}
fmt.Printf("Update from user to oUser: %v\n", updatedOUser) fmt.Printf("Update from user to oUser: %v\n", updatedOUser)
return nil
} }

View File

@ -37,7 +37,7 @@ func (syncer *Syncer) syncUsers() {
var affiliationMap map[int]string var affiliationMap map[int]string
if syncer.AffiliationTable != "" { if syncer.AffiliationTable != "" {
_, affiliationMap = syncer.getAffiliationMap() _, affiliationMap, err = syncer.getAffiliationMap()
} }
newUsers := []*User{} newUsers := []*User{}
@ -86,13 +86,19 @@ func (syncer *Syncer) syncUsers() {
} }
} }
} }
AddUsersInBatch(newUsers) _, err = AddUsersInBatch(newUsers)
if err != nil {
panic(err)
}
for _, user := range users { for _, user := range users {
id := user.Id id := user.Id
if _, ok := oUserMap[id]; !ok { if _, ok := oUserMap[id]; !ok {
newOUser := syncer.createOriginalUserFromUser(user) newOUser := syncer.createOriginalUserFromUser(user)
syncer.addUser(newOUser) _, err = syncer.addUser(newOUser)
if err != nil {
panic(err)
}
fmt.Printf("New oUser: %v\n", newOUser) fmt.Printf("New oUser: %v\n", newOUser)
} }
} }

View File

@ -122,14 +122,18 @@ func (syncer *Syncer) updateUser(user *OriginalUser) (bool, error) {
} }
func (syncer *Syncer) updateUserForOriginalFields(user *User) (bool, error) { func (syncer *Syncer) updateUserForOriginalFields(user *User) (bool, error) {
var err error
owner, name := util.GetOwnerAndNameFromId(user.GetId()) owner, name := util.GetOwnerAndNameFromId(user.GetId())
oldUser := getUserById(owner, name) oldUser, err := getUserById(owner, name)
if oldUser == nil { if oldUser == nil || err != nil {
return false, nil return false, err
} }
if user.Avatar != oldUser.Avatar && user.Avatar != "" { if user.Avatar != oldUser.Avatar && user.Avatar != "" {
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
if err != nil {
return false, err
}
} }
columns := syncer.getCasdoorColumns() columns := syncer.getCasdoorColumns()
@ -175,7 +179,11 @@ func (syncer *Syncer) initAdapter() {
} }
func RunSyncUsersJob() { func RunSyncUsersJob() {
syncers := GetSyncers("admin") syncers, err := GetSyncers("admin")
if err != nil {
panic(err)
}
for _, syncer := range syncers { for _, syncer := range syncers {
addSyncerJob(syncer) addSyncerJob(syncer)
} }

View File

@ -23,7 +23,7 @@ import (
func TestGetUsers(t *testing.T) { func TestGetUsers(t *testing.T) {
InitConfig() InitConfig()
syncers := GetSyncers("admin") syncers, _ := GetSyncers("admin")
syncer := syncers[0] syncer := syncers[0]
syncer.initAdapter() syncer.initAdapter()
users, _ := syncer.getOriginalUsers() users, _ := syncer.getOriginalUsers()

View File

@ -15,7 +15,11 @@
package object package object
func (syncer *Syncer) getUsers() []*User { func (syncer *Syncer) getUsers() []*User {
users := GetUsers(syncer.Organization) users, err := GetUsers(syncer.Organization)
if err != nil {
panic(err)
}
return users return users
} }

View File

@ -91,67 +91,54 @@ type IntrospectionResponse struct {
Jti string `json:"jti,omitempty"` Jti string `json:"jti,omitempty"`
} }
func GetTokenCount(owner, organization, field, value string) int { func GetTokenCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Token{Organization: organization}) return session.Count(&Token{Organization: organization})
if err != nil {
panic(err)
}
return int(count)
} }
func GetTokens(owner string, organization string) []*Token { func GetTokens(owner string, organization string) ([]*Token, error) {
tokens := []*Token{} tokens := []*Token{}
err := adapter.Engine.Desc("created_time").Find(&tokens, &Token{Owner: owner, Organization: organization}) err := adapter.Engine.Desc("created_time").Find(&tokens, &Token{Owner: owner, Organization: organization})
if err != nil { return tokens, err
panic(err)
}
return tokens
} }
func GetPaginationTokens(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Token { func GetPaginationTokens(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Token, error) {
tokens := []*Token{} tokens := []*Token{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&tokens, &Token{Organization: organization}) err := session.Find(&tokens, &Token{Organization: organization})
if err != nil { return tokens, err
panic(err)
}
return tokens
} }
func getToken(owner string, name string) *Token { func getToken(owner string, name string) (*Token, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
token := Token{Owner: owner, Name: name} token := Token{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&token) existed, err := adapter.Engine.Get(&token)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &token return &token, nil
} }
return nil return nil, nil
} }
func getTokenByCode(code string) *Token { func getTokenByCode(code string) (*Token, error) {
token := Token{Code: code} token := Token{Code: code}
existed, err := adapter.Engine.Get(&token) existed, err := adapter.Engine.Get(&token)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &token return &token, nil
} }
return nil return nil, nil
} }
func updateUsedByCode(token *Token) bool { func updateUsedByCode(token *Token) bool {
@ -163,7 +150,7 @@ func updateUsedByCode(token *Token) bool {
return affected != 0 return affected != 0
} }
func GetToken(id string) *Token { func GetToken(id string) (*Token, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getToken(owner, name) return getToken(owner, name)
} }
@ -172,124 +159,155 @@ func (token *Token) GetId() string {
return fmt.Sprintf("%s/%s", token.Owner, token.Name) return fmt.Sprintf("%s/%s", token.Owner, token.Name)
} }
func UpdateToken(id string, token *Token) bool { func UpdateToken(id string, token *Token) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getToken(owner, name) == nil { if t, err := getToken(owner, name); err != nil {
return false return false, err
} else if t == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(token) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(token)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddToken(token *Token) bool { func AddToken(token *Token) (bool, error) {
affected, err := adapter.Engine.Insert(token) affected, err := adapter.Engine.Insert(token)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteToken(token *Token) bool { func DeleteToken(token *Token) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{token.Owner, token.Name}).Delete(&Token{}) affected, err := adapter.Engine.ID(core.PK{token.Owner, token.Name}).Delete(&Token{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func ExpireTokenByAccessToken(accessToken string) (bool, *Application, *Token) { func ExpireTokenByAccessToken(accessToken string) (bool, *Application, *Token, error) {
token := Token{AccessToken: accessToken} token := Token{AccessToken: accessToken}
existed, err := adapter.Engine.Get(&token) existed, err := adapter.Engine.Get(&token)
if err != nil { if err != nil {
panic(err) return false, nil, nil, err
} }
if !existed { if !existed {
return false, nil, nil return false, nil, nil, nil
} }
token.ExpiresIn = 0 token.ExpiresIn = 0
affected, err := adapter.Engine.ID(core.PK{token.Owner, token.Name}).Cols("expires_in").Update(&token) affected, err := adapter.Engine.ID(core.PK{token.Owner, token.Name}).Cols("expires_in").Update(&token)
if err != nil { if err != nil {
panic(err) return false, nil, nil, err
} }
application := getApplication(token.Owner, token.Application) application, err := getApplication(token.Owner, token.Application)
return affected != 0, application, &token if err != nil {
return false, nil, nil, err
}
return affected != 0, application, &token, nil
} }
func GetTokenByAccessToken(accessToken string) *Token { func GetTokenByAccessToken(accessToken string) (*Token, error) {
// Check if the accessToken is in the database // Check if the accessToken is in the database
token := Token{AccessToken: accessToken} token := Token{AccessToken: accessToken}
existed, err := adapter.Engine.Get(&token) existed, err := adapter.Engine.Get(&token)
if err != nil || !existed { if err != nil {
return nil return nil, err
} }
return &token
if !existed {
return nil, nil
}
return &token, nil
} }
func GetTokenByTokenAndApplication(token string, application string) *Token { func GetTokenByTokenAndApplication(token string, application string) (*Token, error) {
tokenResult := Token{} tokenResult := Token{}
existed, err := adapter.Engine.Where("(refresh_token = ? or access_token = ? ) and application = ?", token, token, application).Get(&tokenResult) existed, err := adapter.Engine.Where("(refresh_token = ? or access_token = ? ) and application = ?", token, token, application).Get(&tokenResult)
if err != nil || !existed { if err != nil {
return nil return nil, err
} }
return &tokenResult
if !existed {
return nil, nil
}
return &tokenResult, nil
} }
func CheckOAuthLogin(clientId string, responseType string, redirectUri string, scope string, state string, lang string) (string, *Application) { func CheckOAuthLogin(clientId string, responseType string, redirectUri string, scope string, state string, lang string) (string, *Application, error) {
if responseType != "code" && responseType != "token" && responseType != "id_token" { if responseType != "code" && responseType != "token" && responseType != "id_token" {
return fmt.Sprintf(i18n.Translate(lang, "token:Grant_type: %s is not supported in this application"), responseType), nil return fmt.Sprintf(i18n.Translate(lang, "token:Grant_type: %s is not supported in this application"), responseType), nil, nil
}
application, err := GetApplicationByClientId(clientId)
if err != nil {
return "", nil, err
} }
application := GetApplicationByClientId(clientId)
if application == nil { if application == nil {
return i18n.Translate(lang, "token:Invalid client_id"), nil return i18n.Translate(lang, "token:Invalid client_id"), nil, nil
} }
if !application.IsRedirectUriValid(redirectUri) { if !application.IsRedirectUriValid(redirectUri) {
return fmt.Sprintf(i18n.Translate(lang, "token:Redirect URI: %s doesn't exist in the allowed Redirect URI list"), redirectUri), application return fmt.Sprintf(i18n.Translate(lang, "token:Redirect URI: %s doesn't exist in the allowed Redirect URI list"), redirectUri), application, nil
} }
// Mask application for /api/get-app-login // Mask application for /api/get-app-login
application.ClientSecret = "" application.ClientSecret = ""
return "", application return "", application, nil
} }
func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string, host string, lang string) *Code { func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string, host string, lang string) (*Code, error) {
user := GetUser(userId) user, err := GetUser(userId)
if err != nil {
return nil, err
}
if user == nil { if user == nil {
return &Code{ return &Code{
Message: fmt.Sprintf("general:The user: %s doesn't exist", userId), Message: fmt.Sprintf("general:The user: %s doesn't exist", userId),
Code: "", Code: "",
} }, nil
} }
if user.IsForbidden { if user.IsForbidden {
return &Code{ return &Code{
Message: "error: the user is forbidden to sign in, please contact the administrator", Message: "error: the user is forbidden to sign in, please contact the administrator",
Code: "", Code: "",
} }, nil
}
msg, application, err := CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, lang)
if err != nil {
return nil, err
} }
msg, application := CheckOAuthLogin(clientId, responseType, redirectUri, scope, state, lang)
if msg != "" { if msg != "" {
return &Code{ return &Code{
Message: msg, Message: msg,
Code: "", Code: "",
} }, nil
} }
ExtendUserWithRolesAndPermissions(user) err = ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, err
}
accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, nonce, scope, host) accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, nonce, scope, host)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if challenge == "null" { if challenge == "null" {
@ -313,21 +331,28 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU
CodeIsUsed: false, CodeIsUsed: false,
CodeExpireIn: time.Now().Add(time.Minute * 5).Unix(), CodeExpireIn: time.Now().Add(time.Minute * 5).Unix(),
} }
AddToken(token) _, err = AddToken(token)
if err != nil {
return nil, err
}
return &Code{ return &Code{
Message: "", Message: "",
Code: token.Code, Code: token.Code,
} }, nil
} }
func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string, scope string, username string, password string, host string, refreshToken string, tag string, avatar string, lang string) interface{} { func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string, scope string, username string, password string, host string, refreshToken string, tag string, avatar string, lang string) (interface{}, error) {
application := GetApplicationByClientId(clientId) application, err := GetApplicationByClientId(clientId)
if err != nil {
return nil, err
}
if application == nil { if application == nil {
return &TokenError{ return &TokenError{
Error: InvalidClient, Error: InvalidClient,
ErrorDescription: "client_id is invalid", ErrorDescription: "client_id is invalid",
} }, nil
} }
// Check if grantType is allowed in the current application // Check if grantType is allowed in the current application
@ -336,32 +361,44 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code
return &TokenError{ return &TokenError{
Error: UnsupportedGrantType, Error: UnsupportedGrantType,
ErrorDescription: fmt.Sprintf("grant_type: %s is not supported in this application", grantType), ErrorDescription: fmt.Sprintf("grant_type: %s is not supported in this application", grantType),
} }, nil
} }
var token *Token var token *Token
var tokenError *TokenError var tokenError *TokenError
switch grantType { switch grantType {
case "authorization_code": // Authorization Code Grant case "authorization_code": // Authorization Code Grant
token, tokenError = GetAuthorizationCodeToken(application, clientSecret, code, verifier) token, tokenError, err = GetAuthorizationCodeToken(application, clientSecret, code, verifier)
case "password": // Resource Owner Password Credentials Grant case "password": // Resource Owner Password Credentials Grant
token, tokenError = GetPasswordToken(application, username, password, scope, host) token, tokenError, err = GetPasswordToken(application, username, password, scope, host)
case "client_credentials": // Client Credentials Grant case "client_credentials": // Client Credentials Grant
token, tokenError = GetClientCredentialsToken(application, clientSecret, scope, host) token, tokenError, err = GetClientCredentialsToken(application, clientSecret, scope, host)
case "refresh_token": case "refresh_token":
return RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host) refreshToken2, err := RefreshToken(grantType, refreshToken, scope, clientId, clientSecret, host)
if err != nil {
return nil, err
}
return refreshToken2, nil
}
if err != nil {
return nil, err
} }
if tag == "wechat_miniprogram" { if tag == "wechat_miniprogram" {
// Wechat Mini Program // Wechat Mini Program
token, tokenError = GetWechatMiniProgramToken(application, code, host, username, avatar, lang) token, tokenError, err = GetWechatMiniProgramToken(application, code, host, username, avatar, lang)
if err != nil {
return nil, err
}
} }
if tokenError != nil { if tokenError != nil {
return tokenError return tokenError, nil
} }
token.CodeIsUsed = true token.CodeIsUsed = true
go updateUsedByCode(token) go updateUsedByCode(token)
tokenWrapper := &TokenWrapper{ tokenWrapper := &TokenWrapper{
@ -373,29 +410,33 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code
Scope: token.Scope, Scope: token.Scope,
} }
return tokenWrapper return tokenWrapper, nil
} }
func RefreshToken(grantType string, refreshToken string, scope string, clientId string, clientSecret string, host string) interface{} { func RefreshToken(grantType string, refreshToken string, scope string, clientId string, clientSecret string, host string) (interface{}, error) {
// check parameters // check parameters
if grantType != "refresh_token" { if grantType != "refresh_token" {
return &TokenError{ return &TokenError{
Error: UnsupportedGrantType, Error: UnsupportedGrantType,
ErrorDescription: "grant_type should be refresh_token", ErrorDescription: "grant_type should be refresh_token",
} }, nil
} }
application := GetApplicationByClientId(clientId) application, err := GetApplicationByClientId(clientId)
if err != nil {
return nil, err
}
if application == nil { if application == nil {
return &TokenError{ return &TokenError{
Error: InvalidClient, Error: InvalidClient,
ErrorDescription: "client_id is invalid", ErrorDescription: "client_id is invalid",
} }, nil
} }
if clientSecret != "" && application.ClientSecret != clientSecret { if clientSecret != "" && application.ClientSecret != clientSecret {
return &TokenError{ return &TokenError{
Error: InvalidClient, Error: InvalidClient,
ErrorDescription: "client_secret is invalid", ErrorDescription: "client_secret is invalid",
} }, nil
} }
// check whether the refresh token is valid, and has not expired. // check whether the refresh token is valid, and has not expired.
token := Token{RefreshToken: refreshToken} token := Token{RefreshToken: refreshToken}
@ -404,33 +445,44 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId
return &TokenError{ return &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "refresh token is invalid, expired or revoked", ErrorDescription: "refresh token is invalid, expired or revoked",
} }, nil
}
cert, err := getCertByApplication(application)
if err != nil {
return nil, err
} }
cert := getCertByApplication(application)
_, err = ParseJwtToken(refreshToken, cert) _, err = ParseJwtToken(refreshToken, cert)
if err != nil { if err != nil {
return &TokenError{ return &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: fmt.Sprintf("parse refresh token error: %s", err.Error()), ErrorDescription: fmt.Sprintf("parse refresh token error: %s", err.Error()),
} }, nil
} }
// generate a new token // generate a new token
user := getUser(application.Organization, token.User) user, err := getUser(application.Organization, token.User)
if err != nil {
return nil, err
}
if user.IsForbidden { if user.IsForbidden {
return &TokenError{ return &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "the user is forbidden to sign in, please contact the administrator", ErrorDescription: "the user is forbidden to sign in, please contact the administrator",
} }, nil
} }
ExtendUserWithRolesAndPermissions(user) err = ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, err
}
newAccessToken, newRefreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host) newAccessToken, newRefreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host)
if err != nil { if err != nil {
return &TokenError{ return &TokenError{
Error: EndpointError, Error: EndpointError,
ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()), ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()),
} }, nil
} }
newToken := &Token{ newToken := &Token{
@ -447,8 +499,15 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId
Scope: scope, Scope: scope,
TokenType: "Bearer", TokenType: "Bearer",
} }
AddToken(newToken) _, err = AddToken(newToken)
DeleteToken(&token) if err != nil {
return nil, err
}
_, err = DeleteToken(&token)
if err != nil {
return nil, err
}
tokenWrapper := &TokenWrapper{ tokenWrapper := &TokenWrapper{
AccessToken: newToken.AccessToken, AccessToken: newToken.AccessToken,
@ -459,7 +518,7 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId
Scope: newToken.Scope, Scope: newToken.Scope,
} }
return tokenWrapper return tokenWrapper, nil
} }
// PkceChallenge: base64-URL-encoded SHA256 hash of verifier, per rfc 7636 // PkceChallenge: base64-URL-encoded SHA256 hash of verifier, per rfc 7636
@ -486,34 +545,38 @@ func IsGrantTypeValid(method string, grantTypes []string) bool {
// GetAuthorizationCodeToken // GetAuthorizationCodeToken
// Authorization code flow // Authorization code flow
func GetAuthorizationCodeToken(application *Application, clientSecret string, code string, verifier string) (*Token, *TokenError) { func GetAuthorizationCodeToken(application *Application, clientSecret string, code string, verifier string) (*Token, *TokenError, error) {
if code == "" { if code == "" {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidRequest, Error: InvalidRequest,
ErrorDescription: "authorization code should not be empty", ErrorDescription: "authorization code should not be empty",
} }, nil
}
token, err := getTokenByCode(code)
if err != nil {
return nil, nil, err
} }
token := getTokenByCode(code)
if token == nil { if token == nil {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "authorization code is invalid", ErrorDescription: "authorization code is invalid",
} }, nil
} }
if token.CodeIsUsed { if token.CodeIsUsed {
// anti replay attacks // anti replay attacks
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "authorization code has been used", ErrorDescription: "authorization code has been used",
} }, nil
} }
if token.CodeChallenge != "" && pkceChallenge(verifier) != token.CodeChallenge { if token.CodeChallenge != "" && pkceChallenge(verifier) != token.CodeChallenge {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "verifier is invalid", ErrorDescription: "verifier is invalid",
} }, nil
} }
if application.ClientSecret != clientSecret { if application.ClientSecret != clientSecret {
@ -523,13 +586,13 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidClient, Error: InvalidClient,
ErrorDescription: "client_secret is invalid", ErrorDescription: "client_secret is invalid",
} }, nil
} else { } else {
if clientSecret != "" { if clientSecret != "" {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidClient, Error: InvalidClient,
ErrorDescription: "client_secret is invalid", ErrorDescription: "client_secret is invalid",
} }, nil
} }
} }
} }
@ -538,7 +601,7 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "the token is for wrong application (client_id)", ErrorDescription: "the token is for wrong application (client_id)",
} }, nil
} }
if time.Now().Unix() > token.CodeExpireIn { if time.Now().Unix() > token.CodeExpireIn {
@ -546,42 +609,50 @@ func GetAuthorizationCodeToken(application *Application, clientSecret string, co
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "authorization code has expired", ErrorDescription: "authorization code has expired",
} }, nil
} }
return token, nil return token, nil, nil
} }
// GetPasswordToken // GetPasswordToken
// Resource Owner Password Credentials flow // Resource Owner Password Credentials flow
func GetPasswordToken(application *Application, username string, password string, scope string, host string) (*Token, *TokenError) { func GetPasswordToken(application *Application, username string, password string, scope string, host string) (*Token, *TokenError, error) {
user := getUser(application.Organization, username) user, err := getUser(application.Organization, username)
if err != nil {
return nil, nil, err
}
if user == nil { if user == nil {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "the user does not exist", ErrorDescription: "the user does not exist",
} }, nil
} }
msg := CheckPassword(user, password, "en") msg := CheckPassword(user, password, "en")
if msg != "" { if msg != "" {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "invalid username or password", ErrorDescription: "invalid username or password",
} }, nil
} }
if user.IsForbidden { if user.IsForbidden {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "the user is forbidden to sign in, please contact the administrator", ErrorDescription: "the user is forbidden to sign in, please contact the administrator",
} }, nil
}
err = ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, nil, err
} }
ExtendUserWithRolesAndPermissions(user)
accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host) accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host)
if err != nil { if err != nil {
return nil, &TokenError{ return nil, &TokenError{
Error: EndpointError, Error: EndpointError,
ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()), ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()),
} }, nil
} }
token := &Token{ token := &Token{
Owner: application.Owner, Owner: application.Owner,
@ -598,18 +669,22 @@ func GetPasswordToken(application *Application, username string, password string
TokenType: "Bearer", TokenType: "Bearer",
CodeIsUsed: true, CodeIsUsed: true,
} }
AddToken(token) _, err = AddToken(token)
return token, nil if err != nil {
return nil, nil, err
}
return token, nil, nil
} }
// GetClientCredentialsToken // GetClientCredentialsToken
// Client Credentials flow // Client Credentials flow
func GetClientCredentialsToken(application *Application, clientSecret string, scope string, host string) (*Token, *TokenError) { func GetClientCredentialsToken(application *Application, clientSecret string, scope string, host string) (*Token, *TokenError, error) {
if application.ClientSecret != clientSecret { if application.ClientSecret != clientSecret {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidClient, Error: InvalidClient,
ErrorDescription: "client_secret is invalid", ErrorDescription: "client_secret is invalid",
} }, nil
} }
nullUser := &User{ nullUser := &User{
Owner: application.Owner, Owner: application.Owner,
@ -623,7 +698,7 @@ func GetClientCredentialsToken(application *Application, clientSecret string, sc
return nil, &TokenError{ return nil, &TokenError{
Error: EndpointError, Error: EndpointError,
ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()), ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()),
} }, nil
} }
token := &Token{ token := &Token{
Owner: application.Owner, Owner: application.Owner,
@ -639,18 +714,27 @@ func GetClientCredentialsToken(application *Application, clientSecret string, sc
TokenType: "Bearer", TokenType: "Bearer",
CodeIsUsed: true, CodeIsUsed: true,
} }
AddToken(token) _, err = AddToken(token)
return token, nil if err != nil {
return nil, nil, err
}
return token, nil, nil
} }
// GetTokenByUser // GetTokenByUser
// Implicit flow // Implicit flow
func GetTokenByUser(application *Application, user *User, scope string, host string) (*Token, error) { func GetTokenByUser(application *Application, user *User, scope string, host string) (*Token, error) {
ExtendUserWithRolesAndPermissions(user) err := ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, err
}
accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host) accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", scope, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
token := &Token{ token := &Token{
Owner: application.Owner, Owner: application.Owner,
Name: tokenName, Name: tokenName,
@ -666,43 +750,56 @@ func GetTokenByUser(application *Application, user *User, scope string, host str
TokenType: "Bearer", TokenType: "Bearer",
CodeIsUsed: true, CodeIsUsed: true,
} }
AddToken(token) _, err = AddToken(token)
if err != nil {
return nil, err
}
return token, nil return token, nil
} }
// GetWechatMiniProgramToken // GetWechatMiniProgramToken
// Wechat Mini Program flow // Wechat Mini Program flow
func GetWechatMiniProgramToken(application *Application, code string, host string, username string, avatar string, lang string) (*Token, *TokenError) { func GetWechatMiniProgramToken(application *Application, code string, host string, username string, avatar string, lang string) (*Token, *TokenError, error) {
mpProvider := GetWechatMiniProgramProvider(application) mpProvider := GetWechatMiniProgramProvider(application)
if mpProvider == nil { if mpProvider == nil {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidClient, Error: InvalidClient,
ErrorDescription: "the application does not support wechat mini program", ErrorDescription: "the application does not support wechat mini program",
} }, nil
} }
provider := GetProvider(util.GetId("admin", mpProvider.Name)) provider, err := GetProvider(util.GetId("admin", mpProvider.Name))
if err != nil {
return nil, nil, err
}
mpIdp := idp.NewWeChatMiniProgramIdProvider(provider.ClientId, provider.ClientSecret) mpIdp := idp.NewWeChatMiniProgramIdProvider(provider.ClientId, provider.ClientSecret)
session, err := mpIdp.GetSessionByCode(code) session, err := mpIdp.GetSessionByCode(code)
if err != nil { if err != nil {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: fmt.Sprintf("get wechat mini program session error: %s", err.Error()), ErrorDescription: fmt.Sprintf("get wechat mini program session error: %s", err.Error()),
} }, nil
} }
openId, unionId := session.Openid, session.Unionid openId, unionId := session.Openid, session.Unionid
if openId == "" && unionId == "" { if openId == "" && unionId == "" {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidRequest, Error: InvalidRequest,
ErrorDescription: "the wechat mini program session is invalid", ErrorDescription: "the wechat mini program session is invalid",
} }, nil
} }
user := getUserByWechatId(openId, unionId) user, err := getUserByWechatId(openId, unionId)
if err != nil {
return nil, nil, err
}
if user == nil { if user == nil {
if !application.EnableSignUp { if !application.EnableSignUp {
return nil, &TokenError{ return nil, &TokenError{
Error: InvalidGrant, Error: InvalidGrant,
ErrorDescription: "the application does not allow to sign up new account", ErrorDescription: "the application does not allow to sign up new account",
} }, nil
} }
// Add new user // Add new user
var name string var name string
@ -730,16 +827,23 @@ func GetWechatMiniProgramToken(application *Application, code string, host strin
UserPropertiesWechatUnionId: unionId, UserPropertiesWechatUnionId: unionId,
}, },
} }
AddUser(user) _, err = AddUser(user)
if err != nil {
return nil, nil, err
}
}
err = ExtendUserWithRolesAndPermissions(user)
if err != nil {
return nil, nil, err
} }
ExtendUserWithRolesAndPermissions(user)
accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", "", host) accessToken, refreshToken, tokenName, err := generateJwtToken(application, user, "", "", host)
if err != nil { if err != nil {
return nil, &TokenError{ return nil, &TokenError{
Error: EndpointError, Error: EndpointError,
ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()), ErrorDescription: fmt.Sprintf("generate jwt token error: %s", err.Error()),
} }, nil
} }
token := &Token{ token := &Token{
@ -757,6 +861,9 @@ func GetWechatMiniProgramToken(application *Application, code string, host strin
TokenType: "Bearer", TokenType: "Bearer",
CodeIsUsed: true, CodeIsUsed: true,
} }
AddToken(token) _, err = AddToken(token)
return token, nil if err != nil {
return nil, nil, err
}
return token, nil, nil
} }

View File

@ -177,7 +177,9 @@ func StoreCasTokenForProxyTicket(token *CasAuthenticationSuccess, targetService,
} }
func GenerateCasToken(userId string, service string) (string, error) { func GenerateCasToken(userId string, service string) (string, error) {
if user := GetUser(userId); user != nil { if user, err := GetUser(userId); err != nil {
return "", err
} else if user != nil {
authenticationSuccess := CasAuthenticationSuccess{ authenticationSuccess := CasAuthenticationSuccess{
User: user.Name, User: user.Name,
Attributes: &CasAttributes{ Attributes: &CasAttributes{
@ -232,18 +234,31 @@ func GetValidationBySaml(samlRequest string, host string) (string, string, error
return "", "", fmt.Errorf("ticket %s found", ticket) return "", "", fmt.Errorf("ticket %s found", ticket)
} }
user := GetUser(userId) user, err := GetUser(userId)
if err != nil {
return "", "", err
}
if user == nil { if user == nil {
return "", "", fmt.Errorf("user %s found", userId) return "", "", fmt.Errorf("user %s found", userId)
} }
application := GetApplicationByUser(user)
application, err := GetApplicationByUser(user)
if err != nil {
return "", "", err
}
if application == nil { if application == nil {
return "", "", fmt.Errorf("application for user %s found", userId) return "", "", fmt.Errorf("application for user %s found", userId)
} }
samlResponse := NewSamlResponse11(user, request.RequestID, host) samlResponse := NewSamlResponse11(user, request.RequestID, host)
cert := getCertByApplication(application) cert, err := getCertByApplication(application)
if err != nil {
return "", "", err
}
block, _ := pem.Decode([]byte(cert.Certificate)) block, _ := pem.Decode([]byte(cert.Certificate))
certificate := base64.StdEncoding.EncodeToString(block.Bytes) certificate := base64.StdEncoding.EncodeToString(block.Bytes)
randomKeyStore := &X509Key{ randomKeyStore := &X509Key{

View File

@ -273,7 +273,10 @@ func generateJwtToken(application *Application, user *User, nonce string, scope
refreshToken = jwt.NewWithClaims(jwt.SigningMethodRS256, claimsWithoutThirdIdp) refreshToken = jwt.NewWithClaims(jwt.SigningMethodRS256, claimsWithoutThirdIdp)
} }
cert := getCertByApplication(application) cert, err := getCertByApplication(application)
if err != nil {
return "", "", "", err
}
// RSA private key // RSA private key
key, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(cert.PrivateKey)) key, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(cert.PrivateKey))
@ -316,5 +319,10 @@ func ParseJwtToken(token string, cert *Cert) (*Claims, error) {
} }
func ParseJwtTokenByApplication(token string, application *Application) (*Claims, error) { func ParseJwtTokenByApplication(token string, application *Application) (*Claims, error) {
return ParseJwtToken(token, getCertByApplication(application)) cert, err := getCertByApplication(application)
if err != nil {
return nil, err
}
return ParseJwtToken(token, cert)
} }

View File

@ -191,217 +191,206 @@ type ManagedAccount struct {
SigninUrl string `xorm:"varchar(200)" json:"signinUrl"` SigninUrl string `xorm:"varchar(200)" json:"signinUrl"`
} }
func GetGlobalUserCount(field, value string) int { func GetGlobalUserCount(field, value string) (int64, error) {
session := GetSession("", -1, -1, field, value, "", "") session := GetSession("", -1, -1, field, value, "", "")
count, err := session.Count(&User{}) return session.Count(&User{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetGlobalUsers() []*User { func GetGlobalUsers() ([]*User, error) {
users := []*User{} users := []*User{}
err := adapter.Engine.Desc("created_time").Find(&users) err := adapter.Engine.Desc("created_time").Find(&users)
if err != nil { if err != nil {
panic(err) return nil, err
} }
return users return users, nil
} }
func GetPaginationGlobalUsers(offset, limit int, field, value, sortField, sortOrder string) []*User { func GetPaginationGlobalUsers(offset, limit int, field, value, sortField, sortOrder string) ([]*User, error) {
users := []*User{} users := []*User{}
session := GetSession("", offset, limit, field, value, sortField, sortOrder) session := GetSession("", offset, limit, field, value, sortField, sortOrder)
err := session.Find(&users) err := session.Find(&users)
if err != nil { if err != nil {
panic(err) return nil, err
} }
return users return users, nil
} }
func GetUserCount(owner, field, value string) int { func GetUserCount(owner, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&User{}) return session.Count(&User{})
if err != nil {
panic(err)
}
return int(count)
} }
func GetOnlineUserCount(owner string, isOnline int) int { func GetOnlineUserCount(owner string, isOnline int) (int64, error) {
count, err := adapter.Engine.Where("is_online = ?", isOnline).Count(&User{Owner: owner}) return adapter.Engine.Where("is_online = ?", isOnline).Count(&User{Owner: owner})
if err != nil {
panic(err)
}
return int(count)
} }
func GetUsers(owner string) []*User { func GetUsers(owner string) ([]*User, error) {
users := []*User{} users := []*User{}
err := adapter.Engine.Desc("created_time").Find(&users, &User{Owner: owner}) err := adapter.Engine.Desc("created_time").Find(&users, &User{Owner: owner})
if err != nil { if err != nil {
panic(err) return nil, err
} }
return users return users, nil
} }
func GetUsersByTag(owner string, tag string) []*User { func GetUsersByTag(owner string, tag string) ([]*User, error) {
users := []*User{} users := []*User{}
err := adapter.Engine.Desc("created_time").Find(&users, &User{Owner: owner, Tag: tag}) err := adapter.Engine.Desc("created_time").Find(&users, &User{Owner: owner, Tag: tag})
if err != nil { if err != nil {
panic(err) return nil, err
} }
return users return users, nil
} }
func GetSortedUsers(owner string, sorter string, limit int) []*User { func GetSortedUsers(owner string, sorter string, limit int) ([]*User, error) {
users := []*User{} users := []*User{}
err := adapter.Engine.Desc(sorter).Limit(limit, 0).Find(&users, &User{Owner: owner}) err := adapter.Engine.Desc(sorter).Limit(limit, 0).Find(&users, &User{Owner: owner})
if err != nil { if err != nil {
panic(err) return nil, err
} }
return users return users, nil
} }
func GetPaginationUsers(owner string, offset, limit int, field, value, sortField, sortOrder string) []*User { func GetPaginationUsers(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*User, error) {
users := []*User{} users := []*User{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&users) err := session.Find(&users)
if err != nil { if err != nil {
panic(err) return nil, err
} }
return users return users, nil
} }
func getUser(owner string, name string) *User { func getUser(owner string, name string) (*User, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
user := User{Owner: owner, Name: name} user := User{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&user) existed, err := adapter.Engine.Get(&user)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &user return &user, nil
} else { } else {
return nil return nil, nil
} }
} }
func getUserById(owner string, id string) *User { func getUserById(owner string, id string) (*User, error) {
if owner == "" || id == "" { if owner == "" || id == "" {
return nil return nil, nil
} }
user := User{Owner: owner, Id: id} user := User{Owner: owner, Id: id}
existed, err := adapter.Engine.Get(&user) existed, err := adapter.Engine.Get(&user)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &user return &user, nil
} else { } else {
return nil return nil, nil
} }
} }
func getUserByWechatId(wechatOpenId string, wechatUnionId string) *User { func getUserByWechatId(wechatOpenId string, wechatUnionId string) (*User, error) {
if wechatUnionId == "" { if wechatUnionId == "" {
wechatUnionId = wechatOpenId wechatUnionId = wechatOpenId
} }
user := &User{} user := &User{}
existed, err := adapter.Engine.Where("wechat = ? OR wechat = ?", wechatOpenId, wechatUnionId).Get(user) existed, err := adapter.Engine.Where("wechat = ? OR wechat = ?", wechatOpenId, wechatUnionId).Get(user)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return user return user, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetUserByEmail(owner string, email string) *User { func GetUserByEmail(owner string, email string) (*User, error) {
if owner == "" || email == "" { if owner == "" || email == "" {
return nil return nil, nil
} }
user := User{Owner: owner, Email: email} user := User{Owner: owner, Email: email}
existed, err := adapter.Engine.Get(&user) existed, err := adapter.Engine.Get(&user)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &user return &user, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetUserByPhone(owner string, phone string) *User { func GetUserByPhone(owner string, phone string) (*User, error) {
if owner == "" || phone == "" { if owner == "" || phone == "" {
return nil return nil, nil
} }
user := User{Owner: owner, Phone: phone} user := User{Owner: owner, Phone: phone}
existed, err := adapter.Engine.Get(&user) existed, err := adapter.Engine.Get(&user)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &user return &user, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetUserByUserId(owner string, userId string) *User { func GetUserByUserId(owner string, userId string) (*User, error) {
if owner == "" || userId == "" { if owner == "" || userId == "" {
return nil return nil, nil
} }
user := User{Owner: owner, Id: userId} user := User{Owner: owner, Id: userId}
existed, err := adapter.Engine.Get(&user) existed, err := adapter.Engine.Get(&user)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &user return &user, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetUser(id string) *User { func GetUser(id string) (*User, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getUser(owner, name) return getUser(owner, name)
} }
func GetUserNoCheck(id string) *User { func GetUserNoCheck(id string) (*User, error) {
owner, name := util.GetOwnerAndNameFromIdNoCheck(id) owner, name := util.GetOwnerAndNameFromIdNoCheck(id)
return getUser(owner, name) return getUser(owner, name)
} }
func GetMaskedUser(user *User) *User { func GetMaskedUser(user *User, errs ...error) (*User, error) {
if len(errs) > 0 && errs[0] != nil {
return nil, errs[0]
}
if user == nil { if user == nil {
return nil return nil, nil
} }
if user.Password != "" { if user.Password != "" {
@ -419,51 +408,69 @@ func GetMaskedUser(user *User) *User {
user.MultiFactorAuths[i] = GetMaskedProps(props) user.MultiFactorAuths[i] = GetMaskedProps(props)
} }
} }
return user return user, nil
} }
func GetMaskedUsers(users []*User) []*User { func GetMaskedUsers(users []*User, errs ...error) ([]*User, error) {
for _, user := range users { if len(errs) > 0 && errs[0] != nil {
user = GetMaskedUser(user) return nil, errs[0]
} }
return users
var err error
for _, user := range users {
user, err = GetMaskedUser(user)
if err != nil {
return nil, err
}
}
return users, nil
} }
func GetLastUser(owner string) *User { func GetLastUser(owner string) (*User, error) {
user := User{Owner: owner} user := User{Owner: owner}
existed, err := adapter.Engine.Desc("created_time", "id").Get(&user) existed, err := adapter.Engine.Desc("created_time", "id").Get(&user)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &user return &user, nil
} }
return nil return nil, nil
} }
func UpdateUser(id string, user *User, columns []string, isAdmin bool) bool { func UpdateUser(id string, user *User, columns []string, isAdmin bool) (bool, error) {
var err error
owner, name := util.GetOwnerAndNameFromIdNoCheck(id) owner, name := util.GetOwnerAndNameFromIdNoCheck(id)
oldUser := getUser(owner, name) oldUser, err := getUser(owner, name)
if err != nil {
return false, err
}
if oldUser == nil { if oldUser == nil {
return false return false, nil
} }
if name != user.Name { if name != user.Name {
err := userChangeTrigger(name, user.Name) err := userChangeTrigger(name, user.Name)
if err != nil { if err != nil {
return false return false, nil
} }
} }
if user.Password == "***" { if user.Password == "***" {
user.Password = oldUser.Password user.Password = oldUser.Password
} }
user.UpdateUserHash() err = user.UpdateUserHash()
if err != nil {
panic(err)
}
if user.Avatar != oldUser.Avatar && user.Avatar != "" && user.PermanentAvatar != "*" { if user.Avatar != oldUser.Avatar && user.Avatar != "" && user.PermanentAvatar != "*" {
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false) user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
if err != nil {
return false, err
}
} }
if len(columns) == 0 { if len(columns) == 0 {
@ -487,77 +494,105 @@ func UpdateUser(id string, user *User, columns []string, isAdmin bool) bool {
affected, err := adapter.Engine.ID(core.PK{owner, name}).Cols(columns...).Update(user) affected, err := adapter.Engine.ID(core.PK{owner, name}).Cols(columns...).Update(user)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func UpdateUserForAllFields(id string, user *User) bool { func UpdateUserForAllFields(id string, user *User) (bool, error) {
var err error
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
oldUser := getUser(owner, name) oldUser, err := getUser(owner, name)
if err != nil {
return false, err
}
if oldUser == nil { if oldUser == nil {
return false return false, nil
} }
if name != user.Name { if name != user.Name {
err := userChangeTrigger(name, user.Name) err := userChangeTrigger(name, user.Name)
if err != nil { if err != nil {
return false return false, nil
} }
} }
user.UpdateUserHash() err = user.UpdateUserHash()
if err != nil {
return false, err
}
if user.Avatar != oldUser.Avatar && user.Avatar != "" { if user.Avatar != oldUser.Avatar && user.Avatar != "" {
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false) user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
if err != nil {
return false, err
}
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(user) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(user)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddUser(user *User) bool { func AddUser(user *User) (bool, error) {
var err error
if user.Id == "" { if user.Id == "" {
user.Id = util.GenerateId() user.Id = util.GenerateId()
} }
if user.Owner == "" || user.Name == "" { if user.Owner == "" || user.Name == "" {
return false return false, nil
} }
organization := GetOrganizationByUser(user) organization, _ := GetOrganizationByUser(user)
if organization == nil { if organization == nil {
return false return false, nil
} }
user.UpdateUserPassword(organization) user.UpdateUserPassword(organization)
user.UpdateUserHash() err = user.UpdateUserHash()
user.PreHash = user.Hash if err != nil {
return false, err
updated := user.refreshAvatar()
if updated && user.PermanentAvatar != "*" {
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
} }
user.Ranking = GetUserCount(user.Owner, "", "") + 1 user.PreHash = user.Hash
updated, err := user.refreshAvatar()
if err != nil {
return false, err
}
if updated && user.PermanentAvatar != "*" {
user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, false)
if err != nil {
return false, err
}
}
count, err := GetUserCount(user.Owner, "", "")
if err != nil {
return false, err
}
user.Ranking = int(count + 1)
affected, err := adapter.Engine.Insert(user) affected, err := adapter.Engine.Insert(user)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddUsers(users []*User) bool { func AddUsers(users []*User) (bool, error) {
var err error
if len(users) == 0 { if len(users) == 0 {
return false return false, nil
} }
// organization := GetOrganizationByUser(users[0]) // organization := GetOrganizationByUser(users[0])
@ -565,27 +600,34 @@ func AddUsers(users []*User) bool {
// this function is only used for syncer or batch upload, so no need to encrypt the password // this function is only used for syncer or batch upload, so no need to encrypt the password
// user.UpdateUserPassword(organization) // user.UpdateUserPassword(organization)
user.UpdateUserHash() err = user.UpdateUserHash()
if err != nil {
return false, err
}
user.PreHash = user.Hash user.PreHash = user.Hash
user.PermanentAvatar = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true) user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
if err != nil {
return false, err
}
} }
affected, err := adapter.Engine.Insert(users) affected, err := adapter.Engine.Insert(users)
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "Duplicate entry") { if !strings.Contains(err.Error(), "Duplicate entry") {
panic(err) return false, err
} }
} }
return affected != 0 return affected != 0, nil
} }
func AddUsersInBatch(users []*User) bool { func AddUsersInBatch(users []*User) (bool, error) {
batchSize := conf.GetConfigBatchSize() batchSize := conf.GetConfigBatchSize()
if len(users) == 0 { if len(users) == 0 {
return false return false, nil
} }
affected := false affected := false
@ -599,24 +641,29 @@ func AddUsersInBatch(users []*User) bool {
tmp := users[start:end] tmp := users[start:end]
// TODO: save to log instead of standard output // TODO: save to log instead of standard output
// fmt.Printf("Add users: [%d - %d].\n", start, end) // fmt.Printf("Add users: [%d - %d].\n", start, end)
if AddUsers(tmp) { if ok, err := AddUsers(tmp); err != nil {
return false, err
} else if ok {
affected = true affected = true
} }
} }
return affected return affected, nil
} }
func DeleteUser(user *User) bool { func DeleteUser(user *User) (bool, error) {
// Forced offline the user first // Forced offline the user first
DeleteSession(util.GetSessionId(user.Owner, user.Name, CasdoorApplication)) _, err := DeleteSession(util.GetSessionId(user.Owner, user.Name, CasdoorApplication))
if err != nil {
return false, err
}
affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Delete(&User{}) affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Delete(&User{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func GetUserInfo(user *User, scope string, aud string, host string) *Userinfo { func GetUserInfo(user *User, scope string, aud string, host string) *Userinfo {
@ -644,7 +691,7 @@ func GetUserInfo(user *User, scope string, aud string, host string) *Userinfo {
return &resp return &resp
} }
func LinkUserAccount(user *User, field string, value string) bool { func LinkUserAccount(user *User, field string, value string) (bool, error) {
return SetUserField(user, field, value) return SetUserField(user, field, value)
} }
@ -656,13 +703,18 @@ func isUserIdGlobalAdmin(userId string) bool {
return strings.HasPrefix(userId, "built-in/") return strings.HasPrefix(userId, "built-in/")
} }
func ExtendUserWithRolesAndPermissions(user *User) { func ExtendUserWithRolesAndPermissions(user *User) (err error) {
if user == nil { if user == nil {
return return
} }
user.Roles = GetRolesByUser(user.GetId()) user.Roles, err = GetRolesByUser(user.GetId())
user.Permissions = GetPermissionsByUser(user.GetId()) if err != nil {
return
}
user.Permissions, err = GetPermissionsByUser(user.GetId())
return
} }
func userChangeTrigger(oldName string, newName string) error { func userChangeTrigger(oldName string, newName string) error {
@ -679,6 +731,7 @@ func userChangeTrigger(oldName string, newName string) error {
if err != nil { if err != nil {
return err return err
} }
for _, role := range roles { for _, role := range roles {
for j, u := range role.Users { for j, u := range role.Users {
// u = organization/username // u = organization/username
@ -722,7 +775,7 @@ func userChangeTrigger(oldName string, newName string) error {
return session.Commit() return session.Commit()
} }
func (user *User) refreshAvatar() bool { func (user *User) refreshAvatar() (bool, error) {
var err error var err error
var fileBuffer *bytes.Buffer var fileBuffer *bytes.Buffer
var ext string var ext string
@ -732,13 +785,13 @@ func (user *User) refreshAvatar() bool {
client := proxy.ProxyHttpClient client := proxy.ProxyHttpClient
has, err := hasGravatar(client, user.Email) has, err := hasGravatar(client, user.Email)
if err != nil { if err != nil {
panic(err) return false, err
} }
if has { if has {
fileBuffer, ext, err = getGravatarFileBuffer(client, user.Email) fileBuffer, ext, err = getGravatarFileBuffer(client, user.Email)
if err != nil { if err != nil {
panic(err) return false, err
} }
} }
} }
@ -746,17 +799,20 @@ func (user *User) refreshAvatar() bool {
if fileBuffer == nil && strings.Contains(user.Avatar, "Identicon") { if fileBuffer == nil && strings.Contains(user.Avatar, "Identicon") {
fileBuffer, ext, err = getIdenticonFileBuffer(user.Name) fileBuffer, ext, err = getIdenticonFileBuffer(user.Name)
if err != nil { if err != nil {
panic(err) return false, err
} }
} }
if fileBuffer != nil { if fileBuffer != nil {
avatarUrl := getPermanentAvatarUrlFromBuffer(user.Owner, user.Name, fileBuffer, ext, true) avatarUrl, err := getPermanentAvatarUrlFromBuffer(user.Owner, user.Name, fileBuffer, ext, true)
if err != nil {
return false, err
}
user.Avatar = avatarUrl user.Avatar = avatarUrl
return true return true, nil
} }
return false return false, nil
} }
func (user *User) IsMfaEnabled() bool { func (user *User) IsMfaEnabled() bool {

View File

@ -16,18 +16,27 @@ package object
import "github.com/casdoor/casdoor/cred" import "github.com/casdoor/casdoor/cred"
func calculateHash(user *User) string { func calculateHash(user *User) (string, error) {
syncer := getDbSyncerForUser(user) syncer, err := getDbSyncerForUser(user)
if syncer == nil { if err != nil {
return "" return "", err
} }
return syncer.calculateHash(user) if syncer == nil {
return "", nil
}
return syncer.calculateHash(user), nil
} }
func (user *User) UpdateUserHash() { func (user *User) UpdateUserHash() error {
hash := calculateHash(user) hash, err := calculateHash(user)
if err != nil {
return err
}
user.Hash = hash user.Hash = hash
return nil
} }
func (user *User) UpdateUserPassword(organization *Organization) { func (user *User) UpdateUserPassword(organization *Organization) {

View File

@ -36,7 +36,7 @@ func updateUserColumn(column string, user *User) bool {
func TestSyncAvatarsFromGitHub(t *testing.T) { func TestSyncAvatarsFromGitHub(t *testing.T) {
InitConfig() InitConfig()
users := GetGlobalUsers() users, _ := GetGlobalUsers()
for _, user := range users { for _, user := range users {
if user.GitHub == "" { if user.GitHub == "" {
continue continue
@ -50,7 +50,7 @@ func TestSyncAvatarsFromGitHub(t *testing.T) {
func TestSyncIds(t *testing.T) { func TestSyncIds(t *testing.T) {
InitConfig() InitConfig()
users := GetGlobalUsers() users, _ := GetGlobalUsers()
for _, user := range users { for _, user := range users {
if user.Id != "" { if user.Id != "" {
continue continue
@ -64,13 +64,16 @@ func TestSyncIds(t *testing.T) {
func TestSyncHashes(t *testing.T) { func TestSyncHashes(t *testing.T) {
InitConfig() InitConfig()
users := GetGlobalUsers() users, _ := GetGlobalUsers()
for _, user := range users { for _, user := range users {
if user.Hash != "" { if user.Hash != "" {
continue continue
} }
user.UpdateUserHash() err := user.UpdateUserHash()
if err != nil {
panic(err)
}
updateUserColumn("hash", user) updateUserColumn("hash", user)
} }
} }
@ -92,7 +95,7 @@ func TestGetMaskedUsers(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := GetMaskedUsers(tt.args.users); !reflect.DeepEqual(got, tt.want) { if got, _ := GetMaskedUsers(tt.args.users); !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetMaskedUsers() = %v, want %v", got, tt.want) t.Errorf("GetMaskedUsers() = %v, want %v", got, tt.want)
} }
}) })
@ -102,7 +105,7 @@ func TestGetMaskedUsers(t *testing.T) {
func TestGetUserByField(t *testing.T) { func TestGetUserByField(t *testing.T) {
InitConfig() InitConfig()
user := GetUserByField("built-in", "DingTalk", "test") user, _ := GetUserByField("built-in", "DingTalk", "test")
if user != nil { if user != nil {
t.Logf("%+v", user) t.Logf("%+v", user)
} else { } else {
@ -115,7 +118,7 @@ func TestGetEmailsForUsers(t *testing.T) {
emailMap := map[string]int{} emailMap := map[string]int{}
emails := []string{} emails := []string{}
users := GetUsers("built-in") users, _ := GetUsers("built-in")
for _, user := range users { for _, user := range users {
if user.Email == "" { if user.Email == "" {
continue continue

View File

@ -22,15 +22,18 @@ import (
"github.com/casdoor/casdoor/xlsx" "github.com/casdoor/casdoor/xlsx"
) )
func getUserMap(owner string) map[string]*User { func getUserMap(owner string) (map[string]*User, error) {
m := map[string]*User{} m := map[string]*User{}
users := GetUsers(owner) users, err := GetUsers(owner)
if err != nil {
return m, err
}
for _, user := range users { for _, user := range users {
m[user.GetId()] = user m[user.GetId()] = user
} }
return m return m, nil
} }
func parseLineItem(line *[]string, i int) string { func parseLineItem(line *[]string, i int) string {
@ -70,10 +73,14 @@ func parseListItem(lines *[]string, i int) []string {
return trimmedItems return trimmedItems
} }
func UploadUsers(owner string, fileId string) bool { func UploadUsers(owner string, fileId string) (bool, error) {
table := xlsx.ReadXlsxFile(fileId) table := xlsx.ReadXlsxFile(fileId)
oldUserMap := getUserMap(owner) oldUserMap, err := getUserMap(owner)
if err != nil {
return false, err
}
newUsers := []*User{} newUsers := []*User{}
for index, line := range table { for index, line := range table {
if index == 0 || parseLineItem(&line, 0) == "" { if index == 0 || parseLineItem(&line, 0) == "" {
@ -135,7 +142,7 @@ func UploadUsers(owner string, fileId string) bool {
} }
if len(newUsers) == 0 { if len(newUsers) == 0 {
return false return false, nil
} }
return AddUsersInBatch(newUsers) return AddUsersInBatch(newUsers)
} }

View File

@ -24,62 +24,70 @@ import (
"github.com/xorm-io/core" "github.com/xorm-io/core"
) )
func GetUserByField(organizationName string, field string, value string) *User { func GetUserByField(organizationName string, field string, value string) (*User, error) {
if field == "" || value == "" { if field == "" || value == "" {
return nil return nil, nil
} }
user := User{Owner: organizationName} user := User{Owner: organizationName}
existed, err := adapter.Engine.Where(fmt.Sprintf("%s=?", strings.ToLower(field)), value).Get(&user) existed, err := adapter.Engine.Where(fmt.Sprintf("%s=?", strings.ToLower(field)), value).Get(&user)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if existed { if existed {
return &user return &user, nil
} else { } else {
return nil return nil, nil
} }
} }
func HasUserByField(organizationName string, field string, value string) bool { func HasUserByField(organizationName string, field string, value string) bool {
return GetUserByField(organizationName, field, value) != nil user, err := GetUserByField(organizationName, field, value)
if err != nil {
panic(err)
}
return user != nil
} }
func GetUserByFields(organization string, field string) *User { func GetUserByFields(organization string, field string) (*User, error) {
// check username // check username
user := GetUserByField(organization, "name", field) user, err := GetUserByField(organization, "name", field)
if user != nil { if err != nil || user != nil {
return user return user, err
} }
// check email // check email
if strings.Contains(field, "@") { if strings.Contains(field, "@") {
user = GetUserByField(organization, "email", field) user, err = GetUserByField(organization, "email", field)
if user != nil { if user != nil || err != nil {
return user return user, err
} }
} }
// check phone // check phone
user = GetUserByField(organization, "phone", field) user, err = GetUserByField(organization, "phone", field)
if user != nil { if user != nil || err != nil {
return user return user, err
} }
// check ID card // check ID card
user = GetUserByField(organization, "id_card", field) user, err = GetUserByField(organization, "id_card", field)
if user != nil { if user != nil || err != nil {
return user return user, err
} }
return nil return nil, nil
} }
func SetUserField(user *User, field string, value string) bool { func SetUserField(user *User, field string, value string) (bool, error) {
bean := make(map[string]interface{}) bean := make(map[string]interface{})
if field == "password" { if field == "password" {
organization := GetOrganizationByUser(user) organization, err := GetOrganizationByUser(user)
if err != nil {
return false, err
}
user.UpdateUserPassword(organization) user.UpdateUserPassword(organization)
bean[strings.ToLower(field)] = user.Password bean[strings.ToLower(field)] = user.Password
bean["password_type"] = user.PasswordType bean["password_type"] = user.PasswordType
@ -89,17 +97,25 @@ func SetUserField(user *User, field string, value string) bool {
affected, err := adapter.Engine.Table(user).ID(core.PK{user.Owner, user.Name}).Update(bean) affected, err := adapter.Engine.Table(user).ID(core.PK{user.Owner, user.Name}).Update(bean)
if err != nil { if err != nil {
panic(err) return false, err
}
user, err = getUser(user.Owner, user.Name)
if err != nil {
return false, err
}
err = user.UpdateUserHash()
if err != nil {
return false, err
} }
user = getUser(user.Owner, user.Name)
user.UpdateUserHash()
_, err = adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("hash").Update(user) _, err = adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("hash").Update(user)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func GetUserField(user *User, field string) string { func GetUserField(user *User, field string) string {
@ -121,7 +137,7 @@ func setUserProperty(user *User, field string, value string) {
} }
} }
func SetUserOAuthProperties(organization *Organization, user *User, providerType string, userInfo *idp.UserInfo) bool { func SetUserOAuthProperties(organization *Organization, user *User, providerType string, userInfo *idp.UserInfo) (bool, error) {
if userInfo.Id != "" { if userInfo.Id != "" {
propertyName := fmt.Sprintf("oauth_%s_id", providerType) propertyName := fmt.Sprintf("oauth_%s_id", providerType)
setUserProperty(user, propertyName, userInfo.Id) setUserProperty(user, propertyName, userInfo.Id)
@ -164,11 +180,10 @@ func SetUserOAuthProperties(organization *Organization, user *User, providerType
} }
} }
affected := UpdateUserForAllFields(user.GetId(), user) return UpdateUserForAllFields(user.GetId(), user)
return affected
} }
func ClearUserOAuthProperties(user *User, providerType string) bool { func ClearUserOAuthProperties(user *User, providerType string) (bool, error) {
for k := range user.Properties { for k := range user.Properties {
prefix := fmt.Sprintf("oauth_%s_", providerType) prefix := fmt.Sprintf("oauth_%s_", providerType)
if strings.HasPrefix(k, prefix) { if strings.HasPrefix(k, prefix) {
@ -178,14 +193,18 @@ func ClearUserOAuthProperties(user *User, providerType string) bool {
affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("properties").Update(user) affected, err := adapter.Engine.ID(core.PK{user.Owner, user.Name}).Cols("properties").Update(user)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func CheckPermissionForUpdateUser(oldUser, newUser *User, isAdmin bool, lang string) (bool, string) { func CheckPermissionForUpdateUser(oldUser, newUser *User, isAdmin bool, lang string) (bool, string) {
organization := GetOrganizationByUser(oldUser) organization, err := GetOrganizationByUser(oldUser)
if err != nil {
return false, err.Error()
}
var itemsChanged []*AccountItem var itemsChanged []*AccountItem
if oldUser.Owner != newUser.Owner { if oldUser.Owner != newUser.Owner {
@ -310,7 +329,7 @@ func (user *User) GetCountryCode(countryCode string) string {
return user.CountryCode return user.CountryCode
} }
if org := GetOrganizationByUser(user); org != nil && len(org.CountryCodes) > 0 { if org, _ := GetOrganizationByUser(user); org != nil && len(org.CountryCodes) > 0 {
return org.CountryCodes[0] return org.CountryCodes[0]
} }
return "" return ""

View File

@ -16,6 +16,7 @@ package object
import ( import (
"encoding/base64" "encoding/base64"
"fmt"
"net/url" "net/url"
"strings" "strings"
@ -24,14 +25,14 @@ import (
"github.com/go-webauthn/webauthn/webauthn" "github.com/go-webauthn/webauthn/webauthn"
) )
func GetWebAuthnObject(host string) *webauthn.WebAuthn { func GetWebAuthnObject(host string) (*webauthn.WebAuthn, error) {
var err error var err error
_, originBackend := getOriginFromHost(host) _, originBackend := getOriginFromHost(host)
localUrl, err := url.Parse(originBackend) localUrl, err := url.Parse(originBackend)
if err != nil { if err != nil {
panic("error when parsing origin:" + err.Error()) return nil, fmt.Errorf("error when parsing origin:" + err.Error())
} }
webAuthn, err := webauthn.New(&webauthn.Config{ webAuthn, err := webauthn.New(&webauthn.Config{
@ -41,10 +42,10 @@ func GetWebAuthnObject(host string) *webauthn.WebAuthn {
// RPIcon: "https://duo.com/logo.png", // Optional icon URL for your site // RPIcon: "https://duo.com/logo.png", // Optional icon URL for your site
}) })
if err != nil { if err != nil {
panic(err) return nil, err
} }
return webAuthn return webAuthn, nil
} }
// WebAuthnID // WebAuthnID
@ -84,17 +85,17 @@ func (user *User) CredentialExcludeList() []protocol.CredentialDescriptor {
return credentialExcludeList return credentialExcludeList
} }
func (user *User) AddCredentials(credential webauthn.Credential, isGlobalAdmin bool) bool { func (user *User) AddCredentials(credential webauthn.Credential, isGlobalAdmin bool) (bool, error) {
user.WebauthnCredentials = append(user.WebauthnCredentials, credential) user.WebauthnCredentials = append(user.WebauthnCredentials, credential)
return UpdateUser(user.GetId(), user, []string{"webauthnCredentials"}, isGlobalAdmin) return UpdateUser(user.GetId(), user, []string{"webauthnCredentials"}, isGlobalAdmin)
} }
func (user *User) DeleteCredentials(credentialIdBase64 string) bool { func (user *User) DeleteCredentials(credentialIdBase64 string) (bool, error) {
for i, credential := range user.WebauthnCredentials { for i, credential := range user.WebauthnCredentials {
if base64.StdEncoding.EncodeToString(credential.ID) == credentialIdBase64 { if base64.StdEncoding.EncodeToString(credential.ID) == credentialIdBase64 {
user.WebauthnCredentials = append(user.WebauthnCredentials[0:i], user.WebauthnCredentials[i+1:]...) user.WebauthnCredentials = append(user.WebauthnCredentials[0:i], user.WebauthnCredentials[i+1:]...)
return UpdateUserForAllFields(user.GetId(), user) return UpdateUserForAllFields(user.GetId(), user)
} }
} }
return false return false, nil
} }

View File

@ -151,21 +151,24 @@ func AddToVerificationRecord(user *User, provider *Provider, remoteAddr, recordT
return nil return nil
} }
func getVerificationRecord(dest string) *VerificationRecord { func getVerificationRecord(dest string) (*VerificationRecord, error) {
var record VerificationRecord var record VerificationRecord
record.Receiver = dest record.Receiver = dest
has, err := adapter.Engine.Desc("time").Where("is_used = false").Get(&record) has, err := adapter.Engine.Desc("time").Where("is_used = false").Get(&record)
if err != nil { if err != nil {
panic(err) return nil, err
} }
if !has { if !has {
return nil return nil, nil
} }
return &record return &record, nil
} }
func CheckVerificationCode(dest, code, lang string) *VerifyResult { func CheckVerificationCode(dest, code, lang string) *VerifyResult {
record := getVerificationRecord(dest) record, err := getVerificationRecord(dest)
if err != nil {
panic(err)
}
if record == nil { if record == nil {
return &VerifyResult{noRecordError, i18n.Translate(lang, "verification:Code has not been sent yet!")} return &VerifyResult{noRecordError, i18n.Translate(lang, "verification:Code has not been sent yet!")}
@ -188,17 +191,15 @@ func CheckVerificationCode(dest, code, lang string) *VerifyResult {
return &VerifyResult{VerificationSuccess, ""} return &VerifyResult{VerificationSuccess, ""}
} }
func DisableVerificationCode(dest string) { func DisableVerificationCode(dest string) (err error) {
record := getVerificationRecord(dest) record, err := getVerificationRecord(dest)
if record == nil { if record == nil || err != nil {
return return
} }
record.IsUsed = true record.IsUsed = true
_, err := adapter.Engine.ID(core.PK{record.Owner, record.Name}).AllCols().Update(record) _, err = adapter.Engine.ID(core.PK{record.Owner, record.Name}).AllCols().Update(record)
if err != nil { return
panic(err)
}
} }
func CheckSigninCode(user *User, dest, code, lang string) string { func CheckSigninCode(user *User, dest, code, lang string) string {

View File

@ -42,100 +42,97 @@ type Webhook struct {
IsEnabled bool `json:"isEnabled"` IsEnabled bool `json:"isEnabled"`
} }
func GetWebhookCount(owner, organization, field, value string) int { func GetWebhookCount(owner, organization, field, value string) (int64, error) {
session := GetSession(owner, -1, -1, field, value, "", "") session := GetSession(owner, -1, -1, field, value, "", "")
count, err := session.Count(&Webhook{Organization: organization}) return session.Count(&Webhook{Organization: organization})
if err != nil {
panic(err)
}
return int(count)
} }
func GetWebhooks(owner string, organization string) []*Webhook { func GetWebhooks(owner string, organization string) ([]*Webhook, error) {
webhooks := []*Webhook{} webhooks := []*Webhook{}
err := adapter.Engine.Desc("created_time").Find(&webhooks, &Webhook{Owner: owner, Organization: organization}) err := adapter.Engine.Desc("created_time").Find(&webhooks, &Webhook{Owner: owner, Organization: organization})
if err != nil { if err != nil {
panic(err) return webhooks, err
} }
return webhooks return webhooks, nil
} }
func GetPaginationWebhooks(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) []*Webhook { func GetPaginationWebhooks(owner, organization string, offset, limit int, field, value, sortField, sortOrder string) ([]*Webhook, error) {
webhooks := []*Webhook{} webhooks := []*Webhook{}
session := GetSession(owner, offset, limit, field, value, sortField, sortOrder) session := GetSession(owner, offset, limit, field, value, sortField, sortOrder)
err := session.Find(&webhooks, &Webhook{Organization: organization}) err := session.Find(&webhooks, &Webhook{Organization: organization})
if err != nil { if err != nil {
panic(err) return nil, err
} }
return webhooks return webhooks, nil
} }
func getWebhooksByOrganization(organization string) []*Webhook { func getWebhooksByOrganization(organization string) ([]*Webhook, error) {
webhooks := []*Webhook{} webhooks := []*Webhook{}
err := adapter.Engine.Desc("created_time").Find(&webhooks, &Webhook{Organization: organization}) err := adapter.Engine.Desc("created_time").Find(&webhooks, &Webhook{Organization: organization})
if err != nil { if err != nil {
panic(err) return webhooks, err
} }
return webhooks return webhooks, nil
} }
func getWebhook(owner string, name string) *Webhook { func getWebhook(owner string, name string) (*Webhook, error) {
if owner == "" || name == "" { if owner == "" || name == "" {
return nil return nil, nil
} }
webhook := Webhook{Owner: owner, Name: name} webhook := Webhook{Owner: owner, Name: name}
existed, err := adapter.Engine.Get(&webhook) existed, err := adapter.Engine.Get(&webhook)
if err != nil { if err != nil {
panic(err) return &webhook, err
} }
if existed { if existed {
return &webhook return &webhook, nil
} else { } else {
return nil return nil, nil
} }
} }
func GetWebhook(id string) *Webhook { func GetWebhook(id string) (*Webhook, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
return getWebhook(owner, name) return getWebhook(owner, name)
} }
func UpdateWebhook(id string, webhook *Webhook) bool { func UpdateWebhook(id string, webhook *Webhook) (bool, error) {
owner, name := util.GetOwnerAndNameFromId(id) owner, name := util.GetOwnerAndNameFromId(id)
if getWebhook(owner, name) == nil { if w, err := getWebhook(owner, name); err != nil {
return false return false, err
} else if w == nil {
return false, nil
} }
affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(webhook) affected, err := adapter.Engine.ID(core.PK{owner, name}).AllCols().Update(webhook)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func AddWebhook(webhook *Webhook) bool { func AddWebhook(webhook *Webhook) (bool, error) {
affected, err := adapter.Engine.Insert(webhook) affected, err := adapter.Engine.Insert(webhook)
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func DeleteWebhook(webhook *Webhook) bool { func DeleteWebhook(webhook *Webhook) (bool, error) {
affected, err := adapter.Engine.ID(core.PK{webhook.Owner, webhook.Name}).Delete(&Webhook{}) affected, err := adapter.Engine.ID(core.PK{webhook.Owner, webhook.Name}).Delete(&Webhook{})
if err != nil { if err != nil {
panic(err) return false, err
} }
return affected != 0 return affected != 0, nil
} }
func (p *Webhook) GetId() string { func (p *Webhook) GetId() string {

Some files were not shown because too many files have changed in this diff Show More