...

Source file src/google.golang.org/grpc/test/balancer_test.go

Documentation: google.golang.org/grpc/test

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package test
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"net"
    26  	"reflect"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/google/go-cmp/cmp"
    31  
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/attributes"
    34  	"google.golang.org/grpc/balancer"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/connectivity"
    37  	"google.golang.org/grpc/credentials"
    38  	"google.golang.org/grpc/credentials/insecure"
    39  	"google.golang.org/grpc/internal"
    40  	"google.golang.org/grpc/internal/balancer/stub"
    41  	"google.golang.org/grpc/internal/balancerload"
    42  	"google.golang.org/grpc/internal/grpcsync"
    43  	"google.golang.org/grpc/internal/grpcutil"
    44  	imetadata "google.golang.org/grpc/internal/metadata"
    45  	"google.golang.org/grpc/internal/stubserver"
    46  	"google.golang.org/grpc/internal/testutils"
    47  	"google.golang.org/grpc/metadata"
    48  	"google.golang.org/grpc/resolver"
    49  	"google.golang.org/grpc/resolver/manual"
    50  	"google.golang.org/grpc/status"
    51  	"google.golang.org/grpc/testdata"
    52  
    53  	testgrpc "google.golang.org/grpc/interop/grpc_testing"
    54  	testpb "google.golang.org/grpc/interop/grpc_testing"
    55  )
    56  
    57  const testBalancerName = "testbalancer"
    58  
    59  // testBalancer creates one subconn with the first address from resolved
    60  // addresses.
    61  //
    62  // It's used to test whether options for NewSubConn are applied correctly.
    63  type testBalancer struct {
    64  	cc balancer.ClientConn
    65  	sc balancer.SubConn
    66  
    67  	newSubConnOptions balancer.NewSubConnOptions
    68  	pickInfos         []balancer.PickInfo
    69  	pickExtraMDs      []metadata.MD
    70  	doneInfo          []balancer.DoneInfo
    71  }
    72  
    73  func (b *testBalancer) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
    74  	b.cc = cc
    75  	return b
    76  }
    77  
    78  func (*testBalancer) Name() string {
    79  	return testBalancerName
    80  }
    81  
    82  func (*testBalancer) ResolverError(err error) {
    83  	panic("not implemented")
    84  }
    85  
    86  func (b *testBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
    87  	// Only create a subconn at the first time.
    88  	if b.sc == nil {
    89  		var err error
    90  		b.newSubConnOptions.StateListener = b.updateSubConnState
    91  		b.sc, err = b.cc.NewSubConn(state.ResolverState.Addresses, b.newSubConnOptions)
    92  		if err != nil {
    93  			logger.Errorf("testBalancer: failed to NewSubConn: %v", err)
    94  			return nil
    95  		}
    96  		b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}})
    97  		b.sc.Connect()
    98  	}
    99  	return nil
   100  }
   101  
   102  func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) {
   103  	panic(fmt.Sprintf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, s))
   104  }
   105  
   106  func (b *testBalancer) updateSubConnState(s balancer.SubConnState) {
   107  	logger.Infof("testBalancer: updateSubConnState: %v", s)
   108  
   109  	switch s.ConnectivityState {
   110  	case connectivity.Ready:
   111  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{bal: b}})
   112  	case connectivity.Idle:
   113  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{bal: b, idle: true}})
   114  	case connectivity.Connecting:
   115  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}})
   116  	case connectivity.TransientFailure:
   117  		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrTransientFailure, bal: b}})
   118  	}
   119  }
   120  
   121  func (b *testBalancer) Close() {}
   122  
   123  func (b *testBalancer) ExitIdle() {}
   124  
   125  type picker struct {
   126  	err  error
   127  	bal  *testBalancer
   128  	idle bool
   129  }
   130  
   131  func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
   132  	if p.err != nil {
   133  		return balancer.PickResult{}, p.err
   134  	}
   135  	if p.idle {
   136  		p.bal.sc.Connect()
   137  		return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
   138  	}
   139  	extraMD, _ := grpcutil.ExtraMetadata(info.Ctx)
   140  	info.Ctx = nil // Do not validate context.
   141  	p.bal.pickInfos = append(p.bal.pickInfos, info)
   142  	p.bal.pickExtraMDs = append(p.bal.pickExtraMDs, extraMD)
   143  	return balancer.PickResult{SubConn: p.bal.sc, Done: func(d balancer.DoneInfo) { p.bal.doneInfo = append(p.bal.doneInfo, d) }}, nil
   144  }
   145  
   146  func (s) TestCredsBundleFromBalancer(t *testing.T) {
   147  	balancer.Register(&testBalancer{
   148  		newSubConnOptions: balancer.NewSubConnOptions{
   149  			CredsBundle: &testCredsBundle{},
   150  		},
   151  	})
   152  	te := newTest(t, env{name: "creds-bundle", network: "tcp", balancer: ""})
   153  	te.tapHandle = authHandle
   154  	te.customDialOptions = []grpc.DialOption{
   155  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)),
   156  	}
   157  	creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   158  	if err != nil {
   159  		t.Fatalf("Failed to generate credentials %v", err)
   160  	}
   161  	te.customServerOptions = []grpc.ServerOption{
   162  		grpc.Creds(creds),
   163  	}
   164  	te.startServer(&testServer{})
   165  	defer te.tearDown()
   166  
   167  	cc := te.clientConn()
   168  	tc := testgrpc.NewTestServiceClient(cc)
   169  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   170  	defer cancel()
   171  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   172  		t.Fatalf("Test failed. Reason: %v", err)
   173  	}
   174  }
   175  
   176  func (s) TestPickExtraMetadata(t *testing.T) {
   177  	for _, e := range listTestEnv() {
   178  		testPickExtraMetadata(t, e)
   179  	}
   180  }
   181  
   182  func testPickExtraMetadata(t *testing.T, e env) {
   183  	te := newTest(t, e)
   184  	b := &testBalancer{}
   185  	balancer.Register(b)
   186  	const (
   187  		testUserAgent      = "test-user-agent"
   188  		testSubContentType = "proto"
   189  	)
   190  
   191  	te.customDialOptions = []grpc.DialOption{
   192  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)),
   193  		grpc.WithUserAgent(testUserAgent),
   194  	}
   195  	te.startServer(&testServer{security: e.security})
   196  	defer te.tearDown()
   197  
   198  	// Trigger the extra-metadata-adding code path.
   199  	defer func(old string) { internal.GRPCResolverSchemeExtraMetadata = old }(internal.GRPCResolverSchemeExtraMetadata)
   200  	internal.GRPCResolverSchemeExtraMetadata = "passthrough"
   201  
   202  	cc := te.clientConn()
   203  	tc := testgrpc.NewTestServiceClient(cc)
   204  
   205  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   206  	defer cancel()
   207  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
   208  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   209  	}
   210  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)); err != nil {
   211  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   212  	}
   213  
   214  	want := []metadata.MD{
   215  		// First RPC doesn't have sub-content-type.
   216  		{"content-type": []string{"application/grpc"}},
   217  		// Second RPC has sub-content-type "proto".
   218  		{"content-type": []string{"application/grpc+proto"}},
   219  	}
   220  	if diff := cmp.Diff(want, b.pickExtraMDs); diff != "" {
   221  		t.Fatalf("unexpected diff in metadata (-want, +got): %s", diff)
   222  	}
   223  }
   224  
   225  func (s) TestDoneInfo(t *testing.T) {
   226  	for _, e := range listTestEnv() {
   227  		testDoneInfo(t, e)
   228  	}
   229  }
   230  
   231  func testDoneInfo(t *testing.T, e env) {
   232  	te := newTest(t, e)
   233  	b := &testBalancer{}
   234  	balancer.Register(b)
   235  	te.customDialOptions = []grpc.DialOption{
   236  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName)),
   237  	}
   238  	te.userAgent = failAppUA
   239  	te.startServer(&testServer{security: e.security})
   240  	defer te.tearDown()
   241  
   242  	cc := te.clientConn()
   243  	tc := testgrpc.NewTestServiceClient(cc)
   244  
   245  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   246  	defer cancel()
   247  	wantErr := detailedError
   248  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) {
   249  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", status.Convert(err).Proto(), status.Convert(wantErr).Proto())
   250  	}
   251  	if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   252  		t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
   253  	}
   254  
   255  	if len(b.doneInfo) < 1 || !testutils.StatusErrEqual(b.doneInfo[0].Err, wantErr) {
   256  		t.Fatalf("b.doneInfo = %v; want b.doneInfo[0].Err = %v", b.doneInfo, wantErr)
   257  	}
   258  	if len(b.doneInfo) < 2 || !reflect.DeepEqual(b.doneInfo[1].Trailer, testTrailerMetadata) {
   259  		t.Fatalf("b.doneInfo = %v; want b.doneInfo[1].Trailer = %v", b.doneInfo, testTrailerMetadata)
   260  	}
   261  	if len(b.pickInfos) != len(b.doneInfo) {
   262  		t.Fatalf("Got %d picks, but %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo))
   263  	}
   264  	// To test done() is always called, even if it's returned with a non-Ready
   265  	// SubConn.
   266  	//
   267  	// Stop server and at the same time send RPCs. There are chances that picker
   268  	// is not updated in time, causing a non-Ready SubConn to be returned.
   269  	finished := make(chan struct{})
   270  	go func() {
   271  		for i := 0; i < 20; i++ {
   272  			tc.UnaryCall(ctx, &testpb.SimpleRequest{})
   273  		}
   274  		close(finished)
   275  	}()
   276  	te.srv.Stop()
   277  	<-finished
   278  	if len(b.pickInfos) != len(b.doneInfo) {
   279  		t.Fatalf("Got %d picks, %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo))
   280  	}
   281  }
   282  
   283  const loadMDKey = "X-Endpoint-Load-Metrics-Bin"
   284  
   285  type testLoadParser struct{}
   286  
   287  func (*testLoadParser) Parse(md metadata.MD) any {
   288  	vs := md.Get(loadMDKey)
   289  	if len(vs) == 0 {
   290  		return nil
   291  	}
   292  	return vs[0]
   293  }
   294  
   295  func init() {
   296  	balancerload.SetParser(&testLoadParser{})
   297  }
   298  
   299  func (s) TestDoneLoads(t *testing.T) {
   300  	testDoneLoads(t)
   301  }
   302  
   303  func testDoneLoads(t *testing.T) {
   304  	b := &testBalancer{}
   305  	balancer.Register(b)
   306  
   307  	const testLoad = "test-load-,-should-be-orca"
   308  
   309  	ss := &stubserver.StubServer{
   310  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
   311  			grpc.SetTrailer(ctx, metadata.Pairs(loadMDKey, testLoad))
   312  			return &testpb.Empty{}, nil
   313  		},
   314  	}
   315  	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, testBalancerName))); err != nil {
   316  		t.Fatalf("error starting testing server: %v", err)
   317  	}
   318  	defer ss.Stop()
   319  
   320  	tc := testgrpc.NewTestServiceClient(ss.CC)
   321  
   322  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   323  	defer cancel()
   324  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   325  		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
   326  	}
   327  
   328  	piWant := []balancer.PickInfo{
   329  		{FullMethodName: "/grpc.testing.TestService/EmptyCall"},
   330  	}
   331  	if !reflect.DeepEqual(b.pickInfos, piWant) {
   332  		t.Fatalf("b.pickInfos = %v; want %v", b.pickInfos, piWant)
   333  	}
   334  
   335  	if len(b.doneInfo) < 1 {
   336  		t.Fatalf("b.doneInfo = %v, want length 1", b.doneInfo)
   337  	}
   338  	gotLoad, _ := b.doneInfo[0].ServerLoad.(string)
   339  	if gotLoad != testLoad {
   340  		t.Fatalf("b.doneInfo[0].ServerLoad = %v; want = %v", b.doneInfo[0].ServerLoad, testLoad)
   341  	}
   342  }
   343  
   344  type aiPicker struct {
   345  	result balancer.PickResult
   346  	err    error
   347  }
   348  
   349  func (aip *aiPicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) {
   350  	return aip.result, aip.err
   351  }
   352  
   353  // attrTransportCreds is a transport credential implementation which stores
   354  // Attributes from the ClientHandshakeInfo struct passed in the context locally
   355  // for the test to inspect.
   356  type attrTransportCreds struct {
   357  	credentials.TransportCredentials
   358  	attr *attributes.Attributes
   359  }
   360  
   361  func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   362  	ai := credentials.ClientHandshakeInfoFromContext(ctx)
   363  	ac.attr = ai.Attributes
   364  	return rawConn, nil, nil
   365  }
   366  func (ac *attrTransportCreds) Info() credentials.ProtocolInfo {
   367  	return credentials.ProtocolInfo{}
   368  }
   369  func (ac *attrTransportCreds) Clone() credentials.TransportCredentials {
   370  	return nil
   371  }
   372  
   373  // TestAddressAttributesInNewSubConn verifies that the Attributes passed from a
   374  // balancer in the resolver.Address that is passes to NewSubConn reaches all the
   375  // way to the ClientHandshake method of the credentials configured on the parent
   376  // channel.
   377  func (s) TestAddressAttributesInNewSubConn(t *testing.T) {
   378  	const (
   379  		testAttrKey      = "foo"
   380  		testAttrVal      = "bar"
   381  		attrBalancerName = "attribute-balancer"
   382  	)
   383  
   384  	// Register a stub balancer which adds attributes to the first address that
   385  	// it receives and then calls NewSubConn on it.
   386  	bf := stub.BalancerFuncs{
   387  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   388  			addrs := ccs.ResolverState.Addresses
   389  			if len(addrs) == 0 {
   390  				return nil
   391  			}
   392  
   393  			// Only use the first address.
   394  			attr := attributes.New(testAttrKey, testAttrVal)
   395  			addrs[0].Attributes = attr
   396  			var sc balancer.SubConn
   397  			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{
   398  				StateListener: func(state balancer.SubConnState) {
   399  					bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
   400  				},
   401  			})
   402  			if err != nil {
   403  				return err
   404  			}
   405  			sc.Connect()
   406  			return nil
   407  		},
   408  	}
   409  	stub.Register(attrBalancerName, bf)
   410  	t.Logf("Registered balancer %s...", attrBalancerName)
   411  
   412  	r := manual.NewBuilderWithScheme("whatever")
   413  	t.Logf("Registered manual resolver with scheme %s...", r.Scheme())
   414  
   415  	lis, err := net.Listen("tcp", "localhost:0")
   416  	if err != nil {
   417  		t.Fatal(err)
   418  	}
   419  
   420  	s := grpc.NewServer()
   421  	testgrpc.RegisterTestServiceServer(s, &testServer{})
   422  	go s.Serve(lis)
   423  	defer s.Stop()
   424  	t.Logf("Started gRPC server at %s...", lis.Addr().String())
   425  
   426  	creds := &attrTransportCreds{}
   427  	dopts := []grpc.DialOption{
   428  		grpc.WithTransportCredentials(creds),
   429  		grpc.WithResolvers(r),
   430  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, attrBalancerName)),
   431  	}
   432  	cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
   433  	if err != nil {
   434  		t.Fatal(err)
   435  	}
   436  	defer cc.Close()
   437  	tc := testgrpc.NewTestServiceClient(cc)
   438  	t.Log("Created a ClientConn...")
   439  
   440  	// The first RPC should fail because there's no address.
   441  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   442  	defer cancel()
   443  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
   444  		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
   445  	}
   446  	t.Log("Made an RPC which was expected to fail...")
   447  
   448  	state := resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}
   449  	r.UpdateState(state)
   450  	t.Logf("Pushing resolver state update: %v through the manual resolver", state)
   451  
   452  	// The second RPC should succeed.
   453  	ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
   454  	defer cancel()
   455  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   456  		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
   457  	}
   458  	t.Log("Made an RPC which succeeded...")
   459  
   460  	wantAttr := attributes.New(testAttrKey, testAttrVal)
   461  	if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
   462  		t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr)
   463  	}
   464  }
   465  
   466  // TestMetadataInAddressAttributes verifies that the metadata added to
   467  // address.Attributes will be sent with the RPCs.
   468  func (s) TestMetadataInAddressAttributes(t *testing.T) {
   469  	const (
   470  		testMDKey      = "test-md"
   471  		testMDValue    = "test-md-value"
   472  		mdBalancerName = "metadata-balancer"
   473  	)
   474  
   475  	// Register a stub balancer which adds metadata to the first address that it
   476  	// receives and then calls NewSubConn on it.
   477  	bf := stub.BalancerFuncs{
   478  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   479  			addrs := ccs.ResolverState.Addresses
   480  			if len(addrs) == 0 {
   481  				return nil
   482  			}
   483  			// Only use the first address.
   484  			var sc balancer.SubConn
   485  			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{
   486  				imetadata.Set(addrs[0], metadata.Pairs(testMDKey, testMDValue)),
   487  			}, balancer.NewSubConnOptions{
   488  				StateListener: func(state balancer.SubConnState) {
   489  					bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
   490  				},
   491  			})
   492  			if err != nil {
   493  				return err
   494  			}
   495  			sc.Connect()
   496  			return nil
   497  		},
   498  	}
   499  	stub.Register(mdBalancerName, bf)
   500  	t.Logf("Registered balancer %s...", mdBalancerName)
   501  
   502  	testMDChan := make(chan []string, 1)
   503  	ss := &stubserver.StubServer{
   504  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   505  			md, ok := metadata.FromIncomingContext(ctx)
   506  			if ok {
   507  				select {
   508  				case testMDChan <- md[testMDKey]:
   509  				case <-ctx.Done():
   510  					return nil, ctx.Err()
   511  				}
   512  			}
   513  			return &testpb.Empty{}, nil
   514  		},
   515  	}
   516  	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(
   517  		fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, mdBalancerName),
   518  	)); err != nil {
   519  		t.Fatalf("Error starting endpoint server: %v", err)
   520  	}
   521  	defer ss.Stop()
   522  
   523  	// The RPC should succeed with the expected md.
   524  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   525  	defer cancel()
   526  	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   527  		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
   528  	}
   529  	t.Log("Made an RPC which succeeded...")
   530  
   531  	// The server should receive the test metadata.
   532  	md1 := <-testMDChan
   533  	if len(md1) == 0 || md1[0] != testMDValue {
   534  		t.Fatalf("got md: %v, want %v", md1, []string{testMDValue})
   535  	}
   536  }
   537  
   538  // TestServersSwap creates two servers and verifies the client switches between
   539  // them when the name resolver reports the first and then the second.
   540  func (s) TestServersSwap(t *testing.T) {
   541  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   542  	defer cancel()
   543  
   544  	// Initialize servers
   545  	reg := func(username string) (addr string, cleanup func()) {
   546  		lis, err := net.Listen("tcp", "localhost:0")
   547  		if err != nil {
   548  			t.Fatalf("Error while listening. Err: %v", err)
   549  		}
   550  		s := grpc.NewServer()
   551  		ts := &funcServer{
   552  			unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   553  				return &testpb.SimpleResponse{Username: username}, nil
   554  			},
   555  		}
   556  		testgrpc.RegisterTestServiceServer(s, ts)
   557  		go s.Serve(lis)
   558  		return lis.Addr().String(), s.Stop
   559  	}
   560  	const one = "1"
   561  	addr1, cleanup := reg(one)
   562  	defer cleanup()
   563  	const two = "2"
   564  	addr2, cleanup := reg(two)
   565  	defer cleanup()
   566  
   567  	// Initialize client
   568  	r := manual.NewBuilderWithScheme("whatever")
   569  	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: addr1}}})
   570  	cc, err := grpc.DialContext(ctx, r.Scheme()+":///", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
   571  	if err != nil {
   572  		t.Fatalf("Error creating client: %v", err)
   573  	}
   574  	defer cc.Close()
   575  	client := testgrpc.NewTestServiceClient(cc)
   576  
   577  	// Confirm we are connected to the first server
   578  	if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one {
   579  		t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
   580  	}
   581  
   582  	// Update resolver to report only the second server
   583  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: addr2}}})
   584  
   585  	// Loop until new RPCs talk to server two.
   586  	for i := 0; i < 2000; i++ {
   587  		if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
   588  			t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
   589  		} else if res.Username == two {
   590  			break // pass
   591  		}
   592  		time.Sleep(5 * time.Millisecond)
   593  	}
   594  }
   595  
   596  func (s) TestWaitForReady(t *testing.T) {
   597  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   598  	defer cancel()
   599  
   600  	// Initialize server
   601  	lis, err := net.Listen("tcp", "localhost:0")
   602  	if err != nil {
   603  		t.Fatalf("Error while listening. Err: %v", err)
   604  	}
   605  	s := grpc.NewServer()
   606  	defer s.Stop()
   607  	const one = "1"
   608  	ts := &funcServer{
   609  		unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
   610  			return &testpb.SimpleResponse{Username: one}, nil
   611  		},
   612  	}
   613  	testgrpc.RegisterTestServiceServer(s, ts)
   614  	go s.Serve(lis)
   615  
   616  	// Initialize client
   617  	r := manual.NewBuilderWithScheme("whatever")
   618  
   619  	cc, err := grpc.DialContext(ctx, r.Scheme()+":///", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
   620  	if err != nil {
   621  		t.Fatalf("Error creating client: %v", err)
   622  	}
   623  	defer cc.Close()
   624  	client := testgrpc.NewTestServiceClient(cc)
   625  
   626  	// Report an error so non-WFR RPCs will give up early.
   627  	r.CC.ReportError(errors.New("fake resolver error"))
   628  
   629  	// Ensure the client is not connected to anything and fails non-WFR RPCs.
   630  	if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable {
   631  		t.Fatalf("UnaryCall(_) = %v, %v; want _, Code()=%v", res, err, codes.Unavailable)
   632  	}
   633  
   634  	errChan := make(chan error, 1)
   635  	go func() {
   636  		if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.WaitForReady(true)); err != nil || res.Username != one {
   637  			errChan <- fmt.Errorf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
   638  		}
   639  		close(errChan)
   640  	}()
   641  
   642  	select {
   643  	case err := <-errChan:
   644  		t.Errorf("unexpected receive from errChan before addresses provided")
   645  		t.Fatal(err.Error())
   646  	case <-time.After(5 * time.Millisecond):
   647  	}
   648  
   649  	// Resolve the server.  The WFR RPC should unblock and use it.
   650  	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
   651  
   652  	if err := <-errChan; err != nil {
   653  		t.Fatal(err.Error())
   654  	}
   655  }
   656  
   657  // authorityOverrideTransportCreds returns the configured authority value in its
   658  // Info() method.
   659  type authorityOverrideTransportCreds struct {
   660  	credentials.TransportCredentials
   661  	authorityOverride string
   662  }
   663  
   664  func (ao *authorityOverrideTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
   665  	return rawConn, nil, nil
   666  }
   667  func (ao *authorityOverrideTransportCreds) Info() credentials.ProtocolInfo {
   668  	return credentials.ProtocolInfo{ServerName: ao.authorityOverride}
   669  }
   670  func (ao *authorityOverrideTransportCreds) Clone() credentials.TransportCredentials {
   671  	return &authorityOverrideTransportCreds{authorityOverride: ao.authorityOverride}
   672  }
   673  
   674  // TestAuthorityInBuildOptions tests that the Authority field in
   675  // balancer.BuildOptions is setup correctly from gRPC.
   676  func (s) TestAuthorityInBuildOptions(t *testing.T) {
   677  	const dialTarget = "test.server"
   678  
   679  	tests := []struct {
   680  		name          string
   681  		dopts         []grpc.DialOption
   682  		wantAuthority string
   683  	}{
   684  		{
   685  			name:          "authority from dial target",
   686  			dopts:         []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())},
   687  			wantAuthority: dialTarget,
   688  		},
   689  		{
   690  			name: "authority from dial option",
   691  			dopts: []grpc.DialOption{
   692  				grpc.WithTransportCredentials(insecure.NewCredentials()),
   693  				grpc.WithAuthority("authority-override"),
   694  			},
   695  			wantAuthority: "authority-override",
   696  		},
   697  		{
   698  			name:          "authority from transport creds",
   699  			dopts:         []grpc.DialOption{grpc.WithTransportCredentials(&authorityOverrideTransportCreds{authorityOverride: "authority-override-from-transport-creds"})},
   700  			wantAuthority: "authority-override-from-transport-creds",
   701  		},
   702  	}
   703  
   704  	for _, test := range tests {
   705  		t.Run(test.name, func(t *testing.T) {
   706  			authorityCh := make(chan string, 1)
   707  			bf := stub.BalancerFuncs{
   708  				UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   709  					select {
   710  					case authorityCh <- bd.BuildOptions.Authority:
   711  					default:
   712  					}
   713  
   714  					addrs := ccs.ResolverState.Addresses
   715  					if len(addrs) == 0 {
   716  						return nil
   717  					}
   718  
   719  					// Only use the first address.
   720  					var sc balancer.SubConn
   721  					sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{
   722  						StateListener: func(state balancer.SubConnState) {
   723  							bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
   724  						},
   725  					})
   726  					if err != nil {
   727  						return err
   728  					}
   729  					sc.Connect()
   730  					return nil
   731  				},
   732  			}
   733  			balancerName := "stub-balancer-" + test.name
   734  			stub.Register(balancerName, bf)
   735  			t.Logf("Registered balancer %s...", balancerName)
   736  
   737  			lis, err := testutils.LocalTCPListener()
   738  			if err != nil {
   739  				t.Fatal(err)
   740  			}
   741  
   742  			s := grpc.NewServer()
   743  			testgrpc.RegisterTestServiceServer(s, &testServer{})
   744  			go s.Serve(lis)
   745  			defer s.Stop()
   746  			t.Logf("Started gRPC server at %s...", lis.Addr().String())
   747  
   748  			r := manual.NewBuilderWithScheme("whatever")
   749  			t.Logf("Registered manual resolver with scheme %s...", r.Scheme())
   750  			r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
   751  
   752  			dopts := append([]grpc.DialOption{
   753  				grpc.WithResolvers(r),
   754  				grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, balancerName)),
   755  			}, test.dopts...)
   756  			cc, err := grpc.NewClient(r.Scheme()+":///"+dialTarget, dopts...)
   757  			if err != nil {
   758  				t.Fatal(err)
   759  			}
   760  			defer cc.Close()
   761  			tc := testgrpc.NewTestServiceClient(cc)
   762  			t.Log("Created a ClientConn...")
   763  
   764  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   765  			defer cancel()
   766  			if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   767  				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
   768  			}
   769  			t.Log("Made an RPC which succeeded...")
   770  
   771  			select {
   772  			case <-ctx.Done():
   773  				t.Fatal("timeout when waiting for Authority in balancer.BuildOptions")
   774  			case gotAuthority := <-authorityCh:
   775  				if gotAuthority != test.wantAuthority {
   776  					t.Fatalf("Authority in balancer.BuildOptions is %s, want %s", gotAuthority, test.wantAuthority)
   777  				}
   778  			}
   779  		})
   780  	}
   781  }
   782  
   783  // testCCWrapper wraps a balancer.ClientConn and intercepts UpdateState and
   784  // returns a custom picker which injects arbitrary metadata on a per-call basis.
   785  type testCCWrapper struct {
   786  	balancer.ClientConn
   787  }
   788  
   789  func (t *testCCWrapper) UpdateState(state balancer.State) {
   790  	state.Picker = &wrappedPicker{p: state.Picker}
   791  	t.ClientConn.UpdateState(state)
   792  }
   793  
   794  const (
   795  	metadataHeaderInjectedByBalancer    = "metadata-header-injected-by-balancer"
   796  	metadataHeaderInjectedByApplication = "metadata-header-injected-by-application"
   797  	metadataValueInjectedByBalancer     = "metadata-value-injected-by-balancer"
   798  	metadataValueInjectedByApplication  = "metadata-value-injected-by-application"
   799  )
   800  
   801  // wrappedPicker wraps the picker returned by the pick_first
   802  type wrappedPicker struct {
   803  	p balancer.Picker
   804  }
   805  
   806  func (wp *wrappedPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
   807  	res, err := wp.p.Pick(info)
   808  	if err != nil {
   809  		return balancer.PickResult{}, err
   810  	}
   811  
   812  	if res.Metadata == nil {
   813  		res.Metadata = metadata.Pairs(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
   814  	} else {
   815  		res.Metadata.Append(metadataHeaderInjectedByBalancer, metadataValueInjectedByBalancer)
   816  	}
   817  	return res, nil
   818  }
   819  
   820  // TestMetadataInPickResult tests the scenario where an LB policy inject
   821  // arbitrary metadata on a per-call basis and verifies that the injected
   822  // metadata makes it all the way to the server RPC handler.
   823  func (s) TestMetadataInPickResult(t *testing.T) {
   824  	t.Log("Starting test backend...")
   825  	mdChan := make(chan metadata.MD, 1)
   826  	ss := &stubserver.StubServer{
   827  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   828  			md, _ := metadata.FromIncomingContext(ctx)
   829  			select {
   830  			case mdChan <- md:
   831  			case <-ctx.Done():
   832  				return nil, ctx.Err()
   833  			}
   834  			return &testpb.Empty{}, nil
   835  		},
   836  	}
   837  	if err := ss.StartServer(); err != nil {
   838  		t.Fatalf("Starting test backend: %v", err)
   839  	}
   840  	defer ss.Stop()
   841  	t.Logf("Started test backend at %q", ss.Address)
   842  
   843  	// Register a test balancer that contains a pick_first balancer and forwards
   844  	// all calls from the ClientConn to it. For state updates from the
   845  	// pick_first balancer, it creates a custom picker which injects arbitrary
   846  	// metadata on a per-call basis.
   847  	stub.Register(t.Name(), stub.BalancerFuncs{
   848  		Init: func(bd *stub.BalancerData) {
   849  			cc := &testCCWrapper{ClientConn: bd.ClientConn}
   850  			bd.Data = balancer.Get(grpc.PickFirstBalancerName).Build(cc, bd.BuildOptions)
   851  		},
   852  		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
   853  			bal := bd.Data.(balancer.Balancer)
   854  			return bal.UpdateClientConnState(ccs)
   855  		},
   856  	})
   857  
   858  	t.Log("Creating ClientConn to test backend...")
   859  	r := manual.NewBuilderWithScheme("whatever")
   860  	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address}}})
   861  	dopts := []grpc.DialOption{
   862  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   863  		grpc.WithResolvers(r),
   864  		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, t.Name())),
   865  	}
   866  	cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
   867  	if err != nil {
   868  		t.Fatalf("grpc.NewClient(): %v", err)
   869  	}
   870  	defer cc.Close()
   871  	tc := testgrpc.NewTestServiceClient(cc)
   872  
   873  	t.Log("Making EmptyCall() RPC with custom metadata...")
   874  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   875  	defer cancel()
   876  	md := metadata.Pairs(metadataHeaderInjectedByApplication, metadataValueInjectedByApplication)
   877  	ctx = metadata.NewOutgoingContext(ctx, md)
   878  	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
   879  		t.Fatalf("EmptyCall() RPC: %v", err)
   880  	}
   881  	t.Log("EmptyCall() RPC succeeded")
   882  
   883  	t.Log("Waiting for custom metadata to be received at the test backend...")
   884  	var gotMD metadata.MD
   885  	select {
   886  	case gotMD = <-mdChan:
   887  	case <-ctx.Done():
   888  		t.Fatalf("Timed out waiting for custom metadata to be received at the test backend")
   889  	}
   890  
   891  	t.Log("Verifying custom metadata added by the client application is received at the test backend...")
   892  	wantMDVal := []string{metadataValueInjectedByApplication}
   893  	gotMDVal := gotMD.Get(metadataHeaderInjectedByApplication)
   894  	if !cmp.Equal(gotMDVal, wantMDVal) {
   895  		t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
   896  	}
   897  
   898  	t.Log("Verifying custom metadata added by the LB policy is received at the test backend...")
   899  	wantMDVal = []string{metadataValueInjectedByBalancer}
   900  	gotMDVal = gotMD.Get(metadataHeaderInjectedByBalancer)
   901  	if !cmp.Equal(gotMDVal, wantMDVal) {
   902  		t.Fatalf("Mismatch in custom metadata received at test backend, got: %v, want %v", gotMDVal, wantMDVal)
   903  	}
   904  }
   905  
   906  // producerTestBalancerBuilder and producerTestBalancer start a producer which
   907  // makes an RPC before the subconn is READY, then connects the subconn, and
   908  // pushes the resulting error (expected to be nil) to rpcErrChan.
   909  type producerTestBalancerBuilder struct {
   910  	rpcErrChan chan error
   911  	ctxChan    chan context.Context
   912  	connect    bool
   913  }
   914  
   915  func (bb *producerTestBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
   916  	return &producerTestBalancer{cc: cc, rpcErrChan: bb.rpcErrChan, ctxChan: bb.ctxChan, connect: bb.connect}
   917  }
   918  
   919  const producerTestBalancerName = "producer_test_balancer"
   920  
   921  func (bb *producerTestBalancerBuilder) Name() string { return producerTestBalancerName }
   922  
   923  type producerTestBalancer struct {
   924  	cc         balancer.ClientConn
   925  	rpcErrChan chan error
   926  	ctxChan    chan context.Context
   927  	connect    bool
   928  }
   929  
   930  func (b *producerTestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
   931  	// Create the subconn, but don't connect it.
   932  	sc, err := b.cc.NewSubConn(ccs.ResolverState.Addresses, balancer.NewSubConnOptions{})
   933  	if err != nil {
   934  		return fmt.Errorf("error creating subconn: %v", err)
   935  	}
   936  
   937  	// Create the producer.  This will call the producer builder's Build
   938  	// method, which will try to start an RPC in a goroutine.
   939  	p := &testProducerBuilder{start: grpcsync.NewEvent(), rpcErrChan: b.rpcErrChan, ctxChan: b.ctxChan}
   940  	sc.GetOrBuildProducer(p)
   941  
   942  	// Wait here until the producer is about to perform the RPC, which should
   943  	// block until connected.
   944  	<-p.start.Done()
   945  
   946  	// Ensure the error chan doesn't get anything on it before we connect the
   947  	// subconn.
   948  	select {
   949  	case err := <-b.rpcErrChan:
   950  		go func() { b.rpcErrChan <- fmt.Errorf("Got unexpected data on rpcErrChan: %v", err) }()
   951  	default:
   952  	}
   953  
   954  	if b.connect {
   955  		// Now we can connect, which will unblock the RPC above.
   956  		sc.Connect()
   957  	}
   958  
   959  	// The stub server requires a READY picker to be reported, to unblock its
   960  	// Start method.  We won't make RPCs in our test, so a nil picker is okay.
   961  	b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Ready, Picker: nil})
   962  	return nil
   963  }
   964  
   965  func (b *producerTestBalancer) ResolverError(err error) {
   966  	panic(fmt.Sprintf("Unexpected resolver error: %v", err))
   967  }
   968  
   969  func (b *producerTestBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) {}
   970  func (b *producerTestBalancer) Close()                                                     {}
   971  
   972  type testProducerBuilder struct {
   973  	start      *grpcsync.Event
   974  	rpcErrChan chan error
   975  	ctxChan    chan context.Context
   976  }
   977  
   978  func (b *testProducerBuilder) Build(cci any) (balancer.Producer, func()) {
   979  	c := testgrpc.NewTestServiceClient(cci.(grpc.ClientConnInterface))
   980  	// Perform the RPC in a goroutine instead of during build because the
   981  	// subchannel's mutex is held here.
   982  	go func() {
   983  		ctx := <-b.ctxChan
   984  		b.start.Fire()
   985  		_, err := c.EmptyCall(ctx, &testpb.Empty{})
   986  		b.rpcErrChan <- err
   987  	}()
   988  	return nil, func() {}
   989  }
   990  
   991  // TestBalancerProducerBlockUntilReady tests that we get no RPC errors from
   992  // producers when subchannels aren't ready.
   993  func (s) TestBalancerProducerBlockUntilReady(t *testing.T) {
   994  	// rpcErrChan is given to the LB policy to report the status of the
   995  	// producer's one RPC.
   996  	ctxChan := make(chan context.Context, 1)
   997  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   998  	defer cancel()
   999  	ctxChan <- ctx
  1000  
  1001  	rpcErrChan := make(chan error)
  1002  	balancer.Register(&producerTestBalancerBuilder{rpcErrChan: rpcErrChan, ctxChan: ctxChan, connect: true})
  1003  
  1004  	ss := &stubserver.StubServer{
  1005  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
  1006  			return &testpb.Empty{}, nil
  1007  		},
  1008  	}
  1009  
  1010  	// Start the server & client with the test producer LB policy.
  1011  	svcCfg := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, producerTestBalancerName)
  1012  	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(svcCfg)); err != nil {
  1013  		t.Fatalf("Error starting testing server: %v", err)
  1014  	}
  1015  	defer ss.Stop()
  1016  
  1017  	// Receive the error from the producer's RPC, which should be nil.
  1018  	if err := <-rpcErrChan; err != nil {
  1019  		t.Fatalf("Received unexpected error from producer RPC: %v", err)
  1020  	}
  1021  }
  1022  
  1023  // TestBalancerProducerHonorsContext tests that producers that perform RPC get
  1024  // context errors correctly.
  1025  func (s) TestBalancerProducerHonorsContext(t *testing.T) {
  1026  	// rpcErrChan is given to the LB policy to report the status of the
  1027  	// producer's one RPC.
  1028  	ctxChan := make(chan context.Context, 1)
  1029  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1030  	ctxChan <- ctx
  1031  
  1032  	rpcErrChan := make(chan error)
  1033  	balancer.Register(&producerTestBalancerBuilder{rpcErrChan: rpcErrChan, ctxChan: ctxChan, connect: false})
  1034  
  1035  	ss := &stubserver.StubServer{
  1036  		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
  1037  			return &testpb.Empty{}, nil
  1038  		},
  1039  	}
  1040  
  1041  	// Start the server & client with the test producer LB policy.
  1042  	svcCfg := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, producerTestBalancerName)
  1043  	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(svcCfg)); err != nil {
  1044  		t.Fatalf("Error starting testing server: %v", err)
  1045  	}
  1046  	defer ss.Stop()
  1047  
  1048  	cancel()
  1049  
  1050  	// Receive the error from the producer's RPC, which should be canceled.
  1051  	if err := <-rpcErrChan; status.Code(err) != codes.Canceled {
  1052  		t.Fatalf("RPC error: %v; want status.Code(err)=%v", err, codes.Canceled)
  1053  	}
  1054  }
  1055  
  1056  // TestSubConnShutdown confirms that the Shutdown method on subconns and
  1057  // RemoveSubConn method on ClientConn properly initiates subconn shutdown.
  1058  func (s) TestSubConnShutdown(t *testing.T) {
  1059  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
  1060  	defer cancel()
  1061  
  1062  	testCases := []struct {
  1063  		name     string
  1064  		shutdown func(cc balancer.ClientConn, sc balancer.SubConn)
  1065  	}{{
  1066  		name: "ClientConn.RemoveSubConn",
  1067  		shutdown: func(cc balancer.ClientConn, sc balancer.SubConn) {
  1068  			cc.RemoveSubConn(sc)
  1069  		},
  1070  	}, {
  1071  		name: "SubConn.Shutdown",
  1072  		shutdown: func(_ balancer.ClientConn, sc balancer.SubConn) {
  1073  			sc.Shutdown()
  1074  		},
  1075  	}}
  1076  
  1077  	for _, tc := range testCases {
  1078  		t.Run(tc.name, func(t *testing.T) {
  1079  			gotShutdown := grpcsync.NewEvent()
  1080  
  1081  			bf := stub.BalancerFuncs{
  1082  				UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
  1083  					var sc balancer.SubConn
  1084  					opts := balancer.NewSubConnOptions{
  1085  						StateListener: func(scs balancer.SubConnState) {
  1086  							switch scs.ConnectivityState {
  1087  							case connectivity.Connecting:
  1088  								// Ignored.
  1089  							case connectivity.Ready:
  1090  								tc.shutdown(bd.ClientConn, sc)
  1091  							case connectivity.Shutdown:
  1092  								gotShutdown.Fire()
  1093  							default:
  1094  								t.Errorf("got unexpected state %q in listener", scs.ConnectivityState)
  1095  							}
  1096  						},
  1097  					}
  1098  					sc, err := bd.ClientConn.NewSubConn(ccs.ResolverState.Addresses, opts)
  1099  					if err != nil {
  1100  						return err
  1101  					}
  1102  					sc.Connect()
  1103  					// Report the state as READY to unblock ss.Start(), which waits for ready.
  1104  					bd.ClientConn.UpdateState(balancer.State{ConnectivityState: connectivity.Ready})
  1105  					return nil
  1106  				},
  1107  			}
  1108  
  1109  			testBalName := "shutdown-test-balancer-" + tc.name
  1110  			stub.Register(testBalName, bf)
  1111  			t.Logf("Registered balancer %s...", testBalName)
  1112  
  1113  			ss := &stubserver.StubServer{}
  1114  			if err := ss.Start(nil, grpc.WithDefaultServiceConfig(
  1115  				fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, testBalName),
  1116  			)); err != nil {
  1117  				t.Fatalf("Error starting endpoint server: %v", err)
  1118  			}
  1119  			defer ss.Stop()
  1120  
  1121  			select {
  1122  			case <-gotShutdown.Done():
  1123  				// Success
  1124  			case <-ctx.Done():
  1125  				t.Fatalf("Timed out waiting for gotShutdown to be fired.")
  1126  			}
  1127  		})
  1128  	}
  1129  }
  1130  

View as plain text