fix: fix the SQL injection vulnerability in field filter (#442)

Signed-off-by: Yixiang Zhao <seriouszyx@foxmail.com>
This commit is contained in:
Yixiang Zhao 2022-01-26 19:36:36 +08:00 committed by GitHub
parent 051752340d
commit 5ec0c7a890
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 31 additions and 59 deletions

View File

@ -190,12 +190,17 @@ func (a *Adapter) createTable() {
} }
func GetSession(owner string, offset, limit int, field, value, sortField, sortOrder string) *xorm.Session { func GetSession(owner string, offset, limit int, field, value, sortField, sortOrder string) *xorm.Session {
session := adapter.Engine.Limit(limit, offset).Where("1=1") session := adapter.Engine.Prepare()
if offset != -1 && limit != -1 {
session.Limit(limit, offset)
}
if owner != "" { if owner != "" {
session = session.And("owner=?", owner) session = session.And("owner=?", owner)
} }
if field != "" && value != "" { if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value)) if filterField(field) {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
} }
if sortField == "" || sortOrder == "" { if sortField == "" || sortOrder == "" {
sortField = "created_time" sortField = "created_time"
@ -206,4 +211,4 @@ func GetSession(owner string, offset, limit int, field, value, sortField, sortOr
session = session.Desc(util.SnakeString(sortField)) session = session.Desc(util.SnakeString(sortField))
} }
return session return session
} }

View File

@ -56,10 +56,7 @@ type Application struct {
} }
func GetApplicationCount(owner, field, value string) int { func GetApplicationCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Application{}) count, err := session.Count(&Application{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -53,10 +53,7 @@ func GetMaskedCerts(certs []*Cert) []*Cert {
} }
func GetCertCount(owner, field, value string) int { func GetCertCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Cert{}) count, err := session.Count(&Cert{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -23,10 +23,14 @@ import (
goldap "github.com/go-ldap/ldap/v3" goldap "github.com/go-ldap/ldap/v3"
) )
var reWhiteSpace *regexp.Regexp var (
reWhiteSpace *regexp.Regexp
reFieldWhiteList *regexp.Regexp
)
func init() { func init() {
reWhiteSpace, _ = regexp.Compile(`\s`) reWhiteSpace, _ = regexp.Compile(`\s`)
reFieldWhiteList, _ = regexp.Compile(`^[A-Za-z0-9]+$`)
} }
func CheckUserSignup(application *Application, organization *Organization, username string, password string, displayName string, email string, phone string, affiliation string) string { func CheckUserSignup(application *Application, organization *Organization, username string, password string, displayName string, email string, phone string, affiliation string) string {
@ -179,3 +183,7 @@ func CheckUserPassword(organization string, username string, password string) (*
return user, "" return user, ""
} }
func filterField(field string) bool {
return reFieldWhiteList.MatchString(field)
}

View File

@ -15,8 +15,6 @@
package object package object
import ( import (
"fmt"
"github.com/casdoor/casdoor/cred" "github.com/casdoor/casdoor/cred"
"github.com/casdoor/casdoor/util" "github.com/casdoor/casdoor/util"
"xorm.io/core" "xorm.io/core"
@ -39,10 +37,7 @@ type Organization struct {
} }
func GetOrganizationCount(owner, field, value string) int { func GetOrganizationCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Organization{}) count, err := session.Count(&Organization{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -39,10 +39,7 @@ type Permission struct {
} }
func GetPermissionCount(owner, field, value string) int { func GetPermissionCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Permission{}) count, err := session.Count(&Permission{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -81,10 +81,7 @@ func GetMaskedProviders(providers []*Provider) []*Provider {
} }
func GetProviderCount(owner, field, value string) int { func GetProviderCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Provider{}) count, err := session.Count(&Provider{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -102,10 +102,7 @@ func AddRecord(record *Record) bool {
} }
func GetRecordCount(field, value string) int { func GetRecordCount(field, value string) int {
session := adapter.Engine.Where("1=1") session := GetSession("", -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Record{}) count, err := session.Count(&Record{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -40,11 +40,8 @@ type Resource struct {
} }
func GetResourceCount(owner, user, field, value string) int { func GetResourceCount(owner, user, field, value string) int {
session := adapter.Engine.Where("owner=? and user=?", owner, user) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" { count, err := session.Count(&Resource{User: user})
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Resource{})
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -33,10 +33,7 @@ type Role struct {
} }
func GetRoleCount(owner, field, value string) int { func GetRoleCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Role{}) count, err := session.Count(&Role{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -56,10 +56,7 @@ type Syncer struct {
} }
func GetSyncerCount(owner, field, value string) int { func GetSyncerCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Syncer{}) count, err := session.Count(&Syncer{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -57,10 +57,7 @@ type TokenWrapper struct {
} }
func GetTokenCount(owner, field, value string) int { func GetTokenCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Token{}) count, err := session.Count(&Token{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -89,10 +89,7 @@ type User struct {
} }
func GetGlobalUserCount(field, value string) int { func GetGlobalUserCount(field, value string) int {
session := adapter.Engine.Where("1=1") session := GetSession("", -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&User{}) count, err := session.Count(&User{})
if err != nil { if err != nil {
panic(err) panic(err)
@ -123,10 +120,7 @@ func GetPaginationGlobalUsers(offset, limit int, field, value, sortField, sortOr
} }
func GetUserCount(owner, field, value string) int { func GetUserCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&User{}) count, err := session.Count(&User{})
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -43,10 +43,7 @@ type Webhook struct {
} }
func GetWebhookCount(owner, field, value string) int { func GetWebhookCount(owner, field, value string) int {
session := adapter.Engine.Where("owner=?", owner) session := GetSession(owner, -1, -1, field, value, "", "")
if field != "" && value != "" {
session = session.And(fmt.Sprintf("%s like ?", util.SnakeString(field)), fmt.Sprintf("%%%s%%", value))
}
count, err := session.Count(&Webhook{}) count, err := session.Count(&Webhook{})
if err != nil { if err != nil {
panic(err) panic(err)