diff --git a/object/adapter.go b/object/adapter.go index 79cddc99..ffc982e8 100644 --- a/object/adapter.go +++ b/object/adapter.go @@ -15,8 +15,10 @@ package object import ( + "database/sql" "fmt" "runtime" + "strings" "github.com/beego/beego" "github.com/casdoor/casdoor/conf" @@ -46,6 +48,11 @@ func InitConfig() { } func InitAdapter() { + err := createDatabaseForPostgres(conf.GetConfigString("driverName"), conf.GetConfigDataSourceName(), conf.GetConfigString("dbName")) + if err != nil { + panic(err) + } + adapter = NewAdapter(conf.GetConfigString("driverName"), conf.GetConfigDataSourceName(), conf.GetConfigString("dbName")) tableNamePrefix := conf.GetConfigString("tableNamePrefix") @@ -96,7 +103,32 @@ func NewAdapter(driverName string, dataSourceName string, dbName string) *Adapte return a } +func createDatabaseForPostgres(driverName string, dataSourceName string, dbName string) error { + if driverName == "postgres" { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + return err + } + defer db.Close() + + _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)) + if err != nil { + if !strings.Contains(err.Error(), "already exists") { + return err + } + } + + return nil + } else { + return nil + } +} + func (a *Adapter) CreateDatabase() error { + if a.driverName == "postgres" { + return nil + } + engine, err := xorm.NewEngine(a.driverName, a.dataSourceName) if err != nil { return err