diff --git a/controllers/mfa.go b/controllers/mfa.go index 8e090e1f..23ef64f7 100644 --- a/controllers/mfa.go +++ b/controllers/mfa.go @@ -42,6 +42,7 @@ func (c *ApiController) MfaSetupInitiate() { owner := c.Ctx.Request.Form.Get("owner") name := c.Ctx.Request.Form.Get("name") mfaType := c.Ctx.Request.Form.Get("mfaType") + stateLess := c.Ctx.Request.Form.Get("stateLess") userId := util.GetId(owner, name) if len(userId) == 0 { @@ -72,15 +73,27 @@ func (c *ApiController) MfaSetupInitiate() { } recoveryCode := uuid.NewString() - c.SetSession(MfaRecoveryCodesSession, recoveryCode) - if mfaType == object.TotpType { - c.SetSession(MfaTotpSecretSession, mfaProps.Secret) + mfaCacheKey := "" + if stateLess == "true" { + mfaCacheKey = c.Ctx.Input.CruSession.SessionID() + util.GenerateSimpleTimeId() + mfaCacheVal := make(map[string]string) + mfaCacheVal[MfaRecoveryCodesSession] = recoveryCode + if mfaType == object.TotpType { + mfaCacheVal[MfaTotpSecretSession] = mfaProps.Secret + } + mfaCacheVal["verifiedUserId"] = userId + object.MfaCache.Store(mfaCacheKey, mfaCacheVal) + } else { + c.SetSession(MfaRecoveryCodesSession, recoveryCode) + if mfaType == object.TotpType { + c.SetSession(MfaTotpSecretSession, mfaProps.Secret) + } } mfaProps.RecoveryCodes = []string{recoveryCode} resp := mfaProps - c.ResponseOk(resp) + c.ResponseOk(resp, mfaCacheKey) } // MfaSetupVerify @@ -94,6 +107,7 @@ func (c *ApiController) MfaSetupInitiate() { func (c *ApiController) MfaSetupVerify() { mfaType := c.Ctx.Request.Form.Get("mfaType") passcode := c.Ctx.Request.Form.Get("passcode") + mfaCacheKey := c.Ctx.Request.Form.Get("mfaCacheKey") if mfaType == "" || passcode == "" { c.ResponseError("missing auth type or passcode") @@ -104,32 +118,32 @@ func (c *ApiController) MfaSetupVerify() { MfaType: mfaType, } if mfaType == object.TotpType { - secret := c.GetSession(MfaTotpSecretSession) - if secret == nil { + secret := object.GetPropsFromContext(MfaTotpSecretSession, c.Ctx.Input.CruSession, mfaCacheKey) + if secret == "" { c.ResponseError("totp secret is missing") return } - config.Secret = secret.(string) + config.Secret = secret } else if mfaType == object.SmsType { - dest := c.GetSession(MfaDestSession) - if dest == nil { + dest := object.GetPropsFromContext(MfaDestSession, c.Ctx.Input.CruSession, mfaCacheKey) + if dest == "" { c.ResponseError("destination is missing") return } - config.Secret = dest.(string) - countryCode := c.GetSession(MfaCountryCodeSession) - if countryCode == nil { + config.Secret = dest + countryCode := object.GetPropsFromContext(MfaCountryCodeSession, c.Ctx.Input.CruSession, mfaCacheKey) + if countryCode == "" { c.ResponseError("country code is missing") return } - config.CountryCode = countryCode.(string) + config.CountryCode = countryCode } else if mfaType == object.EmailType { - dest := c.GetSession(MfaDestSession) - if dest == nil { + dest := object.GetPropsFromContext(MfaDestSession, c.Ctx.Input.CruSession, mfaCacheKey) + if dest == "" { c.ResponseError("destination is missing") return } - config.Secret = dest.(string) + config.Secret = dest } mfaUtil := object.GetMfaUtil(mfaType, config) @@ -159,6 +173,7 @@ func (c *ApiController) MfaSetupEnable() { owner := c.Ctx.Request.Form.Get("owner") name := c.Ctx.Request.Form.Get("name") mfaType := c.Ctx.Request.Form.Get("mfaType") + mfaCacheKey := c.Ctx.Request.Form.Get("mfaCacheKey") user, err := object.GetUser(util.GetId(owner, name)) if err != nil { @@ -176,43 +191,43 @@ func (c *ApiController) MfaSetupEnable() { } if mfaType == object.TotpType { - secret := c.GetSession(MfaTotpSecretSession) - if secret == nil { + secret := object.GetPropsFromContext(MfaTotpSecretSession, c.Ctx.Input.CruSession, mfaCacheKey) + if secret == "" { c.ResponseError("totp secret is missing") return } - config.Secret = secret.(string) + config.Secret = secret } else if mfaType == object.EmailType { if user.Email == "" { - dest := c.GetSession(MfaDestSession) - if dest == nil { + dest := object.GetPropsFromContext(MfaDestSession, c.Ctx.Input.CruSession, mfaCacheKey) + if dest == "" { c.ResponseError("destination is missing") return } - user.Email = dest.(string) + user.Email = dest } } else if mfaType == object.SmsType { if user.Phone == "" { - dest := c.GetSession(MfaDestSession) - if dest == nil { + dest := object.GetPropsFromContext(MfaDestSession, c.Ctx.Input.CruSession, mfaCacheKey) + if dest == "" { c.ResponseError("destination is missing") return } - user.Phone = dest.(string) - countryCode := c.GetSession(MfaCountryCodeSession) - if countryCode == nil { + user.Phone = dest + countryCode := object.GetPropsFromContext(MfaCountryCodeSession, c.Ctx.Input.CruSession, mfaCacheKey) + if countryCode == "" { c.ResponseError("country code is missing") return } - user.CountryCode = countryCode.(string) + user.CountryCode = countryCode } } - recoveryCodes := c.GetSession(MfaRecoveryCodesSession) - if recoveryCodes == nil { + recoveryCodes := object.GetPropsFromContext(MfaRecoveryCodesSession, c.Ctx.Input.CruSession, mfaCacheKey) + if recoveryCodes == "" { c.ResponseError("recovery codes is missing") return } - config.RecoveryCodes = []string{recoveryCodes.(string)} + config.RecoveryCodes = []string{recoveryCodes} mfaUtil := object.GetMfaUtil(mfaType, config) if mfaUtil == nil { @@ -226,12 +241,16 @@ func (c *ApiController) MfaSetupEnable() { return } - c.DelSession(MfaRecoveryCodesSession) - if mfaType == object.TotpType { - c.DelSession(MfaTotpSecretSession) + if mfaCacheKey == "" { + c.DelSession(MfaRecoveryCodesSession) + if mfaType == object.TotpType { + c.DelSession(MfaTotpSecretSession) + } else { + c.DelSession(MfaCountryCodeSession) + c.DelSession(MfaDestSession) + } } else { - c.DelSession(MfaCountryCodeSession) - c.DelSession(MfaDestSession) + object.MfaCache.Delete(mfaCacheKey) } c.ResponseOk(http.StatusText(http.StatusOK)) diff --git a/controllers/verification.go b/controllers/verification.go index dd24a978..72a53cbf 100644 --- a/controllers/verification.go +++ b/controllers/verification.go @@ -247,7 +247,7 @@ func (c *ApiController) SendVerificationCode() { vform.Dest = mfaProps.Secret } } else if vform.Method == MfaSetupVerification { - c.SetSession(MfaDestSession, vform.Dest) + object.SetPropsFromContext(MfaDestSession, vform.Dest, c.Ctx.Input.CruSession, vform.MfaCacheKey) } provider, err = application.GetEmailProvider(vform.Method) @@ -284,8 +284,8 @@ func (c *ApiController) SendVerificationCode() { } if vform.Method == MfaSetupVerification { - c.SetSession(MfaCountryCodeSession, vform.CountryCode) - c.SetSession(MfaDestSession, vform.Dest) + object.SetPropsFromContext(MfaCountryCodeSession, vform.CountryCode, c.Ctx.Input.CruSession, vform.MfaCacheKey) + object.SetPropsFromContext(MfaDestSession, vform.Dest, c.Ctx.Input.CruSession, vform.MfaCacheKey) } } else if vform.Method == MfaAuthVerification { mfaProps := user.GetPreferredMfaProps(false) diff --git a/form/verification.go b/form/verification.go index e2a3b530..4ed05771 100644 --- a/form/verification.go +++ b/form/verification.go @@ -31,6 +31,8 @@ type VerificationForm struct { CaptchaType string `form:"captchaType"` ClientSecret string `form:"clientSecret"` CaptchaToken string `form:"captchaToken"` + + MfaCacheKey string `form:"mfaCacheKey"` } const ( diff --git a/object/mfa.go b/object/mfa.go index fb721172..cc47026d 100644 --- a/object/mfa.go +++ b/object/mfa.go @@ -16,7 +16,9 @@ package object import ( "fmt" + "sync" + "github.com/beego/beego/session" "github.com/casdoor/casdoor/util" ) @@ -49,6 +51,8 @@ const ( RequiredMfa = "RequiredMfa" ) +var MfaCache = sync.Map{} + func GetMfaUtil(mfaType string, config *MfaProps) MfaInterface { switch mfaType { case SmsType: @@ -183,3 +187,41 @@ func SetPreferredMultiFactorAuth(user *User, mfaType string) error { } return nil } + +func GetPropsFromContext(key string, curSession session.Store, mfaCacheKey string) string { + if mfaCacheKey != "" { + propMap, exist := MfaCache.Load(mfaCacheKey) + if !exist { + return "" + } + if propMap == nil { + return "" + } + return propMap.(map[string]string)[key] + } + + val := curSession.Get(key) + if val != nil { + return val.(string) + } + return "" +} + +func SetPropsFromContext(key string, value string, curSession session.Store, mfaCacheKey string) { + if mfaCacheKey != "" { + propMap, exist := MfaCache.Load(mfaCacheKey) + if !exist { + return + } + if propMap == nil { + return + } + propMap.(map[string]string)[key] = value + MfaCache.Store(mfaCacheKey, propMap) + } + + err := curSession.Set(key, value) + if err != nil { + return + } +}