Improve CorsFilter code

This commit is contained in:
Yang Luo 2023-09-26 14:51:38 +08:00
parent a8e541159b
commit 03a281cb5d

View File

@ -16,6 +16,7 @@ package routers
import (
"net/http"
"net/url"
"strings"
"github.com/beego/beego/context"
@ -36,9 +37,25 @@ func setCorsHeaders(ctx *context.Context, origin string) {
ctx.Output.Header(headerAllowHeaders, "Content-Type, Authorization")
}
func getHostname(s string) string {
if s == "" {
return ""
}
l, err := url.Parse(s)
if err != nil {
panic(err)
}
res := l.Hostname()
return res
}
func CorsFilter(ctx *context.Context) {
origin := ctx.Input.Header(headerOrigin)
originConf := conf.GetConfigString("origin")
originHostname := getHostname(origin)
host := ctx.Request.Host
if strings.HasPrefix(origin, "http://localhost") {
setCorsHeaders(ctx, origin)
@ -55,22 +72,28 @@ func CorsFilter(ctx *context.Context) {
return
}
if origin != "" && originConf != "" && origin != originConf {
ok, err := object.IsOriginAllowed(origin)
if err != nil {
panic(err)
}
if ok {
if origin != "" {
if origin == originConf {
setCorsHeaders(ctx, origin)
} else if originHostname == host {
setCorsHeaders(ctx, origin)
} else {
ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
return
}
ok, err := object.IsOriginAllowed(origin)
if err != nil {
panic(err)
}
if ctx.Input.Method() == "OPTIONS" {
ctx.ResponseWriter.WriteHeader(http.StatusOK)
return
if ok {
setCorsHeaders(ctx, origin)
} else {
ctx.ResponseWriter.WriteHeader(http.StatusForbidden)
return
}
if ctx.Input.Method() == "OPTIONS" {
ctx.ResponseWriter.WriteHeader(http.StatusOK)
return
}
}
}