1
18
19 package dns
20
21 import (
22 "context"
23 "errors"
24 "fmt"
25 "net"
26 "os"
27 "slices"
28 "strings"
29 "sync"
30 "testing"
31 "time"
32
33 "github.com/letsencrypt/boulder/grpc/internal/leakcheck"
34 "github.com/letsencrypt/boulder/grpc/internal/testutils"
35 "github.com/letsencrypt/boulder/test"
36 "google.golang.org/grpc/balancer"
37 "google.golang.org/grpc/resolver"
38 )
39
40 func TestMain(m *testing.M) {
41
42
43 replaceDNSResRate(time.Duration(0))
44 overrideDefaultResolver(false)
45 code := m.Run()
46 os.Exit(code)
47 }
48
// txtBytesLimit is 255; it is not referenced in the tests visible here
// (presumably the per-string TXT record size limit — TODO confirm).
const txtBytesLimit = 255

const (
	// defaultTestTimeout bounds waits for events that are expected to happen.
	defaultTestTimeout = 10 * time.Second
	// defaultTestShortTimeout bounds waits for events that are expected NOT
	// to happen, keeping negative checks fast.
	defaultTestShortTimeout = 10 * time.Millisecond
)
54
55 type testClientConn struct {
56 resolver.ClientConn
57 target string
58 m1 sync.Mutex
59 state resolver.State
60 updateStateCalls int
61 errChan chan error
62 updateStateErr error
63 }
64
65 func (t *testClientConn) UpdateState(s resolver.State) error {
66 t.m1.Lock()
67 defer t.m1.Unlock()
68 t.state = s
69 t.updateStateCalls++
70
71
72 return t.updateStateErr
73 }
74
75 func (t *testClientConn) getState() (resolver.State, int) {
76 t.m1.Lock()
77 defer t.m1.Unlock()
78 return t.state, t.updateStateCalls
79 }
80
81 func (t *testClientConn) ReportError(err error) {
82 t.errChan <- err
83 }
84
85 type testResolver struct {
86
87
88
89 lookupHostCh *testutils.Channel
90 }
91
92 func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
93 if tr.lookupHostCh != nil {
94 tr.lookupHostCh.Send(nil)
95 }
96 return hostLookup(host)
97 }
98
99 func (*testResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
100 return srvLookup(service, proto, name)
101 }
102
103
104
105
106 func overrideDefaultResolver(pushOnLookup bool) func() {
107 oldResolver := defaultResolver
108
109 var lookupHostCh *testutils.Channel
110 if pushOnLookup {
111 lookupHostCh = testutils.NewChannel()
112 }
113 defaultResolver = &testResolver{lookupHostCh: lookupHostCh}
114
115 return func() {
116 defaultResolver = oldResolver
117 }
118 }
119
120 func replaceDNSResRate(d time.Duration) func() {
121 oldMinDNSResRate := minDNSResRate
122 minDNSResRate = d
123
124 return func() {
125 minDNSResRate = oldMinDNSResRate
126 }
127 }
128
// hostLookupTbl maps fake host names to their IP address strings. The
// embedded mutex guards tbl against concurrent access, since tests edit it
// (see mutateTbl) while resolvers may be looking hosts up.
var hostLookupTbl = struct {
	sync.Mutex
	tbl map[string][]string
}{
	tbl: map[string][]string{
		"ipv4.single.fake": {"2.4.6.8"},
		"ipv4.multi.fake":  {"1.2.3.4", "5.6.7.8", "9.10.11.12"},
		"ipv6.single.fake": {"2607:f8b0:400a:801::1001"},
		"ipv6.multi.fake": {
			"2607:f8b0:400a:801::1001",
			"2607:f8b0:400a:801::1002",
			"2607:f8b0:400a:801::1003",
		},
	},
}

