...
1 package jwt
2
3 import (
4 "context"
5 "errors"
6
7 "github.com/go-kit/kit/endpoint"
8 "github.com/golang-jwt/jwt/v4"
9 )
10
// contextKey is the type of every context key declared by this package. Being
// a distinct named type, its values never compare equal to plain string keys
// (or keys of any other type) that happen to share the same text.
type contextKey string

const (
	// JWTContextKey is the context key under which the signed JWT string
	// travels: NewSigner writes it and NewParser reads it.
	JWTContextKey = contextKey("JWTToken")

	// JWTTokenContextKey is an alias for JWTContextKey; both name the same
	// key value.
	JWTTokenContextKey = JWTContextKey

	// JWTClaimsContextKey is the context key under which NewParser stores
	// the claims of a successfully parsed token.
	JWTClaimsContextKey = contextKey("JWTClaims")
)
26
// Sentinel errors returned by the middlewares in this package. Callers can
// compare against them to tell the individual failure modes apart.
var (
	// ErrTokenContextMissing is returned by NewParser when the context
	// carries no string value under JWTContextKey.
	ErrTokenContextMissing = errors.New("token up for parsing was not passed through the context")

	// ErrTokenInvalid is returned by NewParser when parsing reports no
	// error but the resulting token is not marked valid.
	ErrTokenInvalid = errors.New("JWT was invalid")

	// ErrTokenExpired is returned by NewParser when validation flags the
	// token as expired.
	ErrTokenExpired = errors.New("JWT is expired")

	// ErrTokenMalformed is returned by NewParser when validation flags the
	// token as malformed.
	ErrTokenMalformed = errors.New("JWT is malformed")

	// ErrTokenNotActive is returned by NewParser when validation flags the
	// token as not valid yet.
	ErrTokenNotActive = errors.New("token is not valid yet")

	// ErrUnexpectedSigningMethod is produced inside NewParser's key lookup
	// when the token's signing method differs from the configured one.
	ErrUnexpectedSigningMethod = errors.New("unexpected signing method")
)
49
50
51
52
53
54 func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
55 return func(next endpoint.Endpoint) endpoint.Endpoint {
56 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
57 token := jwt.NewWithClaims(method, claims)
58 token.Header["kid"] = kid
59
60
61 tokenString, err := token.SignedString(key)
62 if err != nil {
63 return nil, err
64 }
65 ctx = context.WithValue(ctx, JWTContextKey, tokenString)
66
67 return next(ctx, request)
68 }
69 }
70 }
71
72
73
// ClaimsFactory is a function that returns a jwt.Claims value. NewParser
// calls it once per request to obtain the claims value the token is parsed
// into.
type ClaimsFactory func() jwt.Claims
75
76
77
78 func MapClaimsFactory() jwt.Claims {
79 return jwt.MapClaims{}
80 }
81
82
83
84 func StandardClaimsFactory() jwt.Claims {
85 return &jwt.StandardClaims{}
86 }
87
88
89
90
91
92 func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware {
93 return func(next endpoint.Endpoint) endpoint.Endpoint {
94 return func(ctx context.Context, request interface{}) (response interface{}, err error) {
95
96 tokenString, ok := ctx.Value(JWTContextKey).(string)
97 if !ok {
98 return nil, ErrTokenContextMissing
99 }
100
101
102
103
104
105
106
107 token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) {
108
109 if token.Method != method {
110 return nil, ErrUnexpectedSigningMethod
111 }
112
113 return keyFunc(token)
114 })
115 if err != nil {
116 if e, ok := err.(*jwt.ValidationError); ok {
117 switch {
118 case e.Errors&jwt.ValidationErrorMalformed != 0:
119
120 return nil, ErrTokenMalformed
121 case e.Errors&jwt.ValidationErrorExpired != 0:
122
123 return nil, ErrTokenExpired
124 case e.Errors&jwt.ValidationErrorNotValidYet != 0:
125
126 return nil, ErrTokenNotActive
127 case e.Inner != nil:
128
129 return nil, e.Inner
130 }
131
132
133 }
134 return nil, err
135 }
136
137 if !token.Valid {
138 return nil, ErrTokenInvalid
139 }
140
141 ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims)
142
143 return next(ctx, request)
144 }
145 }
146 }
147
View as plain text