1
18
19
20
21
22
23 package dns
24
25 import (
26 "context"
27 "errors"
28 "fmt"
29 "net"
30 "strconv"
31 "strings"
32 "sync"
33 "time"
34
35 "github.com/letsencrypt/boulder/bdns"
36 "github.com/letsencrypt/boulder/grpc/internal/backoff"
37 "github.com/letsencrypt/boulder/grpc/noncebalancer"
38 "google.golang.org/grpc/grpclog"
39 "google.golang.org/grpc/resolver"
40 "google.golang.org/grpc/serviceconfig"
41 )
42
43 var logger = grpclog.Component("srv")
44
45
46
47 var (
48 newTimer = time.NewTimer
49 newTimerDNSResRate = time.NewTimer
50 )
51
52 func init() {
53 resolver.Register(NewDefaultSRVBuilder())
54 resolver.Register(NewNonceSRVBuilder())
55 }
56
57 const defaultDNSSvrPort = "53"
58
59 var defaultResolver netResolver = net.DefaultResolver
60
61 var (
62
63
64 minDNSResRate = 30 * time.Second
65 )
66
// customAuthorityDialer returns a Dial function suitable for net.Resolver.
// It ignores the address the resolver asks for and always connects to
// authority, over whichever network the resolver requested.
var customAuthorityDialer = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
	return func(ctx context.Context, network, _ string) (net.Conn, error) {
		d := net.Dialer{}
		return d.DialContext(ctx, network, authority)
	}
}
73
74 var customAuthorityResolver = func(authority string) (*net.Resolver, error) {
75 host, port, err := bdns.ParseTarget(authority, defaultDNSSvrPort)
76 if err != nil {
77 return nil, err
78 }
79 return &net.Resolver{
80 PreferGo: true,
81 Dial: customAuthorityDialer(net.JoinHostPort(host, port)),
82 }, nil
83 }
84
85
86
87 func NewDefaultSRVBuilder() resolver.Builder {
88 return &srvBuilder{scheme: "srv"}
89 }
90
91
92
93 func NewNonceSRVBuilder() resolver.Builder {
94 return &srvBuilder{scheme: noncebalancer.SRVResolverScheme, balancer: noncebalancer.Name}
95 }
96
97 type srvBuilder struct {
98 scheme string
99 balancer string
100 }
101
102
103 func (b *srvBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
104 var names []name
105 for _, i := range strings.Split(target.Endpoint(), ",") {
106 service, domain, err := parseServiceDomain(i)
107 if err != nil {
108 return nil, err
109 }
110 names = append(names, name{service: service, domain: domain})
111 }
112
113 ctx, cancel := context.WithCancel(context.Background())
114 d := &dnsResolver{
115 names: names,
116 ctx: ctx,
117 cancel: cancel,
118 cc: cc,
119 rn: make(chan struct{}, 1),
120 }
121
122 if target.Authority == "" {
123 d.resolver = defaultResolver
124 } else {
125 var err error
126 d.resolver, err = customAuthorityResolver(target.Authority)
127 if err != nil {
128 return nil, err
129 }
130 }
131
132 if b.balancer != "" {
133 d.serviceConfig = cc.ParseServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, b.balancer))
134 }
135
136 d.wg.Add(1)
137 go d.watcher()
138 return d, nil
139 }
140
141
142 func (b *srvBuilder) Scheme() string {
143 return b.scheme
144 }
145
146 type netResolver interface {
147 LookupHost(ctx context.Context, host string) (addrs []string, err error)
148 LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
149 }
150
151 type name struct {
152 service string
153 domain string
154 }
155
156
// dnsResolver resolves a fixed set of SRV names and pushes the resulting
// addresses to a gRPC ClientConn. Build returns it as a resolver.Resolver.
type dnsResolver struct {
	// names are the service/domain pairs resolved on every lookup.
	names []name

	// resolver performs the SRV and host lookups.
	resolver netResolver

	// ctx bounds every lookup and stops the watcher; cancel (called by
	// Close) cancels it.
	ctx    context.Context
	cancel context.CancelFunc

	// cc receives resolved state and lookup errors.
	cc resolver.ClientConn

	// rn carries ResolveNow requests to the watcher. Build allocates it with
	// capacity one and ResolveNow sends without blocking, so at most one
	// request is pending at a time.
	rn chan struct{}

	// wg tracks the watcher goroutine so Close can wait for it to exit.
	wg sync.WaitGroup

	// serviceConfig, when non-nil, is attached to every state update.
	serviceConfig *serviceconfig.ParseResult
}
174
175
// ResolveNow asks the watcher to re-resolve. The send on rn never blocks. If
// a request is already pending, this one is dropped, because the pending
// request covers it. The watcher still waits out minDNSResRate after a
// successful update before acting on the request.
func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
	select {
	case d.rn <- struct{}{}:
	default:
		// A request is already queued; nothing more to do.
	}
}

