Refactor GetProviderFromContext().

This commit is contained in:
Yang Luo
2021-09-05 09:44:15 +08:00
parent 14d09cad2c
commit 1c5ce46bd5
4 changed files with 35 additions and 32 deletions

View File

@ -64,31 +64,6 @@ func (c *ApiController) AddResource() {
c.ServeJSON()
}
func (c *ApiController) GetProviderParam() (*object.Provider, *object.User, bool) {
providerName := c.Input().Get("provider")
if providerName != "" {
provider := object.GetProvider(util.GetId(providerName))
if provider == nil {
c.ResponseError(fmt.Sprintf("The provider: %s is not found", providerName))
return nil, nil, false
}
return provider, nil, true
}
userId, ok := c.RequireSignedIn()
if !ok {
return nil, nil, false
}
application, user := object.GetApplicationByUserId(userId)
provider := application.GetStorageProvider()
if provider == nil {
c.ResponseError(fmt.Sprintf("No storage provider is found for application: %s", application.Name))
return nil, nil, false
}
return provider, user, true
}
func (c *ApiController) DeleteResource() {
var resource object.Resource
err := json.Unmarshal(c.Ctx.Input.RequestBody, &resource)
@ -96,7 +71,7 @@ func (c *ApiController) DeleteResource() {
panic(err)
}
provider, _, ok := c.GetProviderParam()
provider, _, ok := c.GetProviderFromContext("Storage")
if !ok {
return
}
@ -132,7 +107,7 @@ func (c *ApiController) UploadResource() {
return
}
provider, user, ok := c.GetProviderParam()
provider, user, ok := c.GetProviderFromContext("Storage")
if !ok {
return
}

View File

@ -15,9 +15,12 @@
package controllers
import (
"fmt"
"strconv"
"github.com/astaxie/beego"
"github.com/casbin/casdoor/object"
"github.com/casbin/casdoor/util"
)
// ResponseOk ...
@ -66,3 +69,28 @@ func getInitScore() int {
return score
}
func (c *ApiController) GetProviderFromContext(category string) (*object.Provider, *object.User, bool) {
providerName := c.Input().Get("provider")
if providerName != "" {
provider := object.GetProvider(util.GetId(providerName))
if provider == nil {
c.ResponseError(fmt.Sprintf("The provider: %s is not found", providerName))
return nil, nil, false
}
return provider, nil, true
}
userId, ok := c.RequireSignedIn()
if !ok {
return nil, nil, false
}
application, user := object.GetApplicationByUserId(userId)
provider := application.GetProviderByCategory(category)
if provider == nil {
c.ResponseError(fmt.Sprintf("No provider for category: \"%s\" is found for application: %s", category, application.Name))
return nil, nil, false
}
return provider, user, true
}