Refactor GetProviderFromContext().

This commit is contained in:
Yang Luo
2021-09-05 09:44:15 +08:00
parent 14d09cad2c
commit 1c5ce46bd5
4 changed files with 35 additions and 32 deletions

View File

@ -64,31 +64,6 @@ func (c *ApiController) AddResource() {
c.ServeJSON()
}
func (c *ApiController) GetProviderParam() (*object.Provider, *object.User, bool) {
providerName := c.Input().Get("provider")
if providerName != "" {
provider := object.GetProvider(util.GetId(providerName))
if provider == nil {
c.ResponseError(fmt.Sprintf("The provider: %s is not found", providerName))
return nil, nil, false
}
return provider, nil, true
}
userId, ok := c.RequireSignedIn()
if !ok {
return nil, nil, false
}
application, user := object.GetApplicationByUserId(userId)
provider := application.GetStorageProvider()
if provider == nil {
c.ResponseError(fmt.Sprintf("No storage provider is found for application: %s", application.Name))
return nil, nil, false
}
return provider, user, true
}
func (c *ApiController) DeleteResource() {
var resource object.Resource
err := json.Unmarshal(c.Ctx.Input.RequestBody, &resource)
@ -96,7 +71,7 @@ func (c *ApiController) DeleteResource() {
panic(err)
}
provider, _, ok := c.GetProviderParam()
provider, _, ok := c.GetProviderFromContext("Storage")
if !ok {
return
}
@ -132,7 +107,7 @@ func (c *ApiController) UploadResource() {
return
}
provider, user, ok := c.GetProviderParam()
provider, user, ok := c.GetProviderFromContext("Storage")
if !ok {
return
}

View File

@ -15,9 +15,12 @@
package controllers
import (
"fmt"
"strconv"
"github.com/astaxie/beego"
"github.com/casbin/casdoor/object"
"github.com/casbin/casdoor/util"
)
// ResponseOk ...
@ -66,3 +69,28 @@ func getInitScore() int {
return score
}
func (c *ApiController) GetProviderFromContext(category string) (*object.Provider, *object.User, bool) {
providerName := c.Input().Get("provider")
if providerName != "" {
provider := object.GetProvider(util.GetId(providerName))
if provider == nil {
c.ResponseError(fmt.Sprintf("The provider: %s is not found", providerName))
return nil, nil, false
}
return provider, nil, true
}
userId, ok := c.RequireSignedIn()
if !ok {
return nil, nil, false
}
application, user := object.GetApplicationByUserId(userId)
provider := application.GetProviderByCategory(category)
if provider == nil {
c.ResponseError(fmt.Sprintf("No provider for category: \"%s\" is found for application: %s", category, application.Name))
return nil, nil, false
}
return provider, user, true
}

View File

@ -14,7 +14,7 @@
package object
func (application *Application) getProviderByCategory(category string) *Provider {
func (application *Application) GetProviderByCategory(category string) *Provider {
providers := GetProviders(application.Owner)
m := map[string]*Provider{}
for _, provider := range providers {
@ -35,15 +35,15 @@ func (application *Application) getProviderByCategory(category string) *Provider
}
func (application *Application) GetEmailProvider() *Provider {
return application.getProviderByCategory("Email")
return application.GetProviderByCategory("Email")
}
func (application *Application) GetSmsProvider() *Provider {
return application.getProviderByCategory("SMS")
return application.GetProviderByCategory("SMS")
}
func (application *Application) GetStorageProvider() *Provider {
return application.getProviderByCategory("Storage")
return application.GetProviderByCategory("Storage")
}
func (application *Application) getSignupItem(itemName string) *SignupItem {

View File

@ -18,7 +18,7 @@ package object
import "github.com/go-gomail/gomail"
func SendEmail(provider *Provider, title, content, dest, sender string) error {
func SendEmail(provider *Provider, title string, content string, dest string, sender string) error {
dialer := gomail.NewDialer(provider.Host, provider.Port, provider.ClientId, provider.ClientSecret)
message := gomail.NewMessage()