1
18
19
20
21
22 package internal
23
24 import (
25 "io"
26 "sort"
27
28 "google.golang.org/grpc"
29 "google.golang.org/grpc/codes"
30 "google.golang.org/grpc/status"
31 "google.golang.org/protobuf/proto"
32 "google.golang.org/protobuf/reflect/protodesc"
33 "google.golang.org/protobuf/reflect/protoreflect"
34 "google.golang.org/protobuf/reflect/protoregistry"
35
36 v1reflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1"
37 v1reflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1"
38 v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
39 v1alphareflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
40 )
41
42
43
44 type ServiceInfoProvider interface {
45 GetServiceInfo() map[string]grpc.ServiceInfo
46 }
47
48
49
50 type ExtensionResolver interface {
51 protoregistry.ExtensionTypeResolver
52 RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool)
53 }
54
55
56 type ServerReflectionServer struct {
57 v1alphareflectiongrpc.UnimplementedServerReflectionServer
58 S ServiceInfoProvider
59 DescResolver protodesc.Resolver
60 ExtResolver ExtensionResolver
61 }
62
63
64
65
66 func (s *ServerReflectionServer) FileDescWithDependencies(fd protoreflect.FileDescriptor, sentFileDescriptors map[string]bool) ([][]byte, error) {
67 if fd.IsPlaceholder() {
68
69
70 return nil, protoregistry.NotFound
71 }
72 var r [][]byte
73 queue := []protoreflect.FileDescriptor{fd}
74 for len(queue) > 0 {
75 currentfd := queue[0]
76 queue = queue[1:]
77 if currentfd.IsPlaceholder() {
78
79 continue
80 }
81 if sent := sentFileDescriptors[currentfd.Path()]; len(r) == 0 || !sent {
82 sentFileDescriptors[currentfd.Path()] = true
83 fdProto := protodesc.ToFileDescriptorProto(currentfd)
84 currentfdEncoded, err := proto.Marshal(fdProto)
85 if err != nil {
86 return nil, err
87 }
88 r = append(r, currentfdEncoded)
89 }
90 for i := 0; i < currentfd.Imports().Len(); i++ {
91 queue = append(queue, currentfd.Imports().Get(i))
92 }
93 }
94 return r, nil
95 }
96
97
98
99
100
101 func (s *ServerReflectionServer) FileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
102 d, err := s.DescResolver.FindDescriptorByName(protoreflect.FullName(name))
103 if err != nil {
104 return nil, err
105 }
106 return s.FileDescWithDependencies(d.ParentFile(), sentFileDescriptors)
107 }
108
109
110
111
112 func (s *ServerReflectionServer) FileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
113 xt, err := s.ExtResolver.FindExtensionByNumber(protoreflect.FullName(typeName), protoreflect.FieldNumber(extNum))
114 if err != nil {
115 return nil, err
116 }
117 return s.FileDescWithDependencies(xt.TypeDescriptor().ParentFile(), sentFileDescriptors)
118 }
119
120
121 func (s *ServerReflectionServer) AllExtensionNumbersForTypeName(name string) ([]int32, error) {
122 var numbers []int32
123 s.ExtResolver.RangeExtensionsByMessage(protoreflect.FullName(name), func(xt protoreflect.ExtensionType) bool {
124 numbers = append(numbers, int32(xt.TypeDescriptor().Number()))
125 return true
126 })
127 sort.Slice(numbers, func(i, j int) bool {
128 return numbers[i] < numbers[j]
129 })
130 if len(numbers) == 0 {
131
132 if _, err := s.DescResolver.FindDescriptorByName(protoreflect.FullName(name)); err != nil {
133 return nil, err
134 }
135 }
136 return numbers, nil
137 }
138
139
140 func (s *ServerReflectionServer) ListServices() []*v1reflectionpb.ServiceResponse {
141 serviceInfo := s.S.GetServiceInfo()
142 resp := make([]*v1reflectionpb.ServiceResponse, 0, len(serviceInfo))
143 for svc := range serviceInfo {
144 resp = append(resp, &v1reflectionpb.ServiceResponse{Name: svc})
145 }
146 sort.Slice(resp, func(i, j int) bool {
147 return resp[i].Name < resp[j].Name
148 })
149 return resp
150 }
151
152
153 func (s *ServerReflectionServer) ServerReflectionInfo(stream v1reflectiongrpc.ServerReflection_ServerReflectionInfoServer) error {
154 sentFileDescriptors := make(map[string]bool)
155 for {
156 in, err := stream.Recv()
157 if err == io.EOF {
158 return nil
159 }
160 if err != nil {
161 return err
162 }
163
164 out := &v1reflectionpb.ServerReflectionResponse{
165 ValidHost: in.Host,
166 OriginalRequest: in,
167 }
168 switch req := in.MessageRequest.(type) {
169 case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
170 var b [][]byte
171 fd, err := s.DescResolver.FindFileByPath(req.FileByFilename)
172 if err == nil {
173 b, err = s.FileDescWithDependencies(fd, sentFileDescriptors)
174 }
175 if err != nil {
176 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
177 ErrorResponse: &v1reflectionpb.ErrorResponse{
178 ErrorCode: int32(codes.NotFound),
179 ErrorMessage: err.Error(),
180 },
181 }
182 } else {
183 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
184 FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
185 }
186 }
187 case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
188 b, err := s.FileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
189 if err != nil {
190 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
191 ErrorResponse: &v1reflectionpb.ErrorResponse{
192 ErrorCode: int32(codes.NotFound),
193 ErrorMessage: err.Error(),
194 },
195 }
196 } else {
197 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
198 FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
199 }
200 }
201 case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
202 typeName := req.FileContainingExtension.ContainingType
203 extNum := req.FileContainingExtension.ExtensionNumber
204 b, err := s.FileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
205 if err != nil {
206 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
207 ErrorResponse: &v1reflectionpb.ErrorResponse{
208 ErrorCode: int32(codes.NotFound),
209 ErrorMessage: err.Error(),
210 },
211 }
212 } else {
213 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
214 FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
215 }
216 }
217 case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
218 extNums, err := s.AllExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
219 if err != nil {
220 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
221 ErrorResponse: &v1reflectionpb.ErrorResponse{
222 ErrorCode: int32(codes.NotFound),
223 ErrorMessage: err.Error(),
224 },
225 }
226 } else {
227 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
228 AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
229 BaseTypeName: req.AllExtensionNumbersOfType,
230 ExtensionNumber: extNums,
231 },
232 }
233 }
234 case *v1reflectionpb.ServerReflectionRequest_ListServices:
235 out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
236 ListServicesResponse: &v1reflectionpb.ListServiceResponse{
237 Service: s.ListServices(),
238 },
239 }
240 default:
241 return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
242 }
243
244 if err := stream.Send(out); err != nil {
245 return err
246 }
247 }
248 }
249
250
251 func V1ToV1AlphaResponse(v1 *v1reflectionpb.ServerReflectionResponse) *v1alphareflectionpb.ServerReflectionResponse {
252 var v1alpha v1alphareflectionpb.ServerReflectionResponse
253 v1alpha.ValidHost = v1.ValidHost
254 if v1.OriginalRequest != nil {
255 v1alpha.OriginalRequest = V1ToV1AlphaRequest(v1.OriginalRequest)
256 }
257 switch mr := v1.MessageResponse.(type) {
258 case *v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse:
259 if mr != nil {
260 v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse{
261 FileDescriptorResponse: &v1alphareflectionpb.FileDescriptorResponse{
262 FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
263 },
264 }
265 }
266 case *v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
267 if mr != nil {
268 v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
269 AllExtensionNumbersResponse: &v1alphareflectionpb.ExtensionNumberResponse{
270 BaseTypeName: mr.AllExtensionNumbersResponse.GetBaseTypeName(),
271 ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
272 },
273 }
274 }
275 case *v1reflectionpb.ServerReflectionResponse_ListServicesResponse:
276 if mr != nil {
277 svcs := make([]*v1alphareflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
278 for i, svc := range mr.ListServicesResponse.GetService() {
279 svcs[i] = &v1alphareflectionpb.ServiceResponse{
280 Name: svc.GetName(),
281 }
282 }
283 v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse{
284 ListServicesResponse: &v1alphareflectionpb.ListServiceResponse{
285 Service: svcs,
286 },
287 }
288 }
289 case *v1reflectionpb.ServerReflectionResponse_ErrorResponse:
290 if mr != nil {
291 v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ErrorResponse{
292 ErrorResponse: &v1alphareflectionpb.ErrorResponse{
293 ErrorCode: mr.ErrorResponse.GetErrorCode(),
294 ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
295 },
296 }
297 }
298 default:
299
300 }
301 return &v1alpha
302 }
303
304
305 func V1AlphaToV1Request(v1alpha *v1alphareflectionpb.ServerReflectionRequest) *v1reflectionpb.ServerReflectionRequest {
306 var v1 v1reflectionpb.ServerReflectionRequest
307 v1.Host = v1alpha.Host
308 switch mr := v1alpha.MessageRequest.(type) {
309 case *v1alphareflectionpb.ServerReflectionRequest_FileByFilename:
310 v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileByFilename{
311 FileByFilename: mr.FileByFilename,
312 }
313 case *v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol:
314 v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingSymbol{
315 FileContainingSymbol: mr.FileContainingSymbol,
316 }
317 case *v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension:
318 if mr.FileContainingExtension != nil {
319 v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingExtension{
320 FileContainingExtension: &v1reflectionpb.ExtensionRequest{
321 ContainingType: mr.FileContainingExtension.GetContainingType(),
322 ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
323 },
324 }
325 }
326 case *v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
327 v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
328 AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
329 }
330 case *v1alphareflectionpb.ServerReflectionRequest_ListServices:
331 v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_ListServices{
332 ListServices: mr.ListServices,
333 }
334 default:
335
336 }
337 return &v1
338 }
339
340
341 func V1ToV1AlphaRequest(v1 *v1reflectionpb.ServerReflectionRequest) *v1alphareflectionpb.ServerReflectionRequest {
342 var v1alpha v1alphareflectionpb.ServerReflectionRequest
343 v1alpha.Host = v1.Host
344 switch mr := v1.MessageRequest.(type) {
345 case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
346 if mr != nil {
347 v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileByFilename{
348 FileByFilename: mr.FileByFilename,
349 }
350 }
351 case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
352 if mr != nil {
353 v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol{
354 FileContainingSymbol: mr.FileContainingSymbol,
355 }
356 }
357 case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
358 if mr != nil {
359 v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension{
360 FileContainingExtension: &v1alphareflectionpb.ExtensionRequest{
361 ContainingType: mr.FileContainingExtension.GetContainingType(),
362 ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
363 },
364 }
365 }
366 case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
367 if mr != nil {
368 v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
369 AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
370 }
371 }
372 case *v1reflectionpb.ServerReflectionRequest_ListServices:
373 if mr != nil {
374 v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_ListServices{
375 ListServices: mr.ListServices,
376 }
377 }
378 default:
379
380 }
381 return &v1alpha
382 }
383
384
385 func V1AlphaToV1Response(v1alpha *v1alphareflectionpb.ServerReflectionResponse) *v1reflectionpb.ServerReflectionResponse {
386 var v1 v1reflectionpb.ServerReflectionResponse
387 v1.ValidHost = v1alpha.ValidHost
388 if v1alpha.OriginalRequest != nil {
389 v1.OriginalRequest = V1AlphaToV1Request(v1alpha.OriginalRequest)
390 }
391 switch mr := v1alpha.MessageResponse.(type) {
392 case *v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse:
393 if mr != nil {
394 v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
395 FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{
396 FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
397 },
398 }
399 }
400 case *v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
401 if mr != nil {
402 v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
403 AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
404 BaseTypeName: mr.AllExtensionNumbersResponse.GetBaseTypeName(),
405 ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
406 },
407 }
408 }
409 case *v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse:
410 if mr != nil {
411 svcs := make([]*v1reflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
412 for i, svc := range mr.ListServicesResponse.GetService() {
413 svcs[i] = &v1reflectionpb.ServiceResponse{
414 Name: svc.GetName(),
415 }
416 }
417 v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
418 ListServicesResponse: &v1reflectionpb.ListServiceResponse{
419 Service: svcs,
420 },
421 }
422 }
423 case *v1alphareflectionpb.ServerReflectionResponse_ErrorResponse:
424 if mr != nil {
425 v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
426 ErrorResponse: &v1reflectionpb.ErrorResponse{
427 ErrorCode: mr.ErrorResponse.GetErrorCode(),
428 ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
429 },
430 }
431 }
432 default:
433
434 }
435 return &v1
436 }
437
View as plain text