diff --git a/object/adapter.go b/object/adapter.go index d02b859f..2fcf93a6 100644 --- a/object/adapter.go +++ b/object/adapter.go @@ -190,12 +190,17 @@ func (a *Adapter) createTable() { } func GetSession(owner string, offset, limit int, field, value, sortField, sortOrder string) *xorm.Session { - session := adapter.Engine.Limit(limit, offset).Where("1=1") + session := adapter.Engine.Prepare() + if offset != -1 && limit != -1 { + session.Limit(limit, offset) + } if owner != "" { session = session.And("owner=?", owner) } if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) + if filterField(field) { + session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) + } } if sortField == "" || sortOrder == "" { sortField = "created_time" @@ -206,4 +211,4 @@ func GetSession(owner string, offset, limit int, field, value, sortField, sortOr session = session.Desc(util.SnakeString(sortField)) } return session -} +} \ No newline at end of file diff --git a/object/application.go b/object/application.go index d09d1b7b..a515bf45 100644 --- a/object/application.go +++ b/object/application.go @@ -56,10 +56,7 @@ type Application struct { } func GetApplicationCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Application{}) if err != nil { panic(err) diff --git a/object/cert.go b/object/cert.go index 3f5e351b..bc955cde 100644 --- a/object/cert.go +++ b/object/cert.go @@ -53,10 +53,7 @@ func GetMaskedCerts(certs []*Cert) []*Cert { } func GetCertCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Cert{}) if err != nil { panic(err) diff --git a/object/check.go b/object/check.go index 8fa22ac6..02ba0271 100644 --- a/object/check.go +++ b/object/check.go @@ -23,10 +23,14 @@ import ( goldap "github.com/go-ldap/ldap/v3" ) -var reWhiteSpace *regexp.Regexp +var ( + reWhiteSpace *regexp.Regexp + reFieldWhiteList *regexp.Regexp +) func init() { reWhiteSpace, _ = regexp.Compile(`\s`) + reFieldWhiteList, _ = regexp.Compile(`^[A-Za-z0-9]+$`) } func CheckUserSignup(application *Application, organization *Organization, username string, password string, displayName string, email string, phone string, affiliation string) string { @@ -179,3 +183,7 @@ func CheckUserPassword(organization string, username string, password string) (* return user, "" } + +func filterField(field string) bool { + return reFieldWhiteList.MatchString(field) +} \ No newline at end of file diff --git a/object/organization.go b/object/organization.go index 1e7fda81..f3543b9d 100644 --- a/object/organization.go +++ b/object/organization.go @@ -15,8 +15,6 @@ package object import ( - "fmt" - "github.com/casdoor/casdoor/cred" "github.com/casdoor/casdoor/util" "xorm.io/core" @@ -39,10 +37,7 @@ type Organization struct { } func GetOrganizationCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Organization{}) if err != nil { panic(err) diff --git a/object/permission.go b/object/permission.go index 4d7027dd..7c8a37d8 100644 --- a/object/permission.go +++ b/object/permission.go @@ -39,10 +39,7 @@ type Permission struct { } func GetPermissionCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Permission{}) if err != nil { panic(err) diff --git a/object/provider.go b/object/provider.go index 9bf50608..5f408295 100644 --- a/object/provider.go +++ b/object/provider.go @@ -81,10 +81,7 @@ func GetMaskedProviders(providers []*Provider) []*Provider { } func GetProviderCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Provider{}) if err != nil { panic(err) diff --git a/object/record.go b/object/record.go index 99160dfd..318645cc 100644 --- a/object/record.go +++ b/object/record.go @@ -102,10 +102,7 @@ func AddRecord(record *Record) bool { } func GetRecordCount(field, value string) int { - session := adapter.Engine.Where("1=1") - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession("", -1, -1, field, value, "", "") count, err := session.Count(&Record{}) if err != nil { panic(err) diff --git a/object/resource.go b/object/resource.go index fcefe069..7bde6300 100644 --- a/object/resource.go +++ b/object/resource.go @@ -40,11 +40,8 @@ type Resource struct { } func GetResourceCount(owner, user, field, value string) int { - session := adapter.Engine.Where("owner=? and user=?", owner, user) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } - count, err := session.Count(&Resource{}) + session := GetSession(owner, -1, -1, field, value, "", "") + count, err := session.Count(&Resource{User: user}) if err != nil { panic(err) } diff --git a/object/role.go b/object/role.go index bef124f3..d2effa76 100644 --- a/object/role.go +++ b/object/role.go @@ -33,10 +33,7 @@ type Role struct { } func GetRoleCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Role{}) if err != nil { panic(err) diff --git a/object/syncer.go b/object/syncer.go index db03acdc..988f2eec 100644 --- a/object/syncer.go +++ b/object/syncer.go @@ -56,10 +56,7 @@ type Syncer struct { } func GetSyncerCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Syncer{}) if err != nil { panic(err) diff --git a/object/token.go b/object/token.go index 6d87b29b..bd1bfed4 100644 --- a/object/token.go +++ b/object/token.go @@ -57,10 +57,7 @@ type TokenWrapper struct { } func GetTokenCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Token{}) if err != nil { panic(err) diff --git a/object/user.go b/object/user.go index 5cafed64..919dd6af 100644 --- a/object/user.go +++ b/object/user.go @@ -89,10 +89,7 @@ type User struct { } func GetGlobalUserCount(field, value string) int { - session := adapter.Engine.Where("1=1") - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession("", -1, -1, field, value, "", "") count, err := session.Count(&User{}) if err != nil { panic(err) @@ -123,10 +120,7 @@ func GetPaginationGlobalUsers(offset, limit int, field, value, sortField, sortOr } func GetUserCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&User{}) if err != nil { panic(err) diff --git a/object/webhook.go b/object/webhook.go index 6b630330..3ca49163 100644 --- a/object/webhook.go +++ b/object/webhook.go @@ -43,10 +43,7 @@ type Webhook struct { } func GetWebhookCount(owner, field, value string) int { - session := adapter.Engine.Where("owner=?", owner) - if field != "" && value != "" { - session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) - } + session := GetSession(owner, -1, -1, field, value, "", "") count, err := session.Count(&Webhook{}) if err != nil { panic(err)