diff --git a/authz/authz.go b/authz/authz.go index 7110181e..ef3c13a7 100644 --- a/authz/authz.go +++ b/authz/authz.go @@ -33,11 +33,12 @@ func InitApi() { if err != nil { panic(err) } - Enforcer, err = e.InitEnforcer() + err = e.InitEnforcer() if err != nil { panic(err) } + Enforcer = e.Enforcer Enforcer.ClearPolicy() // if len(Enforcer.GetPolicy()) == 0 { diff --git a/controllers/casbin_api.go b/controllers/casbin_api.go index b723f522..99c03a74 100644 --- a/controllers/casbin_api.go +++ b/controllers/casbin_api.go @@ -35,6 +35,7 @@ func (c *ApiController) Enforce() { permissionId := c.Input().Get("permissionId") modelId := c.Input().Get("modelId") resourceId := c.Input().Get("resourceId") + enforcerId := c.Input().Get("enforcerId") var request object.CasbinRequest err := json.Unmarshal(c.Ctx.Input.RequestBody, &request) @@ -43,6 +44,29 @@ func (c *ApiController) Enforce() { return } + if enforcerId != "" { + enforcer, err := object.GetEnforcer(enforcerId) + if err != nil { + c.ResponseError(err.Error()) + return + } + + err = enforcer.InitEnforcer() + if err != nil { + c.ResponseError(err.Error()) + return + } + + res, err := enforcer.Enforce(request...) + if err != nil { + c.ResponseError(err.Error()) + return + } + + c.ResponseOk(res) + return + } + if permissionId != "" { permission, err := object.GetPermission(permissionId) if err != nil { @@ -121,6 +145,7 @@ func (c *ApiController) Enforce() { func (c *ApiController) BatchEnforce() { permissionId := c.Input().Get("permissionId") modelId := c.Input().Get("modelId") + enforcerId := c.Input().Get("enforcerId") var requests []object.CasbinRequest err := json.Unmarshal(c.Ctx.Input.RequestBody, &requests) @@ -129,6 +154,29 @@ func (c *ApiController) BatchEnforce() { return } + if enforcerId != "" { + enforcer, err := object.GetEnforcer(enforcerId) + if err != nil { + c.ResponseError(err.Error()) + return + } + + err = enforcer.InitEnforcer() + if err != nil { + c.ResponseError(err.Error()) + return + } + + res, err := enforcer.BatchEnforce(requests) + if err != nil { + c.ResponseError(err.Error()) + return + } + + c.ResponseOk(res) + return + } + if permissionId != "" { permission, err := object.GetPermission(permissionId) if err != nil { diff --git a/object/enforcer.go b/object/enforcer.go index 76230e95..07b1ee0f 100644 --- a/object/enforcer.go +++ b/object/enforcer.go @@ -120,44 +120,45 @@ func DeleteEnforcer(enforcer *Enforcer) (bool, error) { return affected != 0, nil } -func (enforcer *Enforcer) InitEnforcer() (*casbin.Enforcer, error) { - if enforcer == nil { - return nil, errors.New("enforcer is nil") - } - if enforcer.Model == "" || enforcer.Adapter == "" { - return nil, errors.New("missing model or adapter") +func (enforcer *Enforcer) InitEnforcer() error { + if enforcer.Enforcer == nil { + if enforcer == nil { + return errors.New("enforcer is nil") + } + if enforcer.Model == "" || enforcer.Adapter == "" { + return errors.New("missing model or adapter") + } + + var err error + var m *Model + var a *Adapter + + if m, err = GetModel(enforcer.Model); err != nil { + return err + } else if m == nil { + return errors.New("model not found") + } + if a, err = GetAdapter(enforcer.Adapter); err != nil { + return err + } else if a == nil { + return errors.New("adapter not found") + } + + err = m.initModel() + if err != nil { + return err + } + err = a.initAdapter() + if err != nil { + return err + } + + casbinEnforcer, err := casbin.NewEnforcer(m.Model, a.Adapter) + if err != nil { + return err + } + enforcer.Enforcer = casbinEnforcer } - var err error - var m *Model - var a *Adapter - - if m, err = GetModel(enforcer.Model); err != nil { - return nil, err - } else if m == nil { - return nil, errors.New("model not found") - } - - if a, err = GetAdapter(enforcer.Adapter); err != nil { - return nil, err - } else if a == nil { - return nil, errors.New("adapter not found") - } - - err = m.initModel() - if err != nil { - return nil, err - } - - err = a.initAdapter() - if err != nil { - return nil, err - } - - e, err := casbin.NewEnforcer(m.Model, a.Adapter) - if err != nil { - return nil, err - } - - return e, nil + return nil }