1
2
3
4
5 package jsonrpc2
6
7 import (
8 "context"
9 "encoding/json"
10 "errors"
11 "fmt"
12 "io"
13 "sync"
14 "sync/atomic"
15 "time"
16
17 "golang.org/x/tools/internal/event"
18 "golang.org/x/tools/internal/event/keys"
19 "golang.org/x/tools/internal/event/label"
20 "golang.org/x/tools/internal/jsonrpc2"
21 )
22
23
24
25
26
27 type Binder interface {
28
29
30
31
32
33 Bind(context.Context, *Connection) ConnectionOptions
34 }
35
36
37 type BinderFunc func(context.Context, *Connection) ConnectionOptions
38
39 func (f BinderFunc) Bind(ctx context.Context, c *Connection) ConnectionOptions {
40 return f(ctx, c)
41 }
42
43 var _ Binder = BinderFunc(nil)
44
45
46 type ConnectionOptions struct {
47
48
49 Framer Framer
50
51
52 Preempter Preempter
53
54
55 Handler Handler
56
57
58
59 OnInternalError func(error)
60 }
61
62
63
64
65
66 type Connection struct {
67 seq int64
68
69 stateMu sync.Mutex
70 state inFlightState
71 done chan struct{}
72
73 writer chan Writer
74
75 handler Handler
76
77 onInternalError func(error)
78 onDone func()
79 }
80
81
82
83 type inFlightState struct {
84 connClosing bool
85 reading bool
86 readErr error
87 writeErr error
88
89
90
91
92
93
94
95
96 closer io.Closer
97 closeErr error
98
99 outgoingCalls map[ID]*AsyncCall
100 outgoingNotifications int
101
102
103
104 incoming int
105
106 incomingByID map[ID]*incomingRequest
107
108
109
110
111 handlerQueue []*incomingRequest
112 handlerRunning bool
113 }
114
115
116
117
118 func (c *Connection) updateInFlight(f func(*inFlightState)) {
119 c.stateMu.Lock()
120 defer c.stateMu.Unlock()
121
122 s := &c.state
123
124 f(s)
125
126 select {
127 case <-c.done:
128
129
130
131
132 if !s.idle() {
133 panic("jsonrpc2_v2: updateInFlight transitioned to non-idle when already done")
134 }
135 return
136 default:
137 }
138
139 if s.idle() && s.shuttingDown(ErrUnknown) != nil {
140 if s.closer != nil {
141 s.closeErr = s.closer.Close()
142 s.closer = nil
143 }
144 if s.reading {
145
146
147
148 } else {
149
150
151 if c.onDone != nil {
152 c.onDone()
153 }
154 close(c.done)
155 }
156 }
157 }
158
159
160
161
162
163
164 func (s *inFlightState) idle() bool {
165 return len(s.outgoingCalls) == 0 && s.outgoingNotifications == 0 && s.incoming == 0 && !s.handlerRunning
166 }
167
168
169
170
171 func (s *inFlightState) shuttingDown(errClosing error) error {
172 if s.connClosing {
173
174
175
176 return errClosing
177 }
178 if s.readErr != nil {
179
180
181 return fmt.Errorf("%w: %v", errClosing, s.readErr)
182 }
183 if s.writeErr != nil {
184
185
186 return fmt.Errorf("%w: %v", errClosing, s.writeErr)
187 }
188 return nil
189 }
190
191
192 type incomingRequest struct {
193 *Request
194 ctx context.Context
195 cancel context.CancelFunc
196 endSpan func()
197 }
198
199
200 func (o ConnectionOptions) Bind(context.Context, *Connection) ConnectionOptions {
201 return o
202 }
203
204
205
206
207
208
209
210
211 func newConnection(bindCtx context.Context, rwc io.ReadWriteCloser, binder Binder, onDone func()) *Connection {
212
213
214 ctx := notDone{bindCtx}
215
216 c := &Connection{
217 state: inFlightState{closer: rwc},
218 done: make(chan struct{}),
219 writer: make(chan Writer, 1),
220 onDone: onDone,
221 }
222
223
224
225
226
227
228 options := binder.Bind(bindCtx, c)
229 framer := options.Framer
230 if framer == nil {
231 framer = HeaderFramer()
232 }
233 c.handler = options.Handler
234 if c.handler == nil {
235 c.handler = defaultHandler{}
236 }
237 c.onInternalError = options.OnInternalError
238
239 c.writer <- framer.Writer(rwc)
240 reader := framer.Reader(rwc)
241
242 c.updateInFlight(func(s *inFlightState) {
243 select {
244 case <-c.done:
245
246 return
247 default:
248 }
249
250
251
252
253
254 s.reading = true
255 go c.readIncoming(ctx, reader, options.Preempter)
256 })
257 return c
258 }
259
260
261
262
263 func (c *Connection) Notify(ctx context.Context, method string, params interface{}) (err error) {
264 ctx, done := event.Start(ctx, method,
265 jsonrpc2.Method.Of(method),
266 jsonrpc2.RPCDirection.Of(jsonrpc2.Outbound),
267 )
268 attempted := false
269
270 defer func() {
271 labelStatus(ctx, err)
272 done()
273 if attempted {
274 c.updateInFlight(func(s *inFlightState) {
275 s.outgoingNotifications--
276 })
277 }
278 }()
279
280 c.updateInFlight(func(s *inFlightState) {
281
282
283
284
285 if len(s.outgoingCalls) == 0 && len(s.incomingByID) == 0 {
286 err = s.shuttingDown(ErrClientClosing)
287 if err != nil {
288 return
289 }
290 }
291 s.outgoingNotifications++
292 attempted = true
293 })
294 if err != nil {
295 return err
296 }
297
298 notify, err := NewNotification(method, params)
299 if err != nil {
300 return fmt.Errorf("marshaling notify parameters: %v", err)
301 }
302
303 event.Metric(ctx, jsonrpc2.Started.Of(1))
304 return c.write(ctx, notify)
305 }
306
307
308
309
310
311
312 func (c *Connection) Call(ctx context.Context, method string, params interface{}) *AsyncCall {
313
314 id := Int64ID(atomic.AddInt64(&c.seq, 1))
315 ctx, endSpan := event.Start(ctx, method,
316 jsonrpc2.Method.Of(method),
317 jsonrpc2.RPCDirection.Of(jsonrpc2.Outbound),
318 jsonrpc2.RPCID.Of(fmt.Sprintf("%q", id)),
319 )
320
321 ac := &AsyncCall{
322 id: id,
323 ready: make(chan struct{}),
324 ctx: ctx,
325 endSpan: endSpan,
326 }
327
328
329
330
331 call, err := NewCall(ac.id, method, params)
332 if err != nil {
333 ac.retire(&Response{ID: id, Error: fmt.Errorf("marshaling call parameters: %w", err)})
334 return ac
335 }
336
337 c.updateInFlight(func(s *inFlightState) {
338 err = s.shuttingDown(ErrClientClosing)
339 if err != nil {
340 return
341 }
342 if s.outgoingCalls == nil {
343 s.outgoingCalls = make(map[ID]*AsyncCall)
344 }
345 s.outgoingCalls[ac.id] = ac
346 })
347 if err != nil {
348 ac.retire(&Response{ID: id, Error: err})
349 return ac
350 }
351
352 event.Metric(ctx, jsonrpc2.Started.Of(1))
353 if err := c.write(ctx, call); err != nil {
354
355
356 c.updateInFlight(func(s *inFlightState) {
357 if s.outgoingCalls[ac.id] == ac {
358 delete(s.outgoingCalls, ac.id)
359 ac.retire(&Response{ID: id, Error: err})
360 } else {
361
362
363 }
364 })
365 }
366 return ac
367 }
368
369 type AsyncCall struct {
370 id ID
371 ready chan struct{}
372 response *Response
373 ctx context.Context
374 endSpan func()
375 }
376
377
378
379 func (ac *AsyncCall) ID() ID { return ac.id }
380
381
382
383
384 func (ac *AsyncCall) IsReady() bool {
385 select {
386 case <-ac.ready:
387 return true
388 default:
389 return false
390 }
391 }
392
393
394 func (ac *AsyncCall) retire(response *Response) {
395 select {
396 case <-ac.ready:
397 panic(fmt.Sprintf("jsonrpc2: retire called twice for ID %v", ac.id))
398 default:
399 }
400
401 ac.response = response
402 labelStatus(ac.ctx, response.Error)
403 ac.endSpan()
404
405
406 ac.ctx, ac.endSpan = nil, nil
407
408 close(ac.ready)
409 }
410
411
412
413 func (ac *AsyncCall) Await(ctx context.Context, result interface{}) error {
414 select {
415 case <-ctx.Done():
416 return ctx.Err()
417 case <-ac.ready:
418 }
419 if ac.response.Error != nil {
420 return ac.response.Error
421 }
422 if result == nil {
423 return nil
424 }
425 return json.Unmarshal(ac.response.Result, result)
426 }
427
428
429
430
431
432 func (c *Connection) Respond(id ID, result interface{}, err error) error {
433 var req *incomingRequest
434 c.updateInFlight(func(s *inFlightState) {
435 req = s.incomingByID[id]
436 })
437 if req == nil {
438 return c.internalErrorf("Request not found for ID %v", id)
439 }
440
441 if err == ErrAsyncResponse {
442
443
444
445 err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method)
446 }
447 return c.processResult("Respond", req, result, err)
448 }
449
450
451
452
453
454
455
456 func (c *Connection) Cancel(id ID) {
457 var req *incomingRequest
458 c.updateInFlight(func(s *inFlightState) {
459 req = s.incomingByID[id]
460 })
461 if req != nil {
462 req.cancel()
463 }
464 }
465
466
467 func (c *Connection) Wait() error {
468 var err error
469 <-c.done
470 c.updateInFlight(func(s *inFlightState) {
471 err = s.closeErr
472 })
473 return err
474 }
475
476
477
478
479
480
481
482
483 func (c *Connection) Close() error {
484
485
486 c.updateInFlight(func(s *inFlightState) { s.connClosing = true })
487
488 return c.Wait()
489 }
490
491
492
493 func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter Preempter) {
494 var err error
495 for {
496 var (
497 msg Message
498 n int64
499 )
500 msg, n, err = reader.Read(ctx)
501 if err != nil {
502 break
503 }
504
505 switch msg := msg.(type) {
506 case *Request:
507 c.acceptRequest(ctx, msg, n, preempter)
508
509 case *Response:
510 c.updateInFlight(func(s *inFlightState) {
511 if ac, ok := s.outgoingCalls[msg.ID]; ok {
512 delete(s.outgoingCalls, msg.ID)
513 ac.retire(msg)
514 } else {
515
516 }
517 })
518
519 default:
520 c.internalErrorf("Read returned an unexpected message of type %T", msg)
521 }
522 }
523
524 c.updateInFlight(func(s *inFlightState) {
525 s.reading = false
526 s.readErr = err
527
528
529
530 for id, ac := range s.outgoingCalls {
531 ac.retire(&Response{ID: id, Error: err})
532 }
533 s.outgoingCalls = nil
534 })
535 }
536
537
538
539 func (c *Connection) acceptRequest(ctx context.Context, msg *Request, msgBytes int64, preempter Preempter) {
540
541 labels := append(make([]label.Label, 0, 3),
542 jsonrpc2.Method.Of(msg.Method),
543 jsonrpc2.RPCDirection.Of(jsonrpc2.Inbound),
544 )
545 if msg.IsCall() {
546 labels = append(labels, jsonrpc2.RPCID.Of(fmt.Sprintf("%q", msg.ID)))
547 }
548 ctx, endSpan := event.Start(ctx, msg.Method, labels...)
549 event.Metric(ctx,
550 jsonrpc2.Started.Of(1),
551 jsonrpc2.ReceivedBytes.Of(msgBytes))
552
553
554
555 ctx, cancel := context.WithCancel(ctx)
556 req := &incomingRequest{
557 Request: msg,
558 ctx: ctx,
559 cancel: cancel,
560 endSpan: endSpan,
561 }
562
563
564
565 var err error
566 c.updateInFlight(func(s *inFlightState) {
567 s.incoming++
568
569 if req.IsCall() {
570 if s.incomingByID[req.ID] != nil {
571 err = fmt.Errorf("%w: request ID %v already in use", ErrInvalidRequest, req.ID)
572 req.ID = ID{}
573 return
574 }
575
576 if s.incomingByID == nil {
577 s.incomingByID = make(map[ID]*incomingRequest)
578 }
579 s.incomingByID[req.ID] = req
580
581
582
583
584
585 err = s.shuttingDown(ErrServerClosing)
586 }
587 })
588 if err != nil {
589 c.processResult("acceptRequest", req, nil, err)
590 return
591 }
592
593 if preempter != nil {
594 result, err := preempter.Preempt(req.ctx, req.Request)
595
596 if req.IsCall() && errors.Is(err, ErrAsyncResponse) {
597
598 return
599 }
600
601 if !errors.Is(err, ErrNotHandled) {
602 c.processResult("Preempt", req, result, err)
603 return
604 }
605 }
606
607 c.updateInFlight(func(s *inFlightState) {
608
609
610
611
612 err = s.shuttingDown(ErrServerClosing)
613 if err != nil {
614 return
615 }
616
617
618
619
620
621
622
623 s.handlerQueue = append(s.handlerQueue, req)
624 if !s.handlerRunning {
625
626
627
628
629
630
631
632
633
634
635
636
637 s.handlerRunning = true
638 go c.handleAsync()
639 }
640 })
641 if err != nil {
642 c.processResult("acceptRequest", req, nil, err)
643 }
644 }
645
646
647
648 func (c *Connection) handleAsync() {
649 for {
650 var req *incomingRequest
651 c.updateInFlight(func(s *inFlightState) {
652 if len(s.handlerQueue) > 0 {
653 req, s.handlerQueue = s.handlerQueue[0], s.handlerQueue[1:]
654 } else {
655 s.handlerRunning = false
656 }
657 })
658 if req == nil {
659 return
660 }
661
662
663 if err := req.ctx.Err(); err != nil {
664 c.updateInFlight(func(s *inFlightState) {
665 if s.writeErr != nil {
666
667
668 err = fmt.Errorf("%w: %v", ErrServerClosing, s.writeErr)
669 }
670 })
671 c.processResult("handleAsync", req, nil, err)
672 continue
673 }
674
675 result, err := c.handler.Handle(req.ctx, req.Request)
676 c.processResult(c.handler, req, result, err)
677 }
678 }
679
680
681 func (c *Connection) processResult(from interface{}, req *incomingRequest, result interface{}, err error) error {
682 switch err {
683 case ErrAsyncResponse:
684 if !req.IsCall() {
685 return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method)
686 }
687 return nil
688 case ErrNotHandled, ErrMethodNotFound:
689
690 err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method)
691 }
692
693 if req.endSpan == nil {
694 return c.internalErrorf("%#v produced a duplicate %q Response", from, req.Method)
695 }
696
697 if result != nil && err != nil {
698 c.internalErrorf("%#v returned a non-nil result with a non-nil error for %s:\n%v\n%#v", from, req.Method, err, result)
699 result = nil
700 }
701
702 if req.IsCall() {
703 if result == nil && err == nil {
704 err = c.internalErrorf("%#v returned a nil result and nil error for a %q Request that requires a Response", from, req.Method)
705 }
706
707 response, respErr := NewResponse(req.ID, result, err)
708
709
710
711
712 c.updateInFlight(func(s *inFlightState) {
713 delete(s.incomingByID, req.ID)
714 })
715 if respErr == nil {
716 writeErr := c.write(notDone{req.ctx}, response)
717 if err == nil {
718 err = writeErr
719 }
720 } else {
721 err = c.internalErrorf("%#v returned a malformed result for %q: %w", from, req.Method, respErr)
722 }
723 } else {
724 if result != nil {
725 err = c.internalErrorf("%#v returned a non-nil result for a %q Request without an ID", from, req.Method)
726 } else if err != nil {
727 err = fmt.Errorf("%w: %q notification failed: %v", ErrInternal, req.Method, err)
728 }
729 if err != nil {
730
731
732 event.Label(req.ctx, keys.Err.Of(err))
733 }
734 }
735
736 labelStatus(req.ctx, err)
737
738
739 req.cancel()
740 req.endSpan()
741 req.endSpan = nil
742 c.updateInFlight(func(s *inFlightState) {
743 if s.incoming == 0 {
744 panic("jsonrpc2_v2: processResult called when incoming count is already zero")
745 }
746 s.incoming--
747 })
748 return nil
749 }
750
751
752
753 func (c *Connection) write(ctx context.Context, msg Message) error {
754 writer := <-c.writer
755 defer func() { c.writer <- writer }()
756 n, err := writer.Write(ctx, msg)
757 event.Metric(ctx, jsonrpc2.SentBytes.Of(n))
758
759 if err != nil && ctx.Err() == nil {
760
761
762
763
764
765
766
767
768 c.updateInFlight(func(s *inFlightState) {
769 if s.writeErr == nil {
770 s.writeErr = err
771 for _, r := range s.incomingByID {
772 r.cancel()
773 }
774 }
775 })
776 }
777
778 return err
779 }
780
781
782
783
784 func (c *Connection) internalErrorf(format string, args ...interface{}) error {
785 err := fmt.Errorf(format, args...)
786 if c.onInternalError == nil {
787 panic("jsonrpc2: " + err.Error())
788 }
789 c.onInternalError(err)
790
791 return fmt.Errorf("%w: %v", ErrInternal, err)
792 }
793
794
795 func labelStatus(ctx context.Context, err error) {
796 if err == nil {
797 event.Label(ctx, jsonrpc2.StatusCode.Of("OK"))
798 } else {
799 event.Label(ctx, jsonrpc2.StatusCode.Of("ERROR"))
800 }
801 }
802
803
// notDone is a context.Context that exposes the values of the context it
// wraps but never reports cancellation or a deadline.
type notDone struct {
	ctx context.Context
}

// Value looks up key in the wrapped context.
func (nd notDone) Value(key interface{}) interface{} {
	return nd.ctx.Value(key)
}

// Done always returns a nil channel.
func (notDone) Done() <-chan struct{} {
	return nil
}

// Err always returns nil, whatever the state of the wrapped context.
func (notDone) Err() error {
	return nil
}

// Deadline always reports that no deadline is set.
func (notDone) Deadline() (time.Time, bool) {
	var none time.Time
	return none, false
}
813
View as plain text