1 package request
2
3 import (
4 "fmt"
5 "net/http"
6 "net/url"
7 "testing"
8 )
9
// Token values used by the extractor test table. They differ from each other
// so a case can tell which request source an extractor actually read from.
var (
	extractorTestTokenA = "A"
	extractorTestTokenB = "B"
)
12
13 var extractorTestData = []struct {
14 name string
15 extractor Extractor
16 headers map[string]string
17 query url.Values
18 token string
19 err error
20 }{
21 {
22 name: "simple header",
23 extractor: HeaderExtractor{"Foo"},
24 headers: map[string]string{"Foo": extractorTestTokenA},
25 query: nil,
26 token: extractorTestTokenA,
27 err: nil,
28 },
29 {
30 name: "simple argument",
31 extractor: ArgumentExtractor{"token"},
32 headers: map[string]string{},
33 query: url.Values{"token": {extractorTestTokenA}},
34 token: extractorTestTokenA,
35 err: nil,
36 },
37 {
38 name: "multiple extractors",
39 extractor: MultiExtractor{
40 HeaderExtractor{"Foo"},
41 ArgumentExtractor{"token"},
42 },
43 headers: map[string]string{"Foo": extractorTestTokenA},
44 query: url.Values{"token": {extractorTestTokenB}},
45 token: extractorTestTokenA,
46 err: nil,
47 },
48 {
49 name: "simple miss",
50 extractor: HeaderExtractor{"This-Header-Is-Not-Set"},
51 headers: map[string]string{"Foo": extractorTestTokenA},
52 query: nil,
53 token: "",
54 err: ErrNoTokenInRequest,
55 },
56 {
57 name: "filter",
58 extractor: AuthorizationHeaderExtractor,
59 headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenA},
60 query: nil,
61 token: extractorTestTokenA,
62 err: nil,
63 },
64 }
65
66 func TestExtractor(t *testing.T) {
67
68 for _, data := range extractorTestData {
69
70 r := makeExampleRequest("GET", "/", data.headers, data.query)
71
72
73 token, err := data.extractor.ExtractToken(r)
74 if token != data.token {
75 t.Errorf("[%v] Expected token '%v'. Got '%v'", data.name, data.token, token)
76 continue
77 }
78 if err != data.err {
79 t.Errorf("[%v] Expected error '%v'. Got '%v'", data.name, data.err, err)
80 continue
81 }
82 }
83 }
84
// makeExampleRequest builds an *http.Request for tests.
//
// method and path are passed to http.NewRequest with a nil body. When urlArgs
// encodes to a non-empty string it is appended to path as the query string.
// Every entry in headers is applied with Header.Set.
//
// It panics if the request cannot be constructed (for example, an invalid
// method). Callers have no error to inspect, and returning a nil request
// would only fail later with a less useful nil-pointer panic, or not at all.
func makeExampleRequest(method, path string, headers map[string]string, urlArgs url.Values) *http.Request {
	target := path
	// Encode returns "" for nil or empty values; skip the separator then
	// rather than producing a dangling "?" on the URL.
	if query := urlArgs.Encode(); query != "" {
		target += "?" + query
	}

	r, err := http.NewRequest(method, target, nil)
	if err != nil {
		panic(fmt.Sprintf("makeExampleRequest: building %s %q: %v", method, target, err))
	}
	for k, v := range headers {
		r.Header.Set(k, v)
	}
	return r
}
92
View as plain text