1 package request
2
3 import (
4 "fmt"
5 "net/http"
6 "net/url"
7 "reflect"
8 "strings"
9 "testing"
10
11 "github.com/golang-jwt/jwt"
12 "github.com/golang-jwt/jwt/test"
13 )
14
15 var requestTestData = []struct {
16 name string
17 claims jwt.MapClaims
18 extractor Extractor
19 headers map[string]string
20 query url.Values
21 valid bool
22 }{
23 {
24 "authorization bearer token",
25 jwt.MapClaims{"foo": "bar"},
26 AuthorizationHeaderExtractor,
27 map[string]string{"Authorization": "Bearer %v"},
28 url.Values{},
29 true,
30 },
31 {
32 "oauth bearer token - header",
33 jwt.MapClaims{"foo": "bar"},
34 OAuth2Extractor,
35 map[string]string{"Authorization": "Bearer %v"},
36 url.Values{},
37 true,
38 },
39 {
40 "oauth bearer token - url",
41 jwt.MapClaims{"foo": "bar"},
42 OAuth2Extractor,
43 map[string]string{},
44 url.Values{"access_token": {"%v"}},
45 true,
46 },
47 {
48 "url token",
49 jwt.MapClaims{"foo": "bar"},
50 ArgumentExtractor{"token"},
51 map[string]string{},
52 url.Values{"token": {"%v"}},
53 true,
54 },
55 }
56
57 func TestParseRequest(t *testing.T) {
58
59 privateKey := test.LoadRSAPrivateKeyFromDisk("../test/sample_key")
60 publicKey := test.LoadRSAPublicKeyFromDisk("../test/sample_key.pub")
61 keyfunc := func(*jwt.Token) (interface{}, error) {
62 return publicKey, nil
63 }
64
65
66 for _, data := range requestTestData {
67
68 tokenString := test.MakeSampleToken(data.claims, privateKey)
69
70
71 for k, vv := range data.query {
72 for i, v := range vv {
73 if strings.Contains(v, "%v") {
74 data.query[k][i] = fmt.Sprintf(v, tokenString)
75 }
76 }
77 }
78
79
80 r, _ := http.NewRequest("GET", fmt.Sprintf("/?%v", data.query.Encode()), nil)
81 for k, v := range data.headers {
82 if strings.Contains(v, "%v") {
83 r.Header.Set(k, fmt.Sprintf(v, tokenString))
84 } else {
85 r.Header.Set(k, tokenString)
86 }
87 }
88 token, err := ParseFromRequestWithClaims(r, data.extractor, jwt.MapClaims{}, keyfunc)
89
90 if token == nil {
91 t.Errorf("[%v] Token was not found: %v", data.name, err)
92 continue
93 }
94 if !reflect.DeepEqual(data.claims, token.Claims) {
95 t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims)
96 }
97 if data.valid && err != nil {
98 t.Errorf("[%v] Error while verifying token: %v", data.name, err)
99 }
100 if !data.valid && err == nil {
101 t.Errorf("[%v] Invalid token passed validation", data.name)
102 }
103 }
104 }
105
View as plain text