...

Source file src/github.com/letsencrypt/boulder/grpc/internal/resolver/dns/dns_resolver_test.go

Documentation: github.com/letsencrypt/boulder/grpc/internal/resolver/dns

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package dns
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"net"
    26  	"os"
    27  	"slices"
    28  	"strings"
    29  	"sync"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/letsencrypt/boulder/grpc/internal/leakcheck"
    34  	"github.com/letsencrypt/boulder/grpc/internal/testutils"
    35  	"github.com/letsencrypt/boulder/test"
    36  	"google.golang.org/grpc/balancer"
    37  	"google.golang.org/grpc/resolver"
    38  )
    39  
    40  func TestMain(m *testing.M) {
    41  	// Set a non-zero duration only for tests which are actually testing that
    42  	// feature.
    43  	replaceDNSResRate(time.Duration(0)) // No need to clean up since we os.Exit
    44  	overrideDefaultResolver(false)      // No need to clean up since we os.Exit
    45  	code := m.Run()
    46  	os.Exit(code)
    47  }
    48  
    49  const (
    50  	txtBytesLimit           = 255
    51  	defaultTestTimeout      = 10 * time.Second
    52  	defaultTestShortTimeout = 10 * time.Millisecond
    53  )
    54  
    55  type testClientConn struct {
    56  	resolver.ClientConn // For unimplemented functions
    57  	target              string
    58  	m1                  sync.Mutex
    59  	state               resolver.State
    60  	updateStateCalls    int
    61  	errChan             chan error
    62  	updateStateErr      error
    63  }
    64  
    65  func (t *testClientConn) UpdateState(s resolver.State) error {
    66  	t.m1.Lock()
    67  	defer t.m1.Unlock()
    68  	t.state = s
    69  	t.updateStateCalls++
    70  	// This error determines whether DNS Resolver actually decides to exponentially backoff or not.
    71  	// This can be any error.
    72  	return t.updateStateErr
    73  }
    74  
    75  func (t *testClientConn) getState() (resolver.State, int) {
    76  	t.m1.Lock()
    77  	defer t.m1.Unlock()
    78  	return t.state, t.updateStateCalls
    79  }
    80  
// ReportError forwards err to errChan. The send blocks until a test receives
// it (or the channel has buffer space), so tests that expect errors must
// create errChan and drain it.
func (t *testClientConn) ReportError(err error) {
	t.errChan <- err
}
    84  
