feat: support stateless mfa setup

This commit is contained in:
dacongda
2024-11-28 15:55:16 +08:00
parent 2d1736f13a
commit bd843b2ff3
4 changed files with 102 additions and 39 deletions

View File

@ -42,6 +42,7 @@ func (c *ApiController) MfaSetupInitiate() {
owner := c.Ctx.Request.Form.Get("owner") owner := c.Ctx.Request.Form.Get("owner")
name := c.Ctx.Request.Form.Get("name") name := c.Ctx.Request.Form.Get("name")
mfaType := c.Ctx.Request.Form.Get("mfaType") mfaType := c.Ctx.Request.Form.Get("mfaType")
stateLess := c.Ctx.Request.Form.Get("stateLess")
userId := util.GetId(owner, name) userId := util.GetId(owner, name)
if len(userId) == 0 { if len(userId) == 0 {
@ -72,15 +73,27 @@ func (c *ApiController) MfaSetupInitiate() {
} }
recoveryCode := uuid.NewString() recoveryCode := uuid.NewString()
mfaCacheKey := ""
if stateLess == "true" {
mfaCacheKey = c.Ctx.Input.CruSession.SessionID() + util.GenerateSimpleTimeId()
mfaCacheVal := make(map[string]string)
mfaCacheVal[MfaRecoveryCodesSession] = recoveryCode
if mfaType == object.TotpType {
mfaCacheVal[MfaTotpSecretSession] = mfaProps.Secret
}
mfaCacheVal["verifiedUserId"] = userId
object.MfaCache.Store(mfaCacheKey, mfaCacheVal)
} else {
c.SetSession(MfaRecoveryCodesSession, recoveryCode) c.SetSession(MfaRecoveryCodesSession, recoveryCode)
if mfaType == object.TotpType { if mfaType == object.TotpType {
c.SetSession(MfaTotpSecretSession, mfaProps.Secret) c.SetSession(MfaTotpSecretSession, mfaProps.Secret)
} }
}
mfaProps.RecoveryCodes = []string{recoveryCode} mfaProps.RecoveryCodes = []string{recoveryCode}
resp := mfaProps resp := mfaProps
c.ResponseOk(resp) c.ResponseOk(resp, mfaCacheKey)
} }
// MfaSetupVerify // MfaSetupVerify
@ -94,6 +107,7 @@ func (c *ApiController) MfaSetupInitiate() {
func (c *ApiController) MfaSetupVerify() { func (c *ApiController) MfaSetupVerify() {
mfaType := c.Ctx.Request.Form.Get("mfaType") mfaType := c.Ctx.Request.Form.Get("mfaType")
passcode := c.Ctx.Request.Form.Get("passcode") passcode := c.Ctx.Request.Form.Get("passcode")
mfaCacheKey := c.Ctx.Request.Form.Get("mfaCacheKey")
if mfaType == "" || passcode == "" { if mfaType == "" || passcode == "" {
c.ResponseError("missing auth type or passcode") c.ResponseError("missing auth type or passcode")
@ -104,32 +118,32 @@ func (c *ApiController) MfaSetupVerify() {
MfaType: mfaType, MfaType: mfaType,
} }
if mfaType == object.TotpType { if mfaType == object.TotpType {
secret := c.GetSession(MfaTotpSecretSession) secret := object.GetPropsFromContext(MfaTotpSecretSession, c.Ctx.Input.CruSession, mfaCacheKey)
if secret == nil { if secret == "" {
c.ResponseError("totp secret is missing") c.ResponseError("totp secret is missing")
return return
} }
config.Secret = secret.(string) config.Secret = secret
} else if mfaType == object.SmsType { } else if mfaType == object.SmsType {
dest := c.GetSession(MfaDestSession) dest := object.GetPropsFromContext(MfaDestSession, c.Ctx.Input.CruSession, mfaCacheKey)
if dest == nil { if dest == "" {
c.ResponseError("destination is missing") c.ResponseError("destination is missing")
return return
} }
config.Secret = dest.(string) config.Secret = dest
countryCode := c.GetSession(MfaCountryCodeSession) countryCode := object.GetPropsFromContext(MfaCountryCodeSession, c.Ctx.Input.CruSession, mfaCacheKey)
if countryCode == nil { if countryCode == "" {
c.ResponseError("country code is missing") c.ResponseError("country code is missing")
return return
} }
config.CountryCode = countryCode.(string) config.CountryCode = countryCode
} else if mfaType == object.EmailType { } else if mfaType == object.EmailType {
dest := c.GetSession(MfaDestSession) dest := object.GetPropsFromContext(MfaDestSession, c.Ctx.Input.CruSession, mfaCacheKey)
if dest == nil { if dest == "" {
c.ResponseError("destination is missing") c.ResponseError("destination is missing")
return return
} }
config.Secret = dest.(string) config.Secret = dest
} }
mfaUtil := object.GetMfaUtil(mfaType, config) mfaUtil := object.GetMfaUtil(mfaType, config)
@ -159,6 +173,7 @@ func (c *ApiController) MfaSetupEnable() {
owner := c.Ctx.Request.Form.Get("owner") owner := c.Ctx.Request.Form.Get("owner")
name := c.Ctx.Request.Form.Get("name") name := c.Ctx.Request.Form.Get("name")
mfaType := c.Ctx.Request.Form.Get("mfaType") mfaType := c.Ctx.Request.Form.Get("mfaType")
mfaCacheKey := c.Ctx.Request.Form.Get("mfaCacheKey")
user, err := object.GetUser(util.GetId(owner, name)) user, err := object.GetUser(util.GetId(owner, name))
if err != nil { if err != nil {
@ -176,43 +191,43 @@ func (c *ApiController) MfaSetupEnable() {
} }
if mfaType == object.TotpType { if mfaType == object.TotpType {
secret := c.GetSession(MfaTotpSecretSession) secret := object.GetPropsFromContext(MfaTotpSecretSession, c.Ctx.Input.CruSession, mfaCacheKey)
if secret == nil { if secret == "" {
c.ResponseError("totp secret is missing") c.ResponseError("totp secret is missing")
return return
} }
config.Secret = secret.(string) config.Secret = secret
} else if mfaType == object.EmailType { } else if mfaType == object.EmailType {
if user.Email == "" { if user.Email == "" {
dest := c.GetSession(MfaDestSession) dest := object.GetPropsFromContext(MfaDestSession, c.Ctx.Input.CruSession, mfaCacheKey)
if dest == nil { if dest == "" {
c.ResponseError("destination is missing") c.ResponseError("destination is missing")
return return
} }
user.Email = dest.(string) user.Email = dest
} }
} else if mfaType == object.SmsType { } else if mfaType == object.SmsType {
if user.Phone == "" { if user.Phone == "" {
dest := c.GetSession(MfaDestSession) dest := object.GetPropsFromContext(MfaDestSession, c.Ctx.Input.CruSession, mfaCacheKey)
if dest == nil { if dest == "" {
c.ResponseError("destination is missing") c.ResponseError("destination is missing")
return return
} }
user.Phone = dest.(string) user.Phone = dest
countryCode := c.GetSession(MfaCountryCodeSession) countryCode := object.GetPropsFromContext(MfaCountryCodeSession, c.Ctx.Input.CruSession, mfaCacheKey)
if countryCode == nil { if countryCode == "" {
c.ResponseError("country code is missing") c.ResponseError("country code is missing")
return return
} }
user.CountryCode = countryCode.(string) user.CountryCode = countryCode
} }
} }
recoveryCodes := c.GetSession(MfaRecoveryCodesSession) recoveryCodes := object.GetPropsFromContext(MfaRecoveryCodesSession, c.Ctx.Input.CruSession, mfaCacheKey)
if recoveryCodes == nil { if recoveryCodes == "" {
c.ResponseError("recovery codes is missing") c.ResponseError("recovery codes is missing")
return return
} }
config.RecoveryCodes = []string{recoveryCodes.(string)} config.RecoveryCodes = []string{recoveryCodes}
mfaUtil := object.GetMfaUtil(mfaType, config) mfaUtil := object.GetMfaUtil(mfaType, config)
if mfaUtil == nil { if mfaUtil == nil {
@ -226,6 +241,7 @@ func (c *ApiController) MfaSetupEnable() {
return return
} }
if mfaCacheKey == "" {
c.DelSession(MfaRecoveryCodesSession) c.DelSession(MfaRecoveryCodesSession)
if mfaType == object.TotpType { if mfaType == object.TotpType {
c.DelSession(MfaTotpSecretSession) c.DelSession(MfaTotpSecretSession)
@ -233,6 +249,9 @@ func (c *ApiController) MfaSetupEnable() {
c.DelSession(MfaCountryCodeSession) c.DelSession(MfaCountryCodeSession)
c.DelSession(MfaDestSession) c.DelSession(MfaDestSession)
} }
} else {
object.MfaCache.Delete(mfaCacheKey)
}
c.ResponseOk(http.StatusText(http.StatusOK)) c.ResponseOk(http.StatusText(http.StatusOK))
} }

View File

@ -247,7 +247,7 @@ func (c *ApiController) SendVerificationCode() {
vform.Dest = mfaProps.Secret vform.Dest = mfaProps.Secret
} }
} else if vform.Method == MfaSetupVerification { } else if vform.Method == MfaSetupVerification {
c.SetSession(MfaDestSession, vform.Dest) object.SetPropsFromContext(MfaDestSession, vform.Dest, c.Ctx.Input.CruSession, vform.MfaCacheKey)
} }
provider, err = application.GetEmailProvider(vform.Method) provider, err = application.GetEmailProvider(vform.Method)
@ -284,8 +284,8 @@ func (c *ApiController) SendVerificationCode() {
} }
if vform.Method == MfaSetupVerification { if vform.Method == MfaSetupVerification {
c.SetSession(MfaCountryCodeSession, vform.CountryCode) object.SetPropsFromContext(MfaCountryCodeSession, vform.CountryCode, c.Ctx.Input.CruSession, vform.MfaCacheKey)
c.SetSession(MfaDestSession, vform.Dest) object.SetPropsFromContext(MfaDestSession, vform.Dest, c.Ctx.Input.CruSession, vform.MfaCacheKey)
} }
} else if vform.Method == MfaAuthVerification { } else if vform.Method == MfaAuthVerification {
mfaProps := user.GetPreferredMfaProps(false) mfaProps := user.GetPreferredMfaProps(false)

View File

@ -31,6 +31,8 @@ type VerificationForm struct {
CaptchaType string `form:"captchaType"` CaptchaType string `form:"captchaType"`
ClientSecret string `form:"clientSecret"` ClientSecret string `form:"clientSecret"`
CaptchaToken string `form:"captchaToken"` CaptchaToken string `form:"captchaToken"`
MfaCacheKey string `form:"mfaCacheKey"`
} }
const ( const (

View File

@ -16,7 +16,9 @@ package object
import ( import (
"fmt" "fmt"
"sync"
"github.com/beego/beego/session"
"github.com/casdoor/casdoor/util" "github.com/casdoor/casdoor/util"
) )
@ -49,6 +51,8 @@ const (
RequiredMfa = "RequiredMfa" RequiredMfa = "RequiredMfa"
) )
var MfaCache = sync.Map{}
func GetMfaUtil(mfaType string, config *MfaProps) MfaInterface { func GetMfaUtil(mfaType string, config *MfaProps) MfaInterface {
switch mfaType { switch mfaType {
case SmsType: case SmsType:
@ -183,3 +187,41 @@ func SetPreferredMultiFactorAuth(user *User, mfaType string) error {
} }
return nil return nil
} }
func GetPropsFromContext(key string, curSession session.Store, mfaCacheKey string) string {
if mfaCacheKey != "" {
propMap, exist := MfaCache.Load(mfaCacheKey)
if !exist {
return ""
}
if propMap == nil {
return ""
}
return propMap.(map[string]string)[key]
}
val := curSession.Get(key)
if val != nil {
return val.(string)
}
return ""
}
func SetPropsFromContext(key string, value string, curSession session.Store, mfaCacheKey string) {
if mfaCacheKey != "" {
propMap, exist := MfaCache.Load(mfaCacheKey)
if !exist {
return
}
if propMap == nil {
return
}
propMap.(map[string]string)[key] = value
MfaCache.Store(mfaCacheKey, propMap)
}
err := curSession.Set(key, value)
if err != nil {
return
}
}