mirror of
https://github.com/casdoor/casdoor.git
synced 2025-05-23 02:35:49 +08:00
feat: refactor MFA code and fix no-session bug (#2676)
* refactor: refactor mfa * refactor: refactor mfa * refactor: refactor mfa * lint * chore: reduce wait time
This commit is contained in:
parent
06ef97a080
commit
a60be2b2ab
@ -19,6 +19,14 @@ import (
|
||||
|
||||
"github.com/casdoor/casdoor/object"
|
||||
"github.com/casdoor/casdoor/util"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
MfaRecoveryCodesSession = "mfa_recovery_codes"
|
||||
MfaCountryCodeSession = "mfa_country_code"
|
||||
MfaDestSession = "mfa_dest"
|
||||
MfaTotpSecretSession = "mfa_totp_secret"
|
||||
)
|
||||
|
||||
// MfaSetupInitiate
|
||||
@ -57,12 +65,20 @@ func (c *ApiController) MfaSetupInitiate() {
|
||||
return
|
||||
}
|
||||
|
||||
mfaProps, err := MfaUtil.Initiate(c.Ctx, user.GetId())
|
||||
mfaProps, err := MfaUtil.Initiate(user.GetId())
|
||||
if err != nil {
|
||||
c.ResponseError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
recoveryCode := uuid.NewString()
|
||||
c.SetSession(MfaRecoveryCodesSession, recoveryCode)
|
||||
if mfaType == object.TotpType {
|
||||
c.SetSession(MfaTotpSecretSession, mfaProps.Secret)
|
||||
}
|
||||
|
||||
mfaProps.RecoveryCodes = []string{recoveryCode}
|
||||
|
||||
resp := mfaProps
|
||||
c.ResponseOk(resp)
|
||||
}
|
||||
@ -83,13 +99,39 @@ func (c *ApiController) MfaSetupVerify() {
|
||||
c.ResponseError("missing auth type or passcode")
|
||||
return
|
||||
}
|
||||
mfaUtil := object.GetMfaUtil(mfaType, nil)
|
||||
|
||||
config := &object.MfaProps{
|
||||
MfaType: mfaType,
|
||||
}
|
||||
if mfaType == object.TotpType {
|
||||
secret := c.GetSession(MfaTotpSecretSession)
|
||||
if secret == nil {
|
||||
c.ResponseError("totp secret is missing")
|
||||
return
|
||||
}
|
||||
config.Secret = secret.(string)
|
||||
} else if mfaType == object.EmailType || mfaType == object.SmsType {
|
||||
dest := c.GetSession(MfaDestSession)
|
||||
if dest == nil {
|
||||
c.ResponseError("destination is missing")
|
||||
return
|
||||
}
|
||||
config.Secret = dest.(string)
|
||||
countryCode := c.GetSession(MfaCountryCodeSession)
|
||||
if countryCode == nil {
|
||||
c.ResponseError("country code is missing")
|
||||
return
|
||||
}
|
||||
config.CountryCode = countryCode.(string)
|
||||
}
|
||||
|
||||
mfaUtil := object.GetMfaUtil(mfaType, config)
|
||||
if mfaUtil == nil {
|
||||
c.ResponseError("Invalid multi-factor authentication type")
|
||||
return
|
||||
}
|
||||
|
||||
err := mfaUtil.SetupVerify(c.Ctx, passcode)
|
||||
err := mfaUtil.SetupVerify(passcode)
|
||||
if err != nil {
|
||||
c.ResponseError(err.Error())
|
||||
} else {
|
||||
@ -122,18 +164,58 @@ func (c *ApiController) MfaSetupEnable() {
|
||||
return
|
||||
}
|
||||
|
||||
mfaUtil := object.GetMfaUtil(mfaType, nil)
|
||||
config := &object.MfaProps{
|
||||
MfaType: mfaType,
|
||||
}
|
||||
|
||||
if mfaType == object.TotpType {
|
||||
secret := c.GetSession(MfaTotpSecretSession)
|
||||
if secret == nil {
|
||||
c.ResponseError("totp secret is missing")
|
||||
return
|
||||
}
|
||||
config.Secret = secret.(string)
|
||||
} else if mfaType == object.EmailType || mfaType == object.SmsType {
|
||||
dest := c.GetSession(MfaDestSession)
|
||||
if dest == nil {
|
||||
c.ResponseError("destination is missing")
|
||||
return
|
||||
}
|
||||
config.Secret = dest.(string)
|
||||
countryCode := c.GetSession(MfaCountryCodeSession)
|
||||
if countryCode == nil {
|
||||
c.ResponseError("country code is missing")
|
||||
return
|
||||
}
|
||||
config.CountryCode = countryCode.(string)
|
||||
}
|
||||
recoveryCodes := c.GetSession(MfaRecoveryCodesSession)
|
||||
if recoveryCodes == nil {
|
||||
c.ResponseError("recovery codes is missing")
|
||||
return
|
||||
}
|
||||
config.RecoveryCodes = []string{recoveryCodes.(string)}
|
||||
|
||||
mfaUtil := object.GetMfaUtil(mfaType, config)
|
||||
if mfaUtil == nil {
|
||||
c.ResponseError("Invalid multi-factor authentication type")
|
||||
return
|
||||
}
|
||||
|
||||
err = mfaUtil.Enable(c.Ctx, user)
|
||||
err = mfaUtil.Enable(user)
|
||||
if err != nil {
|
||||
c.ResponseError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.DelSession(MfaRecoveryCodesSession)
|
||||
if mfaType == object.TotpType {
|
||||
c.DelSession(MfaTotpSecretSession)
|
||||
} else {
|
||||
c.DelSession(MfaCountryCodeSession)
|
||||
c.DelSession(MfaDestSession)
|
||||
}
|
||||
|
||||
c.ResponseOk(http.StatusText(http.StatusOK))
|
||||
}
|
||||
|
||||
|
@ -161,7 +161,7 @@ func (c *ApiController) SendVerificationCode() {
|
||||
vform.Dest = mfaProps.Secret
|
||||
}
|
||||
} else if vform.Method == MfaSetupVerification {
|
||||
c.SetSession(object.MfaDestSession, vform.Dest)
|
||||
c.SetSession(MfaDestSession, vform.Dest)
|
||||
}
|
||||
|
||||
provider, err := application.GetEmailProvider()
|
||||
@ -198,8 +198,8 @@ func (c *ApiController) SendVerificationCode() {
|
||||
}
|
||||
|
||||
if vform.Method == MfaSetupVerification {
|
||||
c.SetSession(object.MfaCountryCodeSession, vform.CountryCode)
|
||||
c.SetSession(object.MfaDestSession, vform.Dest)
|
||||
c.SetSession(MfaCountryCodeSession, vform.CountryCode)
|
||||
c.SetSession(MfaDestSession, vform.Dest)
|
||||
}
|
||||
} else if vform.Method == MfaAuthVerification {
|
||||
mfaProps := user.GetPreferredMfaProps(false)
|
||||
|
@ -18,12 +18,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/casdoor/casdoor/util"
|
||||
|
||||
"github.com/beego/beego/context"
|
||||
)
|
||||
|
||||
const MfaRecoveryCodesSession = "mfa_recovery_codes"
|
||||
|
||||
type MfaProps struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
IsPreferred bool `json:"isPreferred"`
|
||||
@ -35,9 +31,9 @@ type MfaProps struct {
|
||||
}
|
||||
|
||||
type MfaInterface interface {
|
||||
Initiate(ctx *context.Context, userId string) (*MfaProps, error)
|
||||
SetupVerify(ctx *context.Context, passcode string) error
|
||||
Enable(ctx *context.Context, user *User) error
|
||||
Initiate(userId string) (*MfaProps, error)
|
||||
SetupVerify(passcode string) error
|
||||
Enable(user *User) error
|
||||
Verify(passcode string) error
|
||||
}
|
||||
|
||||
|
@ -16,85 +16,55 @@ package object
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/beego/beego/context"
|
||||
"github.com/casdoor/casdoor/util"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
MfaCountryCodeSession = "mfa_country_code"
|
||||
MfaDestSession = "mfa_dest"
|
||||
)
|
||||
|
||||
type SmsMfa struct {
|
||||
Config *MfaProps
|
||||
*MfaProps
|
||||
}
|
||||
|
||||
func (mfa *SmsMfa) Initiate(ctx *context.Context, userId string) (*MfaProps, error) {
|
||||
recoveryCode := uuid.NewString()
|
||||
|
||||
ctx.Output.Session(MfaRecoveryCodesSession, []string{recoveryCode})
|
||||
|
||||
func (mfa *SmsMfa) Initiate(userId string) (*MfaProps, error) {
|
||||
mfaProps := MfaProps{
|
||||
MfaType: mfa.Config.MfaType,
|
||||
RecoveryCodes: []string{recoveryCode},
|
||||
MfaType: mfa.MfaType,
|
||||
}
|
||||
return &mfaProps, nil
|
||||
}
|
||||
|
||||
func (mfa *SmsMfa) SetupVerify(ctx *context.Context, passCode string) error {
|
||||
destSession := ctx.Input.CruSession.Get(MfaDestSession)
|
||||
if destSession == nil {
|
||||
return errors.New("dest session is missing")
|
||||
}
|
||||
dest := destSession.(string)
|
||||
|
||||
if !util.IsEmailValid(dest) {
|
||||
countryCodeSession := ctx.Input.CruSession.Get(MfaCountryCodeSession)
|
||||
if countryCodeSession == nil {
|
||||
return errors.New("country code is missing")
|
||||
}
|
||||
countryCode := countryCodeSession.(string)
|
||||
|
||||
dest, _ = util.GetE164Number(dest, countryCode)
|
||||
func (mfa *SmsMfa) SetupVerify(passCode string) error {
|
||||
if !util.IsEmailValid(mfa.Secret) {
|
||||
mfa.Secret, _ = util.GetE164Number(mfa.Secret, mfa.CountryCode)
|
||||
}
|
||||
|
||||
if result := CheckVerificationCode(dest, passCode, "en"); result.Code != VerificationSuccess {
|
||||
if result := CheckVerificationCode(mfa.Secret, passCode, "en"); result.Code != VerificationSuccess {
|
||||
return errors.New(result.Msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mfa *SmsMfa) Enable(ctx *context.Context, user *User) error {
|
||||
recoveryCodes := ctx.Input.CruSession.Get(MfaRecoveryCodesSession).([]string)
|
||||
if len(recoveryCodes) == 0 {
|
||||
return fmt.Errorf("recovery codes is missing")
|
||||
}
|
||||
|
||||
func (mfa *SmsMfa) Enable(user *User) error {
|
||||
columns := []string{"recovery_codes", "preferred_mfa_type"}
|
||||
|
||||
user.RecoveryCodes = append(user.RecoveryCodes, recoveryCodes...)
|
||||
user.RecoveryCodes = append(user.RecoveryCodes, mfa.RecoveryCodes...)
|
||||
if user.PreferredMfaType == "" {
|
||||
user.PreferredMfaType = mfa.Config.MfaType
|
||||
user.PreferredMfaType = mfa.MfaType
|
||||
}
|
||||
|
||||
if mfa.Config.MfaType == SmsType {
|
||||
if mfa.MfaType == SmsType {
|
||||
user.MfaPhoneEnabled = true
|
||||
columns = append(columns, "mfa_phone_enabled")
|
||||
|
||||
if user.Phone == "" {
|
||||
user.Phone = ctx.Input.CruSession.Get(MfaDestSession).(string)
|
||||
user.CountryCode = ctx.Input.CruSession.Get(MfaCountryCodeSession).(string)
|
||||
user.Phone = mfa.Secret
|
||||
user.CountryCode = mfa.CountryCode
|
||||
columns = append(columns, "phone", "country_code")
|
||||
}
|
||||
} else if mfa.Config.MfaType == EmailType {
|
||||
} else if mfa.MfaType == EmailType {
|
||||
user.MfaEmailEnabled = true
|
||||
columns = append(columns, "mfa_email_enabled")
|
||||
|
||||
if user.Email == "" {
|
||||
user.Email = ctx.Input.CruSession.Get(MfaDestSession).(string)
|
||||
user.Email = mfa.Secret
|
||||
columns = append(columns, "email")
|
||||
}
|
||||
}
|
||||
@ -104,18 +74,14 @@ func (mfa *SmsMfa) Enable(ctx *context.Context, user *User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Input.CruSession.Delete(MfaRecoveryCodesSession)
|
||||
ctx.Input.CruSession.Delete(MfaDestSession)
|
||||
ctx.Input.CruSession.Delete(MfaCountryCodeSession)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mfa *SmsMfa) Verify(passCode string) error {
|
||||
if !util.IsEmailValid(mfa.Config.Secret) {
|
||||
mfa.Config.Secret, _ = util.GetE164Number(mfa.Config.Secret, mfa.Config.CountryCode)
|
||||
if !util.IsEmailValid(mfa.Secret) {
|
||||
mfa.Secret, _ = util.GetE164Number(mfa.Secret, mfa.CountryCode)
|
||||
}
|
||||
if result := CheckVerificationCode(mfa.Config.Secret, passCode, "en"); result.Code != VerificationSuccess {
|
||||
if result := CheckVerificationCode(mfa.Secret, passCode, "en"); result.Code != VerificationSuccess {
|
||||
return errors.New(result.Msg)
|
||||
}
|
||||
return nil
|
||||
@ -128,7 +94,7 @@ func NewSmsMfaUtil(config *MfaProps) *SmsMfa {
|
||||
}
|
||||
}
|
||||
return &SmsMfa{
|
||||
Config: config,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
@ -139,6 +105,6 @@ func NewEmailMfaUtil(config *MfaProps) *SmsMfa {
|
||||
}
|
||||
}
|
||||
return &SmsMfa{
|
||||
Config: config,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
@ -16,28 +16,24 @@ package object
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/beego/beego/context"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
const (
|
||||
MfaTotpSecretSession = "mfa_totp_secret"
|
||||
MfaTotpPeriodInSeconds = 30
|
||||
)
|
||||
|
||||
type TotpMfa struct {
|
||||
Config *MfaProps
|
||||
*MfaProps
|
||||
period uint
|
||||
secretSize uint
|
||||
digits otp.Digits
|
||||
}
|
||||
|
||||
func (mfa *TotpMfa) Initiate(ctx *context.Context, userId string) (*MfaProps, error) {
|
||||
func (mfa *TotpMfa) Initiate(userId string) (*MfaProps, error) {
|
||||
//issuer := beego.AppConfig.String("appname")
|
||||
//if issuer == "" {
|
||||
// issuer = "casdoor"
|
||||
@ -55,27 +51,16 @@ func (mfa *TotpMfa) Initiate(ctx *context.Context, userId string) (*MfaProps, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.Output.Session(MfaTotpSecretSession, key.Secret())
|
||||
|
||||
recoveryCode := uuid.NewString()
|
||||
ctx.Output.Session(MfaRecoveryCodesSession, []string{recoveryCode})
|
||||
|
||||
mfaProps := MfaProps{
|
||||
MfaType: mfa.Config.MfaType,
|
||||
RecoveryCodes: []string{recoveryCode},
|
||||
MfaType: mfa.MfaType,
|
||||
Secret: key.Secret(),
|
||||
URL: key.URL(),
|
||||
}
|
||||
return &mfaProps, nil
|
||||
}
|
||||
|
||||
func (mfa *TotpMfa) SetupVerify(ctx *context.Context, passcode string) error {
|
||||
secret := ctx.Input.CruSession.Get(MfaTotpSecretSession)
|
||||
if secret == nil {
|
||||
return errors.New("totp secret is missing")
|
||||
}
|
||||
|
||||
result, err := totp.ValidateCustom(passcode, secret.(string), time.Now().UTC(), totp.ValidateOpts{
|
||||
func (mfa *TotpMfa) SetupVerify(passcode string) error {
|
||||
result, err := totp.ValidateCustom(passcode, mfa.Secret, time.Now().UTC(), totp.ValidateOpts{
|
||||
Period: MfaTotpPeriodInSeconds,
|
||||
Skew: 1,
|
||||
Digits: otp.DigitsSix,
|
||||
@ -92,22 +77,13 @@ func (mfa *TotpMfa) SetupVerify(ctx *context.Context, passcode string) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (mfa *TotpMfa) Enable(ctx *context.Context, user *User) error {
|
||||
recoveryCodes := ctx.Input.CruSession.Get(MfaRecoveryCodesSession).([]string)
|
||||
if len(recoveryCodes) == 0 {
|
||||
return fmt.Errorf("recovery codes is missing")
|
||||
}
|
||||
secret := ctx.Input.CruSession.Get(MfaTotpSecretSession).(string)
|
||||
if secret == "" {
|
||||
return fmt.Errorf("totp secret is missing")
|
||||
}
|
||||
|
||||
func (mfa *TotpMfa) Enable(user *User) error {
|
||||
columns := []string{"recovery_codes", "preferred_mfa_type", "totp_secret"}
|
||||
|
||||
user.RecoveryCodes = append(user.RecoveryCodes, recoveryCodes...)
|
||||
user.TotpSecret = secret
|
||||
user.RecoveryCodes = append(user.RecoveryCodes, mfa.RecoveryCodes...)
|
||||
user.TotpSecret = mfa.Secret
|
||||
if user.PreferredMfaType == "" {
|
||||
user.PreferredMfaType = mfa.Config.MfaType
|
||||
user.PreferredMfaType = mfa.MfaType
|
||||
}
|
||||
|
||||
_, err := updateUser(user.GetId(), user, columns)
|
||||
@ -115,14 +91,11 @@ func (mfa *TotpMfa) Enable(ctx *context.Context, user *User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Input.CruSession.Delete(MfaRecoveryCodesSession)
|
||||
ctx.Input.CruSession.Delete(MfaTotpSecretSession)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mfa *TotpMfa) Verify(passcode string) error {
|
||||
result, err := totp.ValidateCustom(passcode, mfa.Config.Secret, time.Now().UTC(), totp.ValidateOpts{
|
||||
result, err := totp.ValidateCustom(passcode, mfa.Secret, time.Now().UTC(), totp.ValidateOpts{
|
||||
Period: MfaTotpPeriodInSeconds,
|
||||
Skew: 1,
|
||||
Digits: otp.DigitsSix,
|
||||
@ -147,7 +120,7 @@ func NewTotpMfaUtil(config *MfaProps) *TotpMfa {
|
||||
}
|
||||
|
||||
return &TotpMfa{
|
||||
Config: config,
|
||||
MfaProps: config,
|
||||
period: MfaTotpPeriodInSeconds,
|
||||
secretSize: 20,
|
||||
digits: otp.DigitsSix,
|
||||
|
@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
import React from "react";
|
||||
import {Button, Col, Result, Row, Steps} from "antd";
|
||||
import {Button, Col, Result, Row, Spin, Steps} from "antd";
|
||||
import {withRouter} from "react-router-dom";
|
||||
import * as ApplicationBackend from "../backend/ApplicationBackend";
|
||||
import * as Setting from "../Setting";
|
||||
@ -42,13 +42,20 @@ class MfaSetupPage extends React.Component {
|
||||
mfaProps: null,
|
||||
mfaType: params.get("mfaType") ?? SmsMfaType,
|
||||
isPromptPage: props.isPromptPage || location.state?.from !== undefined,
|
||||
loading: false,
|
||||
};
|
||||
}
|
||||
|
||||
componentDidMount() {
|
||||
this.getApplication();
|
||||
if (this.state.current === 1) {
|
||||
this.setState({
|
||||
loading: true,
|
||||
});
|
||||
|
||||
setTimeout(() => {
|
||||
this.initMfaProps();
|
||||
}, 200);
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,6 +92,7 @@ class MfaSetupPage extends React.Component {
|
||||
if (res.status === "ok") {
|
||||
this.setState({
|
||||
mfaProps: res.data,
|
||||
loading: false,
|
||||
});
|
||||
} else {
|
||||
Setting.showMessage("error", i18next.t("mfa:Failed to initiate MFA"));
|
||||
@ -231,6 +239,7 @@ class MfaSetupPage extends React.Component {
|
||||
<p style={{textAlign: "center", fontSize: "16px", marginTop: "10px"}}>{i18next.t("mfa:Each time you sign in to your Account, you'll need your password and a authentication code")}</p>
|
||||
</Col>
|
||||
</Row>
|
||||
<Spin spinning={this.state.loading}>
|
||||
<Steps current={this.state.current}
|
||||
items={[
|
||||
{title: i18next.t("mfa:Verify Password"), icon: <UserOutlined />},
|
||||
@ -240,6 +249,7 @@ class MfaSetupPage extends React.Component {
|
||||
style={{width: "90%", maxWidth: "500px", margin: "auto", marginTop: "50px",
|
||||
}} >
|
||||
</Steps>
|
||||
</Spin>
|
||||
</Col>
|
||||
<Col span={24} style={{display: "flex", justifyContent: "center"}}>
|
||||
<div style={{marginTop: "10px", textAlign: "center"}}>
|
||||
|
Loading…
x
Reference in New Issue
Block a user