diff --git a/authz/authz.go b/authz/authz.go index b4328e95..4d7761d4 100644 --- a/authz/authz.go +++ b/authz/authz.go @@ -15,7 +15,6 @@ package authz import ( - "github.com/astaxie/beego" "github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2/model" xormadapter "github.com/casbin/xorm-adapter/v2" @@ -28,8 +27,8 @@ var Enforcer *casbin.Enforcer func InitAuthz() { var err error - tableNamePrefix := beego.AppConfig.String("tableNamePrefix") - a, err := xormadapter.NewAdapterWithTableName(beego.AppConfig.String("driverName"), conf.GetBeegoConfDataSourceName()+beego.AppConfig.String("dbName"), "casbin_rule", tableNamePrefix, true) + tableNamePrefix := conf.GetConfigString("tableNamePrefix") + a, err := xormadapter.NewAdapterWithTableName(conf.GetConfigString("driverName"), conf.GetBeegoConfDataSourceName()+conf.GetConfigString("dbName"), "casbin_rule", tableNamePrefix, true) if err != nil { panic(err) } diff --git a/conf/conf.go b/conf/conf.go index 1886ff34..7ef49917 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -15,14 +15,49 @@ package conf import ( + "fmt" "os" + "strconv" "strings" "github.com/astaxie/beego" ) +func GetConfigString(key string) string { + if value, ok := os.LookupEnv(key); ok { + return value + } + return beego.AppConfig.String(key) +} + +func GetConfigBool(key string) (bool, error) { + value := GetConfigString(key) + if value == "true" { + return true, nil + } else if value == "false" { + return false, nil + } + return false, fmt.Errorf("value %s cannot be converted into bool", value) +} + +func GetConfigInt64(key string) (int64, error) { + value := GetConfigString(key) + num, err := strconv.ParseInt(value, 10, 64) + return num, err +} + +func init() { + //this array contains the beego configuration items that may be modified via env + var presetConfigItems = []string{"httpport", "appname"} + for _, key := range presetConfigItems { + if value, ok := os.LookupEnv(key); ok { + beego.AppConfig.Set(key, value) + } + } +} + func GetBeegoConfDataSourceName() string { - dataSourceName := beego.AppConfig.String("dataSourceName") + dataSourceName := GetConfigString("dataSourceName") runningInDocker := os.Getenv("RUNNING_IN_DOCKER") if runningInDocker == "true" { diff --git a/conf/conf_test.go b/conf/conf_test.go new file mode 100644 index 00000000..604399dc --- /dev/null +++ b/conf/conf_test.go @@ -0,0 +1,98 @@ +// Copyright 2022 The Casdoor Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package conf + +import ( + "os" + "testing" + + "github.com/astaxie/beego" + "github.com/stretchr/testify/assert" +) + +func TestGetConfString(t *testing.T) { + scenarios := []struct { + description string + input string + expected interface{} + }{ + {"Should be return casbin", "appname", "casbin"}, + {"Should be return 8000", "httpport", "8000"}, + {"Should be return value", "key", "value"}, + } + + //do some set up job + + os.Setenv("appname", "casbin") + os.Setenv("key", "value") + + err := beego.LoadAppConfig("ini", "app.conf") + assert.Nil(t, err) + + for _, scenery := range scenarios { + t.Run(scenery.description, func(t *testing.T) { + actual := GetConfigString(scenery.input) + assert.Equal(t, scenery.expected, actual) + }) + } +} + +func TestGetConfInt(t *testing.T) { + scenarios := []struct { + description string + input string + expected interface{} + }{ + {"Should be return 8000", "httpport", 8001}, + {"Should be return 8000", "verificationCodeTimeout", 10}, + } + + //do some set up job + os.Setenv("httpport", "8001") + + err := beego.LoadAppConfig("ini", "app.conf") + assert.Nil(t, err) + + for _, scenery := range scenarios { + t.Run(scenery.description, func(t *testing.T) { + actual, err := GetConfigInt64(scenery.input) + assert.Nil(t, err) + assert.Equal(t, scenery.expected, int(actual)) + }) + } +} + +func TestGetConfBool(t *testing.T) { + scenarios := []struct { + description string + input string + expected interface{} + }{ + {"Should be return false", "SessionOn", false}, + {"Should be return false", "copyrequestbody", true}, + } + + //do some set up job + os.Setenv("SessionOn", "false") + + err := beego.LoadAppConfig("ini", "app.conf") + assert.Nil(t, err) + for _, scenery := range scenarios { + t.Run(scenery.description, func(t *testing.T) { + actual, err := GetConfigBool(scenery.input) + assert.Nil(t, err) + assert.Equal(t, scenery.expected, actual) + }) + } +} diff --git a/controllers/auth.go b/controllers/auth.go index 0481340c..6fec480e 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/idp" "github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/proxy" @@ -267,8 +267,8 @@ func (c *ApiController) Login() { setHttpClient(idProvider, provider.Type) - if form.State != beego.AppConfig.String("authState") && form.State != application.Name { - c.ResponseError(fmt.Sprintf("state expected: \"%s\", but got: \"%s\"", beego.AppConfig.String("authState"), form.State)) + if form.State != conf.GetConfigString("authState") && form.State != application.Name { + c.ResponseError(fmt.Sprintf("state expected: \"%s\", but got: \"%s\"", conf.GetConfigString("authState"), form.State)) return } diff --git a/controllers/ldap.go b/controllers/ldap.go index 83e8681d..fae04691 100644 --- a/controllers/ldap.go +++ b/controllers/ldap.go @@ -178,7 +178,7 @@ func (c *ApiController) UpdateLdap() { } if ldap.AutoSync != 0 { object.GetLdapAutoSynchronizer().StartAutoSync(ldap.Id) - } else if ldap.AutoSync == 0 && prevLdap.AutoSync != 0{ + } else if ldap.AutoSync == 0 && prevLdap.AutoSync != 0 { object.GetLdapAutoSynchronizer().StopAutoSync(ldap.Id) } diff --git a/controllers/util.go b/controllers/util.go index 4caef4cf..0b3b3c71 100644 --- a/controllers/util.go +++ b/controllers/util.go @@ -18,7 +18,7 @@ import ( "fmt" "strconv" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/util" ) @@ -62,7 +62,7 @@ func (c *ApiController) RequireSignedIn() (string, bool) { } func getInitScore() int { - score, err := strconv.Atoi(beego.AppConfig.String("initScore")) + score, err := strconv.Atoi(conf.GetConfigString("initScore")) if err != nil { panic(err) } diff --git a/main.go b/main.go index 291e706d..46979845 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,7 @@ import ( "github.com/astaxie/beego/logs" _ "github.com/astaxie/beego/session/redis" "github.com/casdoor/casdoor/authz" + "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/object" "github.com/casdoor/casdoor/proxy" "github.com/casdoor/casdoor/routers" @@ -31,6 +32,7 @@ import ( func main() { createDatabase := flag.Bool("createDatabase", false, "true if you need casdoor to create database") flag.Parse() + object.InitAdapter(*createDatabase) object.InitDb() object.InitDefaultStorageProvider() @@ -52,12 +54,12 @@ func main() { beego.InsertFilter("*", beego.BeforeRouter, routers.RecordMessage) beego.BConfig.WebConfig.Session.SessionName = "casdoor_session_id" - if beego.AppConfig.String("redisEndpoint") == "" { + if conf.GetConfigString("redisEndpoint") == "" { beego.BConfig.WebConfig.Session.SessionProvider = "file" beego.BConfig.WebConfig.Session.SessionProviderConfig = "./tmp" } else { beego.BConfig.WebConfig.Session.SessionProvider = "redis" - beego.BConfig.WebConfig.Session.SessionProviderConfig = beego.AppConfig.String("redisEndpoint") + beego.BConfig.WebConfig.Session.SessionProviderConfig = conf.GetConfigString("redisEndpoint") } beego.BConfig.WebConfig.Session.SessionCookieLifeTime = 3600 * 24 * 30 //beego.BConfig.WebConfig.Session.SessionCookieSameSite = http.SameSiteNoneMode diff --git a/object/adapter.go b/object/adapter.go index 3618c102..4308ccb4 100644 --- a/object/adapter.go +++ b/object/adapter.go @@ -41,7 +41,7 @@ func InitConfig() { func InitAdapter(createDatabase bool) { - adapter = NewAdapter(beego.AppConfig.String("driverName"), conf.GetBeegoConfDataSourceName(), beego.AppConfig.String("dbName")) + adapter = NewAdapter(conf.GetConfigString("driverName"), conf.GetBeegoConfDataSourceName(), conf.GetConfigString("dbName")) if createDatabase { adapter.CreateDatabase() } @@ -111,10 +111,10 @@ func (a *Adapter) close() { } func (a *Adapter) createTable() { - showSql, _ := beego.AppConfig.Bool("showSql") + showSql, _ := conf.GetConfigBool("showSql") a.Engine.ShowSQL(showSql) - tableNamePrefix := beego.AppConfig.String("tableNamePrefix") + tableNamePrefix := conf.GetConfigString("tableNamePrefix") tbMapper := core.NewPrefixMapper(core.SnakeMapper{}, tableNamePrefix) a.Engine.SetTableMapper(tbMapper) diff --git a/object/application.go b/object/application.go index 83eff22d..37bfd764 100644 --- a/object/application.go +++ b/object/application.go @@ -229,7 +229,7 @@ func GetMaskedApplication(application *Application, userId string) *Application application.OrganizationObj.PasswordSalt = "***" } } - return application + return application } func GetMaskedApplications(applications []*Application, userId string) []*Application { diff --git a/object/avatar.go b/object/avatar.go index d7f6fc33..14f0188b 100644 --- a/object/avatar.go +++ b/object/avatar.go @@ -19,14 +19,14 @@ import ( "fmt" "io" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/proxy" ) var defaultStorageProvider *Provider = nil func InitDefaultStorageProvider() { - defaultStorageProviderStr := beego.AppConfig.String("defaultStorageProvider") + defaultStorageProviderStr := conf.GetConfigString("defaultStorageProvider") if defaultStorageProviderStr != "" { defaultStorageProvider = getProvider("admin", defaultStorageProviderStr) } diff --git a/object/oidc_discovery.go b/object/oidc_discovery.go index 43aad078..2760aa5e 100644 --- a/object/oidc_discovery.go +++ b/object/oidc_discovery.go @@ -20,7 +20,7 @@ import ( "fmt" "strings" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "gopkg.in/square/go-jose.v2" ) @@ -58,7 +58,7 @@ func getOriginFromHost(host string) (string, string) { func GetOidcDiscovery(host string) OidcDiscovery { originFrontend, originBackend := getOriginFromHost(host) - origin := beego.AppConfig.String("origin") + origin := conf.GetConfigString("origin") if origin != "" { originFrontend = origin originBackend = origin diff --git a/object/record.go b/object/record.go index 0c1b97d4..01d454d6 100644 --- a/object/record.go +++ b/object/record.go @@ -18,8 +18,8 @@ import ( "fmt" "strings" - "github.com/astaxie/beego" "github.com/astaxie/beego/context" + "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/util" ) @@ -27,7 +27,7 @@ var logPostOnly bool func init() { var err error - logPostOnly, err = beego.AppConfig.Bool("logPostOnly") + logPostOnly, err = conf.GetConfigBool("logPostOnly") if err != nil { //panic(err) } diff --git a/object/saml.go b/object/saml.go index a48dc728..f74d0adc 100644 --- a/object/saml.go +++ b/object/saml.go @@ -23,7 +23,7 @@ import ( "regexp" "strings" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" saml2 "github.com/russellhaering/gosaml2" dsig "github.com/russellhaering/goxmldsig" ) @@ -73,7 +73,7 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide certStore := dsig.MemoryX509CertificateStore{ Roots: []*x509.Certificate{}, } - origin := beego.AppConfig.String("origin") + origin := conf.GetConfigString("origin") certEncodedData := "" if samlResponse != "" { certEncodedData = parseSamlResponse(samlResponse, provider.Type) diff --git a/object/storage.go b/object/storage.go index a5a7dda0..372ab57c 100644 --- a/object/storage.go +++ b/object/storage.go @@ -19,7 +19,7 @@ import ( "fmt" "strings" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/storage" "github.com/casdoor/casdoor/util" ) @@ -28,7 +28,7 @@ var isCloudIntranet bool func init() { var err error - isCloudIntranet, err = beego.AppConfig.Bool("isCloudIntranet") + isCloudIntranet, err = conf.GetConfigBool("isCloudIntranet") if err != nil { //panic(err) } diff --git a/object/token_jwt.go b/object/token_jwt.go index fc28dab8..fc23036a 100644 --- a/object/token_jwt.go +++ b/object/token_jwt.go @@ -19,7 +19,7 @@ import ( "fmt" "time" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "github.com/golang-jwt/jwt/v4" ) @@ -67,7 +67,7 @@ func generateJwtToken(application *Application, user *User, nonce string, scope refreshExpireTime := nowTime.Add(time.Duration(application.RefreshExpireInHours) * time.Hour) user.Password = "" - origin := beego.AppConfig.String("origin") + origin := conf.GetConfigString("origin") _, originBackend := getOriginFromHost(host) if origin != "" { originBackend = origin diff --git a/object/user.go b/object/user.go index 3a62717a..34a735a0 100644 --- a/object/user.go +++ b/object/user.go @@ -18,7 +18,7 @@ import ( "fmt" "strings" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/util" "xorm.io/core" ) @@ -429,7 +429,7 @@ func GetUserInfo(userId string, scope string, aud string, host string) (*Userinf if user == nil { return nil, fmt.Errorf("the user: %s doesn't exist", userId) } - origin := beego.AppConfig.String("origin") + origin := conf.GetConfigString("origin") _, originBackend := getOriginFromHost(host) if origin != "" { originBackend = origin diff --git a/object/verification.go b/object/verification.go index 6cac8b44..ac491ea4 100644 --- a/object/verification.go +++ b/object/verification.go @@ -20,7 +20,7 @@ import ( "math/rand" "time" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/util" "xorm.io/core" ) @@ -129,7 +129,7 @@ func CheckVerificationCode(dest, code string) string { return "Code has not been sent yet!" } - timeout, err := beego.AppConfig.Int64("verificationCodeTimeout") + timeout, err := conf.GetConfigInt64("verificationCodeTimeout") if err != nil { panic(err) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 53ccf970..e4f85df2 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -21,7 +21,7 @@ import ( "strings" "time" - "github.com/astaxie/beego" + "github.com/casdoor/casdoor/conf" "golang.org/x/net/proxy" ) @@ -54,7 +54,7 @@ func isAddressOpen(address string) bool { } func getProxyHttpClient() *http.Client { - sock5Proxy := beego.AppConfig.String("sock5Proxy") + sock5Proxy := conf.GetConfigString("sock5Proxy") if sock5Proxy == "" { return &http.Client{} } diff --git a/util/regex.go b/util/regex.go index 54898ce5..fdb6ef24 100644 --- a/util/regex.go +++ b/util/regex.go @@ -39,4 +39,4 @@ func IsPhoneCnValid(phone string) bool { func getMaskedPhone(phone string) string { return rePhone.ReplaceAllString(phone, "$1****$2") -} \ No newline at end of file +} diff --git a/util/string.go b/util/string.go index 72d20e85..6eb39282 100644 --- a/util/string.go +++ b/util/string.go @@ -220,7 +220,7 @@ func GetMaskedEmail(email string) string { username := maskString(tokens[0]) domain := tokens[1] domainTokens := strings.Split(domain, ".") - domainTokens[len(domainTokens) - 2] = maskString(domainTokens[len(domainTokens) - 2]) + domainTokens[len(domainTokens)-2] = maskString(domainTokens[len(domainTokens)-2]) return fmt.Sprintf("%s@%s", username, strings.Join(domainTokens, ".")) } @@ -228,6 +228,6 @@ func maskString(str string) string { if len(str) <= 2 { return str } else { - return fmt.Sprintf("%c%s%c", str[0], strings.Repeat("*", len(str) - 2), str[len(str) - 1]) + return fmt.Sprintf("%c%s%c", str[0], strings.Repeat("*", len(str)-2), str[len(str)-1]) } -} \ No newline at end of file +} diff --git a/util/string_test.go b/util/string_test.go index 58f83550..2af356bd 100644 --- a/util/string_test.go +++ b/util/string_test.go @@ -245,3 +245,4 @@ func TestSnakeString(t *testing.T) { }) } } +