1 package jsonrpc_test
2
3 import (
4 "context"
5 "encoding/json"
6 "io"
7 "io/ioutil"
8 "net/http"
9 "net/http/httptest"
10 "net/url"
11 "testing"
12
13 "github.com/go-kit/kit/transport/http/jsonrpc"
14 )
15
type (
	// TestResponse pairs a raw response body with a string form of it.
	// NOTE(review): not referenced by any test visible in this file.
	TestResponse struct {
		Body   io.ReadCloser
		String string
	}

	// testServerResponseOptions is carried as the JSON-RPC params of a
	// request to httptestServer and scripts the stub's reply: Body is
	// written verbatim and Status is the HTTP status code (0 means 200 OK).
	testServerResponseOptions struct {
		Body   string
		Status int
	}
)
25
26 func httptestServer(t *testing.T) *httptest.Server {
27 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28 defer r.Body.Close()
29
30 var testReq jsonrpc.Request
31 if err := json.NewDecoder(r.Body).Decode(&testReq); err != nil {
32 t.Fatal(err)
33 }
34
35 var options testServerResponseOptions
36 if err := json.Unmarshal(testReq.Params, &options); err != nil {
37 t.Fatal(err)
38 }
39
40 if options.Status == 0 {
41 options.Status = http.StatusOK
42 }
43
44 w.WriteHeader(options.Status)
45 w.Write([]byte(options.Body))
46 }))
47 }
48
49 func TestBeforeAfterFuncs(t *testing.T) {
50 t.Parallel()
51
52 var tests = []struct {
53 name string
54 status int
55 body string
56 }{
57 {
58 name: "empty body",
59 body: "",
60 },
61 {
62 name: "empty body 500",
63 body: "",
64 status: 500,
65 },
66
67 {
68 name: "empty json body",
69 body: "{}",
70 },
71 {
72 name: "error",
73 body: `{"jsonrpc":"2.0","error":{"code":32603,"message":"Bad thing happened."}}`,
74 },
75 }
76
77 server := httptestServer(t)
78 defer server.Close()
79
80 testUrl, err := url.Parse(server.URL)
81 if err != nil {
82 t.Fatal(err)
83 }
84
85 for _, tt := range tests {
86 t.Run(tt.name, func(t *testing.T) {
87 beforeCalled := false
88 afterCalled := false
89 finalizerCalled := false
90
91 sut := jsonrpc.NewClient(
92 testUrl,
93 "dummy",
94 jsonrpc.ClientBefore(func(ctx context.Context, req *http.Request) context.Context {
95 beforeCalled = true
96 return ctx
97 }),
98 jsonrpc.ClientAfter(func(ctx context.Context, resp *http.Response) context.Context {
99 afterCalled = true
100 return ctx
101 }),
102 jsonrpc.ClientFinalizer(func(ctx context.Context, err error) {
103 finalizerCalled = true
104 }),
105 )
106
107 sut.Endpoint()(context.TODO(), testServerResponseOptions{Body: tt.body, Status: tt.status})
108 if !beforeCalled {
109 t.Fatal("Expected client before func to be called. Wasn't.")
110 }
111 if !afterCalled {
112 t.Fatal("Expected client after func to be called. Wasn't.")
113 }
114 if !finalizerCalled {
115 t.Fatal("Expected client finalizer func to be called. Wasn't.")
116 }
117
118 })
119
120 }
121
122 }
123
// staticIDGenerator is a request-ID generator that always hands back its own
// value, so a test knows in advance which ID the client will send.
type staticIDGenerator int

