1
2
3
4
5 package jsonrpc2_test
6
7 import (
8 "context"
9 "encoding/json"
10 "fmt"
11 "path"
12 "reflect"
13 "testing"
14
15 "golang.org/x/tools/internal/event/export/eventtest"
16 jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
17 "golang.org/x/tools/internal/stack/stacktest"
18 )
19
20 var callTests = []invoker{
21 call{"no_args", nil, true},
22 call{"one_string", "fish", "got:fish"},
23 call{"one_number", 10, "got:10"},
24 call{"join", []string{"a", "b", "c"}, "a/b/c"},
25 sequence{"notify", []invoker{
26 notify{"set", 3},
27 notify{"add", 5},
28 call{"get", nil, 8},
29 }},
30 sequence{"preempt", []invoker{
31 async{"a", "wait", "a"},
32 notify{"unblock", "a"},
33 collect{"a", true, false},
34 }},
35 sequence{"basic cancel", []invoker{
36 async{"b", "wait", "b"},
37 cancel{"b"},
38 collect{"b", nil, true},
39 }},
40 sequence{"queue", []invoker{
41 async{"a", "wait", "a"},
42 notify{"set", 1},
43 notify{"add", 2},
44 notify{"add", 3},
45 notify{"add", 4},
46 call{"peek", nil, 0},
47 notify{"unblock", "a"},
48 collect{"a", true, false},
49 call{"get", nil, 10},
50 }},
51 sequence{"fork", []invoker{
52 async{"a", "fork", "a"},
53 notify{"set", 1},
54 notify{"add", 2},
55 notify{"add", 3},
56 notify{"add", 4},
57 call{"get", nil, 10},
58 notify{"unblock", "a"},
59 collect{"a", true, false},
60 }},
61 sequence{"concurrent", []invoker{
62 async{"a", "fork", "a"},
63 notify{"unblock", "a"},
64 async{"b", "fork", "b"},
65 notify{"unblock", "b"},
66 collect{"a", true, false},
67 collect{"b", true, false},
68 }},
69 }
70
71 type binder struct {
72 framer jsonrpc2.Framer
73 runTest func(*handler)
74 }
75
76 type handler struct {
77 conn *jsonrpc2.Connection
78 accumulator int
79 waiters chan map[string]chan struct{}
80 calls map[string]*jsonrpc2.AsyncCall
81 }
82
83 type invoker interface {
84 Name() string
85 Invoke(t *testing.T, ctx context.Context, h *handler)
86 }
87
// notify is a step that sends a notification and does not wait for a reply.
type notify struct {
	// method is the name of the method to notify.
	method string
	// params is the value sent as the notification parameters.
	params interface{}
}
92
// call is a step that makes a synchronous call and checks its result.
// The field order is method, params, expect.
type call struct {
	// method is the name of the method to call.
	method string
	// params is sent with the call; expect is the result it must produce.
	params, expect interface{}
}
98
// async is a step that starts a call without waiting for its result.
// The field order is name, method, params.
type async struct {
	// name is the key the pending call is stored under in handler.calls;
	// method is the name of the method to call.
	name, method string
	// params is the value sent as the call parameters.
	params interface{}
}
104
// collect is a step that waits for a call started by an earlier async step.
type collect struct {
	// name is the key of the pending call in handler.calls.
	name string
	// expect is the result the call must produce.
	expect interface{}
	// fails states whether waiting for the call is required to return an error.
	fails bool
}
110
// cancel is a step that asks the peer to cancel a call started by an
// earlier async step.
type cancel struct {
	// name is the key of the pending call in handler.calls.
	name string
}
114
115 type sequence struct {
116 name string
117 tests []invoker
118 }
119
120 type echo call
121
// cancelParams is the parameter payload of the "cancel" notification.
type cancelParams struct {
	// ID is the identifier of the call to cancel.
	ID int64
}
123
124 func TestConnectionRaw(t *testing.T) {
125 testConnection(t, jsonrpc2.RawFramer())
126 }
127
128 func TestConnectionHeader(t *testing.T) {
129 testConnection(t, jsonrpc2.HeaderFramer())
130 }
131
132 func testConnection(t *testing.T, framer jsonrpc2.Framer) {
133 stacktest.NoLeak(t)
134 ctx := eventtest.NewContext(context.Background(), t)
135 listener, err := jsonrpc2.NetPipeListener(ctx)
136 if err != nil {
137 t.Fatal(err)
138 }
139 server := jsonrpc2.NewServer(ctx, listener, binder{framer, nil})
140 defer func() {
141 listener.Close()
142 server.Wait()
143 }()
144
145 for _, test := range callTests {
146 t.Run(test.Name(), func(t *testing.T) {
147 client, err := jsonrpc2.Dial(ctx,
148 listener.Dialer(), binder{framer, func(h *handler) {
149 defer h.conn.Close()
150 ctx := eventtest.NewContext(ctx, t)
151 test.Invoke(t, ctx, h)
152 if call, ok := test.(*call); ok {
153
154 (*echo)(call).Invoke(t, ctx, h)
155 }
156 }})
157 if err != nil {
158 t.Fatal(err)
159 }
160 client.Wait()
161 })
162 }
163 }
164
165 func (test notify) Name() string { return test.method }
166 func (test notify) Invoke(t *testing.T, ctx context.Context, h *handler) {
167 if err := h.conn.Notify(ctx, test.method, test.params); err != nil {
168 t.Fatalf("%v:Notify failed: %v", test.method, err)
169 }
170 }
171
172 func (test call) Name() string { return test.method }
173 func (test call) Invoke(t *testing.T, ctx context.Context, h *handler) {
174 results := newResults(test.expect)
175 if err := h.conn.Call(ctx, test.method, test.params).Await(ctx, results); err != nil {
176 t.Fatalf("%v:Call failed: %v", test.method, err)
177 }
178 verifyResults(t, test.method, results, test.expect)
179 }
180
181 func (test echo) Invoke(t *testing.T, ctx context.Context, h *handler) {
182 results := newResults(test.expect)
183 if err := h.conn.Call(ctx, "echo", []interface{}{test.method, test.params}).Await(ctx, results); err != nil {
184 t.Fatalf("%v:Echo failed: %v", test.method, err)
185 }
186 verifyResults(t, test.method, results, test.expect)
187 }
188
189 func (test async) Name() string { return test.name }
190 func (test async) Invoke(t *testing.T, ctx context.Context, h *handler) {
191 h.calls[test.name] = h.conn.Call(ctx, test.method, test.params)
192 }
193
194 func (test collect) Name() string { return test.name }
195 func (test collect) Invoke(t *testing.T, ctx context.Context, h *handler) {
196 o := h.calls[test.name]
197 results := newResults(test.expect)
198 err := o.Await(ctx, results)
199 switch {
200 case test.fails && err == nil:
201 t.Fatalf("%v:Collect was supposed to fail", test.name)
202 case !test.fails && err != nil:
203 t.Fatalf("%v:Collect failed: %v", test.name, err)
204 }
205 verifyResults(t, test.name, results, test.expect)
206 }
207
208 func (test cancel) Name() string { return test.name }
209 func (test cancel) Invoke(t *testing.T, ctx context.Context, h *handler) {
210 o := h.calls[test.name]
211 if err := h.conn.Notify(ctx, "cancel", &cancelParams{o.ID().Raw().(int64)}); err != nil {
212 t.Fatalf("%v:Collect failed: %v", test.name, err)
213 }
214 }
215
216 func (test sequence) Name() string { return test.name }
217 func (test sequence) Invoke(t *testing.T, ctx context.Context, h *handler) {
218 for _, child := range test.tests {
219 child.Invoke(t, ctx, h)
220 }
221 }
222
223
// newResults builds a value to decode a reply into, shaped after expect:
//   - nil yields nil;
//   - a []interface{} yields a []interface{} holding, for each element, a
//     pointer to a new zero value of that element's type;
//   - anything else yields a pointer to a new zero value of expect's type.
func newResults(expect interface{}) interface{} {
	if expect == nil {
		return nil
	}
	list, isList := expect.([]interface{})
	if !isList {
		return reflect.New(reflect.TypeOf(expect)).Interface()
	}
	var ptrs []interface{}
	for i := range list {
		elem := reflect.New(reflect.TypeOf(list[i]))
		ptrs = append(ptrs, elem.Interface())
	}
	return ptrs
}
238
239
// verifyResults reports a test error when results does not match expect.
//
// When expect is nil, results must be nil as well. Otherwise results is
// dereferenced if it is a pointer, and the value is compared with expect
// using reflect.DeepEqual. method only labels the error messages.
func verifyResults(t *testing.T, method string, results interface{}, expect interface{}) {
	if expect == nil {
		if results != nil {
			// Print what was actually received; expect is always nil here.
			t.Errorf("%v:Got results %+v where none expected", method, results)
		}
		return
	}
	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
	if !reflect.DeepEqual(val, expect) {
		t.Errorf("%v:Results are incorrect, got %+v expect %+v", method, val, expect)
	}
}
252
253 func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) jsonrpc2.ConnectionOptions {
254 h := &handler{
255 conn: conn,
256 waiters: make(chan map[string]chan struct{}, 1),
257 calls: make(map[string]*jsonrpc2.AsyncCall),
258 }
259 h.waiters <- make(map[string]chan struct{})
260 if b.runTest != nil {
261 go b.runTest(h)
262 }
263 return jsonrpc2.ConnectionOptions{
264 Framer: b.framer,
265 Preempter: h,
266 Handler: h,
267 }
268 }
269
270 func (h *handler) waiter(name string) chan struct{} {
271 waiters := <-h.waiters
272 defer func() { h.waiters <- waiters }()
273 waiter, found := waiters[name]
274 if !found {
275 waiter = make(chan struct{})
276 waiters[name] = waiter
277 }
278 return waiter
279 }
280
281 func (h *handler) Preempt(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
282 switch req.Method {
283 case "unblock":
284 var name string
285 if err := json.Unmarshal(req.Params, &name); err != nil {
286 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
287 }
288 close(h.waiter(name))
289 return nil, nil
290 case "peek":
291 if len(req.Params) > 0 {
292 return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
293 }
294 return h.accumulator, nil
295 case "cancel":
296 var params cancelParams
297 if err := json.Unmarshal(req.Params, ¶ms); err != nil {
298 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
299 }
300 h.conn.Cancel(jsonrpc2.Int64ID(params.ID))
301 return nil, nil
302 default:
303 return nil, jsonrpc2.ErrNotHandled
304 }
305 }
306
307 func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
308 switch req.Method {
309 case "no_args":
310 if len(req.Params) > 0 {
311 return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
312 }
313 return true, nil
314 case "one_string":
315 var v string
316 if err := json.Unmarshal(req.Params, &v); err != nil {
317 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
318 }
319 return "got:" + v, nil
320 case "one_number":
321 var v int
322 if err := json.Unmarshal(req.Params, &v); err != nil {
323 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
324 }
325 return fmt.Sprintf("got:%d", v), nil
326 case "set":
327 var v int
328 if err := json.Unmarshal(req.Params, &v); err != nil {
329 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
330 }
331 h.accumulator = v
332 return nil, nil
333 case "add":
334 var v int
335 if err := json.Unmarshal(req.Params, &v); err != nil {
336 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
337 }
338 h.accumulator += v
339 return nil, nil
340 case "get":
341 if len(req.Params) > 0 {
342 return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
343 }
344 return h.accumulator, nil
345 case "join":
346 var v []string
347 if err := json.Unmarshal(req.Params, &v); err != nil {
348 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
349 }
350 return path.Join(v...), nil
351 case "echo":
352 var v []interface{}
353 if err := json.Unmarshal(req.Params, &v); err != nil {
354 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
355 }
356 var result interface{}
357 err := h.conn.Call(ctx, v[0].(string), v[1]).Await(ctx, &result)
358 return result, err
359 case "wait":
360 var name string
361 if err := json.Unmarshal(req.Params, &name); err != nil {
362 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
363 }
364 select {
365 case <-h.waiter(name):
366 return true, nil
367 case <-ctx.Done():
368 return nil, ctx.Err()
369 }
370 case "fork":
371 var name string
372 if err := json.Unmarshal(req.Params, &name); err != nil {
373 return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
374 }
375 waitFor := h.waiter(name)
376 go func() {
377 select {
378 case <-waitFor:
379 h.conn.Respond(req.ID, true, nil)
380 case <-ctx.Done():
381 h.conn.Respond(req.ID, nil, ctx.Err())
382 }
383 }()
384 return nil, jsonrpc2.ErrAsyncResponse
385 default:
386 return nil, jsonrpc2.ErrNotHandled
387 }
388 }
389
View as plain text