// Close stops the watcher goroutine and blocks until it has exited.
func (d *dnsResolver) Close() {
	// Cancel first so the watcher's selects on d.ctx.Done() fire, then wait
	// for it to return.
	d.cancel()
	d.wg.Wait()
}
188
// watcher is the resolver's background loop. Build starts it in its own
// goroutine after adding one to d.wg, and the loop calls d.wg.Done on exit.
//
// Each iteration performs a lookup, reports the outcome to d.cc, and then
// waits before looping:
//   - On success (the lookup and UpdateState both returned nil), the backoff
//     index is reset. The loop then waits for a ResolveNow request, and after
//     that for the rest of minDNSResRate, measured from just after the update.
//   - On failure (a lookup error, or an error returned by UpdateState), the
//     loop waits for backoff.DefaultExponential.Backoff(backoffIndex) and
//     then increments backoffIndex.
//
// The loop returns as soon as d.ctx is canceled, which Close does.
func (d *dnsResolver) watcher() {
	defer d.wg.Done()
	// backoffIndex is the 1-based count of the current run of consecutive
	// failures. It selects the backoff delay.
	backoffIndex := 1
	for {
		state, err := d.lookup()
		if err != nil {
			// Surface the lookup failure to the ClientConn.
			d.cc.ReportError(err)
		} else {
			// Attach the builder's load-balancing config, if any.
			if d.serviceConfig != nil {
				state.ServiceConfig = d.serviceConfig
			}
			// An UpdateState error is treated like a lookup failure below.
			err = d.cc.UpdateState(*state)
		}

		var timer *time.Timer
		if err == nil {
			// Success: reset the backoff and start the rate-limit timer now,
			// so an early ResolveNow still waits out minDNSResRate below.
			backoffIndex = 1
			timer = newTimerDNSResRate(minDNSResRate)
			select {
			case <-d.ctx.Done():
				timer.Stop()
				return
			case <-d.rn:
			}
		} else {
			// Failure: retry after a backoff delay chosen by backoffIndex.
			timer = newTimer(backoff.DefaultExponential.Backoff(backoffIndex))
			backoffIndex++
		}
		// Wait for the timer to fire, or exit if the resolver was closed.
		select {
		case <-d.ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
229
230 func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
231 var newAddrs []resolver.Address
232 var errs []error
233 for _, n := range d.names {
234 _, srvs, err := d.resolver.LookupSRV(d.ctx, n.service, "tcp", n.domain)
235 if err != nil {
236 err = handleDNSError(err, "SRV")
237 if err != nil {
238 errs = append(errs, err)
239 continue
240 }
241 }
242 for _, s := range srvs {
243 backendAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
244 if err != nil {
245 err = handleDNSError(err, "A")
246 if err != nil {
247 errs = append(errs, err)
248 continue
249 }
250 }
251 for _, a := range backendAddrs {
252 ip, ok := formatIP(a)
253 if !ok {
254 errs = append(errs, fmt.Errorf("srv: error parsing A record IP address %v", a))
255 continue
256 }
257 addr := ip + ":" + strconv.Itoa(int(s.Port))
258 newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
259 }
260 }
261 }
262
263 if len(errs) > 0 && len(newAddrs) == 0 {
264 return nil, errors.Join(errs...)
265 }
266 return newAddrs, nil
267 }
268
269 func handleDNSError(err error, lookupType string) error {
270 if dnsErr, ok := err.(*net.DNSError); ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
271
272
273
274 return nil
275 }
276 if err != nil {
277 err = fmt.Errorf("srv: %v record lookup error: %v", lookupType, err)
278 logger.Info(err)
279 }
280 return err
281 }
282
283 func (d *dnsResolver) lookup() (*resolver.State, error) {
284 addrs, err := d.lookupSRV()
285 if err != nil {
286 return nil, err
287 }
288 return &resolver.State{Addresses: addrs}, nil
289 }
290
291
292
293
// formatIP validates addr as a textual IP address and formats it as the host
// part of a "host:port" string. Dotted-quad IPv4 addresses are returned
// unchanged. Any address written in IPv6 notation, including IPv4-mapped
// forms such as "::ffff:192.0.2.1", is enclosed in square brackets. ok is
// false, and addrIP empty, when addr is not a valid IP address.
func formatIP(addr string) (addrIP string, ok bool) {
	if net.ParseIP(addr) == nil {
		return "", false
	}
	// Decide from the textual form rather than ip.To4(): To4 also succeeds
	// for IPv4-mapped IPv6 literals, whose colons would otherwise end up
	// unbracketed when a ":port" suffix is appended.
	if strings.Contains(addr, ":") {
		return "[" + addr + "]", true
	}
	return addr, true
}
304
305
306
307
308
309
// parseServiceDomain splits target at its first "." into an SRV service label
// and the remaining domain. For example, "foo.example.com" yields ("foo",
// "example.com"). It returns an error if target contains no ".", or if either
// side of the first "." is empty.
func parseServiceDomain(target string) (string, string, error) {
	service, domain, found := strings.Cut(target, ".")
	if !found || service == "" || domain == "" {
		return "", "", fmt.Errorf("srv: hostname %q contains < 2 labels", target)
	}
	return service, domain, nil
}
317
View as plain text