1 package jwt_test
3 import (
4 "context"
5 "reflect"
6 "testing"
7 "time"
9 "github.com/lestrrat-go/jwx/internal/json"
11 "github.com/lestrrat-go/jwx/jwt"
12 "github.com/stretchr/testify/assert"
13 )
// tokenTime is the fixed Unix timestamp (seconds) used by the claim fixtures
// in this file.
const tokenTime = 233431200

var (
	// zeroval is the zero reflect.Value; TestToken compares MethodByName
	// results against it to detect a missing accessor method.
	zeroval reflect.Value

	// expectedTokenTime is tokenTime expressed as a UTC time.Time.
	expectedTokenTime = time.Unix(tokenTime, 0).UTC()
)
22 func TestHeader(t *testing.T) {
23 t.Parallel()
24 values := map[string]interface{}{
25 jwt.AudienceKey: []string{"developers", "secops", "tac"},
26 jwt.ExpirationKey: expectedTokenTime,
27 jwt.IssuedAtKey: expectedTokenTime,
28 jwt.IssuerKey: "http://www.example.com",
29 jwt.JwtIDKey: "e9bc097a-ce51-4036-9562-d2ade882db0d",
30 jwt.NotBeforeKey: expectedTokenTime,
31 jwt.SubjectKey: "unit test",
32 }
34 t.Run("Roundtrip", func(t *testing.T) {
35 t.Parallel()
36 h := jwt.New()
37 for k, v := range values {
38 if !assert.NoError(t, h.Set(k, v), `h.Set should succeed for key %#v`, k) {
39 return
40 }
41 got, ok := h.Get(k)
42 if !assert.True(t, ok, `h.Get should succeed for key %#v`, k) {
43 return
44 }
45 if !reflect.DeepEqual(v, got) {
46 t.Fatalf("Values do not match: (%v, %v)", v, got)
47 }
48 }
49 })
51 t.Run("RoundtripError", func(t *testing.T) {
52 t.Parallel()
53 type dummyStruct struct {
54 dummy1 int
55 dummy2 float64
56 }
57 dummy := &dummyStruct{1, 3.4}
59 values := map[string]interface{}{
60 jwt.AudienceKey: dummy,
61 jwt.ExpirationKey: dummy,
62 jwt.IssuedAtKey: dummy,
63 jwt.IssuerKey: dummy,
64 jwt.JwtIDKey: dummy,
65 jwt.NotBeforeKey: dummy,
66 jwt.SubjectKey: dummy,
67 }
69 h := jwt.New()
70 for k, v := range values {
71 err := h.Set(k, v)
72 if err == nil {
73 t.Fatalf("Setting %s value should have failed", k)
74 }
75 }
76 err := h.Set("default", dummy)
77 if err != nil {
78 t.Fatalf("Setting %s value failed", "default")
79 }
80 for k := range values {
81 _, ok := h.Get(k)
82 if ok {
83 t.Fatalf("Getting %s value should have failed", k)
84 }
85 }
86 _, ok := h.Get("default")
87 if !ok {
88 t.Fatal("Failed to get default value")
89 }
90 })
92 t.Run("GetError", func(t *testing.T) {
93 t.Parallel()
94 h := jwt.New()
95 issuer := h.Issuer()
96 if issuer != "" {
97 t.Fatalf("Get Issuer should return empty string")
98 }
99 jwtID := h.JwtID()
100 if jwtID != "" {
101 t.Fatalf("Get JWT Id should return empty string")
102 }
103 })
104 }
106 func TestTokenMarshal(t *testing.T) {
107 t.Parallel()
108 t1 := jwt.New()
109 err := t1.Set(jwt.JwtIDKey, "AbCdEfG")
110 if err != nil {
111 t.Fatalf("Failed to set JWT ID: %s", err.Error())
112 }
113 err = t1.Set(jwt.SubjectKey, "foobar@example.com")
114 if err != nil {
115 t.Fatalf("Failed to set Subject: %s", err.Error())
116 }
121 now := time.Unix(time.Now().Unix(), 0)
122 err = t1.Set(jwt.IssuedAtKey, now.Unix())
123 if err != nil {
124 t.Fatalf("Failed to set IssuedAt: %s", err.Error())
125 }
126 err = t1.Set(jwt.NotBeforeKey, now.Add(5*time.Second))
127 if err != nil {
128 t.Fatalf("Failed to set NotBefore: %s", err.Error())
129 }
130 err = t1.Set(jwt.ExpirationKey, now.Add(10*time.Second).Unix())
131 if err != nil {
132 t.Fatalf("Failed to set Expiration: %s", err.Error())
133 }
134 err = t1.Set(jwt.AudienceKey, []string{"devops", "secops", "tac"})
135 if err != nil {
136 t.Fatalf("Failed to set audience: %s", err.Error())
137 }
138 err = t1.Set("custom", "MyValue")
139 if err != nil {
140 t.Fatalf(`Failed to set private claim "custom": %s`, err.Error())
141 }
142 jsonbuf1, err := json.MarshalIndent(t1, "", " ")
143 if err != nil {
144 t.Fatalf("JSON Marshal failed: %s", err.Error())
145 }
147 t2 := jwt.New()
148 if !assert.NoError(t, json.Unmarshal(jsonbuf1, t2), `json.Unmarshal should succeed`) {
149 return
150 }
152 if !assert.Equal(t, t1, t2, "tokens should match") {
153 return
154 }
156 _, err = json.MarshalIndent(t2, "", " ")
157 if err != nil {
158 t.Fatalf("JSON marshal error: %s", err.Error())
159 }
160 }
162 func TestToken(t *testing.T) {
163 tok := jwt.New()
165 def := map[string]struct {
166 Value interface{}
167 Method string
168 }{
169 jwt.AudienceKey: {
170 Method: "Audience",
171 Value: []string{"developers", "secops", "tac"},
172 },
173 jwt.ExpirationKey: {
174 Method: "Expiration",
175 Value: expectedTokenTime,
176 },
177 jwt.IssuedAtKey: {
178 Method: "IssuedAt",
179 Value: expectedTokenTime,
180 },
181 jwt.IssuerKey: {
182 Method: "Issuer",
183 Value: "http://www.example.com",
184 },
185 jwt.JwtIDKey: {
186 Method: "JwtID",
187 Value: "e9bc097a-ce51-4036-9562-d2ade882db0d",
188 },
189 jwt.NotBeforeKey: {
190 Method: "NotBefore",
191 Value: expectedTokenTime,
192 },
193 jwt.SubjectKey: {
194 Method: "Subject",
195 Value: "unit test",
196 },
197 "myClaim": {
198 Value: "hello, world",
199 },
200 }
202 t.Run("Set", func(t *testing.T) {
203 for k, kdef := range def {
204 if !assert.NoError(t, tok.Set(k, kdef.Value), `tok.Set(%s) should succeed`, k) {
205 return
206 }
207 }
208 })
209 t.Run("Get", func(t *testing.T) {
210 rv := reflect.ValueOf(tok)
211 for k, kdef := range def {
212 getval, ok := tok.Get(k)
213 if !assert.True(t, ok, `tok.Get(%s) should succeed`, k) {
214 return
215 }
217 if mname := kdef.Method; mname != "" {
218 method := rv.MethodByName(mname)
219 if !assert.NotEqual(t, zeroval, method, `method %s should not be zero value`, mname) {
220 return
221 }
223 retvals := method.Call(nil)
224 if !assert.Len(t, retvals, 1, `should have exactly one return value`) {
225 return
226 }
228 if !assert.Equal(t, getval, retvals[0].Interface(), `values should match`) {
229 return
230 }
231 }
232 }
233 })
234 t.Run("Roundtrip", func(t *testing.T) {
235 buf, err := json.Marshal(tok)
236 if !assert.NoError(t, err, `json.Marshal should succeed`) {
237 return
238 }
240 newtok, err := jwt.Parse(buf)
241 if !assert.NoError(t, err, `jwt.Parse should succeed`) {
242 return
243 }
245 m1, err := tok.AsMap(context.TODO())
246 if !assert.NoError(t, err, `tok.AsMap should succeed`) {
247 return
248 }
250 m2, err := newtok.AsMap(context.TODO())
251 if !assert.NoError(t, err, `tok.AsMap should succeed`) {
252 return
253 }
255 if !assert.Equal(t, m1, m2, `tokens should match`) {
256 return
257 }
258 })
259 t.Run("Set/Remove", func(t *testing.T) {
260 ctx := context.TODO()
262 newtok, err := tok.Clone()
263 if !assert.NoError(t, err, `tok.Clone should succeed`) {
264 return
265 }
267 for iter := tok.Iterate(ctx); iter.Next(ctx); {
268 pair := iter.Pair()
269 newtok.Remove(pair.Key.(string))
270 }
272 m, err := newtok.AsMap(ctx)
273 if !assert.NoError(t, err, `tok.AsMap should succeed`) {
274 return
275 }
277 if !assert.Len(t, m, 0, `toks should have 0 tok`) {
278 return
279 }
281 for iter := tok.Iterate(ctx); iter.Next(ctx); {
282 pair := iter.Pair()
283 if !assert.NoError(t, newtok.Set(pair.Key.(string), pair.Value), `newtok.Set should succeed`) {
284 return
285 }
286 }
287 })
288 }
View as plain text