...

Source file src/github.com/letsencrypt/boulder/grpc/internal/resolver/dns/dns_resolver.go

Documentation: github.com/letsencrypt/boulder/grpc/internal/resolver/dns

     1  /*
     2   *
     3   * Copyright 2018 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Forked from the default internal DNS resolver in the grpc-go package. The
    20  // original source can be found at:
    21  // https://github.com/grpc/grpc-go/blob/v1.49.0/internal/resolver/dns/dns_resolver.go
    22  
    23  package dns
    24  
    25  import (
    26  	"context"
    27  	"errors"
    28  	"fmt"
    29  	"net"
    30  	"strconv"
    31  	"strings"
    32  	"sync"
    33  	"time"
    34  
    35  	"github.com/letsencrypt/boulder/bdns"
    36  	"github.com/letsencrypt/boulder/grpc/internal/backoff"
    37  	"github.com/letsencrypt/boulder/grpc/noncebalancer"
    38  	"google.golang.org/grpc/grpclog"
    39  	"google.golang.org/grpc/resolver"
    40  	"google.golang.org/grpc/serviceconfig"
    41  )
    42  
// logger is the grpclog component logger used by this resolver.
var logger = grpclog.Component("srv")

// Globals to stub out in tests. TODO: Perhaps these two can be combined into a
// single variable for testing the resolver?
//
// newTimer creates the backoff timer the watcher uses after a failed lookup or
// a rejected state update. newTimerDNSResRate creates the timer that enforces
// minDNSResRate between successful resolutions. Both are time.NewTimer unless
// replaced.
var (
	newTimer           = time.NewTimer
	newTimerDNSResRate = time.NewTimer
)

// init registers both SRV resolver builders with gRPC's resolver package: the
// default "srv" scheme and the nonce-service scheme.
func init() {
	resolver.Register(NewDefaultSRVBuilder())
	resolver.Register(NewNonceSRVBuilder())
}

// defaultDNSSvrPort is passed to bdns.ParseTarget when parsing a custom DNS
// authority. Presumably it is the port used when the authority omits one;
// confirm against bdns.ParseTarget.
const defaultDNSSvrPort = "53"

// defaultResolver is used when the target has no authority. It is declared
// with the netResolver interface type rather than *net.Resolver, presumably so
// tests can substitute a fake. TODO confirm.
var defaultResolver netResolver = net.DefaultResolver

var (
	// To prevent excessive re-resolution, we enforce a rate limit on DNS
	// resolution requests. After a successful resolution the watcher
	// resolves again only once ResolveNow has been called AND at least this
	// much time has elapsed.
	minDNSResRate = 30 * time.Second
)
    66  
    67  var customAuthorityDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
    68  	return func(ctx context.Context, network, address string) (net.Conn, error) {
    69  		var dialer net.Dialer
    70  		return dialer.DialContext(ctx, network, authority)
    71  	}
    72  }
    73  
    74  var customAuthorityResolver = func(authority string) (*net.Resolver, error) {
    75  	host, port, err := bdns.ParseTarget(authority, defaultDNSSvrPort)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return &net.Resolver{
    80  		PreferGo: true,
    81  		Dial:     customAuthorityDialer(net.JoinHostPort(host, port)),
    82  	}, nil
    83  }
    84  
    85  // NewDefaultSRVBuilder creates a srvBuilder which is used to factory SRV DNS
    86  // resolvers.
    87  func NewDefaultSRVBuilder() resolver.Builder {
    88  	return &srvBuilder{scheme: "srv"}
    89  }
    90  
    91  // NewNonceSRVBuilder creates a srvBuilder which is used to factory SRV DNS
    92  // resolvers with a custom grpc.Balancer used by nonce-service clients.
    93  func NewNonceSRVBuilder() resolver.Builder {
    94  	return &srvBuilder{scheme: noncebalancer.SRVResolverScheme, balancer: noncebalancer.Name}
    95  }
    96  
// srvBuilder builds SRV-record based DNS resolvers for a single URI scheme.
type srvBuilder struct {
	// scheme is the URI scheme reported by Scheme.
	scheme string
	// balancer, if non-empty, names the load-balancing policy that Build
	// places in the service config of every resolver it creates.
	balancer string
}
   101  
   102  // Build creates and starts a DNS resolver that watches the name resolution of the target.
   103  func (b *srvBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
   104  	var names []name
   105  	for _, i := range strings.Split(target.Endpoint(), ",") {
   106  		service, domain, err := parseServiceDomain(i)
   107  		if err != nil {
   108  			return nil, err
   109  		}
   110  		names = append(names, name{service: service, domain: domain})
   111  	}
   112  
   113  	ctx, cancel := context.WithCancel(context.Background())
   114  	d := &dnsResolver{
   115  		names:  names,
   116  		ctx:    ctx,
   117  		cancel: cancel,
   118  		cc:     cc,
   119  		rn:     make(chan struct{}, 1),
   120  	}
   121  
   122  	if target.Authority == "" {
   123  		d.resolver = defaultResolver
   124  	} else {
   125  		var err error
   126  		d.resolver, err = customAuthorityResolver(target.Authority)
   127  		if err != nil {
   128  			return nil, err
   129  		}
   130  	}
   131  
   132  	if b.balancer != "" {
   133  		d.serviceConfig = cc.ParseServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, b.balancer))
   134  	}
   135  
   136  	d.wg.Add(1)
   137  	go d.watcher()
   138  	return d, nil
   139  }
   140  
   141  // Scheme returns the naming scheme of this resolver builder.
   142  func (b *srvBuilder) Scheme() string {
   143  	return b.scheme
   144  }
   145  
// netResolver is the subset of *net.Resolver's methods that dnsResolver uses.
// Depending on the interface lets the lookup code run unchanged against
// net.DefaultResolver, a custom-authority resolver, or a substitute
// implementation.
type netResolver interface {
	LookupHost(ctx context.Context, host string) (addrs []string, err error)
	LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
}
   150  
// name is one SRV lookup target, produced by parseServiceDomain from a single
// hostname in the resolver endpoint.
type name struct {
	service string // first hostname label, queried as the SRV service
	domain  string // remaining labels, queried as the SRV domain
}
   155  
// dnsResolver watches for the name resolution update for a non-IP target.
// It is created by srvBuilder.Build, which also starts its watcher goroutine.
type dnsResolver struct {
	names    []name              // SRV service/domain pairs parsed from the endpoint
	resolver netResolver         // DNS resolver used for SRV and host lookups
	ctx      context.Context     // passed to lookups; cancelled by Close
	cancel   context.CancelFunc  // cancels ctx
	cc       resolver.ClientConn // receives resolved state and errors
	// rn is signalled by ResolveNow() to request a re-resolution of the target.
	// The watcher honours it subject to the minDNSResRate rate limit.
	rn chan struct{}
	// wg makes Close() return only after the watcher() goroutine has finished.
	// Otherwise a data race is possible. [Race Example] In dns_resolver_test
	// the real lookup functions are replaced with mocks to facilitate testing.
	// If Close() did not wait for watcher() to finish, the race detector could
	// report the lookup inside watcher() (reading the lookup function
	// pointers) as racing with replaceNetFunc (writing those pointers).
	wg            sync.WaitGroup
	serviceConfig *serviceconfig.ParseResult // if non-nil, attached to every successful state update
}
   174  
   175  // ResolveNow invoke an immediate resolution of the target that this dnsResolver watches.
   176  func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
   177  	select {
   178  	case d.rn <- struct{}{}:
   179  	default:
   180  	}
   181  }
   182  
   183  // Close closes the dnsResolver.
   184  func (d *dnsResolver) Close() {
   185  	d.cancel()
   186  	d.wg.Wait()
   187  }
   188  
// watcher is the resolver's background loop, started as a goroutine by Build.
// It exits once d.ctx is cancelled (see Close). Each iteration performs one
// lookup, reports the result to the ClientConn, and then waits:
//   - after a successful lookup and an accepted state update, it waits for a
//     ResolveNow signal and for the minDNSResRate timer started right after
//     the update, so the next lookup needs both;
//   - after a lookup error or a state update rejected by the ClientConn, it
//     waits for backoff.DefaultExponential's interval for the current
//     consecutive-failure count, without reading d.rn. A ResolveNow issued
//     during backoff therefore stays queued.
func (d *dnsResolver) watcher() {
	defer d.wg.Done()
	// backoffIndex counts consecutive failures; it is reset on success.
	backoffIndex := 1
	for {
		state, err := d.lookup()
		if err != nil {
			// Report error to the underlying grpc.ClientConn.
			d.cc.ReportError(err)
		} else {
			if d.serviceConfig != nil {
				state.ServiceConfig = d.serviceConfig
			}
			// An error from UpdateState is handled like a lookup error
			// below, i.e. it triggers backoff.
			err = d.cc.UpdateState(*state)
		}

		var timer *time.Timer
		if err == nil {
			// Success resolving: wait for the next ResolveNow. However, also
			// wait at least minDNSResRate (30 seconds by default) to prevent
			// constantly re-resolving.
			backoffIndex = 1
			timer = newTimerDNSResRate(minDNSResRate)
			select {
			case <-d.ctx.Done():
				timer.Stop()
				return
			case <-d.rn:
			}
		} else {
			// Poll on an error found in DNS Resolver or an error received from ClientConn.
			timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex))
			backoffIndex++
		}
		select {
		case <-d.ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
   229  
   230  func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
   231  	var newAddrs []resolver.Address
   232  	var errs []error
   233  	for _, n := range d.names {
   234  		_, srvs, err := d.resolver.LookupSRV(d.ctx, n.service, "tcp", n.domain)
   235  		if err != nil {
   236  			err = handleDNSError(err, "SRV") // may become nil
   237  			if err != nil {
   238  				errs = append(errs, err)
   239  				continue
   240  			}
   241  		}
   242  		for _, s := range srvs {
   243  			backendAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
   244  			if err != nil {
   245  				err = handleDNSError(err, "A") // may become nil
   246  				if err != nil {
   247  					errs = append(errs, err)
   248  					continue
   249  				}
   250  			}
   251  			for _, a := range backendAddrs {
   252  				ip, ok := formatIP(a)
   253  				if !ok {
   254  					errs = append(errs, fmt.Errorf("srv: error parsing A record IP address %v", a))
   255  					continue
   256  				}
   257  				addr := ip + ":" + strconv.Itoa(int(s.Port))
   258  				newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
   259  			}
   260  		}
   261  	}
   262  	// Only return an error if all lookups failed.
   263  	if len(errs) > 0 && len(newAddrs) == 0 {
   264  		return nil, errors.Join(errs...)
   265  	}
   266  	return newAddrs, nil
   267  }
   268  
   269  func handleDNSError(err error, lookupType string) error {
   270  	if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
   271  		// Timeouts and temporary errors should be communicated to gRPC to
   272  		// attempt another DNS query (with backoff).  Other errors should be
   273  		// suppressed (they may represent the absence of a TXT record).
   274  		return nil
   275  	}
   276  	if err != nil {
   277  		err = fmt.Errorf("srv: %v record lookup error: %v", lookupType, err)
   278  		logger.Info(err)
   279  	}
   280  	return err
   281  }
   282  
   283  func (d *dnsResolver) lookup() (*resolver.State, error) {
   284  	addrs, err := d.lookupSRV()
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  	return &resolver.State{Addresses: addrs}, nil
   289  }
   290  
   291  // formatIP returns ok = false if addr is not a valid textual representation of an IP address.
   292  // If addr is an IPv4 address, return the addr and ok = true.
   293  // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true.
   294  func formatIP(addr string) (addrIP string, ok bool) {
   295  	ip := net.ParseIP(addr)
   296  	if ip == nil {
   297  		return "", false
   298  	}
   299  	if ip.To4() != nil {
   300  		return addr, true
   301  	}
   302  	return "[" + addr + "]", true
   303  }
   304  
   305  // parseServiceDomain takes the user input target string and parses the service domain
   306  // names for SRV lookup. Input is expected to be a hostname containing at least
   307  // two labels (e.g. "foo.bar", "foo.bar.baz"). The first label is the service
   308  // name and the rest is the domain name. If the target is not in the expected
   309  // format, an error is returned.
   310  func parseServiceDomain(target string) (string, string, error) {
   311  	sd := strings.SplitN(target, ".", 2)
   312  	if len(sd) < 2 || sd[0] == "" || sd[1] == "" {
   313  		return "", "", fmt.Errorf("srv: hostname %q contains < 2 labels", target)
   314  	}
   315  	return sd[0], sd[1], nil
   316  }
   317  

View as plain text