Improve CorsFilter code

This commit is contained in:
Yang Luo 2023-09-26 14:51:38 +08:00
parent a8e541159b
commit 03a281cb5d

View File

@ -16,6 +16,7 @@ package routers
import ( import (
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/beego/beego/context" "github.com/beego/beego/context"
@ -36,9 +37,25 @@ func setCorsHeaders(ctx *context.Context, origin string) {
ctx.Output.Header(headerAllowHeaders, "Content-Type, Authorization") ctx.Output.Header(headerAllowHeaders, "Content-Type, Authorization")
} }
func getHostname(s string) string {
if s == "" {
return ""
}
l, err := url.Parse(s)
if err != nil {
panic(err)
}
res := l.Hostname()
return res
}
func CorsFilter(ctx *context.Context) { func CorsFilter(ctx *context.Context) {
origin := ctx.Input.Header(headerOrigin) origin := ctx.Input.Header(headerOrigin)
originConf := conf.GetConfigString("origin") originConf := conf.GetConfigString("origin")
originHostname := getHostname(origin)
host := ctx.Request.Host
if strings.HasPrefix(origin, "http://localhost") { if strings.HasPrefix(origin, "http://localhost") {
setCorsHeaders(ctx, origin) setCorsHeaders(ctx, origin)
@ -55,22 +72,28 @@ func CorsFilter(ctx *context.Context) {
return return
} }
if origin != "" && originConf != "" && origin != originConf { if origin != "" {
ok, err := object.IsOriginAllowed(origin) if origin == originConf {
if err != nil { setCorsHeaders(ctx, origin)
panic(err) } else if originHostname == host {
}
if ok {
setCorsHeaders(ctx, origin) setCorsHeaders(ctx, origin)
} else { } else {
ctx.ResponseWriter.WriteHeader(http.StatusForbidden) ok, err := object.IsOriginAllowed(origin)
return if err != nil {
} panic(err)
}
if ctx.Input.Method() == "OPTIONS" { if ok {
ctx.ResponseWriter.WriteHeader(http.StatusOK) setCorsHeaders(ctx, origin)
return } else {
ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
return
}
if ctx.Input.Method() == "OPTIONS" {
ctx.ResponseWriter.WriteHeader(http.StatusOK)
return
}
} }
} }