...
1 package keyfunc
2
3 import (
4 "encoding/base64"
5 "errors"
6 "fmt"
7 "strings"
8
9 "github.com/golang-jwt/jwt/v5"
10 )
11
// ErrKID indicates that a JWT's "kid" (key ID) header parameter is missing or
// is not a string. kidAlg wraps it with %w, so callers can detect it with
// errors.Is.
var ErrKID = errors.New("the JWT has an invalid kid")
16
17
18 func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) {
19 kid, alg, err := kidAlg(token)
20 if err != nil {
21 return nil, err
22 }
23 return j.getKey(alg, kid)
24 }
25
26
27 func (m *MultipleJWKS) Keyfunc(token *jwt.Token) (interface{}, error) {
28 return m.keySelector(m, token)
29 }
30
31 func kidAlg(token *jwt.Token) (kid, alg string, err error) {
32 kidInter, ok := token.Header["kid"]
33 if !ok {
34 return "", "", fmt.Errorf("%w: could not find kid in JWT header", ErrKID)
35 }
36 kid, ok = kidInter.(string)
37 if !ok {
38 return "", "", fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID)
39 }
40 alg, ok = token.Header["alg"].(string)
41 if !ok {
42
43
44 return "", "", fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch)
45 }
46 return kid, alg, nil
47 }
48
49
50
51
52
53
54
55
// base64urlTrailingPadding decodes a base64url string using the unpadded
// alphabet (base64.RawURLEncoding). Any trailing '=' characters are stripped
// first, so both padded and unpadded inputs decode to the same bytes.
//
// JOSE (RFC 7515 Section 2) defines base64url without padding, but some
// producers emit padding anyway. Only trailing '=' characters are tolerated.
// An '=' anywhere else, or a character from the standard alphabet ('+', '/'),
// still yields a decode error.
func base64urlTrailingPadding(s string) ([]byte, error) {
	isPad := func(r rune) bool { return r == '=' }
	unpadded := strings.TrimRightFunc(s, isPad)
	return base64.RawURLEncoding.DecodeString(unpadded)
}
60
View as plain text