Refactor GenerateCasToken()

This commit is contained in:
Yang Luo 2023-09-06 18:35:13 +08:00
parent a9de7d3aef
commit 3c2fd574a6

View File

@ -185,9 +185,14 @@ func StoreCasTokenForProxyTicket(token *CasAuthenticationSuccess, targetService,
} }
func GenerateCasToken(userId string, service string) (string, error) { func GenerateCasToken(userId string, service string) (string, error) {
if user, err := GetUser(userId); err != nil { user, err := GetUser(userId)
if err != nil {
return "", err return "", err
} else if user != nil { }
if user == nil {
return "", fmt.Errorf("The user: %s doesn't exist", userId)
}
authenticationSuccess := CasAuthenticationSuccess{ authenticationSuccess := CasAuthenticationSuccess{
User: user.Name, User: user.Name,
Attributes: &CasAttributes{ Attributes: &CasAttributes{
@ -196,9 +201,18 @@ func GenerateCasToken(userId string, service string) (string, error) {
}, },
ProxyGrantingTicket: fmt.Sprintf("PGTIOU-%s", util.GenerateId()), ProxyGrantingTicket: fmt.Sprintf("PGTIOU-%s", util.GenerateId()),
} }
data, _ := json.Marshal(user)
data, err := json.Marshal(user)
if err != nil {
return "", err
}
tmp := map[string]string{} tmp := map[string]string{}
json.Unmarshal(data, &tmp) err = json.Unmarshal(data, &tmp)
if err != nil {
return "", err
}
for k, v := range tmp { for k, v := range tmp {
if v != "" { if v != "" {
authenticationSuccess.Attributes.UserAttributes.Attributes = append(authenticationSuccess.Attributes.UserAttributes.Attributes, &CasNamedAttribute{ authenticationSuccess.Attributes.UserAttributes.Attributes = append(authenticationSuccess.Attributes.UserAttributes.Attributes, &CasNamedAttribute{
@ -207,6 +221,7 @@ func GenerateCasToken(userId string, service string) (string, error) {
}) })
} }
} }
st := fmt.Sprintf("ST-%d", rand.Int()) st := fmt.Sprintf("ST-%d", rand.Int())
stToServiceResponse.Store(st, &CasAuthenticationSuccessWrapper{ stToServiceResponse.Store(st, &CasAuthenticationSuccessWrapper{
AuthenticationSuccess: &authenticationSuccess, AuthenticationSuccess: &authenticationSuccess,
@ -214,9 +229,6 @@ func GenerateCasToken(userId string, service string) (string, error) {
UserId: userId, UserId: userId,
}) })
return st, nil return st, nil
} else {
return "", fmt.Errorf("invalid user Id")
}
} }
// GetValidationBySaml // GetValidationBySaml