// testResolver is the fake resolver installed as defaultResolver during
// tests; it answers lookups from the in-memory tables in this file.
type testResolver struct {
	// A write to this channel is made when this resolver receives a resolution
	// request. Tests can rely on reading from this channel to be notified about
	// resolution requests instead of sleeping for a predefined period of time.
	// When nil, no notification is sent.
	lookupHostCh *testutils.Channel
}
    91  
    92  func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
    93  	if tr.lookupHostCh != nil {
    94  		tr.lookupHostCh.Send(nil)
    95  	}
    96  	return hostLookup(host)
    97  }
    98  
    99  func (*testResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
   100  	return srvLookup(service, proto, name)
   101  }
   102  
   103  // overrideDefaultResolver overrides the defaultResolver used by the code with
   104  // an instance of the testResolver. pushOnLookup controls whether the
   105  // testResolver created here pushes lookupHost events on its channel.
   106  func overrideDefaultResolver(pushOnLookup bool) func() {
   107  	oldResolver := defaultResolver
   108  
   109  	var lookupHostCh *testutils.Channel
   110  	if pushOnLookup {
   111  		lookupHostCh = testutils.NewChannel()
   112  	}
   113  	defaultResolver = &testResolver{lookupHostCh: lookupHostCh}
   114  
   115  	return func() {
   116  		defaultResolver = oldResolver
   117  	}
   118  }
   119  
   120  func replaceDNSResRate(d time.Duration) func() {
   121  	oldMinDNSResRate := minDNSResRate
   122  	minDNSResRate = d
   123  
   124  	return func() {
   125  		minDNSResRate = oldMinDNSResRate
   126  	}
   127  }
   128  
   129  var hostLookupTbl = struct {
   130  	sync.Mutex
   131  	tbl map[string][]string
   132  }{
   133  	tbl: map[string][]string{
   134  		"ipv4.single.fake": {"2.4.6.8"},
   135  		"ipv4.multi.fake":  {"1.2.3.4", "5.6.7.8", "9.10.11.12"},
   136  		"ipv6.single.fake": {"2607:f8b0:400a:801::1001"},
   137  		"ipv6.multi.fake":  {"2607:f8b0:400a:801::1001", "2607:f8b0:400a:801::1002", "2607:f8b0:400a:801::1003"},
   138  	},
   139  }
   140  
   141  func hostLookup(host string) ([]string, error) {
   142  	hostLookupTbl.Lock()
   143  	defer hostLookupTbl.Unlock()
   144  	if addrs, ok := hostLookupTbl.tbl[host]; ok {
   145  		return addrs, nil
   146  	}
   147  	return nil, &net.DNSError{
   148  		Err:         "hostLookup error",
   149  		Name:        host,
   150  		Server:      "fake",
   151  		IsTemporary: true,
   152  	}
   153  }
   154  
   155  var srvLookupTbl = struct {
   156  	sync.Mutex
   157  	tbl map[string][]*net.SRV
   158  }{
   159  	tbl: map[string][]*net.SRV{
   160  		"_foo._tcp.ipv4.single.fake": {&net.SRV{Target: "ipv4.single.fake", Port: 1234}},
   161  		"_foo._tcp.ipv4.multi.fake":  {&net.SRV{Target: "ipv4.multi.fake", Port: 1234}},
   162  		"_foo._tcp.ipv6.single.fake": {&net.SRV{Target: "ipv6.single.fake", Port: 1234}},
   163  		"_foo._tcp.ipv6.multi.fake":  {&net.SRV{Target: "ipv6.multi.fake", Port: 1234}},
   164  	},
   165  }
   166  
   167  func srvLookup(service, proto, name string) (string, []*net.SRV, error) {
   168  	cname := "_" + service + "._" + proto + "." + name
   169  	srvLookupTbl.Lock()
   170  	defer srvLookupTbl.Unlock()
   171  	if srvs, cnt := srvLookupTbl.tbl[cname]; cnt {
   172  		return cname, srvs, nil
   173  	}
   174  	return "", nil, &net.DNSError{
   175  		Err:         "srvLookup error",
   176  		Name:        cname,
   177  		Server:      "fake",
   178  		IsTemporary: true,
   179  	}
   180  }
   181  
