1
2
3 package dns
4
5 import (
6 "context"
7 "crypto/tls"
8 "encoding/binary"
9 "errors"
10 "io"
11 "net"
12 "strings"
13 "sync"
14 "time"
15 )
16
17
// maxTCPQueries is the default number of queries a single TCP connection may
// send before the server closes it. It applies when Server.MaxTCPQueries is
// 0; a value of -1 there removes the limit entirely.
const (
	maxTCPQueries = 128
)

// aLongTimeAgo is a non-zero instant far in the past. Installing it as a read
// deadline makes a pending or future read fail at once, which is how
// ShutdownContext unblocks goroutines that are parked in a read.
var (
	aLongTimeAgo = time.Unix(1, 0)
)
23
24
// Handler is implemented by anything that can answer a DNS query. The server
// calls ServeDNS with the decoded request r; the reply, if any, is written
// through w.
type Handler interface {
	ServeDNS(w ResponseWriter, r *Msg)
}

// HandlerFunc adapts an ordinary function with the ServeDNS signature so it
// can be used wherever a Handler is expected.
type HandlerFunc func(ResponseWriter, *Msg)

// ServeDNS calls f(w, r).
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
	f(w, r)
}
39
40
41
// ResponseWriter is the interface a Handler uses to reply to a request and
// to inspect the connection the request arrived on.
type ResponseWriter interface {
	// LocalAddr returns the local address of the connection the request
	// arrived on.
	LocalAddr() net.Addr
	// RemoteAddr returns the address of the client that sent the request.
	RemoteAddr() net.Addr
	// WriteMsg packs the message (TSIG-signing it when a TSIG provider is
	// configured and the message carries a TSIG record) and writes it.
	WriteMsg(*Msg) error
	// Write writes an already-packed wire-format message.
	Write([]byte) (int, error)
	// Close ends the response. The built-in implementation closes the TCP
	// connection and does nothing else for UDP.
	Close() error
	// TsigStatus returns the result of verifying the request's TSIG record.
	TsigStatus() error
	// TsigTimersOnly sets the timers-only flag that is passed to
	// TsigGenerateWithProvider when signing replies.
	TsigTimersOnly(bool)
	// Hijack tells the server that the handler has taken over the
	// connection: the server stops reading from it and will not close it.
	Hijack()
}

// ConnectionStater is implemented by ResponseWriters that can report the TLS
// state of their connection. The built-in response returns nil when the
// connection is not a TLS connection.
type ConnectionStater interface {
	ConnectionState() *tls.ConnectionState
}

// response is the server's ResponseWriter implementation. Exactly one of udp
// and tcp is non-nil.
type response struct {
	closed         bool         // set by Close; later Write/WriteMsg calls fail
	hijacked       bool         // set by Hijack; serveTCPConn then leaves the conn open
	tsigTimersOnly bool         // forwarded to TsigGenerateWithProvider
	tsigStatus     error        // result of verifying the current request's TSIG
	tsigRequestMAC string       // MAC from the request TSIG; replaced by each signed reply
	tsigProvider   TsigProvider // nil disables TSIG handling
	udp            net.PacketConn // shared server socket for UDP requests; nil for TCP
	tcp            net.Conn       // per-client TCP/TLS connection; nil for UDP
	udpSession     *SessionUDP    // reply session when udp is a *net.UDPConn
	pcSession      net.Addr       // peer address when udp is another PacketConn
	writer         Writer         // w itself, or the result of Server.DecorateWriter(w)
}
81
82
83 func handleRefused(w ResponseWriter, r *Msg) {
84 m := new(Msg)
85 m.SetRcode(r, RcodeRefused)
86 w.WriteMsg(m)
87 }
88
89
90
91 func HandleFailed(w ResponseWriter, r *Msg) {
92 m := new(Msg)
93 m.SetRcode(r, RcodeServerFailure)
94
95 w.WriteMsg(m)
96 }
97
98
99
100 func ListenAndServe(addr string, network string, handler Handler) error {
101 server := &Server{Addr: addr, Net: network, Handler: handler}
102 return server.ListenAndServe()
103 }
104
105
106
107 func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
108 cert, err := tls.LoadX509KeyPair(certFile, keyFile)
109 if err != nil {
110 return err
111 }
112
113 config := tls.Config{
114 Certificates: []tls.Certificate{cert},
115 }
116
117 server := &Server{
118 Addr: addr,
119 Net: "tcp-tls",
120 TLSConfig: &config,
121 Handler: handler,
122 }
123
124 return server.ListenAndServe()
125 }
126
127
128
129
130
131 func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
132 server := &Server{Listener: l, PacketConn: p, Handler: handler}
133 return server.ActivateAndServe()
134 }
135
136
// Writer writes raw wire-format DNS messages. response implements it, and
// Server.DecorateWriter may wrap it.
type Writer interface {
	io.Writer
}

