From 03a281cb5d092fdf2c9a4736507ca94d8d2e071a Mon Sep 17 00:00:00 2001 From: Yang Luo Date: Tue, 26 Sep 2023 14:51:38 +0800 Subject: [PATCH] Improve CorsFilter code --- routers/cors_filter.go | 49 +++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/routers/cors_filter.go b/routers/cors_filter.go index 656cd05f..b17ed7b7 100644 --- a/routers/cors_filter.go +++ b/routers/cors_filter.go @@ -16,6 +16,7 @@ package routers import ( "net/http" + "net/url" "strings" "github.com/beego/beego/context" @@ -36,9 +37,25 @@ func setCorsHeaders(ctx *context.Context, origin string) { ctx.Output.Header(headerAllowHeaders, "Content-Type, Authorization") } +func getHostname(s string) string { + if s == "" { + return "" + } + + l, err := url.Parse(s) + if err != nil { + panic(err) + } + + res := l.Hostname() + return res +} + func CorsFilter(ctx *context.Context) { origin := ctx.Input.Header(headerOrigin) originConf := conf.GetConfigString("origin") + originHostname := getHostname(origin) + host := ctx.Request.Host if strings.HasPrefix(origin, "http://localhost") { setCorsHeaders(ctx, origin) @@ -55,22 +72,28 @@ func CorsFilter(ctx *context.Context) { return } - if origin != "" && originConf != "" && origin != originConf { - ok, err := object.IsOriginAllowed(origin) - if err != nil { - panic(err) - } - - if ok { + if origin != "" { + if origin == originConf { + setCorsHeaders(ctx, origin) + } else if originHostname == host { setCorsHeaders(ctx, origin) } else { - ctx.ResponseWriter.WriteHeader(http.StatusForbidden) - return - } + ok, err := object.IsOriginAllowed(origin) + if err != nil { + panic(err) + } - if ctx.Input.Method() == "OPTIONS" { - ctx.ResponseWriter.WriteHeader(http.StatusOK) - return + if ok { + setCorsHeaders(ctx, origin) + } else { + ctx.ResponseWriter.WriteHeader(http.StatusForbidden) + return + } + + if ctx.Input.Method() == "OPTIONS" { + ctx.ResponseWriter.WriteHeader(http.StatusOK) + return + } } }