Improve CorsFilter code

This commit is contained in:
Yang Luo 2023-09-26 14:51:38 +08:00
parent a8e541159b
commit 03a281cb5d

View File

@ -16,6 +16,7 @@ package routers
import ( import (
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/beego/beego/context" "github.com/beego/beego/context"
@ -36,9 +37,25 @@ func setCorsHeaders(ctx *context.Context, origin string) {
ctx.Output.Header(headerAllowHeaders, "Content-Type, Authorization") ctx.Output.Header(headerAllowHeaders, "Content-Type, Authorization")
} }
func getHostname(s string) string {
if s == "" {
return ""
}
l, err := url.Parse(s)
if err != nil {
panic(err)
}
res := l.Hostname()
return res
}
func CorsFilter(ctx *context.Context) { func CorsFilter(ctx *context.Context) {
origin := ctx.Input.Header(headerOrigin) origin := ctx.Input.Header(headerOrigin)
originConf := conf.GetConfigString("origin") originConf := conf.GetConfigString("origin")
originHostname := getHostname(origin)
host := ctx.Request.Host
if strings.HasPrefix(origin, "http://localhost") { if strings.HasPrefix(origin, "http://localhost") {
setCorsHeaders(ctx, origin) setCorsHeaders(ctx, origin)
@ -55,7 +72,12 @@ func CorsFilter(ctx *context.Context) {
return return
} }
if origin != "" && originConf != "" && origin != originConf { if origin != "" {
if origin == originConf {
setCorsHeaders(ctx, origin)
} else if originHostname == host {
setCorsHeaders(ctx, origin)
} else {
ok, err := object.IsOriginAllowed(origin) ok, err := object.IsOriginAllowed(origin)
if err != nil { if err != nil {
panic(err) panic(err)
@ -73,6 +95,7 @@ func CorsFilter(ctx *context.Context) {
return return
} }
} }
}
if ctx.Input.Method() == "OPTIONS" { if ctx.Input.Method() == "OPTIONS" {
ctx.Output.Header(headerAllowOrigin, "*") ctx.Output.Header(headerAllowOrigin, "*")