// Reader reads raw wire-format DNS messages from a connection. The server
// uses defaultReader unless Server.DecorateReader wraps it.
type Reader interface {
	// ReadTCP reads one message from conn. The default implementation
	// installs timeout as a read deadline.
	ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)

	// ReadUDP reads one datagram from conn and returns it with the session
	// needed to reply to its sender.
	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}

// PacketConnReader extends Reader for packet connections that are not
// *net.UDPConn. serveUDP requires it in that case and refuses to start when
// the (possibly decorated) Reader does not implement it.
type PacketConnReader interface {
	Reader

	// ReadPacketConn reads one datagram from conn and returns it with the
	// sender's address.
	ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
}

// defaultReader is the Reader used when no DecorateReader is configured.
// Each method forwards to the embedded Server's read helper of the same name.
type defaultReader struct {
	*Server
}

// Compile-time check that defaultReader implements PacketConnReader.
var _ PacketConnReader = defaultReader{}

// ReadTCP forwards to Server.readTCP.
func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
	return dr.readTCP(conn, timeout)
}

// ReadUDP forwards to Server.readUDP.
func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
	return dr.readUDP(conn, timeout)
}

// ReadPacketConn forwards to Server.readPacketConn.
func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
	return dr.readPacketConn(conn, timeout)
}
180
181
182
183
184
// DecorateReader is a hook that receives the Server's default Reader and
// returns the Reader actually used for incoming messages.
type DecorateReader func(Reader) Reader

// DecorateWriter is a hook that receives a response and returns the Writer
// its wire-format replies are written through.
type DecorateWriter func(Writer) Writer