// hostLookup returns the addresses recorded for host in hostLookupTbl (the
// table's own slice, not a copy). Unknown hosts yield a temporary
// *net.DNSError.
func hostLookup(host string) ([]string, error) {
	hostLookupTbl.Lock()
	addrs, found := hostLookupTbl.tbl[host]
	hostLookupTbl.Unlock()

	if !found {
		return nil, &net.DNSError{
			Err:         "hostLookup error",
			Name:        host,
			Server:      "fake",
			IsTemporary: true,
		}
	}
	return addrs, nil
}
154
// srvLookupTbl maps fake SRV names ("_service._proto.name") to their records.
// The embedded mutex guards tbl against concurrent access.
var srvLookupTbl = struct {
	sync.Mutex
	tbl map[string][]*net.SRV
}{
	tbl: map[string][]*net.SRV{
		"_foo._tcp.ipv4.single.fake": {{Target: "ipv4.single.fake", Port: 1234}},
		"_foo._tcp.ipv4.multi.fake":  {{Target: "ipv4.multi.fake", Port: 1234}},
		"_foo._tcp.ipv6.single.fake": {{Target: "ipv6.single.fake", Port: 1234}},
		"_foo._tcp.ipv6.multi.fake":  {{Target: "ipv6.multi.fake", Port: 1234}},
	},
}

