diff --git a/Dockerfile b/Dockerfile index 582d5bd0..a6d959e5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,7 +20,7 @@ COPY --from=FRONT /web/build /web/build CMD chmod 777 /tmp && service mariadb start&&\ if [ "${MYSQL_ROOT_PASSWORD}" = "" ] ;then MYSQL_ROOT_PASSWORD=123456 ; fi&&\ mysqladmin -u root password ${MYSQL_ROOT_PASSWORD} &&\ -./wait-for-it localhost:3306 -- ./server +./wait-for-it localhost:3306 -- ./server --createDatabase=true FROM alpine:latest diff --git a/main.go b/main.go index d4ad28b1..55b1342b 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,8 @@ package main import ( + "flag" + "github.com/astaxie/beego" "github.com/astaxie/beego/logs" "github.com/astaxie/beego/plugins/cors" @@ -23,12 +25,13 @@ import ( "github.com/casbin/casdoor/object" "github.com/casbin/casdoor/proxy" "github.com/casbin/casdoor/routers" - _ "github.com/casbin/casdoor/routers" ) func main() { - object.InitAdapter() + createDatabase := flag.Bool("createDatabase", false, "true if you need casdoor to create database") + flag.Parse() + object.InitAdapter(*createDatabase) object.InitDb() object.InitDefaultStorageProvider() object.InitLdapAutoSynchronizer() diff --git a/object/adapter.go b/object/adapter.go index b5b0c7d6..16b4165b 100644 --- a/object/adapter.go +++ b/object/adapter.go @@ -35,11 +35,15 @@ func InitConfig() { panic(err) } - InitAdapter() + InitAdapter(true) } -func InitAdapter() { +func InitAdapter(createDatabase bool) { + adapter = NewAdapter(beego.AppConfig.String("driverName"), conf.GetBeegoConfDataSourceName(), beego.AppConfig.String("dbName")) + if createDatabase { + adapter.CreateDatabase() + } adapter.createTable() } @@ -75,6 +79,17 @@ func NewAdapter(driverName string, dataSourceName string, dbName string) *Adapte return a } +func (a *Adapter) CreateDatabase() error { + engine, err := xorm.NewEngine(a.driverName, a.dataSourceName) + if err != nil { + return err + } + defer engine.Close() + + _, err = engine.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s default charset utf8 COLLATE utf8_general_ci", a.dbName)) + return err +} + func (a *Adapter) open() { dataSourceName := a.dataSourceName + a.dbName if a.driverName != "mysql" {