...

Source file src/github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/internal/gengateway/template_test.go

Documentation: github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/internal/gengateway

     1  package gengateway
     2  
     3  import (
     4  	"strings"
     5  	"testing"
     6  
     7  	"github.com/golang/protobuf/proto"
     8  	protodescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
     9  	"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
    10  	"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
    11  )
    12  
    13  func crossLinkFixture(f *descriptor.File) *descriptor.File {
    14  	for _, m := range f.Messages {
    15  		m.File = f
    16  	}
    17  	for _, svc := range f.Services {
    18  		svc.File = f
    19  		for _, m := range svc.Methods {
    20  			m.Service = svc
    21  			for _, b := range m.Bindings {
    22  				b.Method = m
    23  				for _, param := range b.PathParams {
    24  					param.Method = m
    25  				}
    26  			}
    27  		}
    28  	}
    29  	return f
    30  }
    31  
    32  func TestApplyTemplateHeader(t *testing.T) {
    33  	msgdesc := &protodescriptor.DescriptorProto{
    34  		Name: proto.String("ExampleMessage"),
    35  	}
    36  	meth := &protodescriptor.MethodDescriptorProto{
    37  		Name:       proto.String("Example"),
    38  		InputType:  proto.String("ExampleMessage"),
    39  		OutputType: proto.String("ExampleMessage"),
    40  	}
    41  	svc := &protodescriptor.ServiceDescriptorProto{
    42  		Name:   proto.String("ExampleService"),
    43  		Method: []*protodescriptor.MethodDescriptorProto{meth},
    44  	}
    45  	msg := &descriptor.Message{
    46  		DescriptorProto: msgdesc,
    47  	}
    48  	file := descriptor.File{
    49  		FileDescriptorProto: &protodescriptor.FileDescriptorProto{
    50  			Name:        proto.String("example.proto"),
    51  			Package:     proto.String("example"),
    52  			Dependency:  []string{"a.example/b/c.proto", "a.example/d/e.proto"},
    53  			MessageType: []*protodescriptor.DescriptorProto{msgdesc},
    54  			Service:     []*protodescriptor.ServiceDescriptorProto{svc},
    55  		},
    56  		GoPkg: descriptor.GoPackage{
    57  			Path: "example.com/path/to/example/example.pb",
    58  			Name: "example_pb",
    59  		},
    60  		Messages: []*descriptor.Message{msg},
    61  		Services: []*descriptor.Service{
    62  			{
    63  				ServiceDescriptorProto: svc,
    64  				Methods: []*descriptor.Method{
    65  					{
    66  						MethodDescriptorProto: meth,
    67  						RequestType:           msg,
    68  						ResponseType:          msg,
    69  						Bindings: []*descriptor.Binding{
    70  							{
    71  								HTTPMethod: "GET",
    72  								Body:       &descriptor.Body{FieldPath: nil},
    73  							},
    74  						},
    75  					},
    76  				},
    77  			},
    78  		},
    79  	}
    80  	got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
    81  	if err != nil {
    82  		t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
    83  		return
    84  	}
    85  	if want := "package example_pb\n"; !strings.Contains(got, want) {
    86  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
    87  	}
    88  }
    89  
    90  func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) {
    91  	msgdesc := &protodescriptor.DescriptorProto{
    92  		Name: proto.String("ExampleMessage"),
    93  		Field: []*protodescriptor.FieldDescriptorProto{
    94  			{
    95  				Name:     proto.String("nested"),
    96  				Label:    protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
    97  				Type:     protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
    98  				TypeName: proto.String("NestedMessage"),
    99  				Number:   proto.Int32(1),
   100  			},
   101  		},
   102  	}
   103  	nesteddesc := &protodescriptor.DescriptorProto{
   104  		Name: proto.String("NestedMessage"),
   105  		Field: []*protodescriptor.FieldDescriptorProto{
   106  			{
   107  				Name:   proto.String("int32"),
   108  				Label:  protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   109  				Type:   protodescriptor.FieldDescriptorProto_TYPE_INT32.Enum(),
   110  				Number: proto.Int32(1),
   111  			},
   112  			{
   113  				Name:   proto.String("bool"),
   114  				Label:  protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   115  				Type:   protodescriptor.FieldDescriptorProto_TYPE_BOOL.Enum(),
   116  				Number: proto.Int32(2),
   117  			},
   118  		},
   119  	}
   120  	meth := &protodescriptor.MethodDescriptorProto{
   121  		Name:            proto.String("Echo"),
   122  		InputType:       proto.String("ExampleMessage"),
   123  		OutputType:      proto.String("ExampleMessage"),
   124  		ClientStreaming: proto.Bool(false),
   125  	}
   126  	svc := &protodescriptor.ServiceDescriptorProto{
   127  		Name:   proto.String("ExampleService"),
   128  		Method: []*protodescriptor.MethodDescriptorProto{meth},
   129  	}
   130  	for _, spec := range []struct {
   131  		serverStreaming bool
   132  		sigWant         string
   133  	}{
   134  		{
   135  			serverStreaming: false,
   136  			sigWant:         `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`,
   137  		},
   138  		{
   139  			serverStreaming: true,
   140  			sigWant:         `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`,
   141  		},
   142  	} {
   143  		meth.ServerStreaming = proto.Bool(spec.serverStreaming)
   144  
   145  		msg := &descriptor.Message{
   146  			DescriptorProto: msgdesc,
   147  		}
   148  		nested := &descriptor.Message{
   149  			DescriptorProto: nesteddesc,
   150  		}
   151  
   152  		nestedField := &descriptor.Field{
   153  			Message:              msg,
   154  			FieldDescriptorProto: msg.GetField()[0],
   155  		}
   156  		intField := &descriptor.Field{
   157  			Message:              nested,
   158  			FieldDescriptorProto: nested.GetField()[0],
   159  		}
   160  		boolField := &descriptor.Field{
   161  			Message:              nested,
   162  			FieldDescriptorProto: nested.GetField()[1],
   163  		}
   164  		file := descriptor.File{
   165  			FileDescriptorProto: &protodescriptor.FileDescriptorProto{
   166  				Name:        proto.String("example.proto"),
   167  				Package:     proto.String("example"),
   168  				MessageType: []*protodescriptor.DescriptorProto{msgdesc, nesteddesc},
   169  				Service:     []*protodescriptor.ServiceDescriptorProto{svc},
   170  			},
   171  			GoPkg: descriptor.GoPackage{
   172  				Path: "example.com/path/to/example/example.pb",
   173  				Name: "example_pb",
   174  			},
   175  			Messages: []*descriptor.Message{msg, nested},
   176  			Services: []*descriptor.Service{
   177  				{
   178  					ServiceDescriptorProto: svc,
   179  					Methods: []*descriptor.Method{
   180  						{
   181  							MethodDescriptorProto: meth,
   182  							RequestType:           msg,
   183  							ResponseType:          msg,
   184  							Bindings: []*descriptor.Binding{
   185  								{
   186  									HTTPMethod: "POST",
   187  									PathTmpl: httprule.Template{
   188  										Version: 1,
   189  										OpCodes: []int{0, 0},
   190  									},
   191  									PathParams: []descriptor.Parameter{
   192  										{
   193  											FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   194  												{
   195  													Name:   "nested",
   196  													Target: nestedField,
   197  												},
   198  												{
   199  													Name:   "int32",
   200  													Target: intField,
   201  												},
   202  											}),
   203  											Target: intField,
   204  										},
   205  									},
   206  									Body: &descriptor.Body{
   207  										FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   208  											{
   209  												Name:   "nested",
   210  												Target: nestedField,
   211  											},
   212  											{
   213  												Name:   "bool",
   214  												Target: boolField,
   215  											},
   216  										}),
   217  									},
   218  								},
   219  							},
   220  						},
   221  					},
   222  				},
   223  			},
   224  		}
   225  		got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   226  		if err != nil {
   227  			t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   228  			return
   229  		}
   230  		if want := spec.sigWant; !strings.Contains(got, want) {
   231  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   232  		}
   233  		if want := `marshaler.NewDecoder(newReader()).Decode(&protoReq.GetNested().Bool)`; !strings.Contains(got, want) {
   234  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   235  		}
   236  		if want := `val, ok = pathParams["nested.int32"]`; !strings.Contains(got, want) {
   237  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   238  		}
   239  		if want := `protoReq.GetNested().Int32, err = runtime.Int32P(val)`; !strings.Contains(got, want) {
   240  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   241  		}
   242  		if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) {
   243  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   244  		}
   245  		if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), "", runtime.AssumeColonVerbOpt(true)))`; !strings.Contains(got, want) {
   246  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   247  		}
   248  	}
   249  }
   250  
   251  func TestApplyTemplateRequestWithClientStreaming(t *testing.T) {
   252  	msgdesc := &protodescriptor.DescriptorProto{
   253  		Name: proto.String("ExampleMessage"),
   254  		Field: []*protodescriptor.FieldDescriptorProto{
   255  			{
   256  				Name:     proto.String("nested"),
   257  				Label:    protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   258  				Type:     protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   259  				TypeName: proto.String("NestedMessage"),
   260  				Number:   proto.Int32(1),
   261  			},
   262  		},
   263  	}
   264  	nesteddesc := &protodescriptor.DescriptorProto{
   265  		Name: proto.String("NestedMessage"),
   266  		Field: []*protodescriptor.FieldDescriptorProto{
   267  			{
   268  				Name:   proto.String("int32"),
   269  				Label:  protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   270  				Type:   protodescriptor.FieldDescriptorProto_TYPE_INT32.Enum(),
   271  				Number: proto.Int32(1),
   272  			},
   273  			{
   274  				Name:   proto.String("bool"),
   275  				Label:  protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   276  				Type:   protodescriptor.FieldDescriptorProto_TYPE_BOOL.Enum(),
   277  				Number: proto.Int32(2),
   278  			},
   279  		},
   280  	}
   281  	meth := &protodescriptor.MethodDescriptorProto{
   282  		Name:            proto.String("Echo"),
   283  		InputType:       proto.String("ExampleMessage"),
   284  		OutputType:      proto.String("ExampleMessage"),
   285  		ClientStreaming: proto.Bool(true),
   286  	}
   287  	svc := &protodescriptor.ServiceDescriptorProto{
   288  		Name:   proto.String("ExampleService"),
   289  		Method: []*protodescriptor.MethodDescriptorProto{meth},
   290  	}
   291  	for _, spec := range []struct {
   292  		serverStreaming bool
   293  		sigWant         string
   294  	}{
   295  		{
   296  			serverStreaming: false,
   297  			sigWant:         `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`,
   298  		},
   299  		{
   300  			serverStreaming: true,
   301  			sigWant:         `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`,
   302  		},
   303  	} {
   304  		meth.ServerStreaming = proto.Bool(spec.serverStreaming)
   305  
   306  		msg := &descriptor.Message{
   307  			DescriptorProto: msgdesc,
   308  		}
   309  		nested := &descriptor.Message{
   310  			DescriptorProto: nesteddesc,
   311  		}
   312  
   313  		nestedField := &descriptor.Field{
   314  			Message:              msg,
   315  			FieldDescriptorProto: msg.GetField()[0],
   316  		}
   317  		intField := &descriptor.Field{
   318  			Message:              nested,
   319  			FieldDescriptorProto: nested.GetField()[0],
   320  		}
   321  		boolField := &descriptor.Field{
   322  			Message:              nested,
   323  			FieldDescriptorProto: nested.GetField()[1],
   324  		}
   325  		file := descriptor.File{
   326  			FileDescriptorProto: &protodescriptor.FileDescriptorProto{
   327  				Name:        proto.String("example.proto"),
   328  				Package:     proto.String("example"),
   329  				MessageType: []*protodescriptor.DescriptorProto{msgdesc, nesteddesc},
   330  				Service:     []*protodescriptor.ServiceDescriptorProto{svc},
   331  			},
   332  			GoPkg: descriptor.GoPackage{
   333  				Path: "example.com/path/to/example/example.pb",
   334  				Name: "example_pb",
   335  			},
   336  			Messages: []*descriptor.Message{msg, nested},
   337  			Services: []*descriptor.Service{
   338  				{
   339  					ServiceDescriptorProto: svc,
   340  					Methods: []*descriptor.Method{
   341  						{
   342  							MethodDescriptorProto: meth,
   343  							RequestType:           msg,
   344  							ResponseType:          msg,
   345  							Bindings: []*descriptor.Binding{
   346  								{
   347  									HTTPMethod: "POST",
   348  									PathTmpl: httprule.Template{
   349  										Version: 1,
   350  										OpCodes: []int{0, 0},
   351  									},
   352  									PathParams: []descriptor.Parameter{
   353  										{
   354  											FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   355  												{
   356  													Name:   "nested",
   357  													Target: nestedField,
   358  												},
   359  												{
   360  													Name:   "int32",
   361  													Target: intField,
   362  												},
   363  											}),
   364  											Target: intField,
   365  										},
   366  									},
   367  									Body: &descriptor.Body{
   368  										FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   369  											{
   370  												Name:   "nested",
   371  												Target: nestedField,
   372  											},
   373  											{
   374  												Name:   "bool",
   375  												Target: boolField,
   376  											},
   377  										}),
   378  									},
   379  								},
   380  							},
   381  						},
   382  					},
   383  				},
   384  			},
   385  		}
   386  		got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   387  		if err != nil {
   388  			t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   389  			return
   390  		}
   391  		if want := spec.sigWant; !strings.Contains(got, want) {
   392  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   393  		}
   394  		if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) {
   395  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   396  		}
   397  		if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), "", runtime.AssumeColonVerbOpt(true)))`; !strings.Contains(got, want) {
   398  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   399  		}
   400  	}
   401  }
   402  
   403  func TestApplyTemplateInProcess(t *testing.T) {
   404  	msgdesc := &protodescriptor.DescriptorProto{
   405  		Name: proto.String("ExampleMessage"),
   406  		Field: []*protodescriptor.FieldDescriptorProto{
   407  			{
   408  				Name:     proto.String("nested"),
   409  				Label:    protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   410  				Type:     protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   411  				TypeName: proto.String("NestedMessage"),
   412  				Number:   proto.Int32(1),
   413  			},
   414  		},
   415  	}
   416  	nesteddesc := &protodescriptor.DescriptorProto{
   417  		Name: proto.String("NestedMessage"),
   418  		Field: []*protodescriptor.FieldDescriptorProto{
   419  			{
   420  				Name:   proto.String("int32"),
   421  				Label:  protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   422  				Type:   protodescriptor.FieldDescriptorProto_TYPE_INT32.Enum(),
   423  				Number: proto.Int32(1),
   424  			},
   425  			{
   426  				Name:   proto.String("bool"),
   427  				Label:  protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   428  				Type:   protodescriptor.FieldDescriptorProto_TYPE_BOOL.Enum(),
   429  				Number: proto.Int32(2),
   430  			},
   431  		},
   432  	}
   433  	meth := &protodescriptor.MethodDescriptorProto{
   434  		Name:            proto.String("Echo"),
   435  		InputType:       proto.String("ExampleMessage"),
   436  		OutputType:      proto.String("ExampleMessage"),
   437  		ClientStreaming: proto.Bool(true),
   438  	}
   439  	svc := &protodescriptor.ServiceDescriptorProto{
   440  		Name:   proto.String("ExampleService"),
   441  		Method: []*protodescriptor.MethodDescriptorProto{meth},
   442  	}
   443  	for _, spec := range []struct {
   444  		clientStreaming bool
   445  		serverStreaming bool
   446  		sigWant         []string
   447  	}{
   448  		{
   449  			clientStreaming: false,
   450  			serverStreaming: false,
   451  			sigWant: []string{
   452  				`func local_request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, server ExampleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`,
   453  				`resp, md, err := local_request_ExampleService_Echo_0(rctx, inboundMarshaler, server, req, pathParams)`,
   454  			},
   455  		},
   456  		{
   457  			clientStreaming: true,
   458  			serverStreaming: true,
   459  			sigWant: []string{
   460  				`err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")`,
   461  			},
   462  		},
   463  		{
   464  			clientStreaming: true,
   465  			serverStreaming: false,
   466  			sigWant: []string{
   467  				`err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")`,
   468  			},
   469  		},
   470  		{
   471  			clientStreaming: false,
   472  			serverStreaming: true,
   473  			sigWant: []string{
   474  				`err := status.Error(codes.Unimplemented, "streaming calls are not yet supported in the in-process transport")`,
   475  			},
   476  		},
   477  	} {
   478  		meth.ClientStreaming = proto.Bool(spec.clientStreaming)
   479  		meth.ServerStreaming = proto.Bool(spec.serverStreaming)
   480  
   481  		msg := &descriptor.Message{
   482  			DescriptorProto: msgdesc,
   483  		}
   484  		nested := &descriptor.Message{
   485  			DescriptorProto: nesteddesc,
   486  		}
   487  
   488  		nestedField := &descriptor.Field{
   489  			Message:              msg,
   490  			FieldDescriptorProto: msg.GetField()[0],
   491  		}
   492  		intField := &descriptor.Field{
   493  			Message:              nested,
   494  			FieldDescriptorProto: nested.GetField()[0],
   495  		}
   496  		boolField := &descriptor.Field{
   497  			Message:              nested,
   498  			FieldDescriptorProto: nested.GetField()[1],
   499  		}
   500  		file := descriptor.File{
   501  			FileDescriptorProto: &protodescriptor.FileDescriptorProto{
   502  				Name:        proto.String("example.proto"),
   503  				Package:     proto.String("example"),
   504  				MessageType: []*protodescriptor.DescriptorProto{msgdesc, nesteddesc},
   505  				Service:     []*protodescriptor.ServiceDescriptorProto{svc},
   506  			},
   507  			GoPkg: descriptor.GoPackage{
   508  				Path: "example.com/path/to/example/example.pb",
   509  				Name: "example_pb",
   510  			},
   511  			Messages: []*descriptor.Message{msg, nested},
   512  			Services: []*descriptor.Service{
   513  				{
   514  					ServiceDescriptorProto: svc,
   515  					Methods: []*descriptor.Method{
   516  						{
   517  							MethodDescriptorProto: meth,
   518  							RequestType:           msg,
   519  							ResponseType:          msg,
   520  							Bindings: []*descriptor.Binding{
   521  								{
   522  									HTTPMethod: "POST",
   523  									PathTmpl: httprule.Template{
   524  										Version: 1,
   525  										OpCodes: []int{0, 0},
   526  									},
   527  									PathParams: []descriptor.Parameter{
   528  										{
   529  											FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   530  												{
   531  													Name:   "nested",
   532  													Target: nestedField,
   533  												},
   534  												{
   535  													Name:   "int32",
   536  													Target: intField,
   537  												},
   538  											}),
   539  											Target: intField,
   540  										},
   541  									},
   542  									Body: &descriptor.Body{
   543  										FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
   544  											{
   545  												Name:   "nested",
   546  												Target: nestedField,
   547  											},
   548  											{
   549  												Name:   "bool",
   550  												Target: boolField,
   551  											},
   552  										}),
   553  									},
   554  								},
   555  							},
   556  						},
   557  					},
   558  				},
   559  			},
   560  		}
   561  		got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   562  		if err != nil {
   563  			t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   564  			return
   565  		}
   566  
   567  		for _, want := range spec.sigWant {
   568  			if !strings.Contains(got, want) {
   569  				t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   570  			}
   571  		}
   572  
   573  		if want := `func RegisterExampleServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server ExampleServiceServer) error {`; !strings.Contains(got, want) {
   574  			t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   575  		}
   576  	}
   577  }
   578  
   579  func TestAllowPatchFeature(t *testing.T) {
   580  	updateMaskDesc := &protodescriptor.FieldDescriptorProto{
   581  		Name:     proto.String("UpdateMask"),
   582  		Label:    protodescriptor.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   583  		Type:     protodescriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   584  		TypeName: proto.String(".google.protobuf.FieldMask"),
   585  		Number:   proto.Int32(1),
   586  	}
   587  	msgdesc := &protodescriptor.DescriptorProto{
   588  		Name:  proto.String("ExampleMessage"),
   589  		Field: []*protodescriptor.FieldDescriptorProto{updateMaskDesc},
   590  	}
   591  	meth := &protodescriptor.MethodDescriptorProto{
   592  		Name:       proto.String("Example"),
   593  		InputType:  proto.String("ExampleMessage"),
   594  		OutputType: proto.String("ExampleMessage"),
   595  	}
   596  	svc := &protodescriptor.ServiceDescriptorProto{
   597  		Name:   proto.String("ExampleService"),
   598  		Method: []*protodescriptor.MethodDescriptorProto{meth},
   599  	}
   600  	msg := &descriptor.Message{
   601  		DescriptorProto: msgdesc,
   602  	}
   603  	updateMaskField := &descriptor.Field{
   604  		Message:              msg,
   605  		FieldDescriptorProto: updateMaskDesc,
   606  	}
   607  	msg.Fields = append(msg.Fields, updateMaskField)
   608  	file := descriptor.File{
   609  		FileDescriptorProto: &protodescriptor.FileDescriptorProto{
   610  			Name:        proto.String("example.proto"),
   611  			Package:     proto.String("example"),
   612  			MessageType: []*protodescriptor.DescriptorProto{msgdesc},
   613  			Service:     []*protodescriptor.ServiceDescriptorProto{svc},
   614  		},
   615  		GoPkg: descriptor.GoPackage{
   616  			Path: "example.com/path/to/example/example.pb",
   617  			Name: "example_pb",
   618  		},
   619  		Messages: []*descriptor.Message{msg},
   620  		Services: []*descriptor.Service{
   621  			{
   622  				ServiceDescriptorProto: svc,
   623  				Methods: []*descriptor.Method{
   624  					{
   625  						MethodDescriptorProto: meth,
   626  						RequestType:           msg,
   627  						ResponseType:          msg,
   628  						Bindings: []*descriptor.Binding{
   629  							{
   630  								HTTPMethod: "PATCH",
   631  								Body: &descriptor.Body{FieldPath: descriptor.FieldPath{descriptor.FieldPathComponent{
   632  									Name:   "abe",
   633  									Target: msg.Fields[0],
   634  								}}},
   635  							},
   636  						},
   637  					},
   638  				},
   639  			},
   640  		},
   641  	}
   642  	want := "if protoReq.UpdateMask == nil || len(protoReq.UpdateMask.GetPaths()) == 0 {\n"
   643  	for _, allowPatchFeature := range []bool{true, false} {
   644  		got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: allowPatchFeature}, descriptor.NewRegistry())
   645  		if err != nil {
   646  			t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   647  			return
   648  		}
   649  		if allowPatchFeature {
   650  			if !strings.Contains(got, want) {
   651  				t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   652  			}
   653  		} else {
   654  			if strings.Contains(got, want) {
   655  				t.Errorf("applyTemplate(%#v) = %s; want to _not_ contain %s", file, got, want)
   656  			}
   657  		}
   658  	}
   659  }
   660  
   661  func TestIdentifierCapitalization(t *testing.T) {
   662  	msgdesc1 := &protodescriptor.DescriptorProto{
   663  		Name: proto.String("Exam_pleRequest"),
   664  	}
   665  	msgdesc2 := &protodescriptor.DescriptorProto{
   666  		Name: proto.String("example_response"),
   667  	}
   668  	meth1 := &protodescriptor.MethodDescriptorProto{
   669  		Name:       proto.String("ExampleGe2t"),
   670  		InputType:  proto.String("Exam_pleRequest"),
   671  		OutputType: proto.String("example_response"),
   672  	}
   673  	meth2 := &protodescriptor.MethodDescriptorProto{
   674  		Name:       proto.String("Exampl_eGet"),
   675  		InputType:  proto.String("Exam_pleRequest"),
   676  		OutputType: proto.String("example_response"),
   677  	}
   678  	svc := &protodescriptor.ServiceDescriptorProto{
   679  		Name:   proto.String("Example"),
   680  		Method: []*protodescriptor.MethodDescriptorProto{meth1, meth2},
   681  	}
   682  	msg1 := &descriptor.Message{
   683  		DescriptorProto: msgdesc1,
   684  	}
   685  	msg2 := &descriptor.Message{
   686  		DescriptorProto: msgdesc2,
   687  	}
   688  	file := descriptor.File{
   689  		FileDescriptorProto: &protodescriptor.FileDescriptorProto{
   690  			Name:        proto.String("example.proto"),
   691  			Package:     proto.String("example"),
   692  			Dependency:  []string{"a.example/b/c.proto", "a.example/d/e.proto"},
   693  			MessageType: []*protodescriptor.DescriptorProto{msgdesc1, msgdesc2},
   694  			Service:     []*protodescriptor.ServiceDescriptorProto{svc},
   695  		},
   696  		GoPkg: descriptor.GoPackage{
   697  			Path: "example.com/path/to/example/example.pb",
   698  			Name: "example_pb",
   699  		},
   700  		Messages: []*descriptor.Message{msg1, msg2},
   701  		Services: []*descriptor.Service{
   702  			{
   703  				ServiceDescriptorProto: svc,
   704  				Methods: []*descriptor.Method{
   705  					{
   706  						MethodDescriptorProto: meth1,
   707  						RequestType:           msg1,
   708  						ResponseType:          msg1,
   709  						Bindings: []*descriptor.Binding{
   710  							{
   711  								HTTPMethod: "GET",
   712  								Body:       &descriptor.Body{FieldPath: nil},
   713  							},
   714  						},
   715  					},
   716  				},
   717  			},
   718  			{
   719  				ServiceDescriptorProto: svc,
   720  				Methods: []*descriptor.Method{
   721  					{
   722  						MethodDescriptorProto: meth2,
   723  						RequestType:           msg2,
   724  						ResponseType:          msg2,
   725  						Bindings: []*descriptor.Binding{
   726  							{
   727  								HTTPMethod: "GET",
   728  								Body:       &descriptor.Body{FieldPath: nil},
   729  							},
   730  						},
   731  					},
   732  				},
   733  			},
   734  		},
   735  	}
   736  
   737  	got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler", AllowPatchFeature: true}, descriptor.NewRegistry())
   738  	if err != nil {
   739  		t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err)
   740  		return
   741  	}
   742  	if want := `msg, err := client.ExampleGe2T(ctx, &protoReq, grpc.Header(&metadata.HeaderMD)`; !strings.Contains(got, want) {
   743  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   744  	}
   745  	if want := `msg, err := client.ExamplEGet(ctx, &protoReq, grpc.Header(&metadata.HeaderMD)`; !strings.Contains(got, want) {
   746  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   747  	}
   748  	if want := `var protoReq ExamPleRequest`; !strings.Contains(got, want) {
   749  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   750  	}
   751  	if want := `var protoReq ExampleResponse`; !strings.Contains(got, want) {
   752  		t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want)
   753  	}
   754  }
   755  

View as plain text