1
16
17 package ttrpc
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "io"
24 "net"
25 "strings"
26 "sync"
27 "syscall"
28 "time"
29
30 "github.com/sirupsen/logrus"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/grpc/status"
33 "google.golang.org/protobuf/proto"
34 )
35
36
37 type Client struct {
38 codec codec
39 conn net.Conn
40 channel *channel
41
42 streamLock sync.RWMutex
43 streams map[streamID]*stream
44 nextStreamID streamID
45 sendLock sync.Mutex
46
47 ctx context.Context
48 closed func()
49
50 closeOnce sync.Once
51 userCloseFunc func()
52 userCloseWaitCh chan struct{}
53
54 interceptor UnaryClientInterceptor
55 }
56
57
58 type ClientOpts func(c *Client)
59
60
61 func WithOnClose(onClose func()) ClientOpts {
62 return func(c *Client) {
63 c.userCloseFunc = onClose
64 }
65 }
66
67
68 func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
69 return func(c *Client) {
70 c.interceptor = i
71 }
72 }
73
74
75 func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
76 ctx, cancel := context.WithCancel(context.Background())
77 channel := newChannel(conn)
78 c := &Client{
79 codec: codec{},
80 conn: conn,
81 channel: channel,
82 streams: make(map[streamID]*stream),
83 nextStreamID: 1,
84 closed: cancel,
85 ctx: ctx,
86 userCloseFunc: func() {},
87 userCloseWaitCh: make(chan struct{}),
88 interceptor: defaultClientInterceptor,
89 }
90
91 for _, o := range opts {
92 o(c)
93 }
94
95 go c.run()
96 return c
97 }
98
99 func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error {
100 c.sendLock.Lock()
101 defer c.sendLock.Unlock()
102 return c.channel.send(sid, mt, flags, b)
103 }
104
105
106 func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
107 payload, err := c.codec.Marshal(req)
108 if err != nil {
109 return err
110 }
111
112 var (
113 creq = &Request{
114 Service: service,
115 Method: method,
116 Payload: payload,
117
118 }
119
120 cresp = &Response{}
121 )
122
123 if metadata, ok := GetMetadata(ctx); ok {
124 metadata.setRequest(creq)
125 }
126
127 if dl, ok := ctx.Deadline(); ok {
128 creq.TimeoutNano = time.Until(dl).Nanoseconds()
129 }
130
131 info := &UnaryClientInfo{
132 FullMethod: fullPath(service, method),
133 }
134 if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
135 return err
136 }
137
138 if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
139 return err
140 }
141
142 if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
143 return status.ErrorProto(cresp.Status)
144 }
145 return nil
146 }
147
148
149
// StreamDesc describes which sides of a streaming RPC may send more than one
// message. The zero value describes a unary exchange.
type StreamDesc struct {
	StreamingClient, StreamingServer bool
}
154
155
// ClientStream is the client side of a streaming RPC created by NewStream.
type ClientStream interface {
	// SendMsg encodes m and sends it to the server.
	SendMsg(m interface{}) error
	// CloseSend tells the server that no further messages will be sent.
	CloseSend() error
	// RecvMsg blocks for the next server message and decodes it into m.
	RecvMsg(m interface{}) error
}
161
162 type clientStream struct {
163 ctx context.Context
164 s *stream
165 c *Client
166 desc *StreamDesc
167 localClosed bool
168 remoteClosed bool
169 }
170
171 func (cs *clientStream) CloseSend() error {
172 if !cs.desc.StreamingClient {
173 return fmt.Errorf("%w: cannot close non-streaming client", ErrProtocol)
174 }
175 if cs.localClosed {
176 return ErrStreamClosed
177 }
178 err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil)
179 if err != nil {
180 return filterCloseErr(err)
181 }
182 cs.localClosed = true
183 return nil
184 }
185
186 func (cs *clientStream) SendMsg(m interface{}) error {
187 if !cs.desc.StreamingClient {
188 return fmt.Errorf("%w: cannot send data from non-streaming client", ErrProtocol)
189 }
190 if cs.localClosed {
191 return ErrStreamClosed
192 }
193
194 var (
195 payload []byte
196 err error
197 )
198 if m != nil {
199 payload, err = cs.c.codec.Marshal(m)
200 if err != nil {
201 return err
202 }
203 }
204
205 err = cs.s.send(messageTypeData, 0, payload)
206 if err != nil {
207 return filterCloseErr(err)
208 }
209
210 return nil
211 }
212
213 func (cs *clientStream) RecvMsg(m interface{}) error {
214 if cs.remoteClosed {
215 return io.EOF
216 }
217
218 var msg *streamMessage
219 select {
220 case <-cs.ctx.Done():
221 return cs.ctx.Err()
222 case <-cs.s.recvClose:
223
224 select {
225 case msg = <-cs.s.recv:
226 default:
227 return cs.s.recvErr
228 }
229 case msg = <-cs.s.recv:
230 }
231
232 if msg.header.Type == messageTypeResponse {
233 resp := &Response{}
234 err := proto.Unmarshal(msg.payload[:msg.header.Length], resp)
235
236 cs.c.channel.putmbuf(msg.payload)
237 if err != nil {
238 return err
239 }
240
241 if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil {
242 return err
243 }
244
245 if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
246 return status.ErrorProto(resp.Status)
247 }
248
249 cs.c.deleteStream(cs.s)
250 cs.remoteClosed = true
251
252 return nil
253 } else if msg.header.Type == messageTypeData {
254 if !cs.desc.StreamingServer {
255 cs.c.deleteStream(cs.s)
256 cs.remoteClosed = true
257 return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol)
258 }
259 if msg.header.Flags&flagRemoteClosed == flagRemoteClosed {
260 cs.c.deleteStream(cs.s)
261 cs.remoteClosed = true
262
263 if msg.header.Flags&flagNoData == flagNoData {
264 return io.EOF
265 }
266 }
267
268 err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m)
269 cs.c.channel.putmbuf(msg.payload)
270 if err != nil {
271 return err
272 }
273 return nil
274 }
275
276 return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
277 }
278
279
280 func (c *Client) Close() error {
281 c.closeOnce.Do(func() {
282 c.closed()
283
284 c.conn.Close()
285 })
286 return nil
287 }
288
289
290
291 func (c *Client) UserOnCloseWait(ctx context.Context) error {
292 select {
293 case <-c.userCloseWaitCh:
294 return nil
295 case <-ctx.Done():
296 return ctx.Err()
297 }
298 }
299
300 func (c *Client) run() {
301 err := c.receiveLoop()
302 c.Close()
303 c.cleanupStreams(err)
304
305 c.userCloseFunc()
306 close(c.userCloseWaitCh)
307 }
308
309 func (c *Client) receiveLoop() error {
310 for {
311 select {
312 case <-c.ctx.Done():
313 return ErrClosed
314 default:
315 var (
316 msg = &streamMessage{}
317 err error
318 )
319
320 msg.header, msg.payload, err = c.channel.recv()
321 if err != nil {
322 _, ok := status.FromError(err)
323 if !ok {
324
325
326 return filterCloseErr(err)
327 }
328 }
329 sid := streamID(msg.header.StreamID)
330 s := c.getStream(sid)
331 if s == nil {
332 logrus.WithField("stream", sid).Errorf("ttrpc: received message on inactive stream")
333 continue
334 }
335
336 if err != nil {
337 s.closeWithError(err)
338 } else {
339 if err := s.receive(c.ctx, msg); err != nil {
340 logrus.WithError(err).WithField("stream", sid).Errorf("ttrpc: failed to handle message")
341 }
342 }
343 }
344 }
345 }
346
347
348
349 func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
350 c.streamLock.Lock()
351
352
353
354 select {
355 case <-c.ctx.Done():
356 c.streamLock.Unlock()
357 return nil, ErrClosed
358 default:
359 }
360
361
362 s := newStream(c.nextStreamID, c)
363 c.streams[s.id] = s
364 c.nextStreamID = c.nextStreamID + 2
365
366 c.sendLock.Lock()
367 defer c.sendLock.Unlock()
368 c.streamLock.Unlock()
369
370 if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil {
371 return s, filterCloseErr(err)
372 }
373
374 return s, nil
375 }
376
377 func (c *Client) deleteStream(s *stream) {
378 c.streamLock.Lock()
379 delete(c.streams, s.id)
380 c.streamLock.Unlock()
381 s.closeWithError(nil)
382 }
383
384 func (c *Client) getStream(sid streamID) *stream {
385 c.streamLock.RLock()
386 s := c.streams[sid]
387 c.streamLock.RUnlock()
388 return s
389 }
390
391 func (c *Client) cleanupStreams(err error) {
392 c.streamLock.Lock()
393 defer c.streamLock.Unlock()
394
395 for sid, s := range c.streams {
396 s.closeWithError(err)
397 delete(c.streams, sid)
398 }
399 }
400
401
402
403
404
405 func filterCloseErr(err error) error {
406 switch {
407 case err == nil:
408 return nil
409 case err == io.EOF:
410 return ErrClosed
411 case errors.Is(err, io.ErrClosedPipe):
412 return ErrClosed
413 case errors.Is(err, io.EOF):
414 return ErrClosed
415 case strings.Contains(err.Error(), "use of closed network connection"):
416 return ErrClosed
417 default:
418
419 var oerr *net.OpError
420 if errors.As(err, &oerr) {
421 if (oerr.Op == "write" && errors.Is(err, syscall.EPIPE)) ||
422 (oerr.Op == "read" && errors.Is(err, syscall.ECONNRESET)) {
423 return ErrClosed
424 }
425 }
426 }
427
428 return err
429 }
430
431
432
433
434 func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string, req interface{}) (ClientStream, error) {
435 var payload []byte
436 if req != nil {
437 var err error
438 payload, err = c.codec.Marshal(req)
439 if err != nil {
440 return nil, err
441 }
442 }
443
444 request := &Request{
445 Service: service,
446 Method: method,
447 Payload: payload,
448
449 }
450 p, err := c.codec.Marshal(request)
451 if err != nil {
452 return nil, err
453 }
454
455 var flags uint8
456 if desc.StreamingClient {
457 flags = flagRemoteOpen
458 } else {
459 flags = flagRemoteClosed
460 }
461 s, err := c.createStream(flags, p)
462 if err != nil {
463 return nil, err
464 }
465
466 return &clientStream{
467 ctx: ctx,
468 s: s,
469 c: c,
470 desc: desc,
471 }, nil
472 }
473
474 func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
475 p, err := c.codec.Marshal(req)
476 if err != nil {
477 return err
478 }
479
480 s, err := c.createStream(0, p)
481 if err != nil {
482 return err
483 }
484 defer c.deleteStream(s)
485
486 var msg *streamMessage
487 select {
488 case <-ctx.Done():
489 return ctx.Err()
490 case <-c.ctx.Done():
491 return ErrClosed
492 case <-s.recvClose:
493
494 select {
495 case msg = <-s.recv:
496 default:
497 return s.recvErr
498 }
499 case msg = <-s.recv:
500 }
501
502 if msg.header.Type == messageTypeResponse {
503 err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
504 } else {
505 err = fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
506 }
507
508
509 c.channel.putmbuf(msg.payload)
510
511 return err
512 }
513
View as plain text