...

Source file src/google.golang.org/grpc/test/authority_test.go

Documentation: google.golang.org/grpc/test

     1  //go:build linux
     2  // +build linux
     3  
     4  /*
     5   *
     6   * Copyright 2020 gRPC authors.
     7   *
     8   * Licensed under the Apache License, Version 2.0 (the "License");
     9   * you may not use this file except in compliance with the License.
    10   * You may obtain a copy of the License at
    11   *
    12   *     https://www.apache.org/licenses/LICENSE-2.0
    13   *
    14   * Unless required by applicable law or agreed to in writing, software
    15   * distributed under the License is distributed on an "AS IS" BASIS,
    16   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    17   * See the License for the specific language governing permissions and
    18   * limitations under the License.
    19   *
    20   */
    21  
    22  package test
    23  
    24  import (
    25  	"context"
    26  	"fmt"
    27  	"net"
    28  	"os"
    29  	"strings"
    30  	"sync"
    31  	"testing"
    32  
    33  	"google.golang.org/grpc"
    34  	"google.golang.org/grpc/codes"
    35  	"google.golang.org/grpc/credentials/insecure"
    36  	"google.golang.org/grpc/internal/stubserver"
    37  	"google.golang.org/grpc/metadata"
    38  	"google.golang.org/grpc/resolver"
    39  	"google.golang.org/grpc/resolver/manual"
    40  	"google.golang.org/grpc/status"
    41  
    42  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    43  	testpb "google.golang.org/grpc/interop/grpc_testing"
    44  )
    45  
    46  func authorityChecker(ctx context.Context, expectedAuthority string) (*testpb.Empty, error) {
    47  	md, ok := metadata.FromIncomingContext(ctx)
    48  	if !ok {
    49  		return nil, status.Error(codes.InvalidArgument, "failed to parse metadata")
    50  	}
    51  	auths, ok := md[":authority"]
    52  	if !ok {
    53  		return nil, status.Error(codes.InvalidArgument, "no authority header")
    54  	}
    55  	if len(auths) != 1 {
    56  		return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths))
    57  	}
    58  	if auths[0] != expectedAuthority {
    59  		return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority))
    60  	}
    61  	return &testpb.Empty{}, nil
    62  }
    63  
    64  func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer func(context.Context, string) (net.Conn, error)) {
    65  	if !strings.HasPrefix(target, "unix-abstract:") {
    66  		if err := os.RemoveAll(address); err != nil {
    67  			t.Fatalf("Error removing socket file %v: %v\n", address, err)
    68  		}
    69  	}
    70  	ss := &stubserver.StubServer{
    71  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
    72  			return authorityChecker(ctx, expectedAuthority)
    73  		},
    74  		Network: "unix",
    75  		Address: address,
    76  		Target:  target,
    77  	}
    78  	opts := []grpc.DialOption{}
    79  	if dialer != nil {
    80  		opts = append(opts, grpc.WithContextDialer(dialer))
    81  	}
    82  	if err := ss.Start(nil, opts...); err != nil {
    83  		t.Fatalf("Error starting endpoint server: %v", err)
    84  	}
    85  	defer ss.Stop()
    86  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    87  	defer cancel()
    88  	_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
    89  	if err != nil {
    90  		t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
    91  	}
    92  }
    93  
// authorityTest describes one unix-target scenario. The same table entries
// are consumed by both TestUnix and TestUnixCustomDialer.
type authorityTest struct {
	name           string // subtest name handed to t.Run
	address        string // value assigned to StubServer.Address
	target         string // value assigned to StubServer.Target
	authority      string // :authority value the server handler expects
	dialTargetWant string // address the custom dialer expects to be called with
}
   101  
