1 package jwt
2
3 import (
4 "errors"
5 "testing"
6 "time"
7 )
8
// ErrFooBar is the sentinel error that MyCustomClaims.Validate returns when
// the Foo claim is not "bar".
var (
	ErrFooBar = errors.New("must be foobar")
)
10
11 type MyCustomClaims struct {
12 Foo string `json:"foo"`
13 RegisteredClaims
14 }
15
16 func (m MyCustomClaims) Validate() error {
17 if m.Foo != "bar" {
18 return ErrFooBar
19 }
20 return nil
21 }
22
23 func Test_Validator_Validate(t *testing.T) {
24 type fields struct {
25 leeway time.Duration
26 timeFunc func() time.Time
27 verifyIat bool
28 expectedAud string
29 expectedIss string
30 expectedSub string
31 }
32 type args struct {
33 claims Claims
34 }
35 tests := []struct {
36 name string
37 fields fields
38 args args
39 wantErr error
40 }{
41 {
42 name: "expected iss mismatch",
43 fields: fields{expectedIss: "me"},
44 args: args{RegisteredClaims{Issuer: "not_me"}},
45 wantErr: ErrTokenInvalidIssuer,
46 },
47 {
48 name: "expected iss is missing",
49 fields: fields{expectedIss: "me"},
50 args: args{RegisteredClaims{}},
51 wantErr: ErrTokenRequiredClaimMissing,
52 },
53 {
54 name: "expected sub mismatch",
55 fields: fields{expectedSub: "me"},
56 args: args{RegisteredClaims{Subject: "not-me"}},
57 wantErr: ErrTokenInvalidSubject,
58 },
59 {
60 name: "expected sub is missing",
61 fields: fields{expectedSub: "me"},
62 args: args{RegisteredClaims{}},
63 wantErr: ErrTokenRequiredClaimMissing,
64 },
65 {
66 name: "custom validator",
67 fields: fields{},
68 args: args{MyCustomClaims{Foo: "not-bar"}},
69 wantErr: ErrFooBar,
70 },
71 }
72 for _, tt := range tests {
73 t.Run(tt.name, func(t *testing.T) {
74 v := &Validator{
75 leeway: tt.fields.leeway,
76 timeFunc: tt.fields.timeFunc,
77 verifyIat: tt.fields.verifyIat,
78 expectedAud: tt.fields.expectedAud,
79 expectedIss: tt.fields.expectedIss,
80 expectedSub: tt.fields.expectedSub,
81 }
82 if err := v.Validate(tt.args.claims); (err != nil) && !errors.Is(err, tt.wantErr) {
83 t.Errorf("validator.Validate() error = %v, wantErr %v", err, tt.wantErr)
84 }
85 })
86 }
87 }
88
89 func Test_Validator_verifyExpiresAt(t *testing.T) {
90 type fields struct {
91 leeway time.Duration
92 timeFunc func() time.Time
93 }
94 type args struct {
95 claims Claims
96 cmp time.Time
97 required bool
98 }
99 tests := []struct {
100 name string
101 fields fields
102 args args
103 wantErr error
104 }{
105 {
106 name: "good claim",
107 fields: fields{timeFunc: time.Now},
108 args: args{claims: RegisteredClaims{ExpiresAt: NewNumericDate(time.Now().Add(10 * time.Minute))}},
109 wantErr: nil,
110 },
111 {
112 name: "claims with invalid type",
113 fields: fields{},
114 args: args{claims: MapClaims{"exp": "string"}},
115 wantErr: ErrInvalidType,
116 },
117 }
118 for _, tt := range tests {
119 t.Run(tt.name, func(t *testing.T) {
120 v := &Validator{
121 leeway: tt.fields.leeway,
122 timeFunc: tt.fields.timeFunc,
123 }
124
125 err := v.verifyExpiresAt(tt.args.claims, tt.args.cmp, tt.args.required)
126 if (err != nil) && !errors.Is(err, tt.wantErr) {
127 t.Errorf("validator.verifyExpiresAt() error = %v, wantErr %v", err, tt.wantErr)
128 }
129 })
130 }
131 }
132
133 func Test_Validator_verifyIssuer(t *testing.T) {
134 type fields struct {
135 expectedIss string
136 }
137 type args struct {
138 claims Claims
139 cmp string
140 required bool
141 }
142 tests := []struct {
143 name string
144 fields fields
145 args args
146 wantErr error
147 }{
148 {
149 name: "good claim",
150 fields: fields{expectedIss: "me"},
151 args: args{claims: MapClaims{"iss": "me"}, cmp: "me"},
152 wantErr: nil,
153 },
154 {
155 name: "claims with invalid type",
156 fields: fields{expectedIss: "me"},
157 args: args{claims: MapClaims{"iss": 1}, cmp: "me"},
158 wantErr: ErrInvalidType,
159 },
160 }
161 for _, tt := range tests {
162 t.Run(tt.name, func(t *testing.T) {
163 v := &Validator{
164 expectedIss: tt.fields.expectedIss,
165 }
166 err := v.verifyIssuer(tt.args.claims, tt.args.cmp, tt.args.required)
167 if (err != nil) && !errors.Is(err, tt.wantErr) {
168 t.Errorf("validator.verifyIssuer() error = %v, wantErr %v", err, tt.wantErr)
169 }
170 })
171 }
172 }
173
174 func Test_Validator_verifySubject(t *testing.T) {
175 type fields struct {
176 expectedSub string
177 }
178 type args struct {
179 claims Claims
180 cmp string
181 required bool
182 }
183 tests := []struct {
184 name string
185 fields fields
186 args args
187 wantErr error
188 }{
189 {
190 name: "good claim",
191 fields: fields{expectedSub: "me"},
192 args: args{claims: MapClaims{"sub": "me"}, cmp: "me"},
193 wantErr: nil,
194 },
195 {
196 name: "claims with invalid type",
197 fields: fields{expectedSub: "me"},
198 args: args{claims: MapClaims{"sub": 1}, cmp: "me"},
199 wantErr: ErrInvalidType,
200 },
201 }
202 for _, tt := range tests {
203 t.Run(tt.name, func(t *testing.T) {
204 v := &Validator{
205 expectedSub: tt.fields.expectedSub,
206 }
207 err := v.verifySubject(tt.args.claims, tt.args.cmp, tt.args.required)
208 if (err != nil) && !errors.Is(err, tt.wantErr) {
209 t.Errorf("validator.verifySubject() error = %v, wantErr %v", err, tt.wantErr)
210 }
211 })
212 }
213 }
214
215 func Test_Validator_verifyIssuedAt(t *testing.T) {
216 type fields struct {
217 leeway time.Duration
218 timeFunc func() time.Time
219 verifyIat bool
220 }
221 type args struct {
222 claims Claims
223 cmp time.Time
224 required bool
225 }
226 tests := []struct {
227 name string
228 fields fields
229 args args
230 wantErr error
231 }{
232 {
233 name: "good claim without iat",
234 fields: fields{verifyIat: true},
235 args: args{claims: MapClaims{}, required: false},
236 wantErr: nil,
237 },
238 {
239 name: "good claim with iat",
240 fields: fields{verifyIat: true},
241 args: args{
242 claims: RegisteredClaims{IssuedAt: NewNumericDate(time.Now())},
243 cmp: time.Now().Add(10 * time.Minute),
244 required: false,
245 },
246 wantErr: nil,
247 },
248 }
249 for _, tt := range tests {
250 t.Run(tt.name, func(t *testing.T) {
251 v := &Validator{
252 leeway: tt.fields.leeway,
253 timeFunc: tt.fields.timeFunc,
254 verifyIat: tt.fields.verifyIat,
255 }
256 if err := v.verifyIssuedAt(tt.args.claims, tt.args.cmp, tt.args.required); (err != nil) && !errors.Is(err, tt.wantErr) {
257 t.Errorf("validator.verifyIssuedAt() error = %v, wantErr %v", err, tt.wantErr)
258 }
259 })
260 }
261 }
262
View as plain text