1
18
19
20
21 package dns
22
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/internal/backoff"
	"google.golang.org/grpc/internal/envconfig"
	"google.golang.org/grpc/internal/grpcrand"
	"google.golang.org/grpc/internal/resolver/dns/internal"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/serviceconfig"
)
43
44 var (
45
46
47 EnableSRVLookups = false
48
49
50
51 MinResolutionInterval = 30 * time.Second
52
53
54
55
56
57
58 ResolvingTimeout = 30 * time.Second
59
60 logger = grpclog.Component("dns")
61 )
62
63 func init() {
64 resolver.Register(NewBuilder())
65 internal.TimeAfterFunc = time.After
66 internal.NewNetResolver = newNetResolver
67 internal.AddressDialer = addressDialer
68 }
69
const (
	// defaultPort is used when a target does not specify a port.
	defaultPort = "443"

	// defaultDNSSvrPort is used when a custom DNS server authority has no port.
	defaultDNSSvrPort = "53"

	// golang is matched against the clientLanguage list of a service config
	// choice.
	golang = "GO"

	// txtPrefix is prepended to the host name to form the name of the TXT
	// record that carries the service config.
	txtPrefix = "_grpc_config."

	// txtAttribute must begin the concatenated TXT data. Everything after it
	// is the service config JSON.
	txtAttribute = "grpc_config="
)
81
// addressDialer builds a dial function for net.Resolver.Dial that disregards
// the server address the resolver asks for and always connects to address
// instead, using the network the resolver requested.
var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
	dial := func(ctx context.Context, network, _ string) (net.Conn, error) {
		d := net.Dialer{}
		return d.DialContext(ctx, network, address)
	}
	return dial
}
88
89 var newNetResolver = func(authority string) (internal.NetResolver, error) {
90 if authority == "" {
91 return net.DefaultResolver, nil
92 }
93
94 host, port, err := parseTarget(authority, defaultDNSSvrPort)
95 if err != nil {
96 return nil, err
97 }
98
99 authorityWithPort := net.JoinHostPort(host, port)
100
101 return &net.Resolver{
102 PreferGo: true,
103 Dial: internal.AddressDialer(authorityWithPort),
104 }, nil
105 }
106
107
108 func NewBuilder() resolver.Builder {
109 return &dnsBuilder{}
110 }
111
112 type dnsBuilder struct{}
113
114
115
116 func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
117 host, port, err := parseTarget(target.Endpoint(), defaultPort)
118 if err != nil {
119 return nil, err
120 }
121
122
123 if ipAddr, ok := formatIP(host); ok {
124 addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
125 cc.UpdateState(resolver.State{Addresses: addr})
126 return deadResolver{}, nil
127 }
128
129
130 ctx, cancel := context.WithCancel(context.Background())
131 d := &dnsResolver{
132 host: host,
133 port: port,
134 ctx: ctx,
135 cancel: cancel,
136 cc: cc,
137 rn: make(chan struct{}, 1),
138 disableServiceConfig: opts.DisableServiceConfig,
139 }
140
141 d.resolver, err = internal.NewNetResolver(target.URL.Host)
142 if err != nil {
143 return nil, err
144 }
145
146 d.wg.Add(1)
147 go d.watcher()
148 return d, nil
149 }
150
151
152 func (b *dnsBuilder) Scheme() string {
153 return "dns"
154 }
155
156
157 type deadResolver struct{}
158
159 func (deadResolver) ResolveNow(resolver.ResolveNowOptions) {}
160
161 func (deadResolver) Close() {}
162
163
164 type dnsResolver struct {
165 host string
166 port string
167 resolver internal.NetResolver
168 ctx context.Context
169 cancel context.CancelFunc
170 cc resolver.ClientConn
171
172
173 rn chan struct{}
174
175
176
177
178
179
180
181 wg sync.WaitGroup
182 disableServiceConfig bool
183 }
184
185
186
187 func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
188 select {
189 case d.rn <- struct{}{}:
190 default:
191 }
192 }
193
194
195 func (d *dnsResolver) Close() {
196 d.cancel()
197 d.wg.Wait()
198 }
199
200 func (d *dnsResolver) watcher() {
201 defer d.wg.Done()
202 backoffIndex := 1
203 for {
204 state, err := d.lookup()
205 if err != nil {
206
207 d.cc.ReportError(err)
208 } else {
209 err = d.cc.UpdateState(*state)
210 }
211
212 var waitTime time.Duration
213 if err == nil {
214
215
216 backoffIndex = 1
217 waitTime = MinResolutionInterval
218 select {
219 case <-d.ctx.Done():
220 return
221 case <-d.rn:
222 }
223 } else {
224
225
226 waitTime = backoff.DefaultExponential.Backoff(backoffIndex)
227 backoffIndex++
228 }
229 select {
230 case <-d.ctx.Done():
231 return
232 case <-internal.TimeAfterFunc(waitTime):
233 }
234 }
235 }
236
237 func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
238 if !EnableSRVLookups {
239 return nil, nil
240 }
241 var newAddrs []resolver.Address
242 _, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
243 if err != nil {
244 err = handleDNSError(err, "SRV")
245 return nil, err
246 }
247 for _, s := range srvs {
248 lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
249 if err != nil {
250 err = handleDNSError(err, "A")
251 if err == nil {
252
253
254 continue
255 }
256 return nil, err
257 }
258 for _, a := range lbAddrs {
259 ip, ok := formatIP(a)
260 if !ok {
261 return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
262 }
263 addr := ip + ":" + strconv.Itoa(int(s.Port))
264 newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
265 }
266 }
267 return newAddrs, nil
268 }
269
270 func handleDNSError(err error, lookupType string) error {
271 dnsErr, ok := err.(*net.DNSError)
272 if ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
273
274
275
276 return nil
277 }
278 if err != nil {
279 err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err)
280 logger.Info(err)
281 }
282 return err
283 }
284
285 func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
286 ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
287 if err != nil {
288 if envconfig.TXTErrIgnore {
289 return nil
290 }
291 if err = handleDNSError(err, "TXT"); err != nil {
292 return &serviceconfig.ParseResult{Err: err}
293 }
294 return nil
295 }
296 var res string
297 for _, s := range ss {
298 res += s
299 }
300
301
302
303 if !strings.HasPrefix(res, txtAttribute) {
304 logger.Warningf("dns: TXT record %v missing %v attribute", res, txtAttribute)
305
306
307 return nil
308 }
309 sc := canaryingSC(strings.TrimPrefix(res, txtAttribute))
310 return d.cc.ParseServiceConfig(sc)
311 }
312
313 func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
314 addrs, err := d.resolver.LookupHost(ctx, d.host)
315 if err != nil {
316 err = handleDNSError(err, "A")
317 return nil, err
318 }
319 newAddrs := make([]resolver.Address, 0, len(addrs))
320 for _, a := range addrs {
321 ip, ok := formatIP(a)
322 if !ok {
323 return nil, fmt.Errorf("dns: error parsing A record IP address %v", a)
324 }
325 addr := ip + ":" + d.port
326 newAddrs = append(newAddrs, resolver.Address{Addr: addr})
327 }
328 return newAddrs, nil
329 }
330
331 func (d *dnsResolver) lookup() (*resolver.State, error) {
332 ctx, cancel := context.WithTimeout(d.ctx, ResolvingTimeout)
333 defer cancel()
334 srv, srvErr := d.lookupSRV(ctx)
335 addrs, hostErr := d.lookupHost(ctx)
336 if hostErr != nil && (srvErr != nil || len(srv) == 0) {
337 return nil, hostErr
338 }
339
340 state := resolver.State{Addresses: addrs}
341 if len(srv) > 0 {
342 state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
343 }
344 if !d.disableServiceConfig {
345 state.ServiceConfig = d.lookupTXT(ctx)
346 }
347 return &state, nil
348 }
349
350
351
352
353
// formatIP reports whether addr is a textual IP address. If it is, formatIP
// returns it in a form that can be suffixed with ":" + port to build a valid
// host:port string:
//   - IPv4 text is returned unchanged;
//   - IPv6 text is wrapped in brackets.
//
// Bracketing depends on whether the text contains a colon, not on net.IP.To4.
// An IPv4-mapped IPv6 literal such as "::ffff:1.2.3.4" has a 4-byte form, but
// its text still contains colons. Left unbracketed, "::ffff:1.2.3.4:443"
// could not be split back into host and port.
func formatIP(addr string) (addrIP string, ok bool) {
	if net.ParseIP(addr) == nil {
		return "", false
	}
	if strings.Contains(addr, ":") {
		return "[" + addr + "]", true
	}
	return addr, true
}
364
365
366
367
368
369
370
371
372
373
374 func parseTarget(target, defaultPort string) (host, port string, err error) {
375 if target == "" {
376 return "", "", internal.ErrMissingAddr
377 }
378 if ip := net.ParseIP(target); ip != nil {
379
380 return target, defaultPort, nil
381 }
382 if host, port, err = net.SplitHostPort(target); err == nil {
383 if port == "" {
384
385
386 return "", "", internal.ErrEndsWithColon
387 }
388
389 if host == "" {
390
391
392 host = "localhost"
393 }
394 return host, port, nil
395 }
396 if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil {
397
398 return host, port, nil
399 }
400 return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
401 }
402
// rawChoice is one element of the JSON array that follows txtAttribute in the
// TXT record. canaryingSC picks the first element that matches this client.
type rawChoice struct {
	// ClientLanguage restricts the choice to these languages; nil matches all.
	ClientLanguage *[]string `json:"clientLanguage,omitempty"`

	// Percentage is the chance, out of 100, of picking this choice; nil
	// always picks it.
	Percentage *int `json:"percentage,omitempty"`

	// ClientHostName restricts the choice to these host names; nil matches all.
	ClientHostName *[]string `json:"clientHostName,omitempty"`

	// ServiceConfig is the raw config JSON; a choice without it is skipped.
	ServiceConfig *json.RawMessage `json:"serviceConfig,omitempty"`
}
409
// containsString reports whether b appears in *a. A nil a means "no
// restriction" and matches any b; a non-nil empty list matches nothing.
func containsString(a *[]string, b string) bool {
	if a == nil {
		return true
	}
	list := *a
	for i := 0; i < len(list); i++ {
		if list[i] == b {
			return true
		}
	}
	return false
}
421
422 func chosenByPercentage(a *int) bool {
423 if a == nil {
424 return true
425 }
426 return grpcrand.Intn(100)+1 <= *a
427 }
428
429 func canaryingSC(js string) string {
430 if js == "" {
431 return ""
432 }
433 var rcs []rawChoice
434 err := json.Unmarshal([]byte(js), &rcs)
435 if err != nil {
436 logger.Warningf("dns: error parsing service config json: %v", err)
437 return ""
438 }
439 cliHostname, err := os.Hostname()
440 if err != nil {
441 logger.Warningf("dns: error getting client hostname: %v", err)
442 return ""
443 }
444 var sc string
445 for _, c := range rcs {
446 if !containsString(c.ClientLanguage, golang) ||
447 !chosenByPercentage(c.Percentage) ||
448 !containsString(c.ClientHostName, cliHostname) ||
449 c.ServiceConfig == nil {
450 continue
451 }
452 sc = string(*c.ServiceConfig)
453 break
454 }
455 return sc
456 }
457
View as plain text