// TestResolve runs the basic resolution check followed by the ResolveNow
// check, sequentially and under a single test name.
func TestResolve(t *testing.T) {
	testDNSResolver(t)
	testDNSResolveNow(t)
}
   186  
   187  func testDNSResolver(t *testing.T) {
   188  	defer func(nt func(d time.Duration) *time.Timer) {
   189  		newTimer = nt
   190  	}(newTimer)
   191  	newTimer = func(_ time.Duration) *time.Timer {
   192  		// Will never fire on its own, will protect from triggering exponential backoff.
   193  		return time.NewTimer(time.Hour)
   194  	}
   195  	tests := []struct {
   196  		target   string
   197  		addrWant []resolver.Address
   198  	}{
   199  		{
   200  			"foo.ipv4.single.fake",
   201  			[]resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}},
   202  		},
   203  		{
   204  			"foo.ipv4.multi.fake",
   205  			[]resolver.Address{
   206  				{Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"},
   207  				{Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"},
   208  				{Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"},
   209  			},
   210  		},
   211  		{
   212  			"foo.ipv6.single.fake",
   213  			[]resolver.Address{{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.single.fake"}},
   214  		},
   215  		{
   216  			"foo.ipv6.multi.fake",
   217  			[]resolver.Address{
   218  				{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.multi.fake"},
   219  				{Addr: "[2607:f8b0:400a:801::1002]:1234", ServerName: "ipv6.multi.fake"},
   220  				{Addr: "[2607:f8b0:400a:801::1003]:1234", ServerName: "ipv6.multi.fake"},
   221  			},
   222  		},
   223  	}
   224  
   225  	for _, a := range tests {
   226  		b := NewDefaultSRVBuilder()
   227  		cc := &testClientConn{target: a.target}
   228  		r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{})
   229  		if err != nil {
   230  			t.Fatalf("%v\n", err)
   231  		}
   232  		var state resolver.State
   233  		var cnt int
   234  		for i := 0; i < 2000; i++ {
   235  			state, cnt = cc.getState()
   236  			if cnt > 0 {
   237  				break
   238  			}
   239  			time.Sleep(time.Millisecond)
   240  		}
   241  		if cnt == 0 {
   242  			t.Fatalf("UpdateState not called after 2s; aborting")
   243  		}
   244  
   245  		if !slices.Equal(a.addrWant, state.Addresses) {
   246  			t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant)
   247  		}
   248  		r.Close()
   249  	}
   250  }
   251  
   252  // DNS Resolver immediately starts polling on an error from grpc. This should continue until the ClientConn doesn't
   253  // send back an error from updating the DNS Resolver's state.
   254  func TestDNSResolverExponentialBackoff(t *testing.T) {
   255  	defer leakcheck.Check(t)
   256  	defer func(nt func(d time.Duration) *time.Timer) {
   257  		newTimer = nt
   258  	}(newTimer)
   259  	timerChan := testutils.NewChannel()
   260  	newTimer = func(d time.Duration) *time.Timer {
   261  		// Will never fire on its own, allows this test to call timer immediately.
   262  		t := time.NewTimer(time.Hour)
   263  		timerChan.Send(t)
   264  		return t
   265  	}
   266  	target := "foo.ipv4.single.fake"
   267  	wantAddr := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}}
   268  
   269  	b := NewDefaultSRVBuilder()
   270  	cc := &testClientConn{target: target}
   271  	// Cause ClientConn to return an error.
   272  	cc.updateStateErr = balancer.ErrBadResolverState
   273  	r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{})
   274  	if err != nil {
   275  		t.Fatalf("Error building resolver for target %v: %v", target, err)
   276  	}
   277  	defer r.Close()
   278  	var state resolver.State
   279  	var cnt int
   280  	for i := 0; i < 2000; i++ {
   281  		state, cnt = cc.getState()
   282  		if cnt > 0 {
   283  			break
   284  		}
   285  		time.Sleep(time.Millisecond)
   286  	}
   287  	if cnt == 0 {
   288  		t.Fatalf("UpdateState not called after 2s; aborting")
   289  	}
   290  	if !slices.Equal(wantAddr, state.Addresses) {
   291  		t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, target)
   292  	}
   293  	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   294  	defer ctxCancel()
   295  	// Cause timer to go off 10 times, and see if it calls updateState() correctly.
   296  	for i := 0; i < 10; i++ {
   297  		timer, err := timerChan.Receive(ctx)
   298  		if err != nil {
   299  			t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   300  		}
   301  		timerPointer := timer.(*time.Timer)
   302  		timerPointer.Reset(0)
   303  	}
   304  	// Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call
   305  	// ClientConn update state.
   306  	deadline := time.Now().Add(defaultTestTimeout)
   307  	for {
   308  		cc.m1.Lock()
   309  		got := cc.updateStateCalls
   310  		cc.m1.Unlock()
   311  		if got == 11 {
   312  			break
   313  		}
   314  
   315  		if time.Now().After(deadline) {
   316  			t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got)
   317  		}
   318  
   319  		time.Sleep(time.Millisecond)
   320  	}
   321  
   322  	// Update resolver.ClientConn to not return an error anymore - this should stop it from backing off.
   323  	cc.updateStateErr = nil
   324  	timer, err := timerChan.Receive(ctx)
   325  	if err != nil {
   326  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   327  	}
   328  	timerPointer := timer.(*time.Timer)
   329  	timerPointer.Reset(0)
   330  	// Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call
   331  	// ClientConn update state the final time. The DNS Resolver should then stop polling.
   332  	deadline = time.Now().Add(defaultTestTimeout)
   333  	for {
   334  		cc.m1.Lock()
   335  		got := cc.updateStateCalls
   336  		cc.m1.Unlock()
   337  		if got == 12 {
   338  			break
   339  		}
   340  
   341  		if time.Now().After(deadline) {
   342  			t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got)
   343  		}
   344  
   345  		_, err := timerChan.ReceiveOrFail()
   346  		if err {
   347  			t.Fatalf("Should not poll again after Client Conn stops returning error.")
   348  		}
   349  
   350  		time.Sleep(time.Millisecond)
   351  	}
   352  }
   353  
   354  func mutateTbl(target string) func() {
   355  	hostLookupTbl.Lock()
   356  	oldHostTblEntry := hostLookupTbl.tbl[target]
   357  
   358  	// Remove the last address from the target's entry.
   359  	hostLookupTbl.tbl[target] = hostLookupTbl.tbl[target][:len(oldHostTblEntry)-1]
   360  	hostLookupTbl.Unlock()
   361  
   362  	return func() {
   363  		hostLookupTbl.Lock()
   364  		hostLookupTbl.tbl[target] = oldHostTblEntry
   365  		hostLookupTbl.Unlock()
   366  	}
   367  }
   368  
   369  func testDNSResolveNow(t *testing.T) {
   370  	defer leakcheck.Check(t)
   371  	defer func(nt func(d time.Duration) *time.Timer) {
   372  		newTimer = nt
   373  	}(newTimer)
   374  	newTimer = func(_ time.Duration) *time.Timer {
   375  		// Will never fire on its own, will protect from triggering exponential backoff.
   376  		return time.NewTimer(time.Hour)
   377  	}
   378  	tests := []struct {
   379  		target   string
   380  		addrWant []resolver.Address
   381  		addrNext []resolver.Address
   382  	}{
   383  		{
   384  			"foo.ipv4.multi.fake",
   385  			[]resolver.Address{
   386  				{Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"},
   387  				{Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"},
   388  				{Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"},
   389  			},
   390  			[]resolver.Address{
   391  				{Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"},
   392  				{Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"},
   393  			},
   394  		},
   395  	}
   396  
   397  	for _, a := range tests {
   398  		b := NewDefaultSRVBuilder()
   399  		cc := &testClientConn{target: a.target}
   400  		r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{})
   401  		if err != nil {
   402  			t.Fatalf("%v\n", err)
   403  		}
   404  		defer r.Close()
   405  		var state resolver.State
   406  		var cnt int
   407  		for i := 0; i < 2000; i++ {
   408  			state, cnt = cc.getState()
   409  			if cnt > 0 {
   410  				break
   411  			}
   412  			time.Sleep(time.Millisecond)
   413  		}
   414  		if cnt == 0 {
   415  			t.Fatalf("UpdateState not called after 2s; aborting.  state=%v", state)
   416  		}
   417  		if !slices.Equal(a.addrWant, state.Addresses) {
   418  			t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant)
   419  		}
   420  
   421  		revertTbl := mutateTbl(strings.TrimPrefix(a.target, "foo."))
   422  		r.ResolveNow(resolver.ResolveNowOptions{})
   423  		for i := 0; i < 2000; i++ {
   424  			state, cnt = cc.getState()
   425  			if cnt == 2 {
   426  				break
   427  			}
   428  			time.Sleep(time.Millisecond)
   429  		}
   430  		if cnt != 2 {
   431  			t.Fatalf("UpdateState not called after 2s; aborting.  state=%v", state)
   432  		}
   433  		if !slices.Equal(a.addrNext, state.Addresses) {
   434  			t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrNext)
   435  		}
   436  		revertTbl()
   437  	}
   438  }
   439  
   440  func TestDNSResolverRetry(t *testing.T) {
   441  	defer func(nt func(d time.Duration) *time.Timer) {
   442  		newTimer = nt
   443  	}(newTimer)
   444  	newTimer = func(d time.Duration) *time.Timer {
   445  		// Will never fire on its own, will protect from triggering exponential backoff.
   446  		return time.NewTimer(time.Hour)
   447  	}
   448  	b := NewDefaultSRVBuilder()
   449  	target := "foo.ipv4.single.fake"
   450  	cc := &testClientConn{target: target}
   451  	r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{})
   452  	if err != nil {
   453  		t.Fatalf("%v\n", err)
   454  	}
   455  	defer r.Close()
   456  	var state resolver.State
   457  	for i := 0; i < 2000; i++ {
   458  		state, _ = cc.getState()
   459  		if len(state.Addresses) == 1 {
   460  			break
   461  		}
   462  		time.Sleep(time.Millisecond)
   463  	}
   464  	if len(state.Addresses) != 1 {
   465  		t.Fatalf("UpdateState not called with 1 address after 2s; aborting.  state=%v", state)
   466  	}
   467  	want := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}}
   468  	if !slices.Equal(want, state.Addresses) {
   469  		t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want)
   470  	}
   471  	// mutate the host lookup table so the target has 0 address returned.
   472  	revertTbl := mutateTbl(strings.TrimPrefix(target, "foo."))
   473  	// trigger a resolve that will get empty address list
   474  	r.ResolveNow(resolver.ResolveNowOptions{})
   475  	for i := 0; i < 2000; i++ {
   476  		state, _ = cc.getState()
   477  		if len(state.Addresses) == 0 {
   478  			break
   479  		}
   480  		time.Sleep(time.Millisecond)
   481  	}
   482  	if len(state.Addresses) != 0 {
   483  		t.Fatalf("UpdateState not called with 0 address after 2s; aborting.  state=%v", state)
   484  	}
   485  	revertTbl()
   486  	// wait for the retry to happen in two seconds.
   487  	r.ResolveNow(resolver.ResolveNowOptions{})
   488  	for i := 0; i < 2000; i++ {
   489  		state, _ = cc.getState()
   490  		if len(state.Addresses) == 1 {
   491  			break
   492  		}
   493  		time.Sleep(time.Millisecond)
   494  	}
   495  	if !slices.Equal(want, state.Addresses) {
   496  		t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want)
   497  	}
   498  }
   499  
   500  func TestCustomAuthority(t *testing.T) {
   501  	defer leakcheck.Check(t)
   502  	defer func(nt func(d time.Duration) *time.Timer) {
   503  		newTimer = nt
   504  	}(newTimer)
   505  	newTimer = func(d time.Duration) *time.Timer {
   506  		// Will never fire on its own, will protect from triggering exponential backoff.
   507  		return time.NewTimer(time.Hour)
   508  	}
   509  
   510  	tests := []struct {
   511  		authority     string
   512  		authorityWant string
   513  		expectError   bool
   514  	}{
   515  		{
   516  			"4.3.2.1:" + defaultDNSSvrPort,
   517  			"4.3.2.1:" + defaultDNSSvrPort,
   518  			false,
   519  		},
   520  		{
   521  			"4.3.2.1:123",
   522  			"4.3.2.1:123",
   523  			false,
   524  		},
   525  		{
   526  			"4.3.2.1",
   527  			"4.3.2.1:" + defaultDNSSvrPort,
   528  			false,
   529  		},
   530  		{
   531  			"::1",
   532  			"[::1]:" + defaultDNSSvrPort,
   533  			false,
   534  		},
   535  		{
   536  			"[::1]",
   537  			"[::1]:" + defaultDNSSvrPort,
   538  			false,
   539  		},
   540  		{
   541  			"[::1]:123",
   542  			"[::1]:123",
   543  			false,
   544  		},
   545  		{
   546  			"dnsserver.com",
   547  			"dnsserver.com:" + defaultDNSSvrPort,
   548  			false,
   549  		},
   550  		{
   551  			":123",
   552  			"localhost:123",
   553  			false,
   554  		},
   555  		{
   556  			":",
   557  			"",
   558  			true,
   559  		},
   560  		{
   561  			"[::1]:",
   562  			"",
   563  			true,
   564  		},
   565  		{
   566  			"dnsserver.com:",
   567  			"",
   568  			true,
   569  		},
   570  	}
   571  	oldcustomAuthorityDialer := customAuthorityDialer
   572  	defer func() {
   573  		customAuthorityDialer = oldcustomAuthorityDialer
   574  	}()
   575  
   576  	for _, a := range tests {
   577  		errChan := make(chan error, 1)
   578  		customAuthorityDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
   579  			if authority != a.authorityWant {
   580  				errChan <- fmt.Errorf("wrong custom authority passed to resolver. input: %s expected: %s actual: %s", a.authority, a.authorityWant, authority)
   581  			} else {
   582  				errChan <- nil
   583  			}
   584  			return func(ctx context.Context, network, address string) (net.Conn, error) {
   585  				return nil, errors.New("no need to dial")
   586  			}
   587  		}
   588  
   589  		mockEndpointTarget := "foo.bar.com"
   590  		b := NewDefaultSRVBuilder()
   591  		cc := &testClientConn{target: mockEndpointTarget, errChan: make(chan error, 1)}
   592  		target := resolver.Target{
   593  			Authority: a.authority,
   594  			URL:       *testutils.MustParseURL(fmt.Sprintf("scheme://%s/%s", a.authority, mockEndpointTarget)),
   595  		}
   596  		r, err := b.Build(target, cc, resolver.BuildOptions{})
   597  
   598  		if err == nil {
   599  			r.Close()
   600  
   601  			err = <-errChan
   602  			if err != nil {
   603  				t.Errorf(err.Error())
   604  			}
   605  
   606  			if a.expectError {
   607  				t.Errorf("custom authority should have caused an error: %s", a.authority)
   608  			}
   609  		} else if !a.expectError {
   610  			t.Errorf("unexpected error using custom authority %s: %s", a.authority, err)
   611  		}
   612  	}
   613  }
   614  
   615  // TestRateLimitedResolve exercises the rate limit enforced on re-resolution
   616  // requests. It sets the re-resolution rate to a small value and repeatedly
   617  // calls ResolveNow() and ensures only the expected number of resolution
   618  // requests are made.
   619  func TestRateLimitedResolve(t *testing.T) {
   620  	defer leakcheck.Check(t)
   621  	defer func(nt func(d time.Duration) *time.Timer) {
   622  		newTimer = nt
   623  	}(newTimer)
   624  	newTimer = func(d time.Duration) *time.Timer {
   625  		// Will never fire on its own, will protect from triggering exponential
   626  		// backoff.
   627  		return time.NewTimer(time.Hour)
   628  	}
   629  	defer func(nt func(d time.Duration) *time.Timer) {
   630  		newTimerDNSResRate = nt
   631  	}(newTimerDNSResRate)
   632  
   633  	timerChan := testutils.NewChannel()
   634  	newTimerDNSResRate = func(d time.Duration) *time.Timer {
   635  		// Will never fire on its own, allows this test to call timer
   636  		// immediately.
   637  		t := time.NewTimer(time.Hour)
   638  		timerChan.Send(t)
   639  		return t
   640  	}
   641  
   642  	// Create a new testResolver{} for this test because we want the exact count
   643  	// of the number of times the resolver was invoked.
   644  	nc := overrideDefaultResolver(true)
   645  	defer nc()
   646  
   647  	target := "foo.ipv4.single.fake"
   648  	b := NewDefaultSRVBuilder()
   649  	cc := &testClientConn{target: target}
   650  
   651  	r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{})
   652  	if err != nil {
   653  		t.Fatalf("resolver.Build() returned error: %v\n", err)
   654  	}
   655  	defer r.Close()
   656  
   657  	dnsR, ok := r.(*dnsResolver)
   658  	if !ok {
   659  		t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR)
   660  	}
   661  
   662  	tr, ok := dnsR.resolver.(*testResolver)
   663  	if !ok {
   664  		t.Fatalf("delegate resolver returned unexpected type: %T\n", tr)
   665  	}
   666  
   667  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   668  	defer cancel()
   669  
   670  	// Wait for the first resolution request to be done. This happens as part
   671  	// of the first iteration of the for loop in watcher().
   672  	if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
   673  		t.Fatalf("Timed out waiting for lookup() call.")
   674  	}
   675  
   676  	// Call Resolve Now 100 times, shouldn't continue onto next iteration of
   677  	// watcher, thus shouldn't lookup again.
   678  	for i := 0; i <= 100; i++ {
   679  		r.ResolveNow(resolver.ResolveNowOptions{})
   680  	}
   681  
   682  	continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   683  	defer continueCancel()
   684  
   685  	if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil {
   686  		t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
   687  	}
   688  
   689  	// Make the DNSMinResRate timer fire immediately (by receiving it, then
   690  	// resetting to 0), this will unblock the resolver which is currently
   691  	// blocked on the DNS Min Res Rate timer going off, which will allow it to
   692  	// continue to the next iteration of the watcher loop.
   693  	timer, err := timerChan.Receive(ctx)
   694  	if err != nil {
   695  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   696  	}
   697  	timerPointer := timer.(*time.Timer)
   698  	timerPointer.Reset(0)
   699  
   700  	// Now that DNS Min Res Rate timer has gone off, it should lookup again.
   701  	if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
   702  		t.Fatalf("Timed out waiting for lookup() call.")
   703  	}
   704  
   705  	// Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate
   706  	// timer has not gone off.
   707  	for i := 0; i < 1000; i++ {
   708  		r.ResolveNow(resolver.ResolveNowOptions{})
   709  	}
   710  
   711  	if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil {
   712  		t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
   713  	}
   714  
   715  	// Make the DNSMinResRate timer fire immediately again.
   716  	timer, err = timerChan.Receive(ctx)
   717  	if err != nil {
   718  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   719  	}
   720  	timerPointer = timer.(*time.Timer)
   721  	timerPointer.Reset(0)
   722  
   723  	// Now that DNS Min Res Rate timer has gone off, it should lookup again.
   724  	if _, err = tr.lookupHostCh.Receive(ctx); err != nil {
   725  		t.Fatalf("Timed out waiting for lookup() call.")
   726  	}
   727  
   728  	wantAddrs := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}}
   729  	var state resolver.State
   730  	for {
   731  		var cnt int
   732  		state, cnt = cc.getState()
   733  		if cnt > 0 {
   734  			break
   735  		}
   736  		time.Sleep(time.Millisecond)
   737  	}
   738  	if !slices.Equal(state.Addresses, wantAddrs) {
   739  		t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, wantAddrs)
   740  	}
   741  }
   742  
   743  // DNS Resolver immediately starts polling on an error. This will cause the re-resolution to return another error.
   744  // Thus, test that it constantly sends errors to the grpc.ClientConn.
   745  func TestReportError(t *testing.T) {
   746  	const target = "not.found"
   747  	defer func(nt func(d time.Duration) *time.Timer) {
   748  		newTimer = nt
   749  	}(newTimer)
   750  	timerChan := testutils.NewChannel()
   751  	newTimer = func(d time.Duration) *time.Timer {
   752  		// Will never fire on its own, allows this test to call timer immediately.
   753  		t := time.NewTimer(time.Hour)
   754  		timerChan.Send(t)
   755  		return t
   756  	}
   757  	cc := &testClientConn{target: target, errChan: make(chan error)}
   758  	totalTimesCalledError := 0
   759  	b := NewDefaultSRVBuilder()
   760  	r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{})
   761  	if err != nil {
   762  		t.Fatalf("Error building resolver for target %v: %v", target, err)
   763  	}
   764  	// Should receive first error.
   765  	err = <-cc.errChan
   766  	if !strings.Contains(err.Error(), "srvLookup error") {
   767  		t.Fatalf(`ReportError(err=%v) called; want err contains "srvLookupError"`, err)
   768  	}
   769  	totalTimesCalledError++
   770  	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   771  	defer ctxCancel()
   772  	timer, err := timerChan.Receive(ctx)
   773  	if err != nil {
   774  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   775  	}
   776  	timerPointer := timer.(*time.Timer)
   777  	timerPointer.Reset(0)
   778  	defer r.Close()
   779  
   780  	// Cause timer to go off 10 times, and see if it matches DNS Resolver updating Error.
   781  	for i := 0; i < 10; i++ {
   782  		// Should call ReportError().
   783  		err = <-cc.errChan
   784  		if !strings.Contains(err.Error(), "srvLookup error") {
   785  			t.Fatalf(`ReportError(err=%v) called; want err contains "srvLookupError"`, err)
   786  		}
   787  		totalTimesCalledError++
   788  		timer, err := timerChan.Receive(ctx)
   789  		if err != nil {
   790  			t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   791  		}
   792  		timerPointer := timer.(*time.Timer)
   793  		timerPointer.Reset(0)
   794  	}
   795  
   796  	if totalTimesCalledError != 11 {
   797  		t.Errorf("ReportError() not called 11 times, instead called %d times.", totalTimesCalledError)
   798  	}
   799  	// Clean up final watcher iteration.
   800  	<-cc.errChan
   801  	_, err = timerChan.Receive(ctx)
   802  	if err != nil {
   803  		t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
   804  	}
   805  }
   806  
   807  func Test_parseServiceDomain(t *testing.T) {
   808  	tests := []struct {
   809  		target        string
   810  		expectService string
   811  		expectDomain  string
   812  		wantErr       bool
   813  	}{
   814  		// valid
   815  		{"foo.bar", "foo", "bar", false},
   816  		{"foo.bar.baz", "foo", "bar.baz", false},
   817  		{"foo.bar.baz.", "foo", "bar.baz.", false},
   818  
   819  		// invalid
   820  		{"", "", "", true},
   821  		{".", "", "", true},
   822  		{"foo", "", "", true},
   823  		{".foo", "", "", true},
   824  		{"foo.", "", "", true},
   825  		{".foo.bar.baz", "", "", true},
   826  		{".foo.bar.baz.", "", "", true},
   827  	}
   828  	for _, tt := range tests {
   829  		t.Run(tt.target, func(t *testing.T) {
   830  			gotService, gotDomain, err := parseServiceDomain(tt.target)
   831  			if tt.wantErr {
   832  				test.AssertError(t, err, "expect err got nil")
   833  			} else {
   834  				test.AssertNotError(t, err, "expect nil err")
   835  				test.AssertEquals(t, gotService, tt.expectService)
   836  				test.AssertEquals(t, gotDomain, tt.expectDomain)
   837  			}
   838  		})
   839  	}
   840  }
   841  

View as plain text