...

Source file src/github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway/internal/gengateway/template_test.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-grpc-gateway/internal/gengateway

     1  package gengateway
     2  
     3  import (
     4  	"strings"
     5  	"testing"
     6  
     7  	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor"
     8  	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
     9  	"google.golang.org/protobuf/proto"
    10  	"google.golang.org/protobuf/types/descriptorpb"
    11  )
    12  
    13  func crossLinkFixture(f *descriptor.File) *descriptor.File {
    14  	for _, m := range f.Messages {
    15  		m.File = f
    16  	}
    17  	for _, svc := range f.Services {
    18  		svc.File = f
    19  		for _, m := range svc.Methods {
    20  			m.Service = svc
    21  			for _, b := range m.Bindings {
    22  				b.Method = m
    23  				for _, param := range b.PathParams {
    24  					param.Method = m
    25  				}
    26  			}
    27  		}
    28  	}
    29  	return f
    30  }
    31  
    32  func compilePath(t *testing.T, path string) httprule.Template {
    33  	parsed, err := httprule.Parse(path)
    34  	if err != nil {
    35  		t.Fatalf("httprule.Parse(%q) failed with %v; want success", path, err)
    36  	}
    37  	return parsed.Compile()
    38  }
    39  
    40  func TestApplyTemplateHeader(t *testing.T) {
    41  	msgdesc := &descriptorpb.DescriptorProto{
    42  		Name: proto.String("ExampleMessage"),
    43  	}
    44  	meth := &descriptorpb.MethodDescriptorProto{
    45  		Name:       proto.String("Example"),
    46  		InputType:  proto.String("ExampleMessage"),
    47  		OutputType: proto.String("ExampleMessage"),
    48  	}
    49  	svc := &descriptorpb.ServiceDescriptorProto{
    50  		Name:   proto.String("ExampleService"),
    51  		Method: []*descriptorpb.MethodDescriptorProto{meth},
    52  	}
    53  	msg := &descriptor.Message{
    54  		DescriptorProto: msgdesc,
    55  	}
    56  	file := descriptor.File{
    57  		FileDescriptorProto: &descriptorpb.FileDescriptorProto{
    58  			Name:        proto.String("example.proto"),
    59  			Package:     proto.String("example"),
    60  			Dependency:  []string{"a.example/b/c.proto", "a.example/d/e.proto"},
    61  			MessageType: []*descriptorpb.DescriptorProto{msgdesc},
    62  			Service:     []*descriptorpb.ServiceDescriptorProto{svc},
    63  		},
    64  		GoPkg: descriptor.GoPackage{
    65  			Path: "example.com/path/to/example/example.pb",
    66  			Name: "example_pb",
    67  		},
    68  		Messages: []*descriptor.Message{msg},
    69  		Services: []*descriptor.Service{
    70  			{
    71  				ServiceDescriptorProto: svc,
    72  				Methods: []*descriptor.Method{
    73  					{
    74  						MethodDescriptorProto: meth,
    75  						RequestType:           msg,
    76  						ResponseType:          msg,
    77  						Bindings: []*descriptor.Binding{
    78  							{
    79  								HTTPMethod: "GET",
    80  								Body:       &descriptor.Body{FieldPath: nil},
    81  							},
    82  						},
    83  					},
    84  				},
    85  			},
    86  		},
    87  	}
    88  	got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
    89  	if err != nil {
    90  		t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
    91  		return
    92  	}
    93  	if want := "package example_pb\n"; !strings.Contains(got, want) {
    94  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
    95  	}
    96  }
    97  
    98  func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) {
    99  	msgdesc := &descriptorpb.DescriptorProto{
   100  		Name: proto.String("ExampleMessage"),
   101  		Field: []*descriptorpb.FieldDescriptorProto{
   102  			{
   103  				Name:     proto.String("nested"),
   104  				Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   105  				Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   106  				TypeName: proto.String("NestedMessage"),
   107  				Number:   proto.Int32(1),
   108  			},
   109  		},
   110  	}
   111  	nesteddesc := &descriptorpb.DescriptorProto{
   112  		Name: proto.String("NestedMessage"),
   113  		Field: []*descriptorpb.FieldDescriptorProto{
   114  			{
   115  				Name:   proto.String("int32"),
   116  				Label:  descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   117  				Type:   descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(),
   118  				Number: proto.Int32(1),
   119  			},
   120  			{
   121  				Name:   proto.String("bool"),
   122  				Label:  descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   123  				Type:   descriptorpb.FieldDescriptorProto_TYPE_BOOL.Enum(),
   124  				Number: proto.Int32(2),
   125  			},
   126  		},
   127  	}
   128  	meth := &descriptorpb.MethodDescriptorProto{
   129  		Name:            proto.String("Echo"),
   130  		InputType:       proto.String("ExampleMessage"),
   131  		OutputType:      proto.String("ExampleMessage"),
   132  		ClientStreaming: proto.Bool(false),
   133  	}
   134  	svc := &descriptorpb.ServiceDescriptorProto{
   135  		Name:   proto.String("ExampleService"),
   136  		Method: []*descriptorpb.MethodDescriptorProto{meth},
   137  	}
   138  	for _, spec := range []struct {
   139  		serverStreaming bool
   140  		sigWant         string
   141  	}{
   142  		{
   143  			serverStreaming: false,
   144  			sigWant:         `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`,
   145  		},
   146  		{
   147  			serverStreaming: true,
   148  			sigWant:         `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`,
   149  		},
   150  	} {
   151  		meth.ServerStreaming = proto.Bool(spec.serverStreaming)
   152  
   153  		msg := &descriptor.Message{
   154  			DescriptorProto: msgdesc,
   155  		}
   156  		nested := &descriptor.Message{
   157  			DescriptorProto: nesteddesc,
   158  		}
   159  
   160  		nestedField := &descriptor.Field{
   161  			Message:              msg,
   162  			FieldDescriptorProto: msg.GetField()[0],
   163  		}
   164  		intField := &descriptor.Field{
   165  			Message:              nested,
   166  			FieldDescriptorProto: nested.GetField()[0],
   167  		}
   168  		boolField := &descriptor.Field{
   169  			Message:              nested,
   170  			FieldDescriptorProto: nested.GetField()[1],
   171  		}
   172  		file := descriptor.File{
   173  			FileDescriptorProto: &descriptorpb.FileDescriptorProto{
   174  				Name:        proto.String("example.proto"),
   175  				Package:     proto.String("example"),
   176  				MessageType: []*descriptorpb.DescriptorProto{msgdesc, nesteddesc},
   177  				Service:     []*descriptorpb.ServiceDescriptorProto{svc},
   178  			},
   179  			GoPkg: descriptor.GoPackage{
   180  				Path: "example.com/path/to/example/example.pb",
   181  				Name: "example_pb",
   182  			},
   183  			Messages: []*descriptor.Message{msg, nested},
   184  			Services: []*descriptor.Service{
   185  				{
   186  					ServiceDescriptorProto: svc,
   187  					Methods: []*descriptor.Method{
   188  						{
   189  							MethodDescriptorProto: meth,
   190  							RequestType:           msg,
   191  							ResponseType:          msg,
   192  							Bindings: []*descriptor.Binding{
   193  								{
   194  									HTTPMethod: "POST",
   195  									PathTmpl: httprule.Template{
   196  										Version:  1,
   197  										OpCodes:  []int{0, 0},
   198  										Template: "/v1",
   199  									},
   200  									PathParams: []descriptor.Parameter{
   201  										{
   202  											FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   203  												{
   204  													Name:   "nested",
   205  													Target: nestedField,
   206  												},
   207  												{
   208  													Name:   "int32",
   209  													Target: intField,
   210  												},
   211  											}),
   212  											Target: intField,
   213  										},
   214  									},
   215  									Body: &descriptor.Body{
   216  										FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   217  											{
   218  												Name:   "nested",
   219  												Target: nestedField,
   220  											},
   221  											{
   222  												Name:   "bool",
   223  												Target: boolField,
   224  											},
   225  										}),
   226  									},
   227  								},
   228  							},
   229  						},
   230  					},
   231  				},
   232  			},
   233  		}
   234  		got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   235  		if err != nil {
   236  			t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   237  			return
   238  		}
   239  		if want := spec.sigWant; !strings.Contains(got, want) {
   240  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   241  		}
   242  		if want := `marshaler.NewDecoder(req.Body).Decode(&protoReq.GetNested().Bool)`; !strings.Contains(got, want) {
   243  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   244  		}
   245  		if want := `val, ok = pathParams["nested.int32"]`; !strings.Contains(got, want) {
   246  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   247  		}
   248  		if want := `protoReq.GetNested().Int32, err = runtime.Int32P(val)`; !strings.Contains(got, want) {
   249  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   250  		}
   251  		if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) {
   252  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   253  		}
   254  		if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) {
   255  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   256  		}
   257  		if want := `annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/example.ExampleService/Echo", runtime.WithHTTPPathPattern("/v1"))`; !strings.Contains(got, want) {
   258  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   259  		}
   260  	}
   261  }
   262  
   263  func TestApplyTemplateRequestWithClientStreaming(t *testing.T) {
   264  	msgdesc := &descriptorpb.DescriptorProto{
   265  		Name: proto.String("ExampleMessage"),
   266  		Field: []*descriptorpb.FieldDescriptorProto{
   267  			{
   268  				Name:     proto.String("nested"),
   269  				Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   270  				Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   271  				TypeName: proto.String("NestedMessage"),
   272  				Number:   proto.Int32(1),
   273  			},
   274  		},
   275  	}
   276  	nesteddesc := &descriptorpb.DescriptorProto{
   277  		Name: proto.String("NestedMessage"),
   278  		Field: []*descriptorpb.FieldDescriptorProto{
   279  			{
   280  				Name:   proto.String("int32"),
   281  				Label:  descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   282  				Type:   descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(),
   283  				Number: proto.Int32(1),
   284  			},
   285  			{
   286  				Name:   proto.String("bool"),
   287  				Label:  descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   288  				Type:   descriptorpb.FieldDescriptorProto_TYPE_BOOL.Enum(),
   289  				Number: proto.Int32(2),
   290  			},
   291  		},
   292  	}
   293  	meth := &descriptorpb.MethodDescriptorProto{
   294  		Name:            proto.String("Echo"),
   295  		InputType:       proto.String("ExampleMessage"),
   296  		OutputType:      proto.String("ExampleMessage"),
   297  		ClientStreaming: proto.Bool(true),
   298  	}
   299  	svc := &descriptorpb.ServiceDescriptorProto{
   300  		Name:   proto.String("ExampleService"),
   301  		Method: []*descriptorpb.MethodDescriptorProto{meth},
   302  	}
   303  	for _, spec := range []struct {
   304  		serverStreaming bool
   305  		sigWant         string
   306  	}{
   307  		{
   308  			serverStreaming: false,
   309  			sigWant:         `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`,
   310  		},
   311  		{
   312  			serverStreaming: true,
   313  			sigWant:         `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`,
   314  		},
   315  	} {
   316  		meth.ServerStreaming = proto.Bool(spec.serverStreaming)
   317  
   318  		msg := &descriptor.Message{
   319  			DescriptorProto: msgdesc,
   320  		}
   321  		nested := &descriptor.Message{
   322  			DescriptorProto: nesteddesc,
   323  		}
   324  
   325  		nestedField := &descriptor.Field{
   326  			Message:              msg,
   327  			FieldDescriptorProto: msg.GetField()[0],
   328  		}
   329  		intField := &descriptor.Field{
   330  			Message:              nested,
   331  			FieldDescriptorProto: nested.GetField()[0],
   332  		}
   333  		boolField := &descriptor.Field{
   334  			Message:              nested,
   335  			FieldDescriptorProto: nested.GetField()[1],
   336  		}
   337  		file := descriptor.File{
   338  			FileDescriptorProto: &descriptorpb.FileDescriptorProto{
   339  				Name:        proto.String("example.proto"),
   340  				Package:     proto.String("example"),
   341  				MessageType: []*descriptorpb.DescriptorProto{msgdesc, nesteddesc},
   342  				Service:     []*descriptorpb.ServiceDescriptorProto{svc},
   343  			},
   344  			GoPkg: descriptor.GoPackage{
   345  				Path: "example.com/path/to/example/example.pb",
   346  				Name: "example_pb",
   347  			},
   348  			Messages: []*descriptor.Message{msg, nested},
   349  			Services: []*descriptor.Service{
   350  				{
   351  					ServiceDescriptorProto: svc,
   352  					Methods: []*descriptor.Method{
   353  						{
   354  							MethodDescriptorProto: meth,
   355  							RequestType:           msg,
   356  							ResponseType:          msg,
   357  							Bindings: []*descriptor.Binding{
   358  								{
   359  									HTTPMethod: "POST",
   360  									PathTmpl: httprule.Template{
   361  										Version: 1,
   362  										OpCodes: []int{0, 0},
   363  									},
   364  									PathParams: []descriptor.Parameter{
   365  										{
   366  											FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   367  												{
   368  													Name:   "nested",
   369  													Target: nestedField,
   370  												},
   371  												{
   372  													Name:   "int32",
   373  													Target: intField,
   374  												},
   375  											}),
   376  											Target: intField,
   377  										},
   378  									},
   379  									Body: &descriptor.Body{
   380  										FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   381  											{
   382  												Name:   "nested",
   383  												Target: nestedField,
   384  											},
   385  											{
   386  												Name:   "bool",
   387  												Target: boolField,
   388  											},
   389  										}),
   390  									},
   391  								},
   392  							},
   393  						},
   394  					},
   395  				},
   396  			},
   397  		}
   398  		got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   399  		if err != nil {
   400  			t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   401  			return
   402  		}
   403  		if want := spec.sigWant; !strings.Contains(got, want) {
   404  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   405  		}
   406  		if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) {
   407  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   408  		}
   409  		if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) {
   410  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   411  		}
   412  	}
   413  }
   414  
   415  func TestApplyTemplateInProcess(t *testing.T) {
   416  	msgdesc := &descriptorpb.DescriptorProto{
   417  		Name: proto.String("ExampleMessage"),
   418  		Field: []*descriptorpb.FieldDescriptorProto{
   419  			{
   420  				Name:     proto.String("nested"),
   421  				Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   422  				Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   423  				TypeName: proto.String("NestedMessage"),
   424  				Number:   proto.Int32(1),
   425  			},
   426  		},
   427  	}
   428  	nesteddesc := &descriptorpb.DescriptorProto{
   429  		Name: proto.String("NestedMessage"),
   430  		Field: []*descriptorpb.FieldDescriptorProto{
   431  			{
   432  				Name:   proto.String("int32"),
   433  				Label:  descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   434  				Type:   descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(),
   435  				Number: proto.Int32(1),
   436  			},
   437  			{
   438  				Name:   proto.String("bool"),
   439  				Label:  descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   440  				Type:   descriptorpb.FieldDescriptorProto_TYPE_BOOL.Enum(),
   441  				Number: proto.Int32(2),
   442  			},
   443  		},
   444  	}
   445  	meth := &descriptorpb.MethodDescriptorProto{
   446  		Name:            proto.String("Echo"),
   447  		InputType:       proto.String("ExampleMessage"),
   448  		OutputType:      proto.String("ExampleMessage"),
   449  		ClientStreaming: proto.Bool(true),
   450  	}
   451  	svc := &descriptorpb.ServiceDescriptorProto{
   452  		Name:   proto.String("ExampleService"),
   453  		Method: []*descriptorpb.MethodDescriptorProto{meth},
   454  	}
   455  	for _, spec := range []struct {
   456  		clientStreaming bool
   457  		serverStreaming bool
   458  		sigWant         []string
   459  	}{
   460  		{
   461  			clientStreaming: false,
   462  			serverStreaming: false,
   463  			sigWant: []string{
   464  				`func local_request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, server ExampleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`,
   465  				`resp, md, err := local_request_ExampleService_Echo_0(annotatedContext, inboundMarshaler, server, req, pathParams)`,
   466  			},
   467  		},
   468  		{
   469  			clientStreaming: true,
   470  			serverStreaming: true,
   471  			sigWant: []string{
   472  				`err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")`,
   473  			},
   474  		},
   475  		{
   476  			clientStreaming: true,
   477  			serverStreaming: false,
   478  			sigWant: []string{
   479  				`err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")`,
   480  			},
   481  		},
   482  		{
   483  			clientStreaming: false,
   484  			serverStreaming: true,
   485  			sigWant: []string{
   486  				`err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")`,
   487  			},
   488  		},
   489  	} {
   490  		meth.ClientStreaming = proto.Bool(spec.clientStreaming)
   491  		meth.ServerStreaming = proto.Bool(spec.serverStreaming)
   492  
   493  		msg := &descriptor.Message{
   494  			DescriptorProto: msgdesc,
   495  		}
   496  		nested := &descriptor.Message{
   497  			DescriptorProto: nesteddesc,
   498  		}
   499  
   500  		nestedField := &descriptor.Field{
   501  			Message:              msg,
   502  			FieldDescriptorProto: msg.GetField()[0],
   503  		}
   504  		intField := &descriptor.Field{
   505  			Message:              nested,
   506  			FieldDescriptorProto: nested.GetField()[0],
   507  		}
   508  		boolField := &descriptor.Field{
   509  			Message:              nested,
   510  			FieldDescriptorProto: nested.GetField()[1],
   511  		}
   512  		file := descriptor.File{
   513  			FileDescriptorProto: &descriptorpb.FileDescriptorProto{
   514  				Name:        proto.String("example.proto"),
   515  				Package:     proto.String("example"),
   516  				MessageType: []*descriptorpb.DescriptorProto{msgdesc, nesteddesc},
   517  				Service:     []*descriptorpb.ServiceDescriptorProto{svc},
   518  			},
   519  			GoPkg: descriptor.GoPackage{
   520  				Path: "example.com/path/to/example/example.pb",
   521  				Name: "example_pb",
   522  			},
   523  			Messages: []*descriptor.Message{msg, nested},
   524  			Services: []*descriptor.Service{
   525  				{
   526  					ServiceDescriptorProto: svc,
   527  					Methods: []*descriptor.Method{
   528  						{
   529  							MethodDescriptorProto: meth,
   530  							RequestType:           msg,
   531  							ResponseType:          msg,
   532  							Bindings: []*descriptor.Binding{
   533  								{
   534  									HTTPMethod: "POST",
   535  									PathTmpl: httprule.Template{
   536  										Version: 1,
   537  										OpCodes: []int{0, 0},
   538  									},
   539  									PathParams: []descriptor.Parameter{
   540  										{
   541  											FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   542  												{
   543  													Name:   "nested",
   544  													Target: nestedField,
   545  												},
   546  												{
   547  													Name:   "int32",
   548  													Target: intField,
   549  												},
   550  											}),
   551  											Target: intField,
   552  										},
   553  									},
   554  									Body: &descriptor.Body{
   555  										FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   556  											{
   557  												Name:   "nested",
   558  												Target: nestedField,
   559  											},
   560  											{
   561  												Name:   "bool",
   562  												Target: boolField,
   563  											},
   564  										}),
   565  									},
   566  								},
   567  							},
   568  						},
   569  					},
   570  				},
   571  			},
   572  		}
   573  		got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   574  		if err != nil {
   575  			t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   576  			return
   577  		}
   578  
   579  		for _, want := range spec.sigWant {
   580  			if !strings.Contains(got, want) {
   581  				t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   582  			}
   583  		}
   584  
   585  		if want := `func RegisterExampleServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server ExampleServiceServer) error {`; !strings.Contains(got, want) {
   586  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   587  		}
   588  	}
   589  }
   590  
   591  func TestAllowPatchFeature(t *testing.T) {
   592  	updateMaskDesc := &descriptorpb.FieldDescriptorProto{
   593  		Name:     proto.String("UpdateMask"),
   594  		Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   595  		Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   596  		TypeName: proto.String(".google.protobuf.FieldMask"),
   597  		Number:   proto.Int32(1),
   598  	}
   599  	msgdesc := &descriptorpb.DescriptorProto{
   600  		Name:  proto.String("ExampleMessage"),
   601  		Field: []*descriptorpb.FieldDescriptorProto{updateMaskDesc},
   602  	}
   603  	meth := &descriptorpb.MethodDescriptorProto{
   604  		Name:       proto.String("Example"),
   605  		InputType:  proto.String("ExampleMessage"),
   606  		OutputType: proto.String("ExampleMessage"),
   607  	}
   608  	svc := &descriptorpb.ServiceDescriptorProto{
   609  		Name:   proto.String("ExampleService"),
   610  		Method: []*descriptorpb.MethodDescriptorProto{meth},
   611  	}
   612  	msg := &descriptor.Message{
   613  		DescriptorProto: msgdesc,
   614  	}
   615  	updateMaskField := &descriptor.Field{
   616  		Message:              msg,
   617  		FieldDescriptorProto: updateMaskDesc,
   618  	}
   619  	msg.Fields = append(msg.Fields, updateMaskField)
   620  	file := descriptor.File{
   621  		FileDescriptorProto: &descriptorpb.FileDescriptorProto{
   622  			Name:        proto.String("example.proto"),
   623  			Package:     proto.String("example"),
   624  			MessageType: []*descriptorpb.DescriptorProto{msgdesc},
   625  			Service:     []*descriptorpb.ServiceDescriptorProto{svc},
   626  		},
   627  		GoPkg: descriptor.GoPackage{
   628  			Path: "example.com/path/to/example/example.pb",
   629  			Name: "example_pb",
   630  		},
   631  		Messages: []*descriptor.Message{msg},
   632  		Services: []*descriptor.Service{
   633  			{
   634  				ServiceDescriptorProto: svc,
   635  				Methods: []*descriptor.Method{
   636  					{
   637  						MethodDescriptorProto: meth,
   638  						RequestType:           msg,
   639  						ResponseType:          msg,
   640  						Bindings: []*descriptor.Binding{
   641  							{
   642  								HTTPMethod: "PATCH",
   643  								Body: &descriptor.Body{FieldPath: descriptor.FieldPath{descriptor.FieldPathComponent{
   644  									Name:   "abe",
   645  									Target: msg.Fields[0],
   646  								}}},
   647  							},
   648  						},
   649  					},
   650  				},
   651  			},
   652  		},
   653  	}
   654  	want := "if protoReq.UpdateMask == nil || len(protoReq.UpdateMask.GetPaths()) == 0 {\n"
   655  	for _, allowPatchFeature := range []bool{true, false} {
   656  		got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: allowPatchFeature}, descriptor.NewRegistry())
   657  		if err != nil {
   658  			t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   659  			return
   660  		}
   661  		if allowPatchFeature {
   662  			if want := `marshaler.NewDecoder(newReader()).Decode(&protoReq.Abe)`; !strings.Contains(got, want) {
   663  				t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   664  			}
   665  			if !strings.Contains(got, want) {
   666  				t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   667  			}
   668  		} else {
   669  			if strings.Contains(got, want) {
   670  				t.Errorf("applyTemplate(%#v) = %s; want to _not_ contain %s", file, got, want)
   671  			}
   672  		}
   673  	}
   674  }
   675  
   676  func TestIdentifierCapitalization(t *testing.T) {
   677  	msgdesc1 := &descriptorpb.DescriptorProto{
   678  		Name: proto.String("Exam_pleRequest"),
   679  	}
   680  	msgdesc2 := &descriptorpb.DescriptorProto{
   681  		Name: proto.String("example_response"),
   682  	}
   683  	meth1 := &descriptorpb.MethodDescriptorProto{
   684  		Name:       proto.String("ExampleGe2t"),
   685  		InputType:  proto.String("Exam_pleRequest"),
   686  		OutputType: proto.String("example_response"),
   687  	}
   688  	meth2 := &descriptorpb.MethodDescriptorProto{
   689  		Name:       proto.String("Exampl_ePost"),
   690  		InputType:  proto.String("Exam_pleRequest"),
   691  		OutputType: proto.String("example_response"),
   692  	}
   693  	svc := &descriptorpb.ServiceDescriptorProto{
   694  		Name:   proto.String("Example"),
   695  		Method: []*descriptorpb.MethodDescriptorProto{meth1, meth2},
   696  	}
   697  	msg1 := &descriptor.Message{
   698  		DescriptorProto: msgdesc1,
   699  	}
   700  	msg2 := &descriptor.Message{
   701  		DescriptorProto: msgdesc2,
   702  	}
   703  	file := descriptor.File{
   704  		FileDescriptorProto: &descriptorpb.FileDescriptorProto{
   705  			Name:        proto.String("example.proto"),
   706  			Package:     proto.String("example"),
   707  			Dependency:  []string{"a.example/b/c.proto", "a.example/d/e.proto"},
   708  			MessageType: []*descriptorpb.DescriptorProto{msgdesc1, msgdesc2},
   709  			Service:     []*descriptorpb.ServiceDescriptorProto{svc},
   710  		},
   711  		GoPkg: descriptor.GoPackage{
   712  			Path: "example.com/path/to/example/example.pb",
   713  			Name: "example_pb",
   714  		},
   715  		Messages: []*descriptor.Message{msg1, msg2},
   716  		Services: []*descriptor.Service{
   717  			{
   718  				ServiceDescriptorProto: svc,
   719  				Methods: []*descriptor.Method{
   720  					{
   721  						MethodDescriptorProto: meth1,
   722  						RequestType:           msg1,
   723  						ResponseType:          msg1,
   724  						Bindings: []*descriptor.Binding{
   725  							{
   726  								HTTPMethod: "GET",
   727  								Body:       &descriptor.Body{FieldPath: nil},
   728  							},
   729  						},
   730  					},
   731  				},
   732  			},
   733  			{
   734  				ServiceDescriptorProto: svc,
   735  				Methods: []*descriptor.Method{
   736  					{
   737  						MethodDescriptorProto: meth2,
   738  						RequestType:           msg2,
   739  						ResponseType:          msg2,
   740  						Bindings: []*descriptor.Binding{
   741  							{
   742  								HTTPMethod: "POST",
   743  								Body:       &descriptor.Body{FieldPath: nil},
   744  							},
   745  						},
   746  					},
   747  				},
   748  			},
   749  		},
   750  	}
   751  
   752  	got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   753  	if err != nil {
   754  		t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   755  		return
   756  	}
   757  	if want := `msg, err := client.ExampleGe2T(ctx, &protoReq, grpc.Header(&metadata.HeaderMD)`; !strings.Contains(got, want) {
   758  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   759  	}
   760  	if want := `msg, err := client.ExamplEPost(ctx, &protoReq, grpc.Header(&metadata.HeaderMD)`; !strings.Contains(got, want) {
   761  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   762  	}
   763  	if want := `var protoReq ExamPleRequest`; !strings.Contains(got, want) {
   764  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   765  	}
   766  	if want := `var protoReq ExampleResponse`; !strings.Contains(got, want) {
   767  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   768  	}
   769  }
   770  
   771  func TestDuplicatePathsInSameService(t *testing.T) {
   772  	msgdesc := &descriptorpb.DescriptorProto{
   773  		Name: proto.String("ExampleMessage"),
   774  		Field: []*descriptorpb.FieldDescriptorProto{
   775  			{
   776  				Name:     proto.String("nested"),
   777  				Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   778  				Type:     descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
   779  				TypeName: proto.String(".google.protobuf.StringValue"),
   780  				Number:   proto.Int32(1),
   781  			},
   782  		},
   783  	}
   784  	meth1 := &descriptorpb.MethodDescriptorProto{
   785  		Name:       proto.String("Echo"),
   786  		InputType:  proto.String("ExampleMessage"),
   787  		OutputType: proto.String("ExampleMessage"),
   788  	}
   789  	meth2 := &descriptorpb.MethodDescriptorProto{
   790  		Name:       proto.String("Echo2"),
   791  		InputType:  proto.String("ExampleMessage"),
   792  		OutputType: proto.String("ExampleMessage"),
   793  	}
   794  	svc := &descriptorpb.ServiceDescriptorProto{
   795  		Name:   proto.String("ExampleService"),
   796  		Method: []*descriptorpb.MethodDescriptorProto{meth1, meth2},
   797  	}
   798  	msg := &descriptor.Message{
   799  		DescriptorProto: msgdesc,
   800  	}
   801  	binding := &descriptor.Binding{
   802  		Index:      1,
   803  		PathTmpl:   compilePath(t, "/v1/example/echo"),
   804  		HTTPMethod: "GET",
   805  		PathParams: nil,
   806  		Body:       nil,
   807  	}
   808  	file := descriptor.File{
   809  		FileDescriptorProto: &descriptorpb.FileDescriptorProto{
   810  			Name:        proto.String("example.proto"),
   811  			Package:     proto.String("example"),
   812  			MessageType: []*descriptorpb.DescriptorProto{msgdesc},
   813  			Service:     []*descriptorpb.ServiceDescriptorProto{svc},
   814  		},
   815  		GoPkg: descriptor.GoPackage{
   816  			Path: "example.com/path/to/example/example.pb",
   817  			Name: "example_pb",
   818  		},
   819  		Messages: []*descriptor.Message{msg},
   820  		Services: []*descriptor.Service{
   821  			{
   822  				ServiceDescriptorProto: svc,
   823  				Methods: []*descriptor.Method{
   824  					{
   825  						MethodDescriptorProto: meth1,
   826  						RequestType:           msg,
   827  						ResponseType:          msg,
   828  						Bindings:              []*descriptor.Binding{binding},
   829  					},
   830  					{
   831  						MethodDescriptorProto: meth2,
   832  						RequestType:           msg,
   833  						ResponseType:          msg,
   834  						Bindings:              []*descriptor.Binding{binding},
   835  					},
   836  				},
   837  			},
   838  		},
   839  	}
   840  	_, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   841  	if err == nil {
   842  		t.Errorf("applyTemplate(%#v) succeeded; want an error", file)
   843  		return
   844  	}
   845  }
   846  
   847  func TestDuplicatePathsInDifferentService(t *testing.T) {
   848  	msgdesc := &descriptorpb.DescriptorProto{
   849  		Name: proto.String("ExampleMessage"),
   850  		Field: []*descriptorpb.FieldDescriptorProto{
   851  			{
   852  				Name:     proto.String("nested"),
   853  				Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   854  				Type:     descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
   855  				TypeName: proto.String(".google.protobuf.StringValue"),
   856  				Number:   proto.Int32(1),
   857  			},
   858  		},
   859  	}
   860  	meth1 := &descriptorpb.MethodDescriptorProto{
   861  		Name:       proto.String("Echo"),
   862  		InputType:  proto.String("ExampleMessage"),
   863  		OutputType: proto.String("ExampleMessage"),
   864  	}
   865  	meth2 := &descriptorpb.MethodDescriptorProto{
   866  		Name:       proto.String("Echo2"),
   867  		InputType:  proto.String("ExampleMessage"),
   868  		OutputType: proto.String("ExampleMessage"),
   869  	}
   870  	svc1 := &descriptorpb.ServiceDescriptorProto{
   871  		Name:   proto.String("ExampleServiceNumberOne"),
   872  		Method: []*descriptorpb.MethodDescriptorProto{meth1, meth2},
   873  	}
   874  	svc2 := &descriptorpb.ServiceDescriptorProto{
   875  		Name:   proto.String("ExampleServiceNumberTwo"),
   876  		Method: []*descriptorpb.MethodDescriptorProto{meth1, meth2},
   877  	}
   878  	msg := &descriptor.Message{
   879  		DescriptorProto: msgdesc,
   880  	}
   881  	binding := &descriptor.Binding{
   882  		Index:      1,
   883  		PathTmpl:   compilePath(t, "/v1/example/echo"),
   884  		HTTPMethod: "GET",
   885  		PathParams: nil,
   886  		Body:       nil,
   887  	}
   888  	file := descriptor.File{
   889  		FileDescriptorProto: &descriptorpb.FileDescriptorProto{
   890  			Name:        proto.String("example.proto"),
   891  			Package:     proto.String("example"),
   892  			MessageType: []*descriptorpb.DescriptorProto{msgdesc},
   893  			Service:     []*descriptorpb.ServiceDescriptorProto{svc1, svc2},
   894  		},
   895  		GoPkg: descriptor.GoPackage{
   896  			Path: "example.com/path/to/example/example.pb",
   897  			Name: "example_pb",
   898  		},
   899  		Messages: []*descriptor.Message{msg},
   900  		Services: []*descriptor.Service{
   901  			{
   902  				ServiceDescriptorProto: svc1,
   903  				Methods: []*descriptor.Method{
   904  					{
   905  						MethodDescriptorProto: meth1,
   906  						RequestType:           msg,
   907  						ResponseType:          msg,
   908  						Bindings:              []*descriptor.Binding{binding},
   909  					},
   910  				},
   911  			},
   912  			{
   913  				ServiceDescriptorProto: svc2,
   914  				Methods: []*descriptor.Method{
   915  					{
   916  						MethodDescriptorProto: meth2,
   917  						RequestType:           msg,
   918  						ResponseType:          msg,
   919  						Bindings:              []*descriptor.Binding{binding},
   920  					},
   921  				},
   922  			},
   923  		},
   924  	}
   925  	_, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   926  	if err != nil {
   927  		t.Errorf("applyTemplate(%#v) failed; want success", file)
   928  		return
   929  	}
   930  }
   931  

View as plain text