Add createDatabaseForPostgres()

This commit is contained in:
Yang Luo 2023-07-22 16:19:13 +08:00
parent 58e8f9f90b
commit fc9528be43

View File

@ -15,8 +15,10 @@
package object package object
import ( import (
"database/sql"
"fmt" "fmt"
"runtime" "runtime"
"strings"
"github.com/beego/beego" "github.com/beego/beego"
"github.com/casdoor/casdoor/conf" "github.com/casdoor/casdoor/conf"
@ -46,6 +48,11 @@ func InitConfig() {
} }
func InitAdapter() { func InitAdapter() {
err := createDatabaseForPostgres(conf.GetConfigString("driverName"), conf.GetConfigDataSourceName(), conf.GetConfigString("dbName"))
if err != nil {
panic(err)
}
adapter = NewAdapter(conf.GetConfigString("driverName"), conf.GetConfigDataSourceName(), conf.GetConfigString("dbName")) adapter = NewAdapter(conf.GetConfigString("driverName"), conf.GetConfigDataSourceName(), conf.GetConfigString("dbName"))
tableNamePrefix := conf.GetConfigString("tableNamePrefix") tableNamePrefix := conf.GetConfigString("tableNamePrefix")
@ -96,7 +103,32 @@ func NewAdapter(driverName string, dataSourceName string, dbName string) *Adapte
return a return a
} }
func createDatabaseForPostgres(driverName string, dataSourceName string, dbName string) error {
if driverName == "postgres" {
db, err := sql.Open(driverName, dataSourceName)
if err != nil {
return err
}
defer db.Close()
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName))
if err != nil {
if !strings.Contains(err.Error(), "already exists") {
return err
}
}
return nil
} else {
return nil
}
}
func (a *Adapter) CreateDatabase() error { func (a *Adapter) CreateDatabase() error {
if a.driverName == "postgres" {
return nil
}
engine, err := xorm.NewEngine(a.driverName, a.dataSourceName) engine, err := xorm.NewEngine(a.driverName, a.dataSourceName)
if err != nil { if err != nil {
return err return err