feat: refactor util.GetClientIpFromRequest()

This commit is contained in:
Yang Luo 2024-10-15 12:22:38 +08:00
parent e3d135bc6e
commit a3f8ded10c
3 changed files with 19 additions and 19 deletions

View File

@ -132,7 +132,8 @@ func (c *ApiController) SendVerificationCode() {
c.ResponseError(err.Error()) c.ResponseError(err.Error())
return return
} }
remoteAddr := util.GetIPFromRequest(c.Ctx.Request)
clientIp := util.GetClientIpFromRequest(c.Ctx.Request)
if msg := vform.CheckParameter(form.SendVerifyCode, c.GetAcceptLanguage()); msg != "" { if msg := vform.CheckParameter(form.SendVerifyCode, c.GetAcceptLanguage()); msg != "" {
c.ResponseError(msg) c.ResponseError(msg)
@ -259,7 +260,7 @@ func (c *ApiController) SendVerificationCode() {
return return
} }
sendResp = object.SendVerificationCodeToEmail(organization, user, provider, remoteAddr, vform.Dest) sendResp = object.SendVerificationCodeToEmail(organization, user, provider, clientIp, vform.Dest)
case object.VerifyTypePhone: case object.VerifyTypePhone:
if vform.Method == LoginVerification || vform.Method == ForgetVerification { if vform.Method == LoginVerification || vform.Method == ForgetVerification {
if user != nil && util.GetMaskedPhone(user.Phone) == vform.Dest { if user != nil && util.GetMaskedPhone(user.Phone) == vform.Dest {
@ -309,7 +310,7 @@ func (c *ApiController) SendVerificationCode() {
c.ResponseError(fmt.Sprintf(c.T("verification:Phone number is invalid in your region %s"), vform.CountryCode)) c.ResponseError(fmt.Sprintf(c.T("verification:Phone number is invalid in your region %s"), vform.CountryCode))
return return
} else { } else {
sendResp = object.SendVerificationCodeToPhone(organization, user, provider, remoteAddr, phone) sendResp = object.SendVerificationCodeToPhone(organization, user, provider, clientIp, phone)
} }
} }

View File

@ -50,7 +50,7 @@ func maskPassword(recordString string) string {
} }
func NewRecord(ctx *context.Context) (*casvisorsdk.Record, error) { func NewRecord(ctx *context.Context) (*casvisorsdk.Record, error) {
ip := strings.Replace(util.GetIPFromRequest(ctx.Request), ": ", "", -1) clientIp := strings.Replace(util.GetClientIpFromRequest(ctx.Request), ": ", "", -1)
action := strings.Replace(ctx.Request.URL.Path, "/api/", "", -1) action := strings.Replace(ctx.Request.URL.Path, "/api/", "", -1)
requestUri := util.FilterQuery(ctx.Request.RequestURI, []string{"accessToken"}) requestUri := util.FilterQuery(ctx.Request.RequestURI, []string{"accessToken"})
if len(requestUri) > 1000 { if len(requestUri) > 1000 {
@ -83,7 +83,7 @@ func NewRecord(ctx *context.Context) (*casvisorsdk.Record, error) {
record := casvisorsdk.Record{ record := casvisorsdk.Record{
Name: util.GenerateId(), Name: util.GenerateId(),
CreatedTime: util.GetCurrentTime(), CreatedTime: util.GetCurrentTime(),
ClientIp: ip, ClientIp: clientIp,
User: "", User: "",
Method: ctx.Request.Method, Method: ctx.Request.Method,
RequestUri: requestUri, RequestUri: requestUri,

View File

@ -23,16 +23,15 @@ import (
"github.com/beego/beego/logs" "github.com/beego/beego/logs"
) )
func GetIPInfo(clientIP string) string { func getIpInfo(clientIp string) string {
if clientIP == "" { if clientIp == "" {
return "" return ""
} }
ips := strings.Split(clientIP, ",") ips := strings.Split(clientIp, ",")
res := "" res := ""
for i := range ips { for i := range ips {
ip := strings.TrimSpace(ips[i]) ip := strings.TrimSpace(ips[i])
// desc := GetDescFromIP(ip)
ipstr := fmt.Sprintf("%s: %s", ip, "") ipstr := fmt.Sprintf("%s: %s", ip, "")
if i != len(ips)-1 { if i != len(ips)-1 {
res += ipstr + " -> " res += ipstr + " -> "
@ -44,29 +43,29 @@ func GetIPInfo(clientIP string) string {
return res return res
} }
func GetIPFromRequest(req *http.Request) string { func GetClientIpFromRequest(req *http.Request) string {
clientIP := req.Header.Get("x-forwarded-for") clientIp := req.Header.Get("x-forwarded-for")
if clientIP == "" { if clientIp == "" {
ipPort := strings.Split(req.RemoteAddr, ":") ipPort := strings.Split(req.RemoteAddr, ":")
if len(ipPort) >= 1 && len(ipPort) <= 2 { if len(ipPort) >= 1 && len(ipPort) <= 2 {
clientIP = ipPort[0] clientIp = ipPort[0]
} else if len(ipPort) > 2 { } else if len(ipPort) > 2 {
idx := strings.LastIndex(req.RemoteAddr, ":") idx := strings.LastIndex(req.RemoteAddr, ":")
clientIP = req.RemoteAddr[0:idx] clientIp = req.RemoteAddr[0:idx]
clientIP = strings.TrimLeft(clientIP, "[") clientIp = strings.TrimLeft(clientIp, "[")
clientIP = strings.TrimRight(clientIP, "]") clientIp = strings.TrimRight(clientIp, "]")
} }
} }
return GetIPInfo(clientIP) return getIpInfo(clientIp)
} }
func LogInfo(ctx *context.Context, f string, v ...interface{}) { func LogInfo(ctx *context.Context, f string, v ...interface{}) {
ipString := fmt.Sprintf("(%s) ", GetIPFromRequest(ctx.Request)) ipString := fmt.Sprintf("(%s) ", GetClientIpFromRequest(ctx.Request))
logs.Info(ipString+f, v...) logs.Info(ipString+f, v...)
} }
func LogWarning(ctx *context.Context, f string, v ...interface{}) { func LogWarning(ctx *context.Context, f string, v ...interface{}) {
ipString := fmt.Sprintf("(%s) ", GetIPFromRequest(ctx.Request)) ipString := fmt.Sprintf("(%s) ", GetClientIpFromRequest(ctx.Request))
logs.Warning(ipString+f, v...) logs.Warning(ipString+f, v...)
} }