// srvLookup builds the SRV name "_service._proto.name" and returns it along
// with the matching records from srvLookupTbl. Unknown names yield an empty
// name, nil records, and a temporary *net.DNSError.
func srvLookup(service, proto, name string) (string, []*net.SRV, error) {
	cname := "_" + service + "._" + proto + "." + name

	srvLookupTbl.Lock()
	srvs, found := srvLookupTbl.tbl[cname]
	srvLookupTbl.Unlock()

	if !found {
		return "", nil, &net.DNSError{
			Err:         "srvLookup error",
			Name:        cname,
			Server:      "fake",
			IsTemporary: true,
		}
	}
	return cname, srvs, nil
}
181
182 func TestResolve(t *testing.T) {
183 testDNSResolver(t)
184 testDNSResolveNow(t)
185 }
186
187 func testDNSResolver(t *testing.T) {
188 defer func(nt func(d time.Duration) *time.Timer) {
189 newTimer = nt
190 }(newTimer)
191 newTimer = func(_ time.Duration) *time.Timer {
192
193 return time.NewTimer(time.Hour)
194 }
195 tests := []struct {
196 target string
197 addrWant []resolver.Address
198 }{
199 {
200 "foo.ipv4.single.fake",
201 []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}},
202 },
203 {
204 "foo.ipv4.multi.fake",
205 []resolver.Address{
206 {Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"},
207 {Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"},
208 {Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"},
209 },
210 },
211 {
212 "foo.ipv6.single.fake",
213 []resolver.Address{{Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.single.fake"}},
214 },
215 {
216 "foo.ipv6.multi.fake",
217 []resolver.Address{
218 {Addr: "[2607:f8b0:400a:801::1001]:1234", ServerName: "ipv6.multi.fake"},
219 {Addr: "[2607:f8b0:400a:801::1002]:1234", ServerName: "ipv6.multi.fake"},
220 {Addr: "[2607:f8b0:400a:801::1003]:1234", ServerName: "ipv6.multi.fake"},
221 },
222 },
223 }
224
225 for _, a := range tests {
226 b := NewDefaultSRVBuilder()
227 cc := &testClientConn{target: a.target}
228 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{})
229 if err != nil {
230 t.Fatalf("%v\n", err)
231 }
232 var state resolver.State
233 var cnt int
234 for i := 0; i < 2000; i++ {
235 state, cnt = cc.getState()
236 if cnt > 0 {
237 break
238 }
239 time.Sleep(time.Millisecond)
240 }
241 if cnt == 0 {
242 t.Fatalf("UpdateState not called after 2s; aborting")
243 }
244
245 if !slices.Equal(a.addrWant, state.Addresses) {
246 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant)
247 }
248 r.Close()
249 }
250 }
251
252
253
254 func TestDNSResolverExponentialBackoff(t *testing.T) {
255 defer leakcheck.Check(t)
256 defer func(nt func(d time.Duration) *time.Timer) {
257 newTimer = nt
258 }(newTimer)
259 timerChan := testutils.NewChannel()
260 newTimer = func(d time.Duration) *time.Timer {
261
262 t := time.NewTimer(time.Hour)
263 timerChan.Send(t)
264 return t
265 }
266 target := "foo.ipv4.single.fake"
267 wantAddr := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}}
268
269 b := NewDefaultSRVBuilder()
270 cc := &testClientConn{target: target}
271
272 cc.updateStateErr = balancer.ErrBadResolverState
273 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{})
274 if err != nil {
275 t.Fatalf("Error building resolver for target %v: %v", target, err)
276 }
277 defer r.Close()
278 var state resolver.State
279 var cnt int
280 for i := 0; i < 2000; i++ {
281 state, cnt = cc.getState()
282 if cnt > 0 {
283 break
284 }
285 time.Sleep(time.Millisecond)
286 }
287 if cnt == 0 {
288 t.Fatalf("UpdateState not called after 2s; aborting")
289 }
290 if !slices.Equal(wantAddr, state.Addresses) {
291 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, target)
292 }
293 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
294 defer ctxCancel()
295
296 for i := 0; i < 10; i++ {
297 timer, err := timerChan.Receive(ctx)
298 if err != nil {
299 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
300 }
301 timerPointer := timer.(*time.Timer)
302 timerPointer.Reset(0)
303 }
304
305
306 deadline := time.Now().Add(defaultTestTimeout)
307 for {
308 cc.m1.Lock()
309 got := cc.updateStateCalls
310 cc.m1.Unlock()
311 if got == 11 {
312 break
313 }
314
315 if time.Now().After(deadline) {
316 t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got)
317 }
318
319 time.Sleep(time.Millisecond)
320 }
321
322
323 cc.updateStateErr = nil
324 timer, err := timerChan.Receive(ctx)
325 if err != nil {
326 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
327 }
328 timerPointer := timer.(*time.Timer)
329 timerPointer.Reset(0)
330
331
332 deadline = time.Now().Add(defaultTestTimeout)
333 for {
334 cc.m1.Lock()
335 got := cc.updateStateCalls
336 cc.m1.Unlock()
337 if got == 12 {
338 break
339 }
340
341 if time.Now().After(deadline) {
342 t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got)
343 }
344
345 _, err := timerChan.ReceiveOrFail()
346 if err {
347 t.Fatalf("Should not poll again after Client Conn stops returning error.")
348 }
349
350 time.Sleep(time.Millisecond)
351 }
352 }
353
354 func mutateTbl(target string) func() {
355 hostLookupTbl.Lock()
356 oldHostTblEntry := hostLookupTbl.tbl[target]
357
358
359 hostLookupTbl.tbl[target] = hostLookupTbl.tbl[target][:len(oldHostTblEntry)-1]
360 hostLookupTbl.Unlock()
361
362 return func() {
363 hostLookupTbl.Lock()
364 hostLookupTbl.tbl[target] = oldHostTblEntry
365 hostLookupTbl.Unlock()
366 }
367 }
368
369 func testDNSResolveNow(t *testing.T) {
370 defer leakcheck.Check(t)
371 defer func(nt func(d time.Duration) *time.Timer) {
372 newTimer = nt
373 }(newTimer)
374 newTimer = func(_ time.Duration) *time.Timer {
375
376 return time.NewTimer(time.Hour)
377 }
378 tests := []struct {
379 target string
380 addrWant []resolver.Address
381 addrNext []resolver.Address
382 }{
383 {
384 "foo.ipv4.multi.fake",
385 []resolver.Address{
386 {Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"},
387 {Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"},
388 {Addr: "9.10.11.12:1234", ServerName: "ipv4.multi.fake"},
389 },
390 []resolver.Address{
391 {Addr: "1.2.3.4:1234", ServerName: "ipv4.multi.fake"},
392 {Addr: "5.6.7.8:1234", ServerName: "ipv4.multi.fake"},
393 },
394 },
395 }
396
397 for _, a := range tests {
398 b := NewDefaultSRVBuilder()
399 cc := &testClientConn{target: a.target}
400 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", a.target))}, cc, resolver.BuildOptions{})
401 if err != nil {
402 t.Fatalf("%v\n", err)
403 }
404 defer r.Close()
405 var state resolver.State
406 var cnt int
407 for i := 0; i < 2000; i++ {
408 state, cnt = cc.getState()
409 if cnt > 0 {
410 break
411 }
412 time.Sleep(time.Millisecond)
413 }
414 if cnt == 0 {
415 t.Fatalf("UpdateState not called after 2s; aborting. state=%v", state)
416 }
417 if !slices.Equal(a.addrWant, state.Addresses) {
418 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrWant)
419 }
420
421 revertTbl := mutateTbl(strings.TrimPrefix(a.target, "foo."))
422 r.ResolveNow(resolver.ResolveNowOptions{})
423 for i := 0; i < 2000; i++ {
424 state, cnt = cc.getState()
425 if cnt == 2 {
426 break
427 }
428 time.Sleep(time.Millisecond)
429 }
430 if cnt != 2 {
431 t.Fatalf("UpdateState not called after 2s; aborting. state=%v", state)
432 }
433 if !slices.Equal(a.addrNext, state.Addresses) {
434 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", a.target, state.Addresses, a.addrNext)
435 }
436 revertTbl()
437 }
438 }
439
440 func TestDNSResolverRetry(t *testing.T) {
441 defer func(nt func(d time.Duration) *time.Timer) {
442 newTimer = nt
443 }(newTimer)
444 newTimer = func(d time.Duration) *time.Timer {
445
446 return time.NewTimer(time.Hour)
447 }
448 b := NewDefaultSRVBuilder()
449 target := "foo.ipv4.single.fake"
450 cc := &testClientConn{target: target}
451 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{})
452 if err != nil {
453 t.Fatalf("%v\n", err)
454 }
455 defer r.Close()
456 var state resolver.State
457 for i := 0; i < 2000; i++ {
458 state, _ = cc.getState()
459 if len(state.Addresses) == 1 {
460 break
461 }
462 time.Sleep(time.Millisecond)
463 }
464 if len(state.Addresses) != 1 {
465 t.Fatalf("UpdateState not called with 1 address after 2s; aborting. state=%v", state)
466 }
467 want := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}}
468 if !slices.Equal(want, state.Addresses) {
469 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want)
470 }
471
472 revertTbl := mutateTbl(strings.TrimPrefix(target, "foo."))
473
474 r.ResolveNow(resolver.ResolveNowOptions{})
475 for i := 0; i < 2000; i++ {
476 state, _ = cc.getState()
477 if len(state.Addresses) == 0 {
478 break
479 }
480 time.Sleep(time.Millisecond)
481 }
482 if len(state.Addresses) != 0 {
483 t.Fatalf("UpdateState not called with 0 address after 2s; aborting. state=%v", state)
484 }
485 revertTbl()
486
487 r.ResolveNow(resolver.ResolveNowOptions{})
488 for i := 0; i < 2000; i++ {
489 state, _ = cc.getState()
490 if len(state.Addresses) == 1 {
491 break
492 }
493 time.Sleep(time.Millisecond)
494 }
495 if !slices.Equal(want, state.Addresses) {
496 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, want)
497 }
498 }
499
500 func TestCustomAuthority(t *testing.T) {
501 defer leakcheck.Check(t)
502 defer func(nt func(d time.Duration) *time.Timer) {
503 newTimer = nt
504 }(newTimer)
505 newTimer = func(d time.Duration) *time.Timer {
506
507 return time.NewTimer(time.Hour)
508 }
509
510 tests := []struct {
511 authority string
512 authorityWant string
513 expectError bool
514 }{
515 {
516 "4.3.2.1:" + defaultDNSSvrPort,
517 "4.3.2.1:" + defaultDNSSvrPort,
518 false,
519 },
520 {
521 "4.3.2.1:123",
522 "4.3.2.1:123",
523 false,
524 },
525 {
526 "4.3.2.1",
527 "4.3.2.1:" + defaultDNSSvrPort,
528 false,
529 },
530 {
531 "::1",
532 "[::1]:" + defaultDNSSvrPort,
533 false,
534 },
535 {
536 "[::1]",
537 "[::1]:" + defaultDNSSvrPort,
538 false,
539 },
540 {
541 "[::1]:123",
542 "[::1]:123",
543 false,
544 },
545 {
546 "dnsserver.com",
547 "dnsserver.com:" + defaultDNSSvrPort,
548 false,
549 },
550 {
551 ":123",
552 "localhost:123",
553 false,
554 },
555 {
556 ":",
557 "",
558 true,
559 },
560 {
561 "[::1]:",
562 "",
563 true,
564 },
565 {
566 "dnsserver.com:",
567 "",
568 true,
569 },
570 }
571 oldcustomAuthorityDialer := customAuthorityDialer
572 defer func() {
573 customAuthorityDialer = oldcustomAuthorityDialer
574 }()
575
576 for _, a := range tests {
577 errChan := make(chan error, 1)
578 customAuthorityDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
579 if authority != a.authorityWant {
580 errChan <- fmt.Errorf("wrong custom authority passed to resolver. input: %s expected: %s actual: %s", a.authority, a.authorityWant, authority)
581 } else {
582 errChan <- nil
583 }
584 return func(ctx context.Context, network, address string) (net.Conn, error) {
585 return nil, errors.New("no need to dial")
586 }
587 }
588
589 mockEndpointTarget := "foo.bar.com"
590 b := NewDefaultSRVBuilder()
591 cc := &testClientConn{target: mockEndpointTarget, errChan: make(chan error, 1)}
592 target := resolver.Target{
593 Authority: a.authority,
594 URL: *testutils.MustParseURL(fmt.Sprintf("scheme://%s/%s", a.authority, mockEndpointTarget)),
595 }
596 r, err := b.Build(target, cc, resolver.BuildOptions{})
597
598 if err == nil {
599 r.Close()
600
601 err = <-errChan
602 if err != nil {
603 t.Errorf(err.Error())
604 }
605
606 if a.expectError {
607 t.Errorf("custom authority should have caused an error: %s", a.authority)
608 }
609 } else if !a.expectError {
610 t.Errorf("unexpected error using custom authority %s: %s", a.authority, err)
611 }
612 }
613 }
614
615
616
617
618
619 func TestRateLimitedResolve(t *testing.T) {
620 defer leakcheck.Check(t)
621 defer func(nt func(d time.Duration) *time.Timer) {
622 newTimer = nt
623 }(newTimer)
624 newTimer = func(d time.Duration) *time.Timer {
625
626
627 return time.NewTimer(time.Hour)
628 }
629 defer func(nt func(d time.Duration) *time.Timer) {
630 newTimerDNSResRate = nt
631 }(newTimerDNSResRate)
632
633 timerChan := testutils.NewChannel()
634 newTimerDNSResRate = func(d time.Duration) *time.Timer {
635
636
637 t := time.NewTimer(time.Hour)
638 timerChan.Send(t)
639 return t
640 }
641
642
643
644 nc := overrideDefaultResolver(true)
645 defer nc()
646
647 target := "foo.ipv4.single.fake"
648 b := NewDefaultSRVBuilder()
649 cc := &testClientConn{target: target}
650
651 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{})
652 if err != nil {
653 t.Fatalf("resolver.Build() returned error: %v\n", err)
654 }
655 defer r.Close()
656
657 dnsR, ok := r.(*dnsResolver)
658 if !ok {
659 t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR)
660 }
661
662 tr, ok := dnsR.resolver.(*testResolver)
663 if !ok {
664 t.Fatalf("delegate resolver returned unexpected type: %T\n", tr)
665 }
666
667 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
668 defer cancel()
669
670
671
672 if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
673 t.Fatalf("Timed out waiting for lookup() call.")
674 }
675
676
677
678 for i := 0; i <= 100; i++ {
679 r.ResolveNow(resolver.ResolveNowOptions{})
680 }
681
682 continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
683 defer continueCancel()
684
685 if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil {
686 t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
687 }
688
689
690
691
692
693 timer, err := timerChan.Receive(ctx)
694 if err != nil {
695 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
696 }
697 timerPointer := timer.(*time.Timer)
698 timerPointer.Reset(0)
699
700
701 if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
702 t.Fatalf("Timed out waiting for lookup() call.")
703 }
704
705
706
707 for i := 0; i < 1000; i++ {
708 r.ResolveNow(resolver.ResolveNowOptions{})
709 }
710
711 if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil {
712 t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
713 }
714
715
716 timer, err = timerChan.Receive(ctx)
717 if err != nil {
718 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
719 }
720 timerPointer = timer.(*time.Timer)
721 timerPointer.Reset(0)
722
723
724 if _, err = tr.lookupHostCh.Receive(ctx); err != nil {
725 t.Fatalf("Timed out waiting for lookup() call.")
726 }
727
728 wantAddrs := []resolver.Address{{Addr: "2.4.6.8:1234", ServerName: "ipv4.single.fake"}}
729 var state resolver.State
730 for {
731 var cnt int
732 state, cnt = cc.getState()
733 if cnt > 0 {
734 break
735 }
736 time.Sleep(time.Millisecond)
737 }
738 if !slices.Equal(state.Addresses, wantAddrs) {
739 t.Errorf("Resolved addresses of target: %q = %+v, want %+v", target, state.Addresses, wantAddrs)
740 }
741 }
742
743
744
745 func TestReportError(t *testing.T) {
746 const target = "not.found"
747 defer func(nt func(d time.Duration) *time.Timer) {
748 newTimer = nt
749 }(newTimer)
750 timerChan := testutils.NewChannel()
751 newTimer = func(d time.Duration) *time.Timer {
752
753 t := time.NewTimer(time.Hour)
754 timerChan.Send(t)
755 return t
756 }
757 cc := &testClientConn{target: target, errChan: make(chan error)}
758 totalTimesCalledError := 0
759 b := NewDefaultSRVBuilder()
760 r, err := b.Build(resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("scheme:///%s", target))}, cc, resolver.BuildOptions{})
761 if err != nil {
762 t.Fatalf("Error building resolver for target %v: %v", target, err)
763 }
764
765 err = <-cc.errChan
766 if !strings.Contains(err.Error(), "srvLookup error") {
767 t.Fatalf(`ReportError(err=%v) called; want err contains "srvLookupError"`, err)
768 }
769 totalTimesCalledError++
770 ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
771 defer ctxCancel()
772 timer, err := timerChan.Receive(ctx)
773 if err != nil {
774 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
775 }
776 timerPointer := timer.(*time.Timer)
777 timerPointer.Reset(0)
778 defer r.Close()
779
780
781 for i := 0; i < 10; i++ {
782
783 err = <-cc.errChan
784 if !strings.Contains(err.Error(), "srvLookup error") {
785 t.Fatalf(`ReportError(err=%v) called; want err contains "srvLookupError"`, err)
786 }
787 totalTimesCalledError++
788 timer, err := timerChan.Receive(ctx)
789 if err != nil {
790 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
791 }
792 timerPointer := timer.(*time.Timer)
793 timerPointer.Reset(0)
794 }
795
796 if totalTimesCalledError != 11 {
797 t.Errorf("ReportError() not called 11 times, instead called %d times.", totalTimesCalledError)
798 }
799
800 <-cc.errChan
801 _, err = timerChan.Receive(ctx)
802 if err != nil {
803 t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
804 }
805 }
806
807 func Test_parseServiceDomain(t *testing.T) {
808 tests := []struct {
809 target string
810 expectService string
811 expectDomain string
812 wantErr bool
813 }{
814
815 {"foo.bar", "foo", "bar", false},
816 {"foo.bar.baz", "foo", "bar.baz", false},
817 {"foo.bar.baz.", "foo", "bar.baz.", false},
818
819
820 {"", "", "", true},
821 {".", "", "", true},
822 {"foo", "", "", true},
823 {".foo", "", "", true},
824 {"foo.", "", "", true},
825 {".foo.bar.baz", "", "", true},
826 {".foo.bar.baz.", "", "", true},
827 }
828 for _, tt := range tests {
829 t.Run(tt.target, func(t *testing.T) {
830 gotService, gotDomain, err := parseServiceDomain(tt.target)
831 if tt.wantErr {
832 test.AssertError(t, err, "expect err got nil")
833 } else {
834 test.AssertNotError(t, err, "expect nil err")
835 test.AssertEquals(t, gotService, tt.expectService)
836 test.AssertEquals(t, gotDomain, tt.expectDomain)
837 }
838 })
839 }
840 }
841
View as plain text