1 package dns
2
3
4
5 import (
6 "context"
7 "crypto/tls"
8 "encoding/binary"
9 "io"
10 "net"
11 "strings"
12 "time"
13 )
14
// Default durations used when the caller configures nothing else.
const (
	// dnsTimeout is the fallback for the dial, read and write timeouts.
	dnsTimeout = 2 * time.Second
	// tcpIdleTimeout is not referenced in this part of the file; presumably it
	// bounds how long an idle TCP connection is kept — TODO confirm at its users.
	tcpIdleTimeout = 8 * time.Second
)
19
// isPacketConn reports whether c should be treated as a packet-oriented
// (datagram) connection rather than a byte stream.
//
// A connection that does not implement net.PacketConn is never packet-oriented.
// For connections whose local address is a *net.UnixAddr the interface check is
// not conclusive, so the address's network name decides: only "unixgram" and
// "unixpacket" count as packet-oriented.
func isPacketConn(c net.Conn) bool {
	_, packet := c.(net.PacketConn)
	if !packet {
		return false
	}

	switch addr := c.LocalAddr().(type) {
	case *net.UnixAddr:
		return addr.Net == "unixgram" || addr.Net == "unixpacket"
	default:
		return true
	}
}
31
32
// Conn wraps a net.Conn and carries the per-connection state used by the DNS
// read and write helpers in this file.
type Conn struct {
	// Conn is the underlying transport; its methods are promoted onto Conn.
	net.Conn
	// UDPSize is the read buffer size used by ReadMsgHeader on packet
	// connections; values not greater than MinMsgSize fall back to MinMsgSize.
	UDPSize uint16
	// TsigSecret is used (via tsigSecretProvider) when TsigProvider is nil.
	// NOTE(review): presumably maps TSIG key names to secrets — confirm.
	TsigSecret map[string]string
	// TsigProvider, when non-nil, takes precedence over TsigSecret.
	TsigProvider TsigProvider
	// tsigRequestMAC is updated by WriteMsg when it signs a TSIG message and
	// is passed to TSIG verification in ReadMsg.
	tsigRequestMAC string
}
40
41 func (co *Conn) tsigProvider() TsigProvider {
42 if co.TsigProvider != nil {
43 return co.TsigProvider
44 }
45
46 return tsigSecretProvider(co.TsigSecret)
47 }
48
49
// Client holds the parameters used to dial a DNS server and exchange messages
// with it.
type Client struct {
	// Net is the network passed to the dialer. An empty value means "udp".
	// A value with prefix "tcp" and suffix "-tls" enables TLS; the suffix is
	// stripped before dialing.
	Net string
	// UDPSize is copied onto every Conn created by DialContext. During an
	// exchange it is also applied when the message carries no EDNS0 option
	// and the value is at least MinMsgSize.
	UDPSize uint16
	// TLSConfig is handed to the TLS dialer when Net selects TLS.
	TLSConfig *tls.Config
	// Dialer, when non-nil, is copied and used for dialing; its non-zero
	// Timeout also caps the timeouts computed by getTimeoutForRequest.
	// When nil, a net.Dialer with the computed dial timeout is used.
	Dialer *net.Dialer

	// Timeout, when non-zero, overrides DialTimeout, ReadTimeout and
	// WriteTimeout.
	Timeout time.Duration
	// DialTimeout defaults to dnsTimeout when zero; it is not consulted when
	// Dialer is set.
	DialTimeout time.Duration
	// ReadTimeout defaults to dnsTimeout when zero.
	ReadTimeout time.Duration
	// WriteTimeout defaults to dnsTimeout when zero.
	WriteTimeout time.Duration
	// TsigSecret is copied onto the Conn at the start of every exchange.
	TsigSecret map[string]string
	// TsigProvider is copied onto the Conn at the start of every exchange.
	TsigProvider TsigProvider

	// SingleInflight is not referenced by any code in this part of the file.
	// NOTE(review): confirm whether it still has any effect.
	SingleInflight bool
}
73
74
75
76
77
78 func Exchange(m *Msg, a string) (r *Msg, err error) {
79 client := Client{Net: "udp"}
80 r, _, err = client.Exchange(m, a)
81 return r, err
82 }
83
84 func (c *Client) dialTimeout() time.Duration {
85 if c.Timeout != 0 {
86 return c.Timeout
87 }
88 if c.DialTimeout != 0 {
89 return c.DialTimeout
90 }
91 return dnsTimeout
92 }
93
94 func (c *Client) readTimeout() time.Duration {
95 if c.ReadTimeout != 0 {
96 return c.ReadTimeout
97 }
98 return dnsTimeout
99 }
100
101 func (c *Client) writeTimeout() time.Duration {
102 if c.WriteTimeout != 0 {
103 return c.WriteTimeout
104 }
105 return dnsTimeout
106 }
107
108
109 func (c *Client) Dial(address string) (conn *Conn, err error) {
110 return c.DialContext(context.Background(), address)
111 }
112
113
114 func (c *Client) DialContext(ctx context.Context, address string) (conn *Conn, err error) {
115
116 var d net.Dialer
117 if c.Dialer == nil {
118 d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())}
119 } else {
120 d = *c.Dialer
121 }
122
123 network := c.Net
124 if network == "" {
125 network = "udp"
126 }
127
128 useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls")
129
130 conn = new(Conn)
131 if useTLS {
132 network = strings.TrimSuffix(network, "-tls")
133
134 tlsDialer := tls.Dialer{
135 NetDialer: &d,
136 Config: c.TLSConfig,
137 }
138 conn.Conn, err = tlsDialer.DialContext(ctx, network, address)
139 } else {
140 conn.Conn, err = d.DialContext(ctx, network, address)
141 }
142 if err != nil {
143 return nil, err
144 }
145 conn.UDPSize = c.UDPSize
146 return conn, nil
147 }
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163 func (c *Client) Exchange(m *Msg, address string) (r *Msg, rtt time.Duration, err error) {
164 co, err := c.Dial(address)
165
166 if err != nil {
167 return nil, 0, err
168 }
169 defer co.Close()
170 return c.ExchangeWithConn(m, co)
171 }
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186 func (c *Client) ExchangeWithConn(m *Msg, conn *Conn) (r *Msg, rtt time.Duration, err error) {
187 return c.ExchangeWithConnContext(context.Background(), m, conn)
188 }
189
190
191
192 func (c *Client) ExchangeWithConnContext(ctx context.Context, m *Msg, co *Conn) (r *Msg, rtt time.Duration, err error) {
193 opt := m.IsEdns0()
194
195 if opt != nil && opt.UDPSize() >= MinMsgSize {
196 co.UDPSize = opt.UDPSize()
197 }
198
199 if opt == nil && c.UDPSize >= MinMsgSize {
200 co.UDPSize = c.UDPSize
201 }
202
203
204 t := time.Now()
205 writeDeadline := t.Add(c.getTimeoutForRequest(c.writeTimeout()))
206 readDeadline := t.Add(c.getTimeoutForRequest(c.readTimeout()))
207 if deadline, ok := ctx.Deadline(); ok {
208 if deadline.Before(writeDeadline) {
209 writeDeadline = deadline
210 }
211 if deadline.Before(readDeadline) {
212 readDeadline = deadline
213 }
214 }
215 co.SetWriteDeadline(writeDeadline)
216 co.SetReadDeadline(readDeadline)
217
218 co.TsigSecret, co.TsigProvider = c.TsigSecret, c.TsigProvider
219
220 if err = co.WriteMsg(m); err != nil {
221 return nil, 0, err
222 }
223
224 if isPacketConn(co.Conn) {
225 for {
226 r, err = co.ReadMsg()
227
228
229 if err != nil || r.Id == m.Id {
230 break
231 }
232 }
233 } else {
234 r, err = co.ReadMsg()
235 if err == nil && r.Id != m.Id {
236 err = ErrId
237 }
238 }
239 rtt = time.Since(t)
240 return r, rtt, err
241 }
242
243
244
245
246
247
248 func (co *Conn) ReadMsg() (*Msg, error) {
249 p, err := co.ReadMsgHeader(nil)
250 if err != nil {
251 return nil, err
252 }
253
254 m := new(Msg)
255 if err := m.Unpack(p); err != nil {
256
257
258
259 return m, err
260 }
261 if t := m.IsTsig(); t != nil {
262
263 err = TsigVerifyWithProvider(p, co.tsigProvider(), co.tsigRequestMAC, false)
264 }
265 return m, err
266 }
267
268
269
270
271 func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
272 var (
273 p []byte
274 n int
275 err error
276 )
277
278 if isPacketConn(co.Conn) {
279 if co.UDPSize > MinMsgSize {
280 p = make([]byte, co.UDPSize)
281 } else {
282 p = make([]byte, MinMsgSize)
283 }
284 n, err = co.Read(p)
285 } else {
286 var length uint16
287 if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
288 return nil, err
289 }
290
291 p = make([]byte, length)
292 n, err = io.ReadFull(co.Conn, p)
293 }
294
295 if err != nil {
296 return nil, err
297 } else if n < headerSize {
298 return nil, ErrShortRead
299 }
300
301 p = p[:n]
302 if hdr != nil {
303 dh, _, err := unpackMsgHdr(p, 0)
304 if err != nil {
305 return nil, err
306 }
307 *hdr = dh
308 }
309 return p, err
310 }
311
312
313 func (co *Conn) Read(p []byte) (n int, err error) {
314 if co.Conn == nil {
315 return 0, ErrConnEmpty
316 }
317
318 if isPacketConn(co.Conn) {
319
320 return co.Conn.Read(p)
321 }
322
323 var length uint16
324 if err := binary.Read(co.Conn, binary.BigEndian, &length); err != nil {
325 return 0, err
326 }
327 if int(length) > len(p) {
328 return 0, io.ErrShortBuffer
329 }
330
331 return io.ReadFull(co.Conn, p[:length])
332 }
333
334
335
336
337 func (co *Conn) WriteMsg(m *Msg) (err error) {
338 var out []byte
339 if t := m.IsTsig(); t != nil {
340
341 out, co.tsigRequestMAC, err = TsigGenerateWithProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
342 } else {
343 out, err = m.Pack()
344 }
345 if err != nil {
346 return err
347 }
348 _, err = co.Write(out)
349 return err
350 }
351
352
353 func (co *Conn) Write(p []byte) (int, error) {
354 if len(p) > MaxMsgSize {
355 return 0, &Error{err: "message too large"}
356 }
357
358 if isPacketConn(co.Conn) {
359 return co.Conn.Write(p)
360 }
361
362 msg := make([]byte, 2+len(p))
363 binary.BigEndian.PutUint16(msg, uint16(len(p)))
364 copy(msg[2:], p)
365 return co.Conn.Write(msg)
366 }
367
368
369 func (c *Client) getTimeoutForRequest(timeout time.Duration) time.Duration {
370 var requestTimeout time.Duration
371 if c.Timeout != 0 {
372 requestTimeout = c.Timeout
373 } else {
374 requestTimeout = timeout
375 }
376
377
378 if c.Dialer != nil && c.Dialer.Timeout != 0 {
379 if c.Dialer.Timeout < requestTimeout {
380 requestTimeout = c.Dialer.Timeout
381 }
382 }
383 return requestTimeout
384 }
385
386
387 func Dial(network, address string) (conn *Conn, err error) {
388 conn = new(Conn)
389 conn.Conn, err = net.Dial(network, address)
390 if err != nil {
391 return nil, err
392 }
393 return conn, nil
394 }
395
396
397
398 func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
399 client := Client{Net: "udp"}
400 r, _, err = client.ExchangeContext(ctx, m, a)
401
402
403 return r, err
404 }
405
406
407
408
409
410
411
412
413
414 func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
415 println("dns: ExchangeConn: this function is deprecated")
416 co := new(Conn)
417 co.Conn = c
418 if err = co.WriteMsg(m); err != nil {
419 return nil, err
420 }
421 r, err = co.ReadMsg()
422 if err == nil && r.Id != m.Id {
423 err = ErrId
424 }
425 return r, err
426 }
427
428
429 func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
430 client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}}
431 return client.Dial(address)
432 }
433
434
435 func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, err error) {
436 if !strings.HasSuffix(network, "-tls") {
437 network += "-tls"
438 }
439 client := Client{Net: network, TLSConfig: tlsConfig}
440 return client.Dial(address)
441 }
442
443
444 func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout time.Duration) (conn *Conn, err error) {
445 if !strings.HasSuffix(network, "-tls") {
446 network += "-tls"
447 }
448 client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig}
449 return client.Dial(address)
450 }
451
452
453
454
455 func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
456 conn, err := c.DialContext(ctx, a)
457 if err != nil {
458 return nil, 0, err
459 }
460 defer conn.Close()
461
462 return c.ExchangeWithConnContext(ctx, m, conn)
463 }
464
View as plain text