feat: add PKCE support (#434)

* feat: add PKCE support

Signed-off-by: Steve0x2a <stevesough@gmail.com>

* fix: error output when challenge is empty

Signed-off-by: Steve0x2a <stevesough@gmail.com>
This commit is contained in:
Steve0x2a
2022-01-21 09:29:19 +08:00
committed by GitHub
parent 339a85e4b0
commit 630b84f534
5 changed files with 68 additions and 25 deletions

View File

@ -52,7 +52,14 @@ func (c *ApiController) HandleLoggedIn(application *object.Application, user *ob
scope := c.Input().Get("scope") scope := c.Input().Get("scope")
state := c.Input().Get("state") state := c.Input().Get("state")
nonce := c.Input().Get("nonce") nonce := c.Input().Get("nonce")
code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce) challengeMethod := c.Input().Get("code_challenge_method")
codeChallenge := c.Input().Get("code_challenge")
if challengeMethod != "S256" && challengeMethod != "null" {
c.ResponseError("Challenge method should be S256")
return
}
code := object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge)
resp = codeToResponse(code) resp = codeToResponse(code)
if application.EnableSigninSession || application.HasPromptPage() { if application.EnableSigninSession || application.HasPromptPage() {

View File

@ -142,7 +142,15 @@ func (c *ApiController) GetOAuthCode() {
state := c.Input().Get("state") state := c.Input().Get("state")
nonce := c.Input().Get("nonce") nonce := c.Input().Get("nonce")
c.Data["json"] = object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce) challengeMethod := c.Input().Get("code_challenge_method")
codeChallenge := c.Input().Get("code_challenge")
if challengeMethod != "S256" && challengeMethod != "null" {
c.ResponseError("Challenge method should be S256")
return
}
c.Data["json"] = object.GetOAuthCode(userId, clientId, responseType, redirectUri, scope, state, nonce, codeChallenge)
c.ServeJSON() c.ServeJSON()
} }
@ -161,12 +169,13 @@ func (c *ApiController) GetOAuthToken() {
clientId := c.Input().Get("client_id") clientId := c.Input().Get("client_id")
clientSecret := c.Input().Get("client_secret") clientSecret := c.Input().Get("client_secret")
code := c.Input().Get("code") code := c.Input().Get("code")
verifier := c.Input().Get("code_verifier")
if clientId == "" && clientSecret == "" { if clientId == "" && clientSecret == "" {
clientId, clientSecret, _ = c.Ctx.Request.BasicAuth() clientId, clientSecret, _ = c.Ctx.Request.BasicAuth()
} }
c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code) c.Data["json"] = object.GetOAuthToken(grantType, clientId, clientSecret, code, verifier)
c.ServeJSON() c.ServeJSON()
} }

View File

