1
16
17 package oidc
18
19 import (
20 "encoding/base64"
21 "encoding/json"
22 "fmt"
23 "testing"
24 "time"
25 )
26
27 func TestJSONTime(t *testing.T) {
28 data := `{
29 "t1": 1493851263,
30 "t2": 1.493851263e9
31 }`
32
33 var v struct {
34 T1 jsonTime `json:"t1"`
35 T2 jsonTime `json:"t2"`
36 }
37 if err := json.Unmarshal([]byte(data), &v); err != nil {
38 t.Fatal(err)
39 }
40 wantT1 := time.Unix(1493851263, 0)
41 wantT2 := time.Unix(1493851263, 0)
42 gotT1 := time.Time(v.T1)
43 gotT2 := time.Time(v.T2)
44
45 if !wantT1.Equal(gotT1) {
46 t.Errorf("t1 value: wanted %s got %s", wantT1, gotT1)
47 }
48 if !wantT2.Equal(gotT2) {
49 t.Errorf("t2 value: wanted %s got %s", wantT2, gotT2)
50 }
51 }
52
// encodeJWT builds a three-segment "header.payload.sig" token for tests.
// Each argument is taken as raw text and encoded with unpadded base64url
// (base64.RawURLEncoding). No signing happens, so sig is just opaque bytes.
func encodeJWT(header, payload, sig string) string {
	enc := base64.RawURLEncoding
	token := ""
	for i, segment := range [...]string{header, payload, sig} {
		if i > 0 {
			token += "."
		}
		token += enc.EncodeToString([]byte(segment))
	}
	return token
}
59
60 func TestExpired(t *testing.T) {
61 now := time.Now()
62
63 nowFunc := func() time.Time { return now }
64
65 tests := []struct {
66 name string
67 idToken string
68 wantErr bool
69 wantExpired bool
70 }{
71 {
72 name: "valid",
73 idToken: encodeJWT(
74 "{}",
75 fmt.Sprintf(`{"exp":%d}`, now.Add(time.Hour).Unix()),
76 "blah",
77 ),
78 },
79 {
80 name: "expired",
81 idToken: encodeJWT(
82 "{}",
83 fmt.Sprintf(`{"exp":%d}`, now.Add(-time.Hour).Unix()),
84 "blah",
85 ),
86 wantExpired: true,
87 },
88 {
89 name: "bad exp claim",
90 idToken: encodeJWT(
91 "{}",
92 `{"exp":"foobar"}`,
93 "blah",
94 ),
95 wantErr: true,
96 },
97 {
98 name: "not an id token",
99 idToken: "notanidtoken",
100 wantErr: true,
101 },
102 }
103 for _, test := range tests {
104 t.Run(test.name, func(t *testing.T) {
105 valid, err := idTokenExpired(nowFunc, test.idToken)
106 if err != nil {
107 if !test.wantErr {
108 t.Errorf("parse error: %v", err)
109 }
110 return
111 }
112 if test.wantExpired == valid {
113 t.Errorf("wanted expired %t, got %t", test.wantExpired, !valid)
114 }
115 })
116 }
117 }
118
119 func TestClientCache(t *testing.T) {
120 cache := newClientCache()
121
122 if _, ok := cache.getClient("cluster1", "issuer1", "id1"); ok {
123 t.Fatalf("got client before putting one in the cache")
124 }
125 assertCacheLen(t, cache, 0)
126
127 cli1 := new(oidcAuthProvider)
128 cli2 := new(oidcAuthProvider)
129 cli3 := new(oidcAuthProvider)
130
131 gotcli := cache.setClient("cluster1", "issuer1", "id1", cli1)
132 if cli1 != gotcli {
133 t.Fatalf("set first client and got a different one")
134 }
135 assertCacheLen(t, cache, 1)
136
137 gotcli = cache.setClient("cluster1", "issuer1", "id1", cli2)
138 if cli1 != gotcli {
139 t.Fatalf("set a second client and didn't get the first")
140 }
141 assertCacheLen(t, cache, 1)
142
143 gotcli = cache.setClient("cluster2", "issuer1", "id1", cli3)
144 if cli1 == gotcli {
145 t.Fatalf("set a third client and got the first")
146 }
147 if cli3 != gotcli {
148 t.Fatalf("set third client and got a different one")
149 }
150 assertCacheLen(t, cache, 2)
151 }
152
153 func assertCacheLen(t *testing.T, cache *clientCache, length int) {
154 t.Helper()
155 if len(cache.cache) != length {
156 t.Errorf("expected cache length %d got %d", length, len(cache.cache))
157 }
158 }
159
View as plain text