1
2
3
4
5 package jsonrpc2_test
6
7 import (
8 "context"
9 "encoding/json"
10 "flag"
11 "fmt"
12 "net"
13 "path"
14 "reflect"
15 "testing"
16
17 "golang.org/x/tools/internal/event/export/eventtest"
18 "golang.org/x/tools/internal/jsonrpc2"
19 "golang.org/x/tools/internal/stack/stacktest"
20 )
21
var (
	// logRPC holds the value of the -logrpc command-line flag; run passes it
	// to testHandler.
	logRPC = flag.Bool("logrpc", false, "Enable jsonrpc2 communication logging")
)
23
// callTest describes one round-trip case for TestCall: the method to invoke,
// the params to send with it, and the result expected back (nil when no
// result should be decoded).
type callTest struct {
	method         string      // method name dispatched on by testHandler
	params, expect interface{} // request params, then the expected decoded result
}
29
30 var callTests = []callTest{
31 {"no_args", nil, true},
32 {"one_string", "fish", "got:fish"},
33 {"one_number", 10, "got:10"},
34 {"join", []string{"a", "b", "c"}, "a/b/c"},
35
36 }
37
38 func (test *callTest) newResults() interface{} {
39 switch e := test.expect.(type) {
40 case []interface{}:
41 var r []interface{}
42 for _, v := range e {
43 r = append(r, reflect.New(reflect.TypeOf(v)).Interface())
44 }
45 return r
46 case nil:
47 return nil
48 default:
49 return reflect.New(reflect.TypeOf(test.expect)).Interface()
50 }
51 }
52
53 func (test *callTest) verifyResults(t *testing.T, results interface{}) {
54 if results == nil {
55 return
56 }
57 val := reflect.Indirect(reflect.ValueOf(results)).Interface()
58 if !reflect.DeepEqual(val, test.expect) {
59 t.Errorf("%v:Results are incorrect, got %+v expect %+v", test.method, val, test.expect)
60 }
61 }
62
63 func TestCall(t *testing.T) {
64 stacktest.NoLeak(t)
65 ctx := eventtest.NewContext(context.Background(), t)
66 for _, headers := range []bool{false, true} {
67 name := "Plain"
68 if headers {
69 name = "Headers"
70 }
71 t.Run(name, func(t *testing.T) {
72 ctx := eventtest.NewContext(ctx, t)
73 a, b, done := prepare(ctx, t, headers)
74 defer done()
75 for _, test := range callTests {
76 t.Run(test.method, func(t *testing.T) {
77 ctx := eventtest.NewContext(ctx, t)
78 results := test.newResults()
79 if _, err := a.Call(ctx, test.method, test.params, results); err != nil {
80 t.Fatalf("%v:Call failed: %v", test.method, err)
81 }
82 test.verifyResults(t, results)
83 if _, err := b.Call(ctx, test.method, test.params, results); err != nil {
84 t.Fatalf("%v:Call failed: %v", test.method, err)
85 }
86 test.verifyResults(t, results)
87 })
88 }
89 })
90 }
91 }
92
93 func prepare(ctx context.Context, t *testing.T, withHeaders bool) (jsonrpc2.Conn, jsonrpc2.Conn, func()) {
94
95 aPipe, bPipe := net.Pipe()
96 a := run(ctx, withHeaders, aPipe)
97 b := run(ctx, withHeaders, bPipe)
98 return a, b, func() {
99 a.Close()
100 b.Close()
101 <-a.Done()
102 <-b.Done()
103 }
104 }
105
106 func run(ctx context.Context, withHeaders bool, nc net.Conn) jsonrpc2.Conn {
107 var stream jsonrpc2.Stream
108 if withHeaders {
109 stream = jsonrpc2.NewHeaderStream(nc)
110 } else {
111 stream = jsonrpc2.NewRawStream(nc)
112 }
113 conn := jsonrpc2.NewConn(stream)
114 conn.Go(ctx, testHandler(*logRPC))
115 return conn
116 }
117
118 func testHandler(log bool) jsonrpc2.Handler {
119 return func(ctx context.Context, reply jsonrpc2.Replier, req jsonrpc2.Request) error {
120 switch req.Method() {
121 case "no_args":
122 if len(req.Params()) > 0 {
123 return reply(ctx, nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams))
124 }
125 return reply(ctx, true, nil)
126 case "one_string":
127 var v string
128 if err := json.Unmarshal(req.Params(), &v); err != nil {
129 return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err))
130 }
131 return reply(ctx, "got:"+v, nil)
132 case "one_number":
133 var v int
134 if err := json.Unmarshal(req.Params(), &v); err != nil {
135 return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err))
136 }
137 return reply(ctx, fmt.Sprintf("got:%d", v), nil)
138 case "join":
139 var v []string
140 if err := json.Unmarshal(req.Params(), &v); err != nil {
141 return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err))
142 }
143 return reply(ctx, path.Join(v...), nil)
144 default:
145 return jsonrpc2.MethodNotFound(ctx, reply, req)
146 }
147 }
148 }
149
View as plain text