1
16
17 package jwtverifier
18
19 import (
20 "encoding/base64"
21 "encoding/json"
22 "fmt"
23 "net/http"
24 "regexp"
25 "strings"
26 "time"
27
28 "github.com/okta/okta-jwt-verifier-golang/adaptors"
29 "github.com/okta/okta-jwt-verifier-golang/adaptors/lestrratGoJwx"
30 "github.com/okta/okta-jwt-verifier-golang/discovery"
31 "github.com/okta/okta-jwt-verifier-golang/discovery/oidc"
32 "github.com/okta/okta-jwt-verifier-golang/errors"
33 "github.com/okta/okta-jwt-verifier-golang/utils"
34 )
35
// regx is a loose shape check applied to raw tokens before any decoding.
// It requires dot-separated runs of [a-zA-Z0-9-_] characters ending at the
// end of the string. It is not anchored at the start, so it does not by itself
// rule out other characters earlier in the token.
var regx = regexp.MustCompile(`[a-zA-Z0-9-_]+\.[a-zA-Z0-9-_]+\.?([a-zA-Z0-9-_]+)[/a-zA-Z0-9-_]+?$`)
39
40 type JwtVerifier struct {
41 Issuer string
42
43 ClaimsToValidate map[string]string
44
45 Discovery discovery.Discovery
46
47 Adaptor adaptors.Adaptor
48
49
50 Cache func(func(string) (interface{}, error)) (utils.Cacher, error)
51
52 metadataCache utils.Cacher
53
54 leeway int64
55 }
56
// Jwt is the result of a verification call: the token's claim set as
// produced by the configured Adaptor.
type Jwt struct {
	// Claims maps claim names (e.g. "iss", "aud", "exp") to their values.
	Claims map[string]interface{}
}
60
// fetchMetaData downloads the JSON document at url and decodes it into a
// map[string]interface{}. getMetaData passes it to the configured Cache
// constructor as the loader for issuer metadata.
//
// A non-2xx response is reported as an error instead of being decoded, so an
// error payload is never returned as if it were metadata.
func fetchMetaData(url string) (interface{}, error) {
	// http.Get uses a client without a timeout; an unresponsive issuer would
	// otherwise block token verification indefinitely. A nil Transport means
	// http.DefaultTransport, so connection pooling is unchanged.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return nil, fmt.Errorf("request for metadata was not successful: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("request for metadata was not successful: unexpected status %s", resp.Status)
	}

	metadata := make(map[string]interface{})
	if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
		return nil, fmt.Errorf("decoding metadata from %s: %w", url, err)
	}
	return metadata, nil
}
74
75 func (j *JwtVerifier) New() *JwtVerifier {
76
77 if j.Discovery == nil {
78 disc := oidc.Oidc{}
79 j.Discovery = disc.New()
80 }
81
82 if j.Cache == nil {
83 j.Cache = utils.NewDefaultCache
84 }
85
86
87 if j.Adaptor == nil {
88 adaptor := &lestrratGoJwx.LestrratGoJwx{Cache: j.Cache}
89 j.Adaptor = adaptor.New()
90 }
91
92
93 j.leeway = 120
94
95 return j
96 }
97
98 func (j *JwtVerifier) SetLeeway(duration string) {
99 dur, _ := time.ParseDuration(duration)
100 j.leeway = int64(dur.Seconds())
101 }
102
103 func (j *JwtVerifier) VerifyAccessToken(jwt string) (*Jwt, error) {
104 validJwt, err := j.isValidJwt(jwt)
105 if !validJwt {
106 return nil, fmt.Errorf("token is not valid: %w", err)
107 }
108
109 resp, err := j.decodeJwt(jwt)
110 if err != nil {
111 return nil, err
112 }
113
114 token := resp.(map[string]interface{})
115
116 myJwt := Jwt{
117 Claims: token,
118 }
119
120 err = j.validateIss(token["iss"])
121 if err != nil {
122 return &myJwt, fmt.Errorf("the `Issuer` was not able to be validated. %w", err)
123 }
124
125 err = j.validateAudience(token["aud"])
126 if err != nil {
127 return &myJwt, fmt.Errorf("the `Audience` was not able to be validated. %w", err)
128 }
129
130 err = j.validateClientId(token["cid"])
131 if err != nil {
132 return &myJwt, fmt.Errorf("the `Client Id` was not able to be validated. %w", err)
133 }
134
135 err = j.validateExp(token["exp"])
136 if err != nil {
137 return &myJwt, fmt.Errorf("the `Expiration` was not able to be validated. %w", err)
138 }
139
140 err = j.validateIat(token["iat"])
141 if err != nil {
142 return &myJwt, fmt.Errorf("the `Issued At` was not able to be validated. %w", err)
143 }
144
145 return &myJwt, nil
146 }
147
148 func (j *JwtVerifier) decodeJwt(jwt string) (interface{}, error) {
149 metaData, err := j.getMetaData()
150 if err != nil {
151 return nil, err
152 }
153 jwksURI, ok := metaData["jwks_uri"].(string)
154 if !ok {
155 return nil, fmt.Errorf("failed to decode JWT: missing 'jwks_uri' from metadata")
156 }
157 resp, err := j.Adaptor.Decode(jwt, jwksURI)
158 if err != nil {
159 return nil, fmt.Errorf("could not decode token: %w", err)
160 }
161
162 return resp, nil
163 }
164
165 func (j *JwtVerifier) VerifyIdToken(jwt string) (*Jwt, error) {
166 validJwt, err := j.isValidJwt(jwt)
167 if !validJwt {
168 return nil, fmt.Errorf("token is not valid: %w", err)
169 }
170
171 resp, err := j.decodeJwt(jwt)
172 if err != nil {
173 return nil, err
174 }
175
176 token := resp.(map[string]interface{})
177
178 myJwt := Jwt{
179 Claims: token,
180 }
181
182 err = j.validateIss(token["iss"])
183 if err != nil {
184 return &myJwt, fmt.Errorf("the `Issuer` was not able to be validated. %w", err)
185 }
186
187 err = j.validateAudience(token["aud"])
188 if err != nil {
189 return &myJwt, fmt.Errorf("the `Audience` was not able to be validated. %w", err)
190 }
191
192 err = j.validateExp(token["exp"])
193 if err != nil {
194 return &myJwt, fmt.Errorf("the `Expiration` was not able to be validated. %w", err)
195 }
196
197 err = j.validateIat(token["iat"])
198 if err != nil {
199 return &myJwt, fmt.Errorf("the `Issued At` was not able to be validated. %w", err)
200 }
201
202 err = j.validateNonce(token["nonce"])
203 if err != nil {
204 return &myJwt, fmt.Errorf("the `Nonce` was not able to be validated. %w", err)
205 }
206
207 return &myJwt, nil
208 }
209
210 func (j *JwtVerifier) GetDiscovery() discovery.Discovery {
211 return j.Discovery
212 }
213
214 func (j *JwtVerifier) GetAdaptor() adaptors.Adaptor {
215 return j.Adaptor
216 }
217
218 func (j *JwtVerifier) validateNonce(nonce interface{}) error {
219 if nonce == nil {
220 nonce = ""
221 }
222
223 if nonce != j.ClaimsToValidate["nonce"] {
224 return fmt.Errorf("nonce: %s does not match %s", nonce, j.ClaimsToValidate["nonce"])
225 }
226 return nil
227 }
228
229 func (j *JwtVerifier) validateAudience(audience interface{}) error {
230 switch v := audience.(type) {
231 case string:
232 if v != j.ClaimsToValidate["aud"] {
233 return fmt.Errorf("aud: %s does not match %s", v, j.ClaimsToValidate["aud"])
234 }
235 case []string:
236 for _, element := range v {
237 if element == j.ClaimsToValidate["aud"] {
238 return nil
239 }
240 }
241 return fmt.Errorf("aud: %s does not match %s", v, j.ClaimsToValidate["aud"])
242 case []interface{}:
243 for _, e := range v {
244 element, ok := e.(string)
245 if !ok {
246 return fmt.Errorf("unknown type for audience validation")
247 }
248 if element == j.ClaimsToValidate["aud"] {
249 return nil
250 }
251 }
252 return fmt.Errorf("aud: %s does not match %s", v, j.ClaimsToValidate["aud"])
253 default:
254 return fmt.Errorf("unknown type for audience validation")
255 }
256
257 return nil
258 }
259
260 func (j *JwtVerifier) validateClientId(clientId interface{}) error {
261
262 if cid, exists := j.ClaimsToValidate["cid"]; exists && clientId != cid {
263 switch v := clientId.(type) {
264 case string:
265 if v != cid {
266 return fmt.Errorf("aud: %s does not match %s", v, cid)
267 }
268 case []string:
269 for _, element := range v {
270 if element == cid {
271 return nil
272 }
273 }
274 return fmt.Errorf("aud: %s does not match %s", v, cid)
275 default:
276 return fmt.Errorf("unknown type for clientId validation")
277 }
278 }
279 return nil
280 }
281
282 func (j *JwtVerifier) validateExp(exp interface{}) error {
283 expf, ok := exp.(float64)
284 if !ok {
285 return fmt.Errorf("exp: missing")
286 }
287 if float64(time.Now().Unix()-j.leeway) > expf {
288 return fmt.Errorf("the token is expired")
289 }
290 return nil
291 }
292
293 func (j *JwtVerifier) validateIat(iat interface{}) error {
294 iatf, ok := iat.(float64)
295 if !ok {
296 return fmt.Errorf("iat: missing")
297 }
298 if float64(time.Now().Unix()+j.leeway) < iatf {
299 return fmt.Errorf("the token was issued in the future")
300 }
301 return nil
302 }
303
304 func (j *JwtVerifier) validateIss(issuer interface{}) error {
305 if issuer != j.Issuer {
306 return fmt.Errorf("iss: %s does not match %s", issuer, j.Issuer)
307 }
308 return nil
309 }
310
311 func (j *JwtVerifier) getMetaData() (map[string]interface{}, error) {
312 metaDataUrl := j.Issuer + j.Discovery.GetWellKnownUrl()
313
314 if j.metadataCache == nil {
315 metadataCache, err := j.Cache(fetchMetaData)
316 if err != nil {
317 return nil, err
318 }
319 j.metadataCache = metadataCache
320 }
321
322 value, err := j.metadataCache.Get(metaDataUrl)
323 if err != nil {
324 return nil, err
325 }
326
327 metadata, ok := value.(map[string]interface{})
328 if !ok {
329 return nil, fmt.Errorf("unable to cast %v to metadata", value)
330 }
331 return metadata, nil
332 }
333
334 func (j *JwtVerifier) isValidJwt(jwt string) (bool, error) {
335 if jwt == "" {
336 return false, errors.JwtEmptyStringError()
337 }
338
339
340 jwtRegex := regx.MatchString
341 if !jwtRegex(jwt) {
342 return false, fmt.Errorf("token must contain at least 1 period ('.') and only characters 'a-Z 0-9 _'")
343 }
344
345 parts := strings.Split(jwt, ".")
346 header := parts[0]
347 header = padHeader(header)
348 headerDecoded, err := base64.StdEncoding.DecodeString(header)
349 if err != nil {
350 return false, fmt.Errorf("the tokens header does not appear to be a base64 encoded string")
351 }
352
353 var jsonObject map[string]interface{}
354 isHeaderJson := json.Unmarshal([]byte(headerDecoded), &jsonObject) == nil
355 if !isHeaderJson {
356 return false, fmt.Errorf("the tokens header is not a json object")
357 }
358
359 _, algExists := jsonObject["alg"]
360 _, kidExists := jsonObject["kid"]
361
362 if !algExists {
363 return false, fmt.Errorf("the tokens header must contain an 'alg'")
364 }
365
366 if !kidExists {
367 return false, fmt.Errorf("the tokens header must contain a 'kid'")
368 }
369
370 if jsonObject["alg"] != "RS256" {
371 return false, fmt.Errorf("the only supported alg is RS256")
372 }
373
374 return true, nil
375 }
376
// padHeader appends '=' characters so len(header) becomes a multiple of 4,
// as required by the padded base64 decoders. Already-aligned input
// (including "") is returned unchanged.
func padHeader(header string) string {
	missing := (4 - len(header)%4) % 4
	return header + strings.Repeat("=", missing)
}
383
View as plain text