1
21
22 package jwt
23
24 import (
25 "context"
26 "fmt"
27 "strings"
28 "testing"
29 "time"
30
31 "github.com/stretchr/testify/assert"
32 "github.com/stretchr/testify/require"
33 )
34
35 var header = &Headers{
36 Extra: map[string]interface{}{
37 "foo": "bar",
38 },
39 }
40
41 func TestHash(t *testing.T) {
42 for k, tc := range []struct {
43 d string
44 strategy JWTStrategy
45 }{
46 {
47 d: "RS256JWTStrategy",
48 strategy: &RS256JWTStrategy{
49 PrivateKey: MustRSAKey(),
50 },
51 },
52 {
53 d: "ES256JWTStrategy",
54 strategy: &ES256JWTStrategy{
55 PrivateKey: MustECDSAKey(),
56 },
57 },
58 } {
59 t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) {
60 in := []byte("foo")
61 out, err := tc.strategy.Hash(context.TODO(), in)
62 assert.NoError(t, err)
63 assert.NotEqual(t, in, out)
64 })
65 }
66 }
67
68 func TestAssign(t *testing.T) {
69 for k, c := range [][]map[string]interface{}{
70 {
71 {"foo": "bar"},
72 {"baz": "bar"},
73 {"foo": "bar", "baz": "bar"},
74 },
75 {
76 {"foo": "bar"},
77 {"foo": "baz"},
78 {"foo": "bar"},
79 },
80 {
81 {},
82 {"foo": "baz"},
83 {"foo": "baz"},
84 },
85 {
86 {"foo": "bar"},
87 {"foo": "baz", "bar": "baz"},
88 {"foo": "bar", "bar": "baz"},
89 },
90 } {
91 assert.EqualValues(t, c[2], assign(c[0], c[1]), "Case %d", k)
92 }
93 }
94
95 func TestGenerateJWT(t *testing.T) {
96 for k, tc := range []struct {
97 d string
98 strategy JWTStrategy
99 resetKey func(strategy JWTStrategy)
100 }{
101 {
102 d: "RS256JWTStrategy",
103 strategy: &RS256JWTStrategy{
104 PrivateKey: MustRSAKey(),
105 },
106 resetKey: func(strategy JWTStrategy) {
107 strategy.(*RS256JWTStrategy).PrivateKey = MustRSAKey()
108 },
109 },
110 {
111 d: "ES256JWTStrategy",
112 strategy: &ES256JWTStrategy{
113 PrivateKey: MustECDSAKey(),
114 },
115 resetKey: func(strategy JWTStrategy) {
116 strategy.(*ES256JWTStrategy).PrivateKey = MustECDSAKey()
117 },
118 },
119 } {
120 t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) {
121 claims := &JWTClaims{
122 ExpiresAt: time.Now().UTC().Add(time.Hour),
123 }
124
125 token, sig, err := tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header)
126 require.NoError(t, err)
127 require.NotNil(t, token)
128
129 sig, err = tc.strategy.Validate(context.TODO(), token)
130 require.NoError(t, err)
131
132 sig, err = tc.strategy.Validate(context.TODO(), token+"."+"0123456789")
133 require.Error(t, err)
134
135 partToken := strings.Split(token, ".")[2]
136
137 sig, err = tc.strategy.Validate(context.TODO(), partToken)
138 require.Error(t, err)
139
140
141 tc.resetKey(tc.strategy)
142
143
144 claims = &JWTClaims{
145 ExpiresAt: time.Now().UTC().Add(-time.Hour),
146 }
147 token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header)
148 require.NoError(t, err)
149 require.NotNil(t, token)
150
151 sig, err = tc.strategy.Validate(context.TODO(), token)
152 require.Error(t, err)
153
154
155 claims = &JWTClaims{
156 NotBefore: time.Now().UTC().Add(time.Hour),
157 }
158 token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header)
159 require.NoError(t, err)
160 require.NotNil(t, token)
161
162 sig, err = tc.strategy.Validate(context.TODO(), token)
163 require.Error(t, err)
164 require.Empty(t, sig, "%s", err)
165 })
166 }
167 }
168
169 func TestValidateSignatureRejectsJWT(t *testing.T) {
170 for k, tc := range []struct {
171 d string
172 strategy JWTStrategy
173 }{
174 {
175 d: "RS256JWTStrategy",
176 strategy: &RS256JWTStrategy{
177 PrivateKey: MustRSAKey(),
178 },
179 },
180 {
181 d: "ES256JWTStrategy",
182 strategy: &ES256JWTStrategy{
183 PrivateKey: MustECDSAKey(),
184 },
185 },
186 } {
187 t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) {
188 for k, c := range []string{
189 "",
190 " ",
191 "foo.bar",
192 "foo.",
193 ".foo",
194 } {
195 _, err := tc.strategy.Validate(context.TODO(), c)
196 assert.Error(t, err)
197 t.Logf("Passed test case %d", k)
198 }
199 })
200 }
201 }
202
View as plain text