From 630b84f5343eebf686fd26e5ee1a37d023043ddf Mon Sep 17 00:00:00 2001 From: Steve0x2a Date: Fri, 21 Jan 2022 09:29:19 +0800 Subject: [PATCH] feat: add PKCE support (#434) * feat: add PKCE support Signed-off-by: Steve0x2a * fix: error output when challenge is empty Signed-off-by: Steve0x2a --- controllers/auth.go | 9 +++++- controllers/token.go | 13 ++++++-- object/token.go | 63 +++++++++++++++++++++++++------------ web/src/auth/AuthBackend.js | 2 +- web/src/auth/Util.js | 6 +++- 5 files changed, 68 insertions(+), 25 deletions(-) diff --git a/controllers/auth.go b/controllers/auth.go index 1efe89ba..2930dd29 100644 --- a/controllers/auth.go +++ b/controllers/auth.go @@ -52,7 +52,14 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob scope := c.Input().Get("scope") state := c.Input().Get("state") nonce := c.Input().Get("nonce") - code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce) + challengeMethod := c.Input().Get("code_challenge_method") + codeChallenge := c.Input().Get("code_challenge") + + if challengeMethod != "S256" && challengeMethod != "null" { + c.ResponseError("Challenge method should be S256") + return + } + code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge) resp = codeToResponse(code) if application.EnableSigninSession || application.HasPromptPage() { diff --git a/controllers/token.go b/controllers/token.go index c16d10a4..9e990c5b 100644 --- a/controllers/token.go +++ b/controllers/token.go @@ -142,7 +142,15 @@ func (c *ApiController) GetOAuthCode() { state := c.Input().Get("state") nonce := c.Input().Get("nonce") - c.Data["json"] = object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce) + challengeMethod := c.Input().Get("code_challenge_method") + codeChallenge := c.Input().Get("code_challenge") + + if challengeMethod != "S256" && challengeMethod != "null" { + c.ResponseError("Challenge method should be S256") + return + } + + c.Data["json"] = object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge) c.ServeJSON() } @@ -161,12 +169,13 @@ func (c *ApiController) GetOAuthToken() { clientId := c.Input().Get("client_id") clientSecret := c.Input().Get("client_secret") code := c.Input().Get("code") + verifier := c.Input().Get("code_verifier") if clientId == "" && clientSecret == "" { clientId, clientSecret, _ = c.Ctx.Request.BasicAuth() } - c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code) + c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier) c.ServeJSON() } diff --git a/object/token.go b/object/token.go index c1d73e60..7770f80c 100644 --- a/object/token.go +++ b/object/token.go @@ -15,6 +15,8 @@ package object import ( + "crypto/sha256" + "encoding/base64" "fmt" "strings" @@ -36,12 +38,13 @@ type Token struct { Organization string `xorm:"varchar(100)" json:"organization"` User string `xorm:"varchar(100)" json:"user"` - Code string `xorm:"varchar(100)" json:"code"` - AccessToken string `xorm:"mediumtext" json:"accessToken"` - RefreshToken string `xorm:"mediumtext" json:"refreshToken"` - ExpiresIn int `json:"expiresIn"` - Scope string `xorm:"varchar(100)" json:"scope"` - TokenType string `xorm:"varchar(100)" json:"tokenType"` + Code string `xorm:"varchar(100)" json:"code"` + AccessToken string `xorm:"mediumtext" json:"accessToken"` + RefreshToken string `xorm:"mediumtext" json:"refreshToken"` + ExpiresIn int `json:"expiresIn"` + Scope string `xorm:"varchar(100)" json:"scope"` + TokenType string `xorm:"varchar(100)" json:"tokenType"` + CodeChallenge string `xorm:"varchar(100)" json:"codeChallenge"` } type TokenWrapper struct { @@ -182,7 +185,7 @@ func CheckOAuthLogin(clientId string, responseType string, redirectUri string, s return "", application } -func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string) *Code { +func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string) *Code { user := GetUser(userId) if user == nil { return &Code{ @@ -210,19 +213,24 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU panic(err) } + if challenge == "null"{ + challenge = "" + } + token := &Token{ - Owner: application.Owner, - Name: util.GenerateId(), - CreatedTime: util.GetCurrentTime(), - Application: application.Name, - Organization: user.Owner, - User: user.Name, - Code: util.GenerateClientId(), - AccessToken: accessToken, - RefreshToken: refreshToken, - ExpiresIn: application.ExpireInHours * 60, - Scope: scope, - TokenType: "Bearer", + Owner: application.Owner, + Name: util.GenerateId(), + CreatedTime: util.GetCurrentTime(), + Application: application.Name, + Organization: user.Owner, + User: user.Name, + Code: util.GenerateClientId(), + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresIn: application.ExpireInHours * 60, + Scope: scope, + TokenType: "Bearer", + CodeChallenge: challenge, } AddToken(token) @@ -232,7 +240,7 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU } } -func GetOAuthToken(grantType string, clientId string, clientSecret string, code string) *TokenWrapper { +func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string) *TokenWrapper { application := GetApplicationByClientId(clientId) if application == nil { return &TokenWrapper{ @@ -288,6 +296,14 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code Scope: "", } } + if token.CodeChallenge != "" && pkceChallenge(verifier) != token.CodeChallenge { + return &TokenWrapper{ + AccessToken: "error: incorrect code_verifier", + TokenType: "", + ExpiresIn: 0, + Scope: "", + } + } tokenWrapper := &TokenWrapper{ AccessToken: token.AccessToken, @@ -392,3 +408,10 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId return tokenWrapper } + +// PkceChallenge: base64-URL-encoded SHA256 hash of verifier, per rfc 7636 +func pkceChallenge(verifier string) string { + sum := sha256.Sum256([]byte(verifier)) + challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(sum[:]) + return challenge +} diff --git a/web/src/auth/AuthBackend.js b/web/src/auth/AuthBackend.js index 708de631..5b5c4f67 100644 --- a/web/src/auth/AuthBackend.js +++ b/web/src/auth/AuthBackend.js @@ -44,7 +44,7 @@ function oAuthParamsToQuery(oAuthParams) { } // code - return `?clientId=${oAuthParams.clientId}&responseType=${oAuthParams.responseType}&redirectUri=${oAuthParams.redirectUri}&scope=${oAuthParams.scope}&state=${oAuthParams.state}&nonce=${oAuthParams.nonce}`; + return `?clientId=${oAuthParams.clientId}&responseType=${oAuthParams.responseType}&redirectUri=${oAuthParams.redirectUri}&scope=${oAuthParams.scope}&state=${oAuthParams.state}&nonce=${oAuthParams.nonce}&code_challenge_method=${oAuthParams.challengeMethod}&code_challenge=${oAuthParams.codeChallenge}`; } export function getApplicationLogin(oAuthParams) { diff --git a/web/src/auth/Util.js b/web/src/auth/Util.js index 44b7c169..1dd112a6 100644 --- a/web/src/auth/Util.js +++ b/web/src/auth/Util.js @@ -83,7 +83,9 @@ export function getOAuthGetParameters(params) { const scope = queries.get("scope"); const state = queries.get("state"); const nonce = queries.get("nonce") - + const challengeMethod = queries.get("code_challenge_method") + const codeChallenge = queries.get("code_challenge") + if (clientId === undefined || clientId === null) { // login return null; @@ -96,6 +98,8 @@ export function getOAuthGetParameters(params) { scope: scope, state: state, nonce: nonce, + challengeMethod: challengeMethod, + codeChallenge: codeChallenge, }; } }