casdoor/object/syncer_user.go
DacongDA 23c2ba3a2b
feat: support ssh key/pem file in DB syncer (#2727)
* feat: support connect database with ssh tunnel in syncer

* feat: improve i18n translate

* feat: improve code format and i18n
2024-02-21 17:27:37 +08:00

213 lines
5.8 KiB
Go

// Copyright 2021 The Casdoor Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package object
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"strings"
"time"
"golang.org/x/crypto/ssh"
"github.com/casdoor/casdoor/util"
"github.com/go-sql-driver/mysql"
)
// OriginalUser is an alias of User. The syncer methods in this file use it for
// user records read from, or written to, the table returned by
// syncer.getTable().
type OriginalUser = User
// Credential pairs a value with a salt and is serialized to JSON under the
// keys "value" and "salt".
//
// NOTE(review): nothing in the visible part of this file references this type;
// presumably it carries a salted password credential — confirm against callers.
type Credential struct {
Value string `json:"value"`
Salt string `json:"salt"`
}
// getOriginalUsers reads every row of the syncer's table and converts the rows
// into OriginalUser values via getOriginalUsersFromMap.
//
// It returns the converted users, or the error reported by the query.
func (syncer *Syncer) getOriginalUsers() ([]*OriginalUser, error) {
	var rows []map[string]sql.NullString
	if err := syncer.Ormer.Engine.Table(syncer.getTable()).Find(&rows); err != nil {
		return nil, err
	}

	users := syncer.getOriginalUsersFromMap(rows)

	// Empty each row map once the rows have been converted. This is the memory
	// leak handling tracked in https://github.com/casdoor/casdoor/issues/1256
	for i := range rows {
		for column := range rows[i] {
			delete(rows[i], column)
		}
	}

	return users, nil
}
// addUser inserts the given user into the syncer's table, using the map
// produced by getMapFromOriginalUser as the row to insert.
//
// It returns true when the insert reports a non-zero affected-row count, and
// the insert error otherwise.
func (syncer *Syncer) addUser(user *OriginalUser) (bool, error) {
	row := syncer.getMapFromOriginalUser(user)
	session := syncer.Ormer.Engine.Table(syncer.getTable())

	inserted, err := session.Insert(row)
	if err != nil {
		return false, err
	}

	return inserted != 0, nil
}
// getCasdoorColumns returns the snake_case form of the CasdoorName of every
// configured table column, skipping the column whose CasdoorName is "Id".
//
// The result is always a non-nil slice, even when no column qualifies.
func (syncer *Syncer) getCasdoorColumns() []string {
	columns := []string{}
	for _, column := range syncer.TableColumns {
		if column.CasdoorName == "Id" {
			continue
		}
		columns = append(columns, util.CamelToSnakeCase(column.CasdoorName))
	}
	return columns
}
// updateUser updates the row of the syncer's table that matches the given user.
//
// The row is located by the column returned by getKey: that entry is taken out
// of the field map and used as the WHERE value, and the remaining entries are
// written as the update. It returns true when the update reports a non-zero
// affected-row count, and the update error otherwise.
func (syncer *Syncer) updateUser(user *OriginalUser) (bool, error) {
	keyColumn := syncer.getKey()
	fields := syncer.getMapFromOriginalUser(user)

	// The key column identifies the row, so it must not be part of the SET list.
	keyValue := fields[keyColumn]
	delete(fields, keyColumn)

	condition := fmt.Sprintf("%s = ?", keyColumn)
	updated, err := syncer.Ormer.Engine.Table(syncer.getTable()).Where(condition, keyValue).Update(&fields)
	if err != nil {
		return false, err
	}

	return updated != 0, nil
}
// updateUserForOriginalFields writes the synced columns of user onto the stored
// user that has the same value in the key column and the same owner.
//
// The lookup and the update both go through the package-level ormer rather
// than syncer.Ormer. When no matching user exists it returns (false, nil).
// When the incoming avatar is non-empty and differs from the stored one,
// user.PermanentAvatar is refreshed through getPermanentAvatarUrl first.
// Only the columns from getCasdoorColumns plus "affiliation", "hash" and
// "pre_hash" are updated. The boolean result reports whether the update
// affected any row.
func (syncer *Syncer) updateUserForOriginalFields(user *User, key string) (bool, error) {
	condition := key + " = ? and owner = ?"

	var existing User
	found, err := ormer.Engine.Where(condition, syncer.getUserValue(user, key), user.Owner).Get(&existing)
	if err != nil {
		return false, err
	}
	if !found {
		return false, nil
	}

	if user.Avatar != "" && user.Avatar != existing.Avatar {
		user.PermanentAvatar, err = getPermanentAvatarUrl(user.Owner, user.Name, user.Avatar, true)
		if err != nil {
			return false, err
		}
	}

	columns := append(syncer.getCasdoorColumns(), "affiliation", "hash", "pre_hash")

	// The row is addressed through the stored user's key value and owner.
	updated, err := ormer.Engine.Where(condition, syncer.getUserValue(&existing, key), existing.Owner).Cols(columns...).Update(user)
	if err != nil {
		return false, err
	}

	return updated != 0, nil
}
// calculateHash returns the MD5 hash of the user's hashed column values.
//
// The values of all table columns flagged IsHashed are looked up by column
// name in the map from getMapFromOriginalUser, taken in TableColumns order,
// and joined with "|" before hashing.
func (syncer *Syncer) calculateHash(user *OriginalUser) string {
	fields := syncer.getMapFromOriginalUser(user)

	hashed := []string{}
	for _, column := range syncer.TableColumns {
		if !column.IsHashed {
			continue
		}
		hashed = append(hashed, fields[column.Name])
	}

	return util.GetMd5Hash(strings.Join(hashed, "|"))
}
// dsnConnector opens connections by handing a fixed data source name to a
// fixed driver. Its Connect and Driver methods give it the method set of
// driver.Connector, so it can be passed to sql.OpenDB.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

