1
16
17 package ttrpc
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "io"
24 "os"
25 "path"
26 "unsafe"
27
28 "google.golang.org/grpc/codes"
29 "google.golang.org/grpc/status"
30 "google.golang.org/protobuf/proto"
31 )
32
33 type Method func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error)
34
35 type StreamHandler func(context.Context, StreamServer) (interface{}, error)
36
37 type Stream struct {
38 Handler StreamHandler
39 StreamingClient bool
40 StreamingServer bool
41 }
42
43 type ServiceDesc struct {
44 Methods map[string]Method
45 Streams map[string]Stream
46 }
47
48 type serviceSet struct {
49 services map[string]*ServiceDesc
50 unaryInterceptor UnaryServerInterceptor
51 streamInterceptor StreamServerInterceptor
52 }
53
54 func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet {
55 return &serviceSet{
56 services: make(map[string]*ServiceDesc),
57 unaryInterceptor: interceptor,
58 streamInterceptor: defaultStreamServerInterceptor,
59 }
60 }
61
62 func (s *serviceSet) register(name string, desc *ServiceDesc) {
63 if _, ok := s.services[name]; ok {
64 panic(fmt.Errorf("duplicate service %v registered", name))
65 }
66
67 s.services[name] = desc
68 }
69
70 func (s *serviceSet) unaryCall(ctx context.Context, method Method, info *UnaryServerInfo, data []byte) (p []byte, st *status.Status) {
71 unmarshal := func(obj interface{}) error {
72 return protoUnmarshal(data, obj)
73 }
74
75 resp, err := s.unaryInterceptor(ctx, unmarshal, info, method)
76 if err == nil {
77 if isNil(resp) {
78 err = errors.New("ttrpc: marshal called with nil")
79 } else {
80 p, err = protoMarshal(resp)
81 }
82 }
83
84 st, ok := status.FromError(err)
85 if !ok {
86 st = status.New(convertCode(err), err.Error())
87 }
88
89 return p, st
90 }
91
92 func (s *serviceSet) streamCall(ctx context.Context, stream StreamHandler, info *StreamServerInfo, ss StreamServer) (p []byte, st *status.Status) {
93 resp, err := s.streamInterceptor(ctx, ss, info, stream)
94 if err == nil {
95 p, err = protoMarshal(resp)
96 }
97 st, ok := status.FromError(err)
98 if !ok {
99 st = status.New(convertCode(err), err.Error())
100 }
101 return
102 }
103
104 func (s *serviceSet) handle(ctx context.Context, req *Request, respond func(*status.Status, []byte, bool, bool) error) (*streamHandler, error) {
105 srv, ok := s.services[req.Service]
106 if !ok {
107 return nil, status.Errorf(codes.Unimplemented, "service %v", req.Service)
108 }
109
110 if method, ok := srv.Methods[req.Method]; ok {
111 go func() {
112 ctx, cancel := getRequestContext(ctx, req)
113 defer cancel()
114
115 info := &UnaryServerInfo{
116 FullMethod: fullPath(req.Service, req.Method),
117 }
118 p, st := s.unaryCall(ctx, method, info, req.Payload)
119
120 respond(st, p, false, true)
121 }()
122 return nil, nil
123 }
124 if stream, ok := srv.Streams[req.Method]; ok {
125 ctx, cancel := getRequestContext(ctx, req)
126 info := &StreamServerInfo{
127 FullMethod: fullPath(req.Service, req.Method),
128 StreamingClient: stream.StreamingClient,
129 StreamingServer: stream.StreamingServer,
130 }
131 sh := &streamHandler{
132 ctx: ctx,
133 respond: respond,
134 recv: make(chan Unmarshaler, 5),
135 info: info,
136 }
137 go func() {
138 defer cancel()
139 p, st := s.streamCall(ctx, stream.Handler, info, sh)
140 respond(st, p, stream.StreamingServer, true)
141 }()
142
143 if req.Payload != nil {
144 unmarshal := func(obj interface{}) error {
145 return protoUnmarshal(req.Payload, obj)
146 }
147 if err := sh.data(unmarshal); err != nil {
148 return nil, err
149 }
150 }
151
152 return sh, nil
153 }
154 return nil, status.Errorf(codes.Unimplemented, "method %v", req.Method)
155 }
156
157 type streamHandler struct {
158 ctx context.Context
159 respond func(*status.Status, []byte, bool, bool) error
160 recv chan Unmarshaler
161 info *StreamServerInfo
162
163 remoteClosed bool
164 localClosed bool
165 }
166
167 func (s *streamHandler) closeSend() {
168 if !s.remoteClosed {
169 s.remoteClosed = true
170 close(s.recv)
171 }
172 }
173
174 func (s *streamHandler) data(unmarshal Unmarshaler) error {
175 if s.remoteClosed {
176 return ErrStreamClosed
177 }
178 select {
179 case s.recv <- unmarshal:
180 return nil
181 case <-s.ctx.Done():
182 return s.ctx.Err()
183 }
184 }
185
186 func (s *streamHandler) SendMsg(m interface{}) error {
187 if s.localClosed {
188 return ErrStreamClosed
189 }
190 p, err := protoMarshal(m)
191 if err != nil {
192 return err
193 }
194 return s.respond(nil, p, true, false)
195 }
196
197 func (s *streamHandler) RecvMsg(m interface{}) error {
198 select {
199 case unmarshal, ok := <-s.recv:
200 if !ok {
201 return io.EOF
202 }
203 return unmarshal(m)
204 case <-s.ctx.Done():
205 return s.ctx.Err()
206
207 }
208 }
209
210 func protoUnmarshal(p []byte, obj interface{}) error {
211 switch v := obj.(type) {
212 case proto.Message:
213 if err := proto.Unmarshal(p, v); err != nil {
214 return status.Errorf(codes.Internal, "ttrpc: error unmarshalling payload: %v", err.Error())
215 }
216 default:
217 return status.Errorf(codes.Internal, "ttrpc: error unsupported request type: %T", v)
218 }
219 return nil
220 }
221
222 func protoMarshal(obj interface{}) ([]byte, error) {
223 if obj == nil {
224 return nil, nil
225 }
226
227 switch v := obj.(type) {
228 case proto.Message:
229 r, err := proto.Marshal(v)
230 if err != nil {
231 return nil, status.Errorf(codes.Internal, "ttrpc: error marshaling payload: %v", err.Error())
232 }
233
234 return r, nil
235 default:
236 return nil, status.Errorf(codes.Internal, "ttrpc: error unsupported response type: %T", v)
237 }
238 }
239
240
241
242
243 func convertCode(err error) codes.Code {
244 switch err {
245 case nil:
246 return codes.OK
247 case io.EOF:
248 return codes.OutOfRange
249 case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF:
250 return codes.FailedPrecondition
251 case os.ErrInvalid:
252 return codes.InvalidArgument
253 case context.Canceled:
254 return codes.Canceled
255 case context.DeadlineExceeded:
256 return codes.DeadlineExceeded
257 }
258 switch {
259 case os.IsExist(err):
260 return codes.AlreadyExists
261 case os.IsNotExist(err):
262 return codes.NotFound
263 case os.IsPermission(err):
264 return codes.PermissionDenied
265 }
266 return codes.Unknown
267 }
268
// fullPath builds the "/service/method" name reported to interceptors. The
// two parts are joined and cleaned with path.Join, then prefixed with "/".
func fullPath(service, method string) string {
	joined := path.Join(service, method)
	return "/" + joined
}
272
// isNil reports whether the data word of the interface value resp is zero.
//
// This is true both for a nil interface and for an interface holding a typed
// nil pointer-shaped value (pointer, map, chan, func), which a plain
// resp == nil check would miss.
//
// It relies on the runtime's two-word interface layout (type word, then data
// word), which is why unsafe is used.
func isNil(resp interface{}) bool {
	type ifaceWords struct {
		typ  unsafe.Pointer
		data unsafe.Pointer
	}
	words := (*ifaceWords)(unsafe.Pointer(&resp))
	return words.data == nil
}
276
View as plain text