diff --git a/routers/base.go b/routers/base.go index 682e85bf..9cbaad3f 100644 --- a/routers/base.go +++ b/routers/base.go @@ -16,7 +16,9 @@ package routers import ( "fmt" + "net" "net/http" + "net/url" "strings" "github.com/beego/beego/context" @@ -154,3 +156,31 @@ func parseBearerToken(ctx *context.Context) string { return tokens[1] } + +func getHostname(s string) string { + if s == "" { + return "" + } + + l, err := url.Parse(s) + if err != nil { + panic(err) + } + + res := l.Hostname() + return res +} + +func isHostIntranet(s string) bool { + ipStr, _, err := net.SplitHostPort(s) + if err != nil { + ipStr = s + } + + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + + return ip.IsPrivate() +} diff --git a/routers/cors_filter.go b/routers/cors_filter.go index 09e897af..78c57685 100644 --- a/routers/cors_filter.go +++ b/routers/cors_filter.go @@ -16,7 +16,6 @@ package routers import ( "net/http" - "net/url" "strings" "github.com/beego/beego/context" @@ -41,20 +40,6 @@ func setCorsHeaders(ctx *context.Context, origin string) { } } -func getHostname(s string) string { - if s == "" { - return "" - } - - l, err := url.Parse(s) - if err != nil { - panic(err) - } - - res := l.Hostname() - return res -} - func CorsFilter(ctx *context.Context) { origin := ctx.Input.Header(headerOrigin) originConf := conf.GetConfigString("origin") @@ -81,6 +66,8 @@ func CorsFilter(ctx *context.Context) { setCorsHeaders(ctx, origin) } else if originHostname == host { setCorsHeaders(ctx, origin) + } else if isHostIntranet(host) { + setCorsHeaders(ctx, origin) } else { ok, err := object.IsOriginAllowed(origin) if err != nil {