1
16
17 package ttrpc
18
19 import (
20 "bufio"
21 "encoding/binary"
22 "fmt"
23 "io"
24 "net"
25 "sync"
26
27 "google.golang.org/grpc/codes"
28 "google.golang.org/grpc/status"
29 )
30
const (
	// messageHeaderLength is the size in bytes of the fixed header that
	// precedes every message: a 4-byte length, a 4-byte stream ID, a 1-byte
	// type and a 1-byte flags field.
	messageHeaderLength = 10

	// messageLengthMax is the largest payload, in bytes (4 MiB), that recv
	// accepts; larger payloads are discarded and rejected.
	messageLengthMax = 4 * 1024 * 1024
)
35
// messageType identifies the kind of message carried in a frame. It is
// encoded as a single byte of the message header.
type messageType uint8

// Message types understood by the channel.
const (
	messageTypeRequest  messageType = 1
	messageTypeResponse messageType = 2
	messageTypeData     messageType = 3
)

// String implements fmt.Stringer. It returns the lower-case name of mt, or
// "unknown" for any value that has no defined constant.
func (mt messageType) String() string {
	var name string
	switch mt {
	case messageTypeRequest:
		name = "request"
	case messageTypeResponse:
		name = "response"
	case messageTypeData:
		name = "data"
	default:
		name = "unknown"
	}
	return name
}
56
// Single-bit values for the Flags byte of a message header. The names suggest
// remote stream state (closed/open) and an empty-payload marker; their exact
// semantics are defined by the code that sets and tests them, not here.
const (
	flagRemoteClosed uint8 = 1 << 0
	flagRemoteOpen   uint8 = 1 << 1
	flagNoData       uint8 = 1 << 2
)
62
63
64
65 type messageHeader struct {
66 Length uint32
67 StreamID uint32
68 Type messageType
69 Flags uint8
70 }
71
72 func readMessageHeader(p []byte, r io.Reader) (messageHeader, error) {
73 _, err := io.ReadFull(r, p[:messageHeaderLength])
74 if err != nil {
75 return messageHeader{}, err
76 }
77
78 return messageHeader{
79 Length: binary.BigEndian.Uint32(p[:4]),
80 StreamID: binary.BigEndian.Uint32(p[4:8]),
81 Type: messageType(p[8]),
82 Flags: p[9],
83 }, nil
84 }
85
86 func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
87 binary.BigEndian.PutUint32(p[:4], mh.Length)
88 binary.BigEndian.PutUint32(p[4:8], mh.StreamID)
89 p[8] = byte(mh.Type)
90 p[9] = mh.Flags
91
92 _, err := w.Write(p[:])
93 return err
94 }
95
// buffers is the shared pool of message payload buffers, stored as *[]byte.
// It deliberately has no New func: getmbuf allocates on a miss because only
// it knows the size that is needed.
var buffers = sync.Pool{}
97
98 type channel struct {
99 conn net.Conn
100 bw *bufio.Writer
101 br *bufio.Reader
102 hrbuf [messageHeaderLength]byte
103 hwbuf [messageHeaderLength]byte
104 }
105
106 func newChannel(conn net.Conn) *channel {
107 return &channel{
108 conn: conn,
109 bw: bufio.NewWriter(conn),
110 br: bufio.NewReader(conn),
111 }
112 }
113
114
115
116
117
118
119
120 func (ch *channel) recv() (messageHeader, []byte, error) {
121 mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
122 if err != nil {
123 return messageHeader{}, nil, err
124 }
125
126 if mh.Length > uint32(messageLengthMax) {
127 if _, err := ch.br.Discard(int(mh.Length)); err != nil {
128 return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err)
129 }
130
131 return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
132 }
133
134 var p []byte
135 if mh.Length > 0 {
136 p = ch.getmbuf(int(mh.Length))
137 if _, err := io.ReadFull(ch.br, p); err != nil {
138 return messageHeader{}, nil, fmt.Errorf("failed reading message: %w", err)
139 }
140 }
141
142 return mh, p, nil
143 }
144
145 func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
146
147
148
149
150 if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
151 return err
152 }
153
154 if len(p) > 0 {
155 _, err := ch.bw.Write(p)
156 if err != nil {
157 return err
158 }
159 }
160
161 return ch.bw.Flush()
162 }
163
164 func (ch *channel) getmbuf(size int) []byte {
165
166
167 b, ok := buffers.Get().(*[]byte)
168 if !ok || cap(*b) < size {
169
170
171
172 bb := make([]byte, size)
173 b = &bb
174 } else {
175 *b = (*b)[:size]
176 }
177 return *b
178 }
179
180 func (ch *channel) putmbuf(p []byte) {
181 buffers.Put(&p)
182 }
183
View as plain text