1
21
22 package fosite_test
23
24 import (
25 "context"
26 "fmt"
27 "net/http"
28 "net/http/httptest"
29 "net/url"
30 "testing"
31
32 "github.com/golang/mock/gomock"
33 "github.com/pkg/errors"
34 "github.com/stretchr/testify/assert"
35
36 . "github.com/ory/fosite"
37 "github.com/ory/fosite/internal"
38 )
39
40 func TestNewRevocationRequest(t *testing.T) {
41 ctrl := gomock.NewController(t)
42 store := internal.NewMockStorage(ctrl)
43 handler := internal.NewMockRevocationHandler(ctrl)
44 hasher := internal.NewMockHasher(ctrl)
45 defer ctrl.Finish()
46
47 ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))
48
49 client := &DefaultClient{}
50 fosite := &Fosite{Store: store, Hasher: hasher}
51 for k, c := range []struct {
52 header http.Header
53 form url.Values
54 mock func()
55 method string
56 expectErr error
57 expect *AccessRequest
58 handlers RevocationHandlers
59 }{
60 {
61 header: http.Header{},
62 expectErr: ErrInvalidRequest,
63 method: "GET",
64 mock: func() {},
65 },
66 {
67 header: http.Header{},
68 expectErr: ErrInvalidRequest,
69 method: "POST",
70 mock: func() {},
71 },
72 {
73 header: http.Header{},
74 method: "POST",
75 form: url.Values{
76 "token": {"foo"},
77 },
78 mock: func() {},
79 expectErr: ErrInvalidRequest,
80 },
81 {
82 header: http.Header{
83 "Authorization": {basicAuth("foo", "bar")},
84 },
85 method: "POST",
86 form: url.Values{
87 "token": {"foo"},
88 },
89 expectErr: ErrInvalidClient,
90 mock: func() {
91 store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
92 },
93 },
94 {
95 header: http.Header{
96 "Authorization": {basicAuth("foo", "bar")},
97 },
98 method: "POST",
99 form: url.Values{
100 "token": {"foo"},
101 },
102 expectErr: ErrInvalidClient,
103 mock: func() {
104 store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
105 client.Secret = []byte("foo")
106 client.Public = false
107 hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
108 },
109 },
110 {
111 header: http.Header{
112 "Authorization": {basicAuth("foo", "bar")},
113 },
114 method: "POST",
115 form: url.Values{
116 "token": {"foo"},
117 },
118 expectErr: nil,
119 mock: func() {
120 store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
121 client.Secret = []byte("foo")
122 client.Public = false
123 hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
124 handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
125 },
126 handlers: RevocationHandlers{handler},
127 },
128 {
129 header: http.Header{
130 "Authorization": {basicAuth("foo", "bar")},
131 },
132 method: "POST",
133 form: url.Values{
134 "token": {"foo"},
135 "token_type_hint": {"access_token"},
136 },
137 expectErr: nil,
138 mock: func() {
139 store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
140 client.Secret = []byte("foo")
141 client.Public = false
142 hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
143 handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
144 },
145 handlers: RevocationHandlers{handler},
146 },
147 {
148 header: http.Header{
149 "Authorization": {basicAuth("foo", "")},
150 },
151 method: "POST",
152 form: url.Values{
153 "token": {"foo"},
154 "token_type_hint": {"refresh_token"},
155 },
156 expectErr: nil,
157 mock: func() {
158 store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
159 client.Public = true
160 handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
161 },
162 handlers: RevocationHandlers{handler},
163 },
164 {
165 header: http.Header{
166 "Authorization": {basicAuth("foo", "bar")},
167 },
168 method: "POST",
169 form: url.Values{
170 "token": {"foo"},
171 "token_type_hint": {"refresh_token"},
172 },
173 expectErr: nil,
174 mock: func() {
175 store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
176 client.Secret = []byte("foo")
177 client.Public = false
178 hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
179 handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
180 },
181 handlers: RevocationHandlers{handler},
182 },
183 {
184 header: http.Header{
185 "Authorization": {basicAuth("foo", "bar")},
186 },
187 method: "POST",
188 form: url.Values{
189 "token": {"foo"},
190 "token_type_hint": {"bar"},
191 },
192 expectErr: nil,
193 mock: func() {
194 store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
195 client.Secret = []byte("foo")
196 client.Public = false
197 hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
198 handler.EXPECT().RevokeToken(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
199 },
200 handlers: RevocationHandlers{handler},
201 },
202 } {
203 t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
204 r := &http.Request{
205 Header: c.header,
206 PostForm: c.form,
207 Form: c.form,
208 Method: c.method,
209 }
210 c.mock()
211 ctx := NewContext()
212 fosite.RevocationHandlers = c.handlers
213 err := fosite.NewRevocationRequest(ctx, r)
214
215 if c.expectErr != nil {
216 assert.EqualError(t, err, c.expectErr.Error())
217 } else {
218 assert.NoError(t, err)
219 }
220 })
221 }
222 }
223
224 func TestWriteRevocationResponse(t *testing.T) {
225 ctrl := gomock.NewController(t)
226 store := internal.NewMockStorage(ctrl)
227 hasher := internal.NewMockHasher(ctrl)
228 defer ctrl.Finish()
229
230 fosite := &Fosite{Store: store, Hasher: hasher}
231
232 type args struct {
233 rw *httptest.ResponseRecorder
234 err error
235 }
236 cases := []struct {
237 input args
238 expectCode int
239 }{
240 {
241 input: args{
242 rw: httptest.NewRecorder(),
243 err: ErrInvalidRequest,
244 },
245 expectCode: ErrInvalidRequest.CodeField,
246 },
247 {
248 input: args{
249 rw: httptest.NewRecorder(),
250 err: ErrInvalidClient,
251 },
252 expectCode: ErrInvalidClient.CodeField,
253 },
254 {
255 input: args{
256 rw: httptest.NewRecorder(),
257 err: nil,
258 },
259 expectCode: http.StatusOK,
260 },
261 }
262
263 for _, tc := range cases {
264 fosite.WriteRevocationResponse(tc.input.rw, tc.input.err)
265 assert.Equal(t, tc.expectCode, tc.input.rw.Code)
266 }
267 }
268
View as plain text