@ -15,6 +15,8 @@
package object package object
import ( import (
"crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"strings" "strings"
@ -36,12 +38,13 @@ type Token struct {
Organization string `xorm:"varchar(100)" json:"organization"` Organization string `xorm:"varchar(100)" json:"organization"`
User string `xorm:"varchar(100)" json:"user"` User string `xorm:"varchar(100)" json:"user"`
Code string `xorm:"varchar(100)" json:"code"` Code string `xorm:"varchar(100)" json:"code"`
AccessToken string `xorm:"mediumtext" json:"accessToken"` AccessToken string `xorm:"mediumtext" json:"accessToken"`
RefreshToken string `xorm:"mediumtext" json:"refreshToken"` RefreshToken string `xorm:"mediumtext" json:"refreshToken"`
ExpiresIn int `json:"expiresIn"` ExpiresIn int `json:"expiresIn"`
Scope string `xorm:"varchar(100)" json:"scope"` Scope string `xorm:"varchar(100)" json:"scope"`
TokenType string `xorm:"varchar(100)" json:"tokenType"` TokenType string `xorm:"varchar(100)" json:"tokenType"`
CodeChallenge string `xorm:"varchar(100)" json:"codeChallenge"`
} }
type TokenWrapper struct { type TokenWrapper struct {
@ -182,7 +185,7 @@ func CheckOAuthLogin(clientId string, responseType string, redirectUri string, s
return "", application return "", application
} }
func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string) *Code { func GetOAuthCode(userId string, clientId string, responseType string, redirectUri string, scope string, state string, nonce string, challenge string) *Code {
user := GetUser(userId) user := GetUser(userId)
if user == nil { if user == nil {
return &Code{ return &Code{
@ -210,19 +213,24 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU
panic(err) panic(err)
} }
if challenge == "null"{
challenge = ""
}
token := &Token{ token := &Token{
Owner: application.Owner, Owner: application.Owner,
Name: util.GenerateId(), Name: util.GenerateId(),
CreatedTime: util.GetCurrentTime(), CreatedTime: util.GetCurrentTime(),
Application: application.Name, Application: application.Name,
Organization: user.Owner, Organization: user.Owner,
User: user.Name, User: user.Name,
Code: util.GenerateClientId(), Code: util.GenerateClientId(),
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
ExpiresIn: application.ExpireInHours * 60, ExpiresIn: application.ExpireInHours * 60,
Scope: scope, Scope: scope,
TokenType: "Bearer", TokenType: "Bearer",
CodeChallenge: challenge,
} }
AddToken(token) AddToken(token)
@ -232,7 +240,7 @@ func GetOAuthCode(userId string, clientId string, responseType string, redirectU
} }
} }
func GetOAuthToken(grantType string, clientId string, clientSecret string, code string) *TokenWrapper { func GetOAuthToken(grantType string, clientId string, clientSecret string, code string, verifier string) *TokenWrapper {
application := GetApplicationByClientId(clientId) application := GetApplicationByClientId(clientId)
if application == nil { if application == nil {
return &TokenWrapper{ return &TokenWrapper{
@ -288,6 +296,14 @@ func GetOAuthToken(grantType string, clientId string, clientSecret string, code
Scope: "", Scope: "",
} }
} }
if token.CodeChallenge != "" && pkceChallenge(verifier) != token.CodeChallenge {
return &TokenWrapper{
AccessToken: "error: incorrect code_verifier",
TokenType: "",
ExpiresIn: 0,
Scope: "",
}
}
tokenWrapper := &TokenWrapper{ tokenWrapper := &TokenWrapper{
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
@ -392,3 +408,10 @@ func RefreshToken(grantType string, refreshToken string, scope string, clientId
return tokenWrapper return tokenWrapper
} }
// PkceChallenge: base64-URL-encoded SHA256 hash of verifier, per rfc 7636
func pkceChallenge(verifier string) string {
sum := sha256.Sum256([]byte(verifier))
challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(sum[:])
return challenge
}

View File

@ -44,7 +44,7 @@ function oAuthParamsToQuery(oAuthParams) {
} }
// code // code
return `?clientId=${oAuthParams.clientId}&responseType=${oAuthParams.responseType}&redirectUri=${oAuthParams.redirectUri}&scope=${oAuthParams.scope}&state=${oAuthParams.state}&nonce=${oAuthParams.nonce}`; return `?clientId=${oAuthParams.clientId}&responseType=${oAuthParams.responseType}&redirectUri=${oAuthParams.redirectUri}&scope=${oAuthParams.scope}&state=${oAuthParams.state}&nonce=${oAuthParams.nonce}&code_challenge_method=${oAuthParams.challengeMethod}&code_challenge=${oAuthParams.codeChallenge}`;
} }
export function getApplicationLogin(oAuthParams) { export function getApplicationLogin(oAuthParams) {

View File

@ -83,7 +83,9 @@ export function getOAuthGetParameters(params) {
const scope = queries.get("scope"); const scope = queries.get("scope");
const state = queries.get("state"); const state = queries.get("state");
const nonce = queries.get("nonce") const nonce = queries.get("nonce")
const challengeMethod = queries.get("code_challenge_method")
const codeChallenge = queries.get("code_challenge")
if (clientId === undefined || clientId === null) { if (clientId === undefined || clientId === null) {
// login // login
return null; return null;
@ -96,6 +98,8 @@ export function getOAuthGetParameters(params) {
scope: scope, scope: scope,
state: state, state: state,
nonce: nonce, nonce: nonce,
challengeMethod: challengeMethod,
codeChallenge: codeChallenge,
}; };
} }
} }