// authorityTests is the table of unix target formats exercised by TestUnix
// and TestUnixCustomDialer.
var authorityTests = []authorityTest{
	{
		name:           "UnixRelative",
		address:        "sock.sock",
		target:         "unix:sock.sock",
		authority:      "localhost",
		dialTargetWant: "unix:sock.sock",
	},
	{
		// The single-slash form is expected to reach the dialer in the
		// triple-slash form.
		name:           "UnixAbsolute",
		address:        "/tmp/sock.sock",
		target:         "unix:/tmp/sock.sock",
		authority:      "localhost",
		dialTargetWant: "unix:///tmp/sock.sock",
	},
	{
		name:           "UnixAbsoluteAlternate",
		address:        "/tmp/sock.sock",
		target:         "unix:///tmp/sock.sock",
		authority:      "localhost",
		dialTargetWant: "unix:///tmp/sock.sock",
	},
	{
		// With the passthrough scheme the expected authority is the
		// endpoint "unix:///tmp/sock.sock" with each "/" written as "%2F".
		name:           "UnixPassthrough",
		address:        "/tmp/sock.sock",
		target:         "passthrough:///unix:///tmp/sock.sock",
		authority:      "unix:%2F%2F%2Ftmp%2Fsock.sock",
		dialTargetWant: "unix:///tmp/sock.sock",
	},
	{
		// NOTE(review): the leading "@" in address appears to be the Go net
		// package's spelling of a Linux abstract socket name — confirm.
		name:           "UnixAbstract",
		address:        "@abc efg",
		target:         "unix-abstract:abc efg",
		authority:      "localhost",
		dialTargetWant: "unix:@abc efg",
	},
}
   139  
   140  // TestUnix does end to end tests with the various supported unix target
   141  // formats, ensuring that the authority is set as expected.
   142  func (s) TestUnix(t *testing.T) {
   143  	for _, test := range authorityTests {
   144  		t.Run(test.name, func(t *testing.T) {
   145  			runUnixTest(t, test.address, test.target, test.authority, nil)
   146  		})
   147  	}
   148  }
   149  
   150  // TestUnixCustomDialer does end to end tests with various supported unix target
   151  // formats, ensuring that the target sent to the dialer does NOT have the
   152  // "unix:" prefix stripped.
   153  func (s) TestUnixCustomDialer(t *testing.T) {
   154  	for _, test := range authorityTests {
   155  		t.Run(test.name+"WithDialer", func(t *testing.T) {
   156  			dialer := func(ctx context.Context, address string) (net.Conn, error) {
   157  				if address != test.dialTargetWant {
   158  					return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.dialTargetWant, address)
   159  				}
   160  				address = address[len("unix:"):]
   161  				return (&net.Dialer{}).DialContext(ctx, "unix", address)
   162  			}
   163  			runUnixTest(t, test.address, test.target, test.authority, dialer)
   164  		})
   165  	}
   166  }
   167  
   168  // TestColonPortAuthority does an end to end test with the target for grpc.Dial
   169  // being ":[port]". Ensures authority is "localhost:[port]".
   170  func (s) TestColonPortAuthority(t *testing.T) {
   171  	expectedAuthority := ""
   172  	var authorityMu sync.Mutex
   173  	ss := &stubserver.StubServer{
   174  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   175  			authorityMu.Lock()
   176  			defer authorityMu.Unlock()
   177  			return authorityChecker(ctx, expectedAuthority)
   178  		},
   179  		Network: "tcp",
   180  	}
   181  	if err := ss.Start(nil); err != nil {
   182  		t.Fatalf("Error starting endpoint server: %v", err)
   183  	}
   184  	defer ss.Stop()
   185  	_, port, err := net.SplitHostPort(ss.Address)
   186  	if err != nil {
   187  		t.Fatalf("Failed splitting host from post: %v", err)
   188  	}
   189  	authorityMu.Lock()
   190  	expectedAuthority = "localhost:" + port
   191  	authorityMu.Unlock()
   192  	// ss.Start dials, but not the ":[port]" target that is being tested here.
   193  	// Dial again, with ":[port]" as the target.
   194  	//
   195  	// Append "localhost" before calling net.Dial, in case net.Dial on certain
   196  	// platforms doesn't work well for address without the IP.
   197  	cc, err := grpc.Dial(":"+port, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
   198  		return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost"+addr)
   199  	}))
   200  	if err != nil {
   201  		t.Fatalf("grpc.Dial(%q) = %v", ss.Target, err)
   202  	}
   203  	defer cc.Close()
   204  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   205  	defer cancel()
   206  	_, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{})
   207  	if err != nil {
   208  		t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
   209  	}
   210  }
   211  
   212  // TestAuthorityReplacedWithResolverAddress tests the scenario where the resolver
   213  // returned address contains a ServerName override. The test verifies that the
   214  // :authority header value sent to the server as part of the http/2 HEADERS frame
   215  // is set to the value specified in the resolver returned address.
   216  func (s) TestAuthorityReplacedWithResolverAddress(t *testing.T) {
   217  	const expectedAuthority = "test.server.name"
   218  
   219  	ss := &stubserver.StubServer{
   220  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   221  			return authorityChecker(ctx, expectedAuthority)
   222  		},
   223  	}
   224  	if err := ss.Start(nil); err != nil {
   225  		t.Fatalf("Error starting endpoint server: %v", err)
   226  	}
   227  	defer ss.Stop()
   228  
   229  	r := manual.NewBuilderWithScheme("whatever")
   230  	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address, ServerName: expectedAuthority}}})
   231  	cc, err := grpc.NewClient(r.Scheme()+":///whatever", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
   232  	if err != nil {
   233  		t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
   234  	}
   235  	defer cc.Close()
   236  
   237  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   238  	defer cancel()
   239  	if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}); err != nil {
   240  		t.Fatalf("EmptyCall() rpc failed: %v", err)
   241  	}
   242  }
   243  

View as plain text