// Generate returns the receiver itself, boxed as interface{}, on every call.
func (id staticIDGenerator) Generate() interface{} {
	return id
}
127
128 func TestClientHappyPath(t *testing.T) {
129 t.Parallel()
130
131 var (
132 afterCalledKey = "AC"
133 beforeHeaderKey = "BF"
134 beforeHeaderValue = "beforeFuncWozEre"
135 testbody = `{"jsonrpc":"2.0", "result":5}`
136 requestBody []byte
137 beforeFunc = func(ctx context.Context, r *http.Request) context.Context {
138 r.Header.Add(beforeHeaderKey, beforeHeaderValue)
139 return ctx
140 }
141 encode = func(ctx context.Context, req interface{}) (json.RawMessage, error) {
142 return json.Marshal(req)
143 }
144 afterFunc = func(ctx context.Context, r *http.Response) context.Context {
145 return context.WithValue(ctx, afterCalledKey, true)
146 }
147 finalizerCalled = false
148 fin = func(ctx context.Context, err error) {
149 finalizerCalled = true
150 }
151 decode = func(ctx context.Context, res jsonrpc.Response) (interface{}, error) {
152 if ac := ctx.Value(afterCalledKey); ac == nil {
153 t.Fatal("after not called")
154 }
155 var result int
156 err := json.Unmarshal(res.Result, &result)
157 if err != nil {
158 return nil, err
159 }
160 return result, nil
161 }
162
163 wantID = 666
164 gen = staticIDGenerator(wantID)
165 )
166
167 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
168 if r.Header.Get(beforeHeaderKey) != beforeHeaderValue {
169 t.Fatal("Header not set by before func.")
170 }
171
172 b, err := ioutil.ReadAll(r.Body)
173 if err != nil && err != io.EOF {
174 t.Fatal(err)
175 }
176 requestBody = b
177
178 w.WriteHeader(http.StatusOK)
179 w.Write([]byte(testbody))
180 }))
181 defer server.Close()
182
183 sut := jsonrpc.NewClient(
184 mustParse(server.URL),
185 "add",
186 jsonrpc.ClientRequestEncoder(encode),
187 jsonrpc.ClientResponseDecoder(decode),
188 jsonrpc.ClientBefore(beforeFunc),
189 jsonrpc.ClientAfter(afterFunc),
190 jsonrpc.ClientRequestIDGenerator(gen),
191 jsonrpc.ClientFinalizer(fin),
192 jsonrpc.SetClient(http.DefaultClient),
193 jsonrpc.BufferedStream(false),
194 )
195
196 type addRequest struct {
197 A int
198 B int
199 }
200
201 in := addRequest{2, 2}
202
203 result, err := sut.Endpoint()(context.Background(), in)
204 if err != nil {
205 t.Fatal(err)
206 }
207 ri, ok := result.(int)
208 if !ok {
209 t.Fatalf("result is not int: (%T)%+v", result, result)
210 }
211 if ri != 5 {
212 t.Fatalf("want=5, got=%d", ri)
213 }
214
215 var requestAtServer jsonrpc.Request
216 err = json.Unmarshal(requestBody, &requestAtServer)
217 if err != nil {
218 t.Fatal(err)
219 }
220 if id, _ := requestAtServer.ID.Int(); id != wantID {
221 t.Fatalf("Request ID at server: want=%d, got=%d", wantID, id)
222 }
223 if requestAtServer.JSONRPC != jsonrpc.Version {
224 t.Fatalf("JSON-RPC version at server: want=%s, got=%s", jsonrpc.Version, requestAtServer.JSONRPC)
225 }
226
227 var paramsAtServer addRequest
228 err = json.Unmarshal(requestAtServer.Params, ¶msAtServer)
229 if err != nil {
230 t.Fatal(err)
231 }
232
233 if paramsAtServer != in {
234 t.Fatalf("want=%+v, got=%+v", in, paramsAtServer)
235 }
236
237 if !finalizerCalled {
238 t.Fatal("Expected finalizer to be called. Wasn't.")
239 }
240 }
241
242 func TestCanUseDefaults(t *testing.T) {
243 t.Parallel()
244
245 var (
246 testbody = `{"jsonrpc":"2.0", "result":"boogaloo"}`
247 requestBody []byte
248 )
249
250 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
251 b, err := ioutil.ReadAll(r.Body)
252 if err != nil && err != io.EOF {
253 t.Fatal(err)
254 }
255 requestBody = b
256
257 w.WriteHeader(http.StatusOK)
258 w.Write([]byte(testbody))
259 }))
260 defer server.Close()
261
262 sut := jsonrpc.NewClient(
263 mustParse(server.URL),
264 "add",
265 )
266
267 type addRequest struct {
268 A int
269 B int
270 }
271
272 in := addRequest{2, 2}
273
274 result, err := sut.Endpoint()(context.Background(), in)
275 if err != nil {
276 t.Fatal(err)
277 }
278 rs, ok := result.(string)
279 if !ok {
280 t.Fatalf("result is not string: (%T)%+v", result, result)
281 }
282 if rs != "boogaloo" {
283 t.Fatalf("want=boogaloo, got=%s", rs)
284 }
285
286 var requestAtServer jsonrpc.Request
287 err = json.Unmarshal(requestBody, &requestAtServer)
288 if err != nil {
289 t.Fatal(err)
290 }
291 var paramsAtServer addRequest
292 err = json.Unmarshal(requestAtServer.Params, ¶msAtServer)
293 if err != nil {
294 t.Fatal(err)
295 }
296
297 if paramsAtServer != in {
298 t.Fatalf("want=%+v, got=%+v", in, paramsAtServer)
299 }
300 }
301
302 func TestClientCanHandleJSONRPCError(t *testing.T) {
303 t.Parallel()
304
305 var testbody = `{
306 "jsonrpc": "2.0",
307 "error": {
308 "code": -32603,
309 "message": "Bad thing happened."
310 }
311 }`
312 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
313 w.WriteHeader(http.StatusOK)
314 w.Write([]byte(testbody))
315 }))
316 defer server.Close()
317
318 sut := jsonrpc.NewClient(mustParse(server.URL), "add")
319
320 _, err := sut.Endpoint()(context.Background(), 5)
321 if err == nil {
322 t.Fatal("Expected error, got none.")
323 }
324
325 {
326 want := "Bad thing happened."
327 got := err.Error()
328 if got != want {
329 t.Fatalf("error message: want=%s, got=%s", want, got)
330 }
331 }
332
333 type errorCoder interface {
334 ErrorCode() int
335 }
336 ec, ok := err.(errorCoder)
337 if !ok {
338 t.Fatal("Error is not errorCoder")
339 }
340
341 {
342 want := -32603
343 got := ec.ErrorCode()
344 if got != want {
345 t.Fatalf("error code: want=%d, got=%d", want, got)
346 }
347 }
348 }
349
350 func TestDefaultAutoIncrementer(t *testing.T) {
351 t.Parallel()
352
353 sut := jsonrpc.NewAutoIncrementID(0)
354 var want uint64
355 for ; want < 100; want++ {
356 got := sut.Generate()
357 if got != want {
358 t.Fatalf("want=%d, got=%d", want, got)
359 }
360 }
361 }
362
// mustParse parses s with url.Parse and panics with the parse error on
// failure. It is meant for test setup, where a bad URL is a programming error.
func mustParse(s string) *url.URL {
	parsed, parseErr := url.Parse(s)
	if parseErr == nil {
		return parsed
	}
	panic(parseErr)
}
370
View as plain text