feat: add PKCE support (#434)

* feat: add PKCE support

Signed-off-by: Steve0x2a <stevesough@gmail.com>

* fix: error output when challenge is empty

Signed-off-by: Steve0x2a <stevesough@gmail.com>
This commit is contained in:
Steve0x2a
2022-01-21 09:29:19 +08:00
committed by GitHub
parent 339a85e4b0
commit 630b84f534
5 changed files with 68 additions and 25 deletions

View File

@ -52,7 +52,14 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob
scope := c.Input().Get("scope")
state := c.Input().Get("state")
nonce := c.Input().Get("nonce")
code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce)
challengeMethod := c.Input().Get("code_challenge_method")
codeChallenge := c.Input().Get("code_challenge")
if challengeMethod != "S256" && challengeMethod != "null" {
c.ResponseError("Challenge method should be S256")
return
}
code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge)
resp = codeToResponse(code)
if application.EnableSigninSession || application.HasPromptPage() {

View File

@ -142,7 +142,15 @@ func (c *ApiController) GetOAuthCode() {
state := c.Input().Get("state")
nonce := c.Input().Get("nonce")
c.Data["json"] = object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce)
challengeMethod := c.Input().Get("code_challenge_method")
codeChallenge := c.Input().Get("code_challenge")
if challengeMethod != "S256" && challengeMethod != "null" {
c.ResponseError("Challenge method should be S256")
return
}
c.Data["json"] = object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge)
c.ServeJSON()
}
@ -161,12 +169,13 @@ func (c *ApiController) GetOAuthToken() {
clientId := c.Input().Get("client_id")
clientSecret := c.Input().Get("client_secret")
code := c.Input().Get("code")
verifier := c.Input().Get("code_verifier")
if clientId == "" && clientSecret == "" {
clientId, clientSecret, _ = c.Ctx.Request.BasicAuth()
}
c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code)
c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier)
c.ServeJSON()
}

View File

@ -15,6 +15,8 @@
package object
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"strings"
@ -42,6 +44,7 @@ type Token struct {
ExpiresIn int `json:"expiresIn"`
Scope string `xorm:"varchar(100)" json:"scope"`
TokenType string `xorm:"varchar(100)" json:"tokenType"`
CodeChallenge string `xorm:"varchar(100)" json:"codeChallenge"`
}
type TokenWrapper struct {
@ -182,7 +185,7 @@ func CheckOAuthLogin(clientId string, responseType string, redirectUri string, s
return "", application
}
func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string) *Code {
func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string) *Code {
user := GetUser(userId)
if user == nil {
return &Code{
@ -210,6 +213,10 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU
panic(err)
}
if challenge == "null"{
challenge = ""
}
token := &Token{
Owner: application.Owner,
Name: util.GenerateId(),
@ -223,6 +230,7 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU
ExpiresIn: application.ExpireInHours * 60,
Scope: scope,
TokenType: "Bearer",
CodeChallenge: challenge,
}
AddToken(token)
@ -232,7 +240,7 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU
}
}
func GetOAuthToken(grantType string, clientId string, clientSecret string, code string) *TokenWrapper {
func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string) *TokenWrapper {
application := GetApplicationByClientId(clientId)
if application == nil {
return &TokenWrapper{
@ -288,6 +296,14 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code
Scope: "",
}
}
if token.CodeChallenge != "" && pkceChallenge(verifier) != token.CodeChallenge {
return &TokenWrapper{
AccessToken: "error: incorrect code_verifier",
TokenType: "",
ExpiresIn: 0,
Scope: "",
}
}
tokenWrapper := &TokenWrapper{
AccessToken: token.AccessToken,
@ -392,3 +408,10 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId
return tokenWrapper
}
// PkceChallenge: base64-URL-encoded SHA256 hash of verifier, per rfc 7636
func pkceChallenge(verifier string) string {
sum := sha256.Sum256([]byte(verifier))
challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(sum[:])
return challenge
}

View File

@ -44,7 +44,7 @@ function oAuthParamsToQuery(oAuthParams) {
}
// code
return `?clientId=${oAuthParams.clientId}&responseType=${oAuthParams.responseType}&redirectUri=${oAuthParams.redirectUri}&scope=${oAuthParams.scope}&state=${oAuthParams.state}&nonce=${oAuthParams.nonce}`;
return `?clientId=${oAuthParams.clientId}&responseType=${oAuthParams.responseType}&redirectUri=${oAuthParams.redirectUri}&scope=${oAuthParams.scope}&state=${oAuthParams.state}&nonce=${oAuthParams.nonce}&code_challenge_method=${oAuthParams.challengeMethod}&code_challenge=${oAuthParams.codeChallenge}`;
}
export function getApplicationLogin(oAuthParams) {

View File

@ -83,6 +83,8 @@ export function getOAuthGetParameters(params) {
const scope = queries.get("scope");
const state = queries.get("state");
const nonce = queries.get("nonce")
const challengeMethod = queries.get("code_challenge_method")
const codeChallenge = queries.get("code_challenge")
if (clientId === undefined || clientId === null) {
// login
@ -96,6 +98,8 @@ export function getOAuthGetParameters(params) {
scope: scope,
state: state,
nonce: nonce,
challengeMethod: challengeMethod,
codeChallenge: codeChallenge,
};
}
}