// Server is a DNS server serving either a stream listener (TCP or TLS) or a
// packet connection. Exported fields configure it and are read when
// ListenAndServe or ActivateAndServe starts; unexported fields hold runtime
// state.
type Server struct {
	// Addr is the address ListenAndServe listens on; ":domain" when empty.
	Addr string
	// Net selects the transport for ListenAndServe: "tcp", "tcp4", "tcp6",
	// "tcp-tls", "tcp4-tls", "tcp6-tls", "udp", "udp4" or "udp6".
	Net string
	// Listener is the stream listener. ListenAndServe sets it for TCP/TLS;
	// ActivateAndServe serves it when PacketConn is nil.
	Listener net.Listener
	// TLSConfig is required for the "-tls" networks and must provide
	// Certificates or GetCertificate.
	TLSConfig *tls.Config
	// PacketConn is the packet connection. ListenAndServe sets it for the
	// "udp" networks; ActivateAndServe prefers it over Listener.
	PacketConn net.PacketConn
	// Handler answers requests; DefaultServeMux is used when nil.
	Handler Handler

	// UDPSize is the size of the pooled UDP read buffers; MinMsgSize when 0.
	UDPSize int
	// ReadTimeout bounds every UDP read and the first read on a TCP
	// connection; dnsTimeout is used when 0.
	ReadTimeout time.Duration
	// WriteTimeout is not referenced by the code in this file.
	// NOTE(review): confirm whether anything applies it.
	WriteTimeout time.Duration
	// IdleTimeout, if non-nil, supplies the read timeout for TCP queries
	// after the first on a connection; tcpIdleTimeout is used when nil.
	IdleTimeout func() time.Duration
	// TsigProvider verifies request TSIGs and signs replies. It takes
	// precedence over TsigSecret.
	TsigProvider TsigProvider
	// TsigSecret is wrapped with tsigSecretProvider when TsigProvider is nil.
	// NOTE(review): presumably keyed by TSIG key name; confirm.
	TsigSecret map[string]string
	// NotifyStartedFunc, if non-nil, is called once the serve loop is about
	// to begin accepting or reading.
	NotifyStartedFunc func()
	// DecorateReader, if non-nil, wraps the default Reader.
	DecorateReader DecorateReader
	// DecorateWriter, if non-nil, wraps each response's Writer.
	DecorateWriter DecorateWriter
	// MaxTCPQueries caps the queries served per TCP connection: 0 means the
	// default maxTCPQueries, -1 means unlimited.
	MaxTCPQueries int

	// ReusePort is passed to listenTCP/listenUDP.
	// NOTE(review): presumably requests SO_REUSEPORT; confirm in those helpers.
	ReusePort bool

	// ReuseAddr is passed to listenTCP/listenUDP.
	// NOTE(review): presumably requests SO_REUSEADDR; confirm in those helpers.
	ReuseAddr bool

	// MsgAcceptFunc decides from the header whether a request is accepted,
	// rejected or ignored; DefaultMsgAcceptFunc is used when nil.
	MsgAcceptFunc MsgAcceptFunc

	// lock guards started, conns and the Listener/PacketConn assignments
	// made while starting.
	lock sync.RWMutex
	// started is true from a successful start until ShutdownContext.
	started bool
	// shutdown is closed by the serve loop after it and every in-flight
	// handler goroutine have returned; ShutdownContext waits on it.
	shutdown chan struct{}
	// conns tracks open TCP connections so ShutdownContext can unblock
	// their pending reads.
	conns map[net.Conn]struct{}

	// udpPool recycles UDPSize-byte read buffers between UDP requests.
	udpPool sync.Pool
}
246
247 func (srv *Server) tsigProvider() TsigProvider {
248 if srv.TsigProvider != nil {
249 return srv.TsigProvider
250 }
251 if srv.TsigSecret != nil {
252 return tsigSecretProvider(srv.TsigSecret)
253 }
254 return nil
255 }
256
257 func (srv *Server) isStarted() bool {
258 srv.lock.RLock()
259 started := srv.started
260 srv.lock.RUnlock()
261 return started
262 }
263
// makeUDPBuffer returns a sync.Pool New function. Every call of the returned
// function allocates a fresh, zeroed []byte whose length and capacity are
// exactly size.
func makeUDPBuffer(size int) func() interface{} {
	alloc := func() interface{} {
		buf := make([]byte, size)
		return buf
	}
	return alloc
}
269
270 func (srv *Server) init() {
271 srv.shutdown = make(chan struct{})
272 srv.conns = make(map[net.Conn]struct{})
273
274 if srv.UDPSize == 0 {
275 srv.UDPSize = MinMsgSize
276 }
277 if srv.MsgAcceptFunc == nil {
278 srv.MsgAcceptFunc = DefaultMsgAcceptFunc
279 }
280 if srv.Handler == nil {
281 srv.Handler = DefaultServeMux
282 }
283
284 srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
285 }
286
// unlockOnce returns a function that calls l.Unlock the first time it is
// invoked and does nothing on later calls. This lets a caller release a
// lock early and still keep a deferred release as a safety net.
func unlockOnce(l sync.Locker) func() {
	var once sync.Once
	release := func() { l.Unlock() }
	return func() {
		once.Do(release)
	}
}
291
292
293 func (srv *Server) ListenAndServe() error {
294 unlock := unlockOnce(&srv.lock)
295 srv.lock.Lock()
296 defer unlock()
297
298 if srv.started {
299 return &Error{err: "server already started"}
300 }
301
302 addr := srv.Addr
303 if addr == "" {
304 addr = ":domain"
305 }
306
307 srv.init()
308
309 switch srv.Net {
310 case "tcp", "tcp4", "tcp6":
311 l, err := listenTCP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
312 if err != nil {
313 return err
314 }
315 srv.Listener = l
316 srv.started = true
317 unlock()
318 return srv.serveTCP(l)
319 case "tcp-tls", "tcp4-tls", "tcp6-tls":
320 if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
321 return errors.New("dns: neither Certificates nor GetCertificate set in Config")
322 }
323 network := strings.TrimSuffix(srv.Net, "-tls")
324 l, err := listenTCP(network, addr, srv.ReusePort, srv.ReuseAddr)
325 if err != nil {
326 return err
327 }
328 l = tls.NewListener(l, srv.TLSConfig)
329 srv.Listener = l
330 srv.started = true
331 unlock()
332 return srv.serveTCP(l)
333 case "udp", "udp4", "udp6":
334 l, err := listenUDP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
335 if err != nil {
336 return err
337 }
338 u := l.(*net.UDPConn)
339 if e := setUDPSocketOptions(u); e != nil {
340 u.Close()
341 return e
342 }
343 srv.PacketConn = l
344 srv.started = true
345 unlock()
346 return srv.serveUDP(u)
347 }
348 return &Error{err: "bad network"}
349 }
350
351
352
353 func (srv *Server) ActivateAndServe() error {
354 unlock := unlockOnce(&srv.lock)
355 srv.lock.Lock()
356 defer unlock()
357
358 if srv.started {
359 return &Error{err: "server already started"}
360 }
361
362 srv.init()
363
364 if srv.PacketConn != nil {
365
366
367 if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
368 if e := setUDPSocketOptions(t); e != nil {
369 return e
370 }
371 }
372 srv.started = true
373 unlock()
374 return srv.serveUDP(srv.PacketConn)
375 }
376 if srv.Listener != nil {
377 srv.started = true
378 unlock()
379 return srv.serveTCP(srv.Listener)
380 }
381 return &Error{err: "bad listeners"}
382 }
383
384
385
386 func (srv *Server) Shutdown() error {
387 return srv.ShutdownContext(context.Background())
388 }
389
390
391
392
393
394
395 func (srv *Server) ShutdownContext(ctx context.Context) error {
396 srv.lock.Lock()
397 if !srv.started {
398 srv.lock.Unlock()
399 return &Error{err: "server not started"}
400 }
401
402 srv.started = false
403
404 if srv.PacketConn != nil {
405 srv.PacketConn.SetReadDeadline(aLongTimeAgo)
406 }
407
408 if srv.Listener != nil {
409 srv.Listener.Close()
410 }
411
412 for rw := range srv.conns {
413 rw.SetReadDeadline(aLongTimeAgo)
414 }
415
416 srv.lock.Unlock()
417
418 if testShutdownNotify != nil {
419 testShutdownNotify.Broadcast()
420 }
421
422 var ctxErr error
423 select {
424 case <-srv.shutdown:
425 case <-ctx.Done():
426 ctxErr = ctx.Err()
427 }
428
429 if srv.PacketConn != nil {
430 srv.PacketConn.Close()
431 }
432
433 return ctxErr
434 }
435
436 var testShutdownNotify *sync.Cond
437
438
439 func (srv *Server) getReadTimeout() time.Duration {
440 if srv.ReadTimeout != 0 {
441 return srv.ReadTimeout
442 }
443 return dnsTimeout
444 }
445
446
447 func (srv *Server) serveTCP(l net.Listener) error {
448 defer l.Close()
449
450 if srv.NotifyStartedFunc != nil {
451 srv.NotifyStartedFunc()
452 }
453
454 var wg sync.WaitGroup
455 defer func() {
456 wg.Wait()
457 close(srv.shutdown)
458 }()
459
460 for srv.isStarted() {
461 rw, err := l.Accept()
462 if err != nil {
463 if !srv.isStarted() {
464 return nil
465 }
466 if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
467 continue
468 }
469 return err
470 }
471 srv.lock.Lock()
472
473 srv.conns[rw] = struct{}{}
474 srv.lock.Unlock()
475 wg.Add(1)
476 go srv.serveTCPConn(&wg, rw)
477 }
478
479 return nil
480 }
481
482
483 func (srv *Server) serveUDP(l net.PacketConn) error {
484 defer l.Close()
485
486 reader := Reader(defaultReader{srv})
487 if srv.DecorateReader != nil {
488 reader = srv.DecorateReader(reader)
489 }
490
491 lUDP, isUDP := l.(*net.UDPConn)
492 readerPC, canPacketConn := reader.(PacketConnReader)
493 if !isUDP && !canPacketConn {
494 return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
495 }
496
497 if srv.NotifyStartedFunc != nil {
498 srv.NotifyStartedFunc()
499 }
500
501 var wg sync.WaitGroup
502 defer func() {
503 wg.Wait()
504 close(srv.shutdown)
505 }()
506
507 rtimeout := srv.getReadTimeout()
508
509 for srv.isStarted() {
510 var (
511 m []byte
512 sPC net.Addr
513 sUDP *SessionUDP
514 err error
515 )
516 if isUDP {
517 m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
518 } else {
519 m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
520 }
521 if err != nil {
522 if !srv.isStarted() {
523 return nil
524 }
525 if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
526 continue
527 }
528 return err
529 }
530 if len(m) < headerSize {
531 if cap(m) == srv.UDPSize {
532 srv.udpPool.Put(m[:srv.UDPSize])
533 }
534 continue
535 }
536 wg.Add(1)
537 go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
538 }
539
540 return nil
541 }
542
543
544 func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
545 w := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
546 if srv.DecorateWriter != nil {
547 w.writer = srv.DecorateWriter(w)
548 } else {
549 w.writer = w
550 }
551
552 reader := Reader(defaultReader{srv})
553 if srv.DecorateReader != nil {
554 reader = srv.DecorateReader(reader)
555 }
556
557 idleTimeout := tcpIdleTimeout
558 if srv.IdleTimeout != nil {
559 idleTimeout = srv.IdleTimeout()
560 }
561
562 timeout := srv.getReadTimeout()
563
564 limit := srv.MaxTCPQueries
565 if limit == 0 {
566 limit = maxTCPQueries
567 }
568
569 for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
570 m, err := reader.ReadTCP(w.tcp, timeout)
571 if err != nil {
572
573 break
574 }
575 srv.serveDNS(m, w)
576 if w.closed {
577 break
578 }
579 if w.hijacked {
580 break
581 }
582
583
584 timeout = idleTimeout
585 }
586
587 if !w.hijacked {
588 w.Close()
589 }
590
591 srv.lock.Lock()
592 delete(srv.conns, w.tcp)
593 srv.lock.Unlock()
594
595 wg.Done()
596 }
597
598
599 func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
600 w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
601 if srv.DecorateWriter != nil {
602 w.writer = srv.DecorateWriter(w)
603 } else {
604 w.writer = w
605 }
606
607 srv.serveDNS(m, w)
608 wg.Done()
609 }
610
// serveDNS handles one wire-format request m.
//
// The header is decoded and passed to srv.MsgAcceptFunc:
//   - MsgAccept: the full message is unpacked. If that fails, the request is
//     treated like MsgReject.
//   - MsgReject / MsgRejectNotImplemented: a format-error reply (or, for
//     MsgRejectNotImplemented, an RcodeNotImplemented reply keeping the
//     original opcode) is written with the RR sections cleared.
//   - MsgIgnore: nothing is written.
//
// For accepted requests, any TSIG record is verified and the result is stored
// on w, then srv.Handler.ServeDNS is called. On UDP the pooled buffer m is
// returned to srv.udpPool before the handler runs (and on the reject/ignore
// paths), so m must not be used by the caller afterwards.
func (srv *Server) serveDNS(m []byte, w *response) {
	dh, off, err := unpackMsgHdr(m, 0)
	if err != nil {
		// Undecodable header: deliberately send no reply at all.
		return
	}

	req := new(Msg)
	req.setHdr(dh)

	switch action := srv.MsgAcceptFunc(dh); action {
	case MsgAccept:
		if req.unpack(dh, m, off) == nil {
			// Leaves the switch and continues to TSIG handling below.
			break
		}

		// The body did not unpack: answer as if the message were rejected.
		fallthrough
	case MsgReject, MsgRejectNotImplemented:
		opcode := req.Opcode
		req.SetRcodeFormatError(req)
		req.Zero = false
		if action == MsgRejectNotImplemented {
			req.Opcode = opcode
			req.Rcode = RcodeNotImplemented
		}

		// Strip the answer, authority and additional sections from the reply.
		req.Ns, req.Answer, req.Extra = nil, nil, nil

		w.WriteMsg(req)
		// After replying, recycle the buffer exactly as the ignore path does.
		fallthrough
	case MsgIgnore:
		if w.udp != nil && cap(m) == srv.UDPSize {
			srv.udpPool.Put(m[:srv.UDPSize])
		}

		return
	}

	// w is reused across queries on a TCP connection, so clear the previous
	// query's TSIG status before (possibly) verifying this one.
	w.tsigStatus = nil
	if w.tsigProvider != nil {
		if t := req.IsTsig(); t != nil {
			w.tsigStatus = TsigVerifyWithProvider(m, w.tsigProvider, "", false)
			w.tsigTimersOnly = false
			w.tsigRequestMAC = t.MAC
		}
	}

	// m is not used past this point; give a pooled UDP buffer back.
	if w.udp != nil && cap(m) == srv.UDPSize {
		srv.udpPool.Put(m[:srv.UDPSize])
	}

	srv.Handler.ServeDNS(w, req)
}
665
666 func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
667
668
669
670
671 srv.lock.RLock()
672 if srv.started {
673 conn.SetReadDeadline(time.Now().Add(timeout))
674 }
675 srv.lock.RUnlock()
676
677 var length uint16
678 if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
679 return nil, err
680 }
681
682 m := make([]byte, length)
683 if _, err := io.ReadFull(conn, m); err != nil {
684 return nil, err
685 }
686
687 return m, nil
688 }
689
690 func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
691 srv.lock.RLock()
692 if srv.started {
693
694 conn.SetReadDeadline(time.Now().Add(timeout))
695 }
696 srv.lock.RUnlock()
697
698 m := srv.udpPool.Get().([]byte)
699 n, s, err := ReadFromSessionUDP(conn, m)
700 if err != nil {
701 srv.udpPool.Put(m)
702 return nil, nil, err
703 }
704 m = m[:n]
705 return m, s, nil
706 }
707
708 func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
709 srv.lock.RLock()
710 if srv.started {
711
712 conn.SetReadDeadline(time.Now().Add(timeout))
713 }
714 srv.lock.RUnlock()
715
716 m := srv.udpPool.Get().([]byte)
717 n, addr, err := conn.ReadFrom(m)
718 if err != nil {
719 srv.udpPool.Put(m)
720 return nil, nil, err
721 }
722 m = m[:n]
723 return m, addr, nil
724 }
725
726
727 func (w *response) WriteMsg(m *Msg) (err error) {
728 if w.closed {
729 return &Error{err: "WriteMsg called after Close"}
730 }
731
732 var data []byte
733 if w.tsigProvider != nil {
734 if t := m.IsTsig(); t != nil {
735 data, w.tsigRequestMAC, err = TsigGenerateWithProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
736 if err != nil {
737 return err
738 }
739 _, err = w.writer.Write(data)
740 return err
741 }
742 }
743 data, err = m.Pack()
744 if err != nil {
745 return err
746 }
747 _, err = w.writer.Write(data)
748 return err
749 }
750
751
752 func (w *response) Write(m []byte) (int, error) {
753 if w.closed {
754 return 0, &Error{err: "Write called after Close"}
755 }
756
757 switch {
758 case w.udp != nil:
759 if u, ok := w.udp.(*net.UDPConn); ok {
760 return WriteToSessionUDP(u, m, w.udpSession)
761 }
762 return w.udp.WriteTo(m, w.pcSession)
763 case w.tcp != nil:
764 if len(m) > MaxMsgSize {
765 return 0, &Error{err: "message too large"}
766 }
767
768 msg := make([]byte, 2+len(m))
769 binary.BigEndian.PutUint16(msg, uint16(len(m)))
770 copy(msg[2:], m)
771 return w.tcp.Write(msg)
772 default:
773 panic("dns: internal error: udp and tcp both nil")
774 }
775 }
776
777
778 func (w *response) LocalAddr() net.Addr {
779 switch {
780 case w.udp != nil:
781 return w.udp.LocalAddr()
782 case w.tcp != nil:
783 return w.tcp.LocalAddr()
784 default:
785 panic("dns: internal error: udp and tcp both nil")
786 }
787 }
788
789
790 func (w *response) RemoteAddr() net.Addr {
791 switch {
792 case w.udpSession != nil:
793 return w.udpSession.RemoteAddr()
794 case w.pcSession != nil:
795 return w.pcSession
796 case w.tcp != nil:
797 return w.tcp.RemoteAddr()
798 default:
799 panic("dns: internal error: udpSession, pcSession and tcp are all nil")
800 }
801 }
802
803
804 func (w *response) TsigStatus() error { return w.tsigStatus }
805
806
807 func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
808
809
810 func (w *response) Hijack() { w.hijacked = true }
811
812
813 func (w *response) Close() error {
814 if w.closed {
815 return &Error{err: "connection already closed"}
816 }
817 w.closed = true
818
819 switch {
820 case w.udp != nil:
821
822 return nil
823 case w.tcp != nil:
824 return w.tcp.Close()
825 default:
826 panic("dns: internal error: udp and tcp both nil")
827 }
828 }
829
830
831 func (w *response) ConnectionState() *tls.ConnectionState {
832 type tlsConnectionStater interface {
833 ConnectionState() tls.ConnectionState
834 }
835 if v, ok := w.tcp.(tlsConnectionStater); ok {
836 t := v.ConnectionState()
837 return &t
838 }
839 return nil
840 }
841
View as plain text