1 package jwt
2
3 import (
4 "context"
5 "sync"
6 "testing"
7 "time"
8
9 "crypto/subtle"
10
11 "github.com/go-kit/kit/endpoint"
12 "github.com/golang-jwt/jwt/v4"
13 )
14
// customClaims is a test claims type that pairs an application-specific
// field with the embedded jwt.StandardClaims, used to check that NewSigner and
// NewParser round-trip user-defined claims.
//
// NOTE(review): keep the field order as is. The customSignedKey fixture's
// payload encodes "my_property" before "aud", matching this declaration
// order; reordering the fields would presumably change the serialized payload
// and break TestNewSigner.
type customClaims struct {
	// MyProperty is serialized under the "my_property" JSON key.
	MyProperty string `json:"my_property"`
	jwt.StandardClaims
}
19
20 func (c customClaims) VerifyMyProperty(p string) bool {
21 return subtle.ConstantTimeCompare([]byte(c.MyProperty), []byte(p)) != 0
22 }
23
24 var (
25 kid = "kid"
26 key = []byte("test_signing_key")
27 myProperty = "some value"
28 method = jwt.SigningMethodHS256
29 invalidMethod = jwt.SigningMethodRS256
30 mapClaims = jwt.MapClaims{"user": "go-kit"}
31 standardClaims = jwt.StandardClaims{Audience: "go-kit"}
32 myCustomClaims = customClaims{MyProperty: myProperty, StandardClaims: standardClaims}
33
34 signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
35 standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY"
36 customSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJteV9wcm9wZXJ0eSI6InNvbWUgdmFsdWUiLCJhdWQiOiJnby1raXQifQ.s8F-IDrV4WPJUsqr7qfDi-3GRlcKR0SRnkTeUT_U-i0"
37 invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
38 malformedKey = "malformed.jwt.token"
39 )
40
41 func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) {
42 ctx, err := signer(context.Background(), struct{}{})
43 if err != nil {
44 t.Fatalf("Signer returned error: %s", err)
45 }
46
47 token, ok := ctx.(context.Context).Value(JWTContextKey).(string)
48 if !ok {
49 t.Fatal("Token did not exist in context")
50 }
51
52 if token != expectedKey {
53 t.Fatalf("JWTs did not match: expecting %s got %s", expectedKey, token)
54 }
55 }
56
57 func TestNewSigner(t *testing.T) {
58 e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
59
60 signer := NewSigner(kid, key, method, mapClaims)(e)
61 signingValidator(t, signer, signedKey)
62
63 signer = NewSigner(kid, key, method, standardClaims)(e)
64 signingValidator(t, signer, standardSignedKey)
65
66 signer = NewSigner(kid, key, method, myCustomClaims)(e)
67 signingValidator(t, signer, customSignedKey)
68 }
69
70 func TestJWTParser(t *testing.T) {
71 e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
72
73 keys := func(token *jwt.Token) (interface{}, error) {
74 return key, nil
75 }
76
77 parser := NewParser(keys, method, MapClaimsFactory)(e)
78
79
80 _, err := parser(context.Background(), struct{}{})
81 if err == nil {
82 t.Error("Parser should have returned an error")
83 }
84
85 if err != ErrTokenContextMissing {
86 t.Errorf("unexpected error returned, expected: %s got: %s", ErrTokenContextMissing, err)
87 }
88
89
90 ctx := context.WithValue(context.Background(), JWTContextKey, invalidKey)
91 _, err = parser(ctx, struct{}{})
92 if err == nil {
93 t.Error("Parser should have returned an error")
94 }
95
96
97 badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
98 ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
99 _, err = badParser(ctx, struct{}{})
100 if err == nil {
101 t.Error("Parser should have returned an error")
102 }
103
104 if err != ErrUnexpectedSigningMethod {
105 t.Errorf("unexpected error returned, expected: %s got: %s", ErrUnexpectedSigningMethod, err)
106 }
107
108
109 invalidKeys := func(token *jwt.Token) (interface{}, error) {
110 return []byte("bad"), nil
111 }
112
113 badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
114 ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
115 _, err = badParser(ctx, struct{}{})
116 if err == nil {
117 t.Error("Parser should have returned an error")
118 }
119
120
121 ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
122 ctx1, err := parser(ctx, struct{}{})
123 if err != nil {
124 t.Fatalf("Parser returned error: %s", err)
125 }
126
127 cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims)
128 if !ok {
129 t.Fatal("Claims were not passed into context correctly")
130 }
131
132 if cl["user"] != mapClaims["user"] {
133 t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"])
134 }
135
136
137 parser = NewParser(keys, method, StandardClaimsFactory)(e)
138 ctx = context.WithValue(context.Background(), JWTContextKey, malformedKey)
139 ctx1, err = parser(ctx, struct{}{})
140 if want, have := ErrTokenMalformed, err; want != have {
141 t.Fatalf("Expected %+v, got %+v", want, have)
142 }
143
144
145 parser = NewParser(keys, method, StandardClaimsFactory)(e)
146 expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100})
147 token, err := expired.SignedString(key)
148 if err != nil {
149 t.Fatalf("Unable to Sign Token: %+v", err)
150 }
151 ctx = context.WithValue(context.Background(), JWTContextKey, token)
152 ctx1, err = parser(ctx, struct{}{})
153 if want, have := ErrTokenExpired, err; want != have {
154 t.Fatalf("Expected %+v, got %+v", want, have)
155 }
156
157
158 parser = NewParser(keys, method, StandardClaimsFactory)(e)
159 notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100})
160 token, err = notactive.SignedString(key)
161 if err != nil {
162 t.Fatalf("Unable to Sign Token: %+v", err)
163 }
164 ctx = context.WithValue(context.Background(), JWTContextKey, token)
165 ctx1, err = parser(ctx, struct{}{})
166 if want, have := ErrTokenNotActive, err; want != have {
167 t.Fatalf("Expected %+v, got %+v", want, have)
168 }
169
170
171 parser = NewParser(keys, method, StandardClaimsFactory)(e)
172 ctx = context.WithValue(context.Background(), JWTContextKey, standardSignedKey)
173 ctx1, err = parser(ctx, struct{}{})
174 if err != nil {
175 t.Fatalf("Parser returned error: %s", err)
176 }
177 stdCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*jwt.StandardClaims)
178 if !ok {
179 t.Fatal("Claims were not passed into context correctly")
180 }
181 if !stdCl.VerifyAudience("go-kit", true) {
182 t.Fatalf("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience)
183 }
184
185
186 parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
187 ctx = context.WithValue(context.Background(), JWTContextKey, customSignedKey)
188 ctx1, err = parser(ctx, struct{}{})
189 if err != nil {
190 t.Fatalf("Parser returned error: %s", err)
191 }
192 custCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*customClaims)
193 if !ok {
194 t.Fatal("Claims were not passed into context correctly")
195 }
196 if !custCl.VerifyAudience("go-kit", true) {
197 t.Fatalf("JWT customClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, custCl.Audience)
198 }
199 if !custCl.VerifyMyProperty(myProperty) {
200 t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty)
201 }
202 }
203
204 func TestIssue562(t *testing.T) {
205 var (
206 kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
207 e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
208 key = JWTContextKey
209 val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
210 ctx = context.WithValue(context.Background(), key, val)
211 )
212 wg := sync.WaitGroup{}
213 for i := 0; i < 100; i++ {
214 wg.Add(1)
215 go func() {
216 defer wg.Done()
217 e(ctx, struct{}{})
218 }()
219 }
220 wg.Wait()
221 }
222
View as plain text