// Connect opens a new connection by calling the driver's Open with the stored
// data source name. The context argument is not consulted.
func (c dsnConnector) Connect(ctx context.Context) (driver.Conn, error) {
	conn, err := c.driver.Open(c.dsn)
	return conn, err
}

// Driver returns the driver that Connect uses to open connections.
func (c dsnConnector) Driver() driver.Driver {
	return c.driver
}
// initAdapter lazily creates syncer.Ormer, the adapter used to reach the
// syncer's database. It does nothing when syncer.Ormer is already set.
//
// The data source name is built from the syncer's connection fields in the
// format selected by DatabaseType ("mssql", "postgres", or the default
// user:password@tcp(host:port)/ form).
//
// When SshType is non-empty and DatabaseType is "mysql", "postgres" or
// "mssql", an SSH client is dialed first (DialWithPassword for SshType
// "password", DialWithCert otherwise) and the connection is routed through it:
//   - for "mysql", a dial function is registered under the network name
//     Owner+Name and the data source name is rewritten to use that network;
//   - for "postgres" and "mssql", a *sql.DB is opened over a ViaSSHDialer and
//     handed to NewAdapterFromDb.
//
// It returns the SSH dial error, or the error from creating the adapter.
func (syncer *Syncer) initAdapter() error {
	if syncer.Ormer != nil {
		return nil
	}

	var dsn string
	switch syncer.DatabaseType {
	case "mssql":
		dsn = fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s", syncer.User, syncer.Password, syncer.Host, syncer.Port, syncer.Database)
	case "postgres":
		sslMode := syncer.SslMode
		if sslMode == "" {
			sslMode = "disable"
		}
		dsn = fmt.Sprintf("user=%s password=%s host=%s port=%d sslmode=%s dbname=%s", syncer.User, syncer.Password, syncer.Host, syncer.Port, sslMode, syncer.Database)
	default:
		dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/", syncer.User, syncer.Password, syncer.Host, syncer.Port)
	}

	var (
		tunneledDb *sql.DB
		err        error
	)

	tunnelSupported := syncer.DatabaseType == "mysql" || syncer.DatabaseType == "postgres" || syncer.DatabaseType == "mssql"
	if syncer.SshType != "" && tunnelSupported {
		var client *ssh.Client
		switch syncer.SshType {
		case "password":
			client, err = DialWithPassword(syncer.SshUser, syncer.SshPassword, syncer.SshHost, syncer.SshPort)
		default:
			client, err = DialWithCert(syncer.SshUser, syncer.Owner+"/"+syncer.Cert, syncer.SshHost, syncer.SshPort)
		}
		if err != nil {
			return err
		}

		switch syncer.DatabaseType {
		case "mysql":
			// The MySQL driver looks the dial function up by the network name
			// embedded in the data source name.
			dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/", syncer.User, syncer.Password, syncer.Owner+syncer.Name, syncer.Host, syncer.Port)
			mysql.RegisterDialContext(syncer.Owner+syncer.Name, (&ViaSSHDialer{Client: client, Context: nil}).MysqlDial)
		case "postgres", "mssql":
			// The DB is opened with the data source name as built above, i.e.
			// before the "dbi." replacement below is applied.
			tunneledDb = sql.OpenDB(dsnConnector{dsn: dsn, driver: &ViaSSHDialer{Client: client, Context: nil, DatabaseType: syncer.DatabaseType}})
		}
	}

	if !isCloudIntranet {
		dsn = strings.ReplaceAll(dsn, "dbi.", "db.")
	}

	if tunneledDb == nil {
		syncer.Ormer, err = NewAdapter(syncer.DatabaseType, dsn, syncer.Database)
	} else {
		syncer.Ormer, err = NewAdapterFromDb(syncer.DatabaseType, dsn, syncer.Database, tunneledDb)
	}
	return err
}
// RunSyncUsersJob registers a job for every syncer returned by
// GetSyncers("admin") and then blocks the calling goroutine.
//
// It panics if the syncers cannot be loaded or if any job fails to register.
// After registration it sleeps for the largest representable time.Duration,
// so in practice it never returns.
func RunSyncUsersJob() {
	syncers, err := GetSyncers("admin")
	if err != nil {
		panic(err)
	}

	for i := range syncers {
		if jobErr := addSyncerJob(syncers[i]); jobErr != nil {
			panic(jobErr)
		}
	}

	// Maximum int64 nanoseconds: keeps this goroutine parked indefinitely.
	const forever = time.Duration(1<<63 - 1)
	time.Sleep(forever)
}