...

Source file src/google.golang.org/grpc/reflection/internal/internal.go

Documentation: google.golang.org/grpc/reflection/internal

     1  /*
     2   *
     3   * Copyright 2024 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
// Package internal contains code that is shared by both the reflection package
// and the test package. The packages are split in this way in order to avoid a
// dependency on the deprecated package github.com/golang/protobuf.
    22  package internal
    23  
    24  import (
    25  	"io"
    26  	"sort"
    27  
    28  	"google.golang.org/grpc"
    29  	"google.golang.org/grpc/codes"
    30  	"google.golang.org/grpc/status"
    31  	"google.golang.org/protobuf/proto"
    32  	"google.golang.org/protobuf/reflect/protodesc"
    33  	"google.golang.org/protobuf/reflect/protoreflect"
    34  	"google.golang.org/protobuf/reflect/protoregistry"
    35  
    36  	v1reflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1"
    37  	v1reflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1"
    38  	v1alphareflectiongrpc "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
    39  	v1alphareflectionpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
    40  )
    41  
// ServiceInfoProvider is an interface used to retrieve metadata about the
// services to expose.
//
// Only the keys of the map returned by GetServiceInfo are used by
// ServerReflectionServer.ListServices: they are reported, sorted, as the
// names of the exposed services.
type ServiceInfoProvider interface {
	GetServiceInfo() map[string]grpc.ServiceInfo
}
    47  
// ExtensionResolver is the interface used to query details about extensions.
// This interface is satisfied by protoregistry.GlobalTypes.
//
// ServerReflectionServer uses FindExtensionByNumber (from the embedded
// protoregistry.ExtensionTypeResolver) to locate the file that defines a
// specific extension, and RangeExtensionsByMessage to enumerate the extension
// numbers of a message type.
type ExtensionResolver interface {
	protoregistry.ExtensionTypeResolver
	// RangeExtensionsByMessage is expected to call f for each extension of the
	// message named by message. AllExtensionNumbersForTypeName always returns
	// true from f so that every extension is visited.
	RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool)
}
    54  
// ServerReflectionServer is the server API for ServerReflection service.
//
// The methods of this type call S, DescResolver and ExtResolver without nil
// checks, so each must be set before the request types that use it are served.
type ServerReflectionServer struct {
	// NOTE(review): this embeds the v1alpha Unimplemented server even though
	// ServerReflectionInfo accepts the v1 stream type; presumably kept for
	// forward-compatible interface satisfaction — confirm.
	v1alphareflectiongrpc.UnimplementedServerReflectionServer
	// S supplies the service names returned by ListServices.
	S ServiceInfoProvider
	// DescResolver resolves file paths (FindFileByPath) and fully-qualified
	// symbol names (FindDescriptorByName) to descriptors.
	DescResolver protodesc.Resolver
	// ExtResolver looks up an extension by containing type and field number,
	// and enumerates the extensions of a message type.
	ExtResolver ExtensionResolver
}
    62  
    63  // FileDescWithDependencies returns a slice of serialized fileDescriptors in
    64  // wire format ([]byte). The fileDescriptors will include fd and all the
    65  // transitive dependencies of fd with names not in sentFileDescriptors.
    66  func (s *ServerReflectionServer) FileDescWithDependencies(fd protoreflect.FileDescriptor, sentFileDescriptors map[string]bool) ([][]byte, error) {
    67  	if fd.IsPlaceholder() {
    68  		// If the given root file is a placeholder, treat it
    69  		// as missing instead of serializing it.
    70  		return nil, protoregistry.NotFound
    71  	}
    72  	var r [][]byte
    73  	queue := []protoreflect.FileDescriptor{fd}
    74  	for len(queue) > 0 {
    75  		currentfd := queue[0]
    76  		queue = queue[1:]
    77  		if currentfd.IsPlaceholder() {
    78  			// Skip any missing files in the dependency graph.
    79  			continue
    80  		}
    81  		if sent := sentFileDescriptors[currentfd.Path()]; len(r) == 0 || !sent {
    82  			sentFileDescriptors[currentfd.Path()] = true
    83  			fdProto := protodesc.ToFileDescriptorProto(currentfd)
    84  			currentfdEncoded, err := proto.Marshal(fdProto)
    85  			if err != nil {
    86  				return nil, err
    87  			}
    88  			r = append(r, currentfdEncoded)
    89  		}
    90  		for i := 0; i < currentfd.Imports().Len(); i++ {
    91  			queue = append(queue, currentfd.Imports().Get(i))
    92  		}
    93  	}
    94  	return r, nil
    95  }
    96  
    97  // FileDescEncodingContainingSymbol finds the file descriptor containing the
    98  // given symbol, finds all of its previously unsent transitive dependencies,
    99  // does marshalling on them, and returns the marshalled result. The given symbol
   100  // can be a type, a service or a method.
   101  func (s *ServerReflectionServer) FileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
   102  	d, err := s.DescResolver.FindDescriptorByName(protoreflect.FullName(name))
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	return s.FileDescWithDependencies(d.ParentFile(), sentFileDescriptors)
   107  }
   108  
   109  // FileDescEncodingContainingExtension finds the file descriptor containing
   110  // given extension, finds all of its previously unsent transitive dependencies,
   111  // does marshalling on them, and returns the marshalled result.
   112  func (s *ServerReflectionServer) FileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
   113  	xt, err := s.ExtResolver.FindExtensionByNumber(protoreflect.FullName(typeName), protoreflect.FieldNumber(extNum))
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	return s.FileDescWithDependencies(xt.TypeDescriptor().ParentFile(), sentFileDescriptors)
   118  }
   119  
   120  // AllExtensionNumbersForTypeName returns all extension numbers for the given type.
   121  func (s *ServerReflectionServer) AllExtensionNumbersForTypeName(name string) ([]int32, error) {
   122  	var numbers []int32
   123  	s.ExtResolver.RangeExtensionsByMessage(protoreflect.FullName(name), func(xt protoreflect.ExtensionType) bool {
   124  		numbers = append(numbers, int32(xt.TypeDescriptor().Number()))
   125  		return true
   126  	})
   127  	sort.Slice(numbers, func(i, j int) bool {
   128  		return numbers[i] < numbers[j]
   129  	})
   130  	if len(numbers) == 0 {
   131  		// maybe return an error if given type name is not known
   132  		if _, err := s.DescResolver.FindDescriptorByName(protoreflect.FullName(name)); err != nil {
   133  			return nil, err
   134  		}
   135  	}
   136  	return numbers, nil
   137  }
   138  
   139  // ListServices returns the names of services this server exposes.
   140  func (s *ServerReflectionServer) ListServices() []*v1reflectionpb.ServiceResponse {
   141  	serviceInfo := s.S.GetServiceInfo()
   142  	resp := make([]*v1reflectionpb.ServiceResponse, 0, len(serviceInfo))
   143  	for svc := range serviceInfo {
   144  		resp = append(resp, &v1reflectionpb.ServiceResponse{Name: svc})
   145  	}
   146  	sort.Slice(resp, func(i, j int) bool {
   147  		return resp[i].Name < resp[j].Name
   148  	})
   149  	return resp
   150  }
   151  
   152  // ServerReflectionInfo is the reflection service handler.
   153  func (s *ServerReflectionServer) ServerReflectionInfo(stream v1reflectiongrpc.ServerReflection_ServerReflectionInfoServer) error {
   154  	sentFileDescriptors := make(map[string]bool)
   155  	for {
   156  		in, err := stream.Recv()
   157  		if err == io.EOF {
   158  			return nil
   159  		}
   160  		if err != nil {
   161  			return err
   162  		}
   163  
   164  		out := &v1reflectionpb.ServerReflectionResponse{
   165  			ValidHost:       in.Host,
   166  			OriginalRequest: in,
   167  		}
   168  		switch req := in.MessageRequest.(type) {
   169  		case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
   170  			var b [][]byte
   171  			fd, err := s.DescResolver.FindFileByPath(req.FileByFilename)
   172  			if err == nil {
   173  				b, err = s.FileDescWithDependencies(fd, sentFileDescriptors)
   174  			}
   175  			if err != nil {
   176  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
   177  					ErrorResponse: &v1reflectionpb.ErrorResponse{
   178  						ErrorCode:    int32(codes.NotFound),
   179  						ErrorMessage: err.Error(),
   180  					},
   181  				}
   182  			} else {
   183  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
   184  					FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
   185  				}
   186  			}
   187  		case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
   188  			b, err := s.FileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
   189  			if err != nil {
   190  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
   191  					ErrorResponse: &v1reflectionpb.ErrorResponse{
   192  						ErrorCode:    int32(codes.NotFound),
   193  						ErrorMessage: err.Error(),
   194  					},
   195  				}
   196  			} else {
   197  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
   198  					FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
   199  				}
   200  			}
   201  		case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
   202  			typeName := req.FileContainingExtension.ContainingType
   203  			extNum := req.FileContainingExtension.ExtensionNumber
   204  			b, err := s.FileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
   205  			if err != nil {
   206  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
   207  					ErrorResponse: &v1reflectionpb.ErrorResponse{
   208  						ErrorCode:    int32(codes.NotFound),
   209  						ErrorMessage: err.Error(),
   210  					},
   211  				}
   212  			} else {
   213  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
   214  					FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{FileDescriptorProto: b},
   215  				}
   216  			}
   217  		case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
   218  			extNums, err := s.AllExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
   219  			if err != nil {
   220  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
   221  					ErrorResponse: &v1reflectionpb.ErrorResponse{
   222  						ErrorCode:    int32(codes.NotFound),
   223  						ErrorMessage: err.Error(),
   224  					},
   225  				}
   226  			} else {
   227  				out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
   228  					AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
   229  						BaseTypeName:    req.AllExtensionNumbersOfType,
   230  						ExtensionNumber: extNums,
   231  					},
   232  				}
   233  			}
   234  		case *v1reflectionpb.ServerReflectionRequest_ListServices:
   235  			out.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
   236  				ListServicesResponse: &v1reflectionpb.ListServiceResponse{
   237  					Service: s.ListServices(),
   238  				},
   239  			}
   240  		default:
   241  			return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
   242  		}
   243  
   244  		if err := stream.Send(out); err != nil {
   245  			return err
   246  		}
   247  	}
   248  }
   249  
   250  // V1ToV1AlphaResponse converts a v1 ServerReflectionResponse to a v1alpha.
   251  func V1ToV1AlphaResponse(v1 *v1reflectionpb.ServerReflectionResponse) *v1alphareflectionpb.ServerReflectionResponse {
   252  	var v1alpha v1alphareflectionpb.ServerReflectionResponse
   253  	v1alpha.ValidHost = v1.ValidHost
   254  	if v1.OriginalRequest != nil {
   255  		v1alpha.OriginalRequest = V1ToV1AlphaRequest(v1.OriginalRequest)
   256  	}
   257  	switch mr := v1.MessageResponse.(type) {
   258  	case *v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse:
   259  		if mr != nil {
   260  			v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse{
   261  				FileDescriptorResponse: &v1alphareflectionpb.FileDescriptorResponse{
   262  					FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
   263  				},
   264  			}
   265  		}
   266  	case *v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
   267  		if mr != nil {
   268  			v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
   269  				AllExtensionNumbersResponse: &v1alphareflectionpb.ExtensionNumberResponse{
   270  					BaseTypeName:    mr.AllExtensionNumbersResponse.GetBaseTypeName(),
   271  					ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
   272  				},
   273  			}
   274  		}
   275  	case *v1reflectionpb.ServerReflectionResponse_ListServicesResponse:
   276  		if mr != nil {
   277  			svcs := make([]*v1alphareflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
   278  			for i, svc := range mr.ListServicesResponse.GetService() {
   279  				svcs[i] = &v1alphareflectionpb.ServiceResponse{
   280  					Name: svc.GetName(),
   281  				}
   282  			}
   283  			v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse{
   284  				ListServicesResponse: &v1alphareflectionpb.ListServiceResponse{
   285  					Service: svcs,
   286  				},
   287  			}
   288  		}
   289  	case *v1reflectionpb.ServerReflectionResponse_ErrorResponse:
   290  		if mr != nil {
   291  			v1alpha.MessageResponse = &v1alphareflectionpb.ServerReflectionResponse_ErrorResponse{
   292  				ErrorResponse: &v1alphareflectionpb.ErrorResponse{
   293  					ErrorCode:    mr.ErrorResponse.GetErrorCode(),
   294  					ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
   295  				},
   296  			}
   297  		}
   298  	default:
   299  		// no value set
   300  	}
   301  	return &v1alpha
   302  }
   303  
   304  // V1AlphaToV1Request converts a v1alpha ServerReflectionRequest to a v1.
   305  func V1AlphaToV1Request(v1alpha *v1alphareflectionpb.ServerReflectionRequest) *v1reflectionpb.ServerReflectionRequest {
   306  	var v1 v1reflectionpb.ServerReflectionRequest
   307  	v1.Host = v1alpha.Host
   308  	switch mr := v1alpha.MessageRequest.(type) {
   309  	case *v1alphareflectionpb.ServerReflectionRequest_FileByFilename:
   310  		v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileByFilename{
   311  			FileByFilename: mr.FileByFilename,
   312  		}
   313  	case *v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol:
   314  		v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingSymbol{
   315  			FileContainingSymbol: mr.FileContainingSymbol,
   316  		}
   317  	case *v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension:
   318  		if mr.FileContainingExtension != nil {
   319  			v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_FileContainingExtension{
   320  				FileContainingExtension: &v1reflectionpb.ExtensionRequest{
   321  					ContainingType:  mr.FileContainingExtension.GetContainingType(),
   322  					ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
   323  				},
   324  			}
   325  		}
   326  	case *v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
   327  		v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
   328  			AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
   329  		}
   330  	case *v1alphareflectionpb.ServerReflectionRequest_ListServices:
   331  		v1.MessageRequest = &v1reflectionpb.ServerReflectionRequest_ListServices{
   332  			ListServices: mr.ListServices,
   333  		}
   334  	default:
   335  		// no value set
   336  	}
   337  	return &v1
   338  }
   339  
   340  // V1ToV1AlphaRequest converts a v1 ServerReflectionRequest to a v1alpha.
   341  func V1ToV1AlphaRequest(v1 *v1reflectionpb.ServerReflectionRequest) *v1alphareflectionpb.ServerReflectionRequest {
   342  	var v1alpha v1alphareflectionpb.ServerReflectionRequest
   343  	v1alpha.Host = v1.Host
   344  	switch mr := v1.MessageRequest.(type) {
   345  	case *v1reflectionpb.ServerReflectionRequest_FileByFilename:
   346  		if mr != nil {
   347  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileByFilename{
   348  				FileByFilename: mr.FileByFilename,
   349  			}
   350  		}
   351  	case *v1reflectionpb.ServerReflectionRequest_FileContainingSymbol:
   352  		if mr != nil {
   353  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingSymbol{
   354  				FileContainingSymbol: mr.FileContainingSymbol,
   355  			}
   356  		}
   357  	case *v1reflectionpb.ServerReflectionRequest_FileContainingExtension:
   358  		if mr != nil {
   359  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_FileContainingExtension{
   360  				FileContainingExtension: &v1alphareflectionpb.ExtensionRequest{
   361  					ContainingType:  mr.FileContainingExtension.GetContainingType(),
   362  					ExtensionNumber: mr.FileContainingExtension.GetExtensionNumber(),
   363  				},
   364  			}
   365  		}
   366  	case *v1reflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType:
   367  		if mr != nil {
   368  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_AllExtensionNumbersOfType{
   369  				AllExtensionNumbersOfType: mr.AllExtensionNumbersOfType,
   370  			}
   371  		}
   372  	case *v1reflectionpb.ServerReflectionRequest_ListServices:
   373  		if mr != nil {
   374  			v1alpha.MessageRequest = &v1alphareflectionpb.ServerReflectionRequest_ListServices{
   375  				ListServices: mr.ListServices,
   376  			}
   377  		}
   378  	default:
   379  		// no value set
   380  	}
   381  	return &v1alpha
   382  }
   383  
   384  // V1AlphaToV1Response converts a v1alpha ServerReflectionResponse to a v1.
   385  func V1AlphaToV1Response(v1alpha *v1alphareflectionpb.ServerReflectionResponse) *v1reflectionpb.ServerReflectionResponse {
   386  	var v1 v1reflectionpb.ServerReflectionResponse
   387  	v1.ValidHost = v1alpha.ValidHost
   388  	if v1alpha.OriginalRequest != nil {
   389  		v1.OriginalRequest = V1AlphaToV1Request(v1alpha.OriginalRequest)
   390  	}
   391  	switch mr := v1alpha.MessageResponse.(type) {
   392  	case *v1alphareflectionpb.ServerReflectionResponse_FileDescriptorResponse:
   393  		if mr != nil {
   394  			v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_FileDescriptorResponse{
   395  				FileDescriptorResponse: &v1reflectionpb.FileDescriptorResponse{
   396  					FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(),
   397  				},
   398  			}
   399  		}
   400  	case *v1alphareflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse:
   401  		if mr != nil {
   402  			v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_AllExtensionNumbersResponse{
   403  				AllExtensionNumbersResponse: &v1reflectionpb.ExtensionNumberResponse{
   404  					BaseTypeName:    mr.AllExtensionNumbersResponse.GetBaseTypeName(),
   405  					ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(),
   406  				},
   407  			}
   408  		}
   409  	case *v1alphareflectionpb.ServerReflectionResponse_ListServicesResponse:
   410  		if mr != nil {
   411  			svcs := make([]*v1reflectionpb.ServiceResponse, len(mr.ListServicesResponse.GetService()))
   412  			for i, svc := range mr.ListServicesResponse.GetService() {
   413  				svcs[i] = &v1reflectionpb.ServiceResponse{
   414  					Name: svc.GetName(),
   415  				}
   416  			}
   417  			v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ListServicesResponse{
   418  				ListServicesResponse: &v1reflectionpb.ListServiceResponse{
   419  					Service: svcs,
   420  				},
   421  			}
   422  		}
   423  	case *v1alphareflectionpb.ServerReflectionResponse_ErrorResponse:
   424  		if mr != nil {
   425  			v1.MessageResponse = &v1reflectionpb.ServerReflectionResponse_ErrorResponse{
   426  				ErrorResponse: &v1reflectionpb.ErrorResponse{
   427  					ErrorCode:    mr.ErrorResponse.GetErrorCode(),
   428  					ErrorMessage: mr.ErrorResponse.GetErrorMessage(),
   429  				},
   430  			}
   431  		}
   432  	default:
   433  		// no value set
   434  	}
   435  	return &v1
   436  }
   437  

View as plain text