1
16
17 package ttrpc
18
19 import (
20 "context"
21 "errors"
22 "io"
23 "math/rand"
24 "net"
25 "sync"
26 "sync/atomic"
27 "syscall"
28 "time"
29
30 "github.com/sirupsen/logrus"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/grpc/status"
33 )
34
35 type Server struct {
36 config *serverConfig
37 services *serviceSet
38 codec codec
39
40 mu sync.Mutex
41 listeners map[net.Listener]struct{}
42 connections map[*serverConn]struct{}
43 done chan struct{}
44 }
45
46 func NewServer(opts ...ServerOpt) (*Server, error) {
47 config := &serverConfig{}
48 for _, opt := range opts {
49 if err := opt(config); err != nil {
50 return nil, err
51 }
52 }
53 if config.interceptor == nil {
54 config.interceptor = defaultServerInterceptor
55 }
56
57 return &Server{
58 config: config,
59 services: newServiceSet(config.interceptor),
60 done: make(chan struct{}),
61 listeners: make(map[net.Listener]struct{}),
62 connections: make(map[*serverConn]struct{}),
63 }, nil
64 }
65
66
67
68 func (s *Server) Register(name string, methods map[string]Method) {
69 s.services.register(name, &ServiceDesc{Methods: methods})
70 }
71
72 func (s *Server) RegisterService(name string, desc *ServiceDesc) {
73 s.services.register(name, desc)
74 }
75
76 func (s *Server) Serve(ctx context.Context, l net.Listener) error {
77 s.addListener(l)
78 defer s.closeListener(l)
79
80 var (
81 backoff time.Duration
82 handshaker = s.config.handshaker
83 )
84
85 if handshaker == nil {
86 handshaker = handshakerFunc(noopHandshake)
87 }
88
89 for {
90 conn, err := l.Accept()
91 if err != nil {
92 select {
93 case <-s.done:
94 return ErrServerClosed
95 default:
96 }
97
98 if terr, ok := err.(interface {
99 Temporary() bool
100 }); ok && terr.Temporary() {
101 if backoff == 0 {
102 backoff = time.Millisecond
103 } else {
104 backoff *= 2
105 }
106
107 if max := time.Second; backoff > max {
108 backoff = max
109 }
110
111 sleep := time.Duration(rand.Int63n(int64(backoff)))
112 logrus.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
113 time.Sleep(sleep)
114 continue
115 }
116
117 return err
118 }
119
120 backoff = 0
121
122 approved, handshake, err := handshaker.Handshake(ctx, conn)
123 if err != nil {
124 logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
125 conn.Close()
126 continue
127 }
128
129 sc, err := s.newConn(approved, handshake)
130 if err != nil {
131 logrus.WithError(err).Error("ttrpc: create connection failed")
132 conn.Close()
133 continue
134 }
135
136 go sc.run(ctx)
137 }
138 }
139
140 func (s *Server) Shutdown(ctx context.Context) error {
141 s.mu.Lock()
142 select {
143 case <-s.done:
144 default:
145
146 close(s.done)
147 }
148 lnerr := s.closeListeners()
149 s.mu.Unlock()
150
151 ticker := time.NewTicker(200 * time.Millisecond)
152 defer ticker.Stop()
153 for {
154 s.closeIdleConns()
155
156 if s.countConnection() == 0 {
157 break
158 }
159
160 select {
161 case <-ctx.Done():
162 return ctx.Err()
163 case <-ticker.C:
164 }
165 }
166
167 return lnerr
168 }
169
170
171 func (s *Server) Close() error {
172 s.mu.Lock()
173 defer s.mu.Unlock()
174
175 select {
176 case <-s.done:
177 default:
178
179 close(s.done)
180 }
181
182 err := s.closeListeners()
183 for c := range s.connections {
184 c.close()
185 delete(s.connections, c)
186 }
187
188 return err
189 }
190
191 func (s *Server) addListener(l net.Listener) {
192 s.mu.Lock()
193 defer s.mu.Unlock()
194 s.listeners[l] = struct{}{}
195 }
196
197 func (s *Server) closeListener(l net.Listener) error {
198 s.mu.Lock()
199 defer s.mu.Unlock()
200
201 return s.closeListenerLocked(l)
202 }
203
204 func (s *Server) closeListenerLocked(l net.Listener) error {
205 defer delete(s.listeners, l)
206 return l.Close()
207 }
208
209 func (s *Server) closeListeners() error {
210 var err error
211 for l := range s.listeners {
212 if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
213 err = cerr
214 }
215 }
216 return err
217 }
218
219 func (s *Server) addConnection(c *serverConn) error {
220 s.mu.Lock()
221 defer s.mu.Unlock()
222
223 select {
224 case <-s.done:
225 return ErrServerClosed
226 default:
227 }
228
229 s.connections[c] = struct{}{}
230 return nil
231 }
232
233 func (s *Server) delConnection(c *serverConn) {
234 s.mu.Lock()
235 defer s.mu.Unlock()
236
237 delete(s.connections, c)
238 }
239
240 func (s *Server) countConnection() int {
241 s.mu.Lock()
242 defer s.mu.Unlock()
243
244 return len(s.connections)
245 }
246
247 func (s *Server) closeIdleConns() {
248 s.mu.Lock()
249 defer s.mu.Unlock()
250
251 for c := range s.connections {
252 if st, ok := c.getState(); !ok || st == connStateActive {
253 continue
254 }
255 c.close()
256 delete(s.connections, c)
257 }
258 }
259
// connState describes the lifecycle state of a server connection.
type connState int

// The constants are explicitly typed so that they carry the String method;
// left untyped they would default to int and print as bare numbers.
const (
	// connStateActive marks a connection with at least one open stream.
	connStateActive connState = iota + 1
	// connStateIdle marks a connection with no open streams.
	connStateIdle
	// connStateClosed marks a connection that has been closed.
	connStateClosed
)

// String returns the lower-case name of the state, or "unknown" for any value
// outside the declared constants.
func (cs connState) String() string {
	switch cs {
	case connStateActive:
		return "active"
	case connStateIdle:
		return "idle"
	case connStateClosed:
		return "closed"
	default:
		return "unknown"
	}
}
280
281 func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
282 c := &serverConn{
283 server: s,
284 conn: conn,
285 handshake: handshake,
286 shutdown: make(chan struct{}),
287 }
288 c.setState(connStateIdle)
289 if err := s.addConnection(c); err != nil {
290 c.close()
291 return nil, err
292 }
293 return c, nil
294 }
295
296 type serverConn struct {
297 server *Server
298 conn net.Conn
299 handshake interface{}
300 state atomic.Value
301
302 shutdownOnce sync.Once
303 shutdown chan struct{}
304 }
305
306 func (c *serverConn) getState() (connState, bool) {
307 cs, ok := c.state.Load().(connState)
308 return cs, ok
309 }
310
311 func (c *serverConn) setState(newstate connState) {
312 c.state.Store(newstate)
313 }
314
315 func (c *serverConn) close() error {
316 c.shutdownOnce.Do(func() {
317 close(c.shutdown)
318 })
319
320 return nil
321 }
322
323 func (c *serverConn) run(sctx context.Context) {
324 type (
325 response struct {
326 id uint32
327 status *status.Status
328 data []byte
329 closeStream bool
330 streaming bool
331 }
332 )
333
334 var (
335 ch = newChannel(c.conn)
336 ctx, cancel = context.WithCancel(sctx)
337 state connState = connStateIdle
338 responses = make(chan response)
339 recvErr = make(chan error, 1)
340 done = make(chan struct{})
341 streams = sync.Map{}
342 active int32
343 lastStreamID uint32
344 )
345
346 defer c.conn.Close()
347 defer cancel()
348 defer close(done)
349 defer c.server.delConnection(c)
350
351 sendStatus := func(id uint32, st *status.Status) bool {
352 select {
353 case responses <- response{
354
355
356
357 id: id,
358 status: st,
359 closeStream: true,
360 }:
361 return true
362 case <-c.shutdown:
363 return false
364 case <-done:
365 return false
366 }
367 }
368
369 go func(recvErr chan error) {
370 defer close(recvErr)
371 for {
372 select {
373 case <-c.shutdown:
374 return
375 case <-done:
376 return
377 default:
378 }
379
380 mh, p, err := ch.recv()
381 if err != nil {
382 status, ok := status.FromError(err)
383 if !ok {
384 recvErr <- err
385 return
386 }
387
388
389
390 if !sendStatus(mh.StreamID, status) {
391 return
392 }
393
394 continue
395 }
396
397 if mh.StreamID%2 != 1 {
398
399 if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
400 return
401 }
402 continue
403 }
404
405 if mh.Type == messageTypeData {
406 i, ok := streams.Load(mh.StreamID)
407 if !ok {
408 if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID is no longer active")) {
409 return
410 }
411 }
412 sh := i.(*streamHandler)
413 if mh.Flags&flagNoData != flagNoData {
414 unmarshal := func(obj interface{}) error {
415 err := protoUnmarshal(p, obj)
416 ch.putmbuf(p)
417 return err
418 }
419
420 if err := sh.data(unmarshal); err != nil {
421 if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data handling error: %v", err)) {
422 return
423 }
424 }
425 }
426
427 if mh.Flags&flagRemoteClosed == flagRemoteClosed {
428 sh.closeSend()
429 if len(p) > 0 {
430 if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data close message cannot include data")) {
431 return
432 }
433 }
434 }
435 } else if mh.Type == messageTypeRequest {
436 if mh.StreamID <= lastStreamID {
437
438 if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID cannot be re-used and must increment")) {
439 return
440 }
441 continue
442
443 }
444 lastStreamID = mh.StreamID
445
446
447
448 var req Request
449 if err := c.server.codec.Unmarshal(p, &req); err != nil {
450 ch.putmbuf(p)
451 if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
452 return
453 }
454 continue
455 }
456 ch.putmbuf(p)
457
458 id := mh.StreamID
459 respond := func(status *status.Status, data []byte, streaming, closeStream bool) error {
460 select {
461 case responses <- response{
462 id: id,
463 status: status,
464 data: data,
465 closeStream: closeStream,
466 streaming: streaming,
467 }:
468 case <-done:
469 return ErrClosed
470 }
471 return nil
472 }
473 sh, err := c.server.services.handle(ctx, &req, respond)
474 if err != nil {
475 status, _ := status.FromError(err)
476 if !sendStatus(mh.StreamID, status) {
477 return
478 }
479 continue
480 }
481
482 streams.Store(id, sh)
483 atomic.AddInt32(&active, 1)
484 }
485
486 }
487 }(recvErr)
488
489 for {
490 var (
491 newstate connState
492 shutdown chan struct{}
493 )
494
495 activeN := atomic.LoadInt32(&active)
496 if activeN > 0 {
497 newstate = connStateActive
498 shutdown = nil
499 } else {
500 newstate = connStateIdle
501 shutdown = c.shutdown
502 }
503 if newstate != state {
504 c.setState(newstate)
505 state = newstate
506 }
507
508 select {
509 case response := <-responses:
510 if !response.streaming || response.status.Code() != codes.OK {
511 p, err := c.server.codec.Marshal(&Response{
512 Status: response.status.Proto(),
513 Payload: response.data,
514 })
515 if err != nil {
516 logrus.WithError(err).Error("failed marshaling response")
517 return
518 }
519
520 if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
521 logrus.WithError(err).Error("failed sending message on channel")
522 return
523 }
524 } else {
525 var flags uint8
526 if response.closeStream {
527 flags = flagRemoteClosed
528 }
529 if response.data == nil {
530 flags = flags | flagNoData
531 }
532 if err := ch.send(response.id, messageTypeData, flags, response.data); err != nil {
533 logrus.WithError(err).Error("failed sending message on channel")
534 return
535 }
536 }
537
538 if response.closeStream {
539
540
541
542 streams.Delete(response.id)
543 atomic.AddInt32(&active, -1)
544 }
545 case err := <-recvErr:
546
547
548
549 recvErr = nil
550 if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, syscall.ECONNRESET) {
551
552
553 return
554 }
555 logrus.WithError(err).Error("error receiving message")
556
557 case <-shutdown:
558 return
559 }
560 }
561 }
562
563 var noopFunc = func() {}
564
565 func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
566 if len(req.Metadata) > 0 {
567 md := MD{}
568 md.fromRequest(req)
569 ctx = WithMetadata(ctx, md)
570 }
571
572 cancel = noopFunc
573 if req.TimeoutNano == 0 {
574 return ctx, cancel
575 }
576
577 ctx, cancel = context.WithTimeout(ctx, time.Duration(req.TimeoutNano))
578 return ctx, cancel
579 }
580
View as plain text