...

Source file src/github.com/miekg/dns/server.go

Documentation: github.com/miekg/dns

     1  // DNS server implementation.
     2  
     3  package dns
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"encoding/binary"
     9  	"errors"
    10  	"io"
    11  	"net"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  )
    16  
    17  // Default maximum number of TCP queries before we close the socket.
    18  const maxTCPQueries = 128
    19  
    20  // aLongTimeAgo is a non-zero time, far in the past, used for
    21  // immediate cancellation of network operations.
    22  var aLongTimeAgo = time.Unix(1, 0)
    23  
    24  // Handler is implemented by any value that implements ServeDNS.
    25  type Handler interface {
    26  	ServeDNS(w ResponseWriter, r *Msg)
    27  }
    28  
    29  // The HandlerFunc type is an adapter to allow the use of
    30  // ordinary functions as DNS handlers.  If f is a function
    31  // with the appropriate signature, HandlerFunc(f) is a
    32  // Handler object that calls f.
    33  type HandlerFunc func(ResponseWriter, *Msg)
    34  
    35  // ServeDNS calls f(w, r).
    36  func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
    37  	f(w, r)
    38  }
    39  
// A ResponseWriter is used by a DNS handler to construct and send a DNS
// response, and to inspect the connection the request arrived on.
type ResponseWriter interface {
	// LocalAddr returns the net.Addr of the server.
	LocalAddr() net.Addr
	// RemoteAddr returns the net.Addr of the client that sent the current request.
	RemoteAddr() net.Addr
	// WriteMsg packs the message and writes it back to the client.
	WriteMsg(*Msg) error
	// Write writes a raw, already-packed buffer back to the client.
	Write([]byte) (int, error)
	// Close closes the connection.
	Close() error
	// TsigStatus returns the status of the Tsig.
	TsigStatus() error
	// TsigTimersOnly sets the tsig timers only boolean.
	TsigTimersOnly(bool)
	// Hijack lets the caller take over the connection.
	// After a call to Hijack(), the DNS package will not do anything with the connection.
	Hijack()
}

// A ConnectionStater interface is used by a DNS Handler to access TLS connection state
// when available. The server's own ResponseWriter returns nil from
// ConnectionState when the underlying connection does not expose TLS state.
type ConnectionStater interface {
	ConnectionState() *tls.ConnectionState
}
    67  
// response is the server's ResponseWriter implementation. serveUDPPacket
// creates one per UDP request and serveTCPConn one per TCP connection;
// exactly one of udp and tcp is set.
type response struct {
	closed         bool           // connection has been closed
	hijacked       bool           // connection has been hijacked by handler
	tsigTimersOnly bool           // set by TsigTimersOnly; passed to TsigGenerateWithProvider
	tsigStatus     error          // outcome of verifying the request's TSIG; returned by TsigStatus
	tsigRequestMAC string         // request MAC used when signing a reply; updated after each signed reply
	tsigProvider   TsigProvider   // TSIG provider; nil disables TSIG handling
	udp            net.PacketConn // i/o connection if UDP was used
	tcp            net.Conn       // i/o connection if TCP was used
	udpSession     *SessionUDP    // oob data to get egress interface right
	pcSession      net.Addr       // address to use when writing to a generic net.PacketConn
	writer         Writer         // writer to output the raw DNS bits
}
    81  
    82  // handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
    83  func handleRefused(w ResponseWriter, r *Msg) {
    84  	m := new(Msg)
    85  	m.SetRcode(r, RcodeRefused)
    86  	w.WriteMsg(m)
    87  }
    88  
    89  // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
    90  // Deprecated: This function is going away.
    91  func HandleFailed(w ResponseWriter, r *Msg) {
    92  	m := new(Msg)
    93  	m.SetRcode(r, RcodeServerFailure)
    94  	// does not matter if this write fails
    95  	w.WriteMsg(m)
    96  }
    97  
    98  // ListenAndServe Starts a server on address and network specified Invoke handler
    99  // for incoming queries.
   100  func ListenAndServe(addr string, network string, handler Handler) error {
   101  	server := &Server{Addr: addr, Net: network, Handler: handler}
   102  	return server.ListenAndServe()
   103  }
   104  
   105  // ListenAndServeTLS acts like http.ListenAndServeTLS, more information in
   106  // http://golang.org/pkg/net/http/#ListenAndServeTLS
   107  func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
   108  	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
   109  	if err != nil {
   110  		return err
   111  	}
   112  
   113  	config := tls.Config{
   114  		Certificates: []tls.Certificate{cert},
   115  	}
   116  
   117  	server := &Server{
   118  		Addr:      addr,
   119  		Net:       "tcp-tls",
   120  		TLSConfig: &config,
   121  		Handler:   handler,
   122  	}
   123  
   124  	return server.ListenAndServe()
   125  }
   126  
   127  // ActivateAndServe activates a server with a listener from systemd,
   128  // l and p should not both be non-nil.
   129  // If both l and p are not nil only p will be used.
   130  // Invoke handler for incoming queries.
   131  func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
   132  	server := &Server{Listener: l, PacketConn: p, Handler: handler}
   133  	return server.ActivateAndServe()
   134  }
   135  
// Writer writes raw DNS messages; each call to Write should send an entire message.
// The server's own implementation adds the two-byte length prefix for TCP
// itself, so callers pass the bare packed message.
type Writer interface {
	io.Writer
}

// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
type Reader interface {
	// ReadTCP reads a raw message from a TCP connection. Implementations may alter
	// connection properties, for example the read-deadline. The default
	// implementation returns the message without its two-byte length prefix.
	ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
	// ReadUDP reads a raw message from a UDP connection. Implementations may alter
	// connection properties, for example the read-deadline. The returned
	// SessionUDP is used to address the reply.
	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}

// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
// The server refuses to serve a PacketConn that is not a *net.UDPConn when
// its Reader does not implement this interface.
type PacketConnReader interface {
	Reader

	// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
	// alter connection properties, for example the read-deadline. The
	// returned net.Addr is the sender and is used to address the reply.
	ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
}
   159  
   160  // defaultReader is an adapter for the Server struct that implements the Reader and
   161  // PacketConnReader interfaces using the readTCP, readUDP and readPacketConn funcs
   162  // of the embedded Server.
   163  type defaultReader struct {
   164  	*Server
   165  }
   166  
   167  var _ PacketConnReader = defaultReader{}
   168  
   169  func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
   170  	return dr.readTCP(conn, timeout)
   171  }
   172  
   173  func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
   174  	return dr.readUDP(conn, timeout)
   175  }
   176  
   177  func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
   178  	return dr.readPacketConn(conn, timeout)
   179  }
   180  
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
// It receives the server's default Reader and returns the Reader to use instead.
// Implementations should never return a nil Reader.
// Readers should also implement the optional PacketConnReader interface.
// PacketConnReader is required to use a generic net.PacketConn.
type DecorateReader func(Reader) Reader

// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
// It receives the server's own response writer and returns the Writer that
// raw replies are sent through.
// Implementations should never return a nil Writer.
type DecorateWriter func(Writer) Writer
   190  
// A Server defines parameters for running a DNS server.
//
// Start it with ListenAndServe or ActivateAndServe and stop it with Shutdown
// or ShutdownContext.
type Server struct {
	// Address to listen on, ":domain" if empty.
	Addr string
	// Network used by ListenAndServe: "tcp", "tcp4" or "tcp6" for plain TCP;
	// "tcp-tls", "tcp4-tls" or "tcp6-tls" for DNS over TLS; "udp", "udp4" or
	// "udp6" for UDP. Any other value makes ListenAndServe fail with
	// "bad network".
	Net string
	// TCP Listener to use, this is to aid in systemd's socket activation.
	// ListenAndServe overwrites it with the listener it creates.
	Listener net.Listener
	// TLS connection configuration. Required for the "-tls" networks, where it
	// must set Certificates or GetCertificate.
	TLSConfig *tls.Config
	// UDP "Listener" to use, this is to aid in systemd's socket activation.
	// ListenAndServe overwrites it with the socket it creates. If both
	// PacketConn and Listener are set, ActivateAndServe uses only PacketConn.
	PacketConn net.PacketConn
	// Handler to invoke, dns.DefaultServeMux if nil.
	Handler Handler
	// Default buffer size to use to read incoming UDP messages. If not set
	// it defaults to MinMsgSize (512 B).
	UDPSize int
	// Read deadline (net.Conn.SetReadDeadline) applied before each read. On
	// TCP it bounds only the first query of a connection; later queries use
	// IdleTimeout. If zero, dnsTimeout is used (documented as 2 * time.Second).
	ReadTimeout time.Duration
	// Write timeout for responses, documented as defaulting to 2 * time.Second.
	// NOTE(review): nothing in this file reads WriteTimeout; confirm whether
	// write deadlines are applied anywhere.
	WriteTimeout time.Duration
	// TCP idle timeout for multiple queries; if nil, tcpIdleTimeout is used
	// (documented as 8 * time.Second, RFC 5966, since obsoleted by RFC 7766).
	IdleTimeout func() time.Duration
	// An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
	TsigProvider TsigProvider
	// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
	TsigSecret map[string]string
	// If NotifyStartedFunc is set it is called once the server has started listening,
	// just before the serve loop begins.
	NotifyStartedFunc func()
	// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
	DecorateReader DecorateReader
	// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
	DecorateWriter DecorateWriter
	// Maximum number of queries served on one TCP connection before it is
	// closed. Zero means maxTCPQueries; -1 means unlimited.
	MaxTCPQueries int
	// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
	// It is only supported on certain GOOSes and when using ListenAndServe.
	ReusePort bool
	// Whether to set the SO_REUSEADDR socket option, allowing multiple listeners to be bound to a single address.
	// Crucially this allows binding when an existing server is listening on `0.0.0.0` or `::`.
	// It is only supported on certain GOOSes and when using ListenAndServe.
	ReuseAddr bool
	// MsgAcceptFunc is called with the header of each incoming message and
	// decides, before the message is fully unpacked, whether it is accepted,
	// rejected or ignored. If nil, DefaultMsgAcceptFunc is used.
	MsgAcceptFunc MsgAcceptFunc

	// Shutdown handling. lock guards started and conns. shutdown is created
	// by init and closed by the serve loop once all in-flight handlers have
	// returned; ShutdownContext waits on it. conns tracks live TCP
	// connections so ShutdownContext can unblock their reads.
	lock     sync.RWMutex
	started  bool
	shutdown chan struct{}
	conns    map[net.Conn]struct{}

	// A pool for UDP message buffers of UDPSize bytes.
	udpPool sync.Pool
}
   246  
   247  func (srv *Server) tsigProvider() TsigProvider {
   248  	if srv.TsigProvider != nil {
   249  		return srv.TsigProvider
   250  	}
   251  	if srv.TsigSecret != nil {
   252  		return tsigSecretProvider(srv.TsigSecret)
   253  	}
   254  	return nil
   255  }
   256  
   257  func (srv *Server) isStarted() bool {
   258  	srv.lock.RLock()
   259  	started := srv.started
   260  	srv.lock.RUnlock()
   261  	return started
   262  }
   263  
   264  func makeUDPBuffer(size int) func() interface{} {
   265  	return func() interface{} {
   266  		return make([]byte, size)
   267  	}
   268  }
   269  
   270  func (srv *Server) init() {
   271  	srv.shutdown = make(chan struct{})
   272  	srv.conns = make(map[net.Conn]struct{})
   273  
   274  	if srv.UDPSize == 0 {
   275  		srv.UDPSize = MinMsgSize
   276  	}
   277  	if srv.MsgAcceptFunc == nil {
   278  		srv.MsgAcceptFunc = DefaultMsgAcceptFunc
   279  	}
   280  	if srv.Handler == nil {
   281  		srv.Handler = DefaultServeMux
   282  	}
   283  
   284  	srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
   285  }
   286  
   287  func unlockOnce(l sync.Locker) func() {
   288  	var once sync.Once
   289  	return func() { once.Do(l.Unlock) }
   290  }
   291  
   292  // ListenAndServe starts a nameserver on the configured address in *Server.
   293  func (srv *Server) ListenAndServe() error {
   294  	unlock := unlockOnce(&srv.lock)
   295  	srv.lock.Lock()
   296  	defer unlock()
   297  
   298  	if srv.started {
   299  		return &Error{err: "server already started"}
   300  	}
   301  
   302  	addr := srv.Addr
   303  	if addr == "" {
   304  		addr = ":domain"
   305  	}
   306  
   307  	srv.init()
   308  
   309  	switch srv.Net {
   310  	case "tcp", "tcp4", "tcp6":
   311  		l, err := listenTCP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
   312  		if err != nil {
   313  			return err
   314  		}
   315  		srv.Listener = l
   316  		srv.started = true
   317  		unlock()
   318  		return srv.serveTCP(l)
   319  	case "tcp-tls", "tcp4-tls", "tcp6-tls":
   320  		if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
   321  			return errors.New("dns: neither Certificates nor GetCertificate set in Config")
   322  		}
   323  		network := strings.TrimSuffix(srv.Net, "-tls")
   324  		l, err := listenTCP(network, addr, srv.ReusePort, srv.ReuseAddr)
   325  		if err != nil {
   326  			return err
   327  		}
   328  		l = tls.NewListener(l, srv.TLSConfig)
   329  		srv.Listener = l
   330  		srv.started = true
   331  		unlock()
   332  		return srv.serveTCP(l)
   333  	case "udp", "udp4", "udp6":
   334  		l, err := listenUDP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
   335  		if err != nil {
   336  			return err
   337  		}
   338  		u := l.(*net.UDPConn)
   339  		if e := setUDPSocketOptions(u); e != nil {
   340  			u.Close()
   341  			return e
   342  		}
   343  		srv.PacketConn = l
   344  		srv.started = true
   345  		unlock()
   346  		return srv.serveUDP(u)
   347  	}
   348  	return &Error{err: "bad network"}
   349  }
   350  
   351  // ActivateAndServe starts a nameserver with the PacketConn or Listener
   352  // configured in *Server. Its main use is to start a server from systemd.
   353  func (srv *Server) ActivateAndServe() error {
   354  	unlock := unlockOnce(&srv.lock)
   355  	srv.lock.Lock()
   356  	defer unlock()
   357  
   358  	if srv.started {
   359  		return &Error{err: "server already started"}
   360  	}
   361  
   362  	srv.init()
   363  
   364  	if srv.PacketConn != nil {
   365  		// Check PacketConn interface's type is valid and value
   366  		// is not nil
   367  		if t, ok := srv.PacketConn.(*net.UDPConn); ok && t != nil {
   368  			if e := setUDPSocketOptions(t); e != nil {
   369  				return e
   370  			}
   371  		}
   372  		srv.started = true
   373  		unlock()
   374  		return srv.serveUDP(srv.PacketConn)
   375  	}
   376  	if srv.Listener != nil {
   377  		srv.started = true
   378  		unlock()
   379  		return srv.serveTCP(srv.Listener)
   380  	}
   381  	return &Error{err: "bad listeners"}
   382  }
   383  
   384  // Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and
   385  // ActivateAndServe will return.
   386  func (srv *Server) Shutdown() error {
   387  	return srv.ShutdownContext(context.Background())
   388  }
   389  
   390  // ShutdownContext shuts down a server. After a call to ShutdownContext,
   391  // ListenAndServe and ActivateAndServe will return.
   392  //
   393  // A context.Context may be passed to limit how long to wait for connections
   394  // to terminate.
   395  func (srv *Server) ShutdownContext(ctx context.Context) error {
   396  	srv.lock.Lock()
   397  	if !srv.started {
   398  		srv.lock.Unlock()
   399  		return &Error{err: "server not started"}
   400  	}
   401  
   402  	srv.started = false
   403  
   404  	if srv.PacketConn != nil {
   405  		srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
   406  	}
   407  
   408  	if srv.Listener != nil {
   409  		srv.Listener.Close()
   410  	}
   411  
   412  	for rw := range srv.conns {
   413  		rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
   414  	}
   415  
   416  	srv.lock.Unlock()
   417  
   418  	if testShutdownNotify != nil {
   419  		testShutdownNotify.Broadcast()
   420  	}
   421  
   422  	var ctxErr error
   423  	select {
   424  	case <-srv.shutdown:
   425  	case <-ctx.Done():
   426  		ctxErr = ctx.Err()
   427  	}
   428  
   429  	if srv.PacketConn != nil {
   430  		srv.PacketConn.Close()
   431  	}
   432  
   433  	return ctxErr
   434  }
   435  
   436  var testShutdownNotify *sync.Cond
   437  
   438  // getReadTimeout is a helper func to use system timeout if server did not intend to change it.
   439  func (srv *Server) getReadTimeout() time.Duration {
   440  	if srv.ReadTimeout != 0 {
   441  		return srv.ReadTimeout
   442  	}
   443  	return dnsTimeout
   444  }
   445  
   446  // serveTCP starts a TCP listener for the server.
   447  func (srv *Server) serveTCP(l net.Listener) error {
   448  	defer l.Close()
   449  
   450  	if srv.NotifyStartedFunc != nil {
   451  		srv.NotifyStartedFunc()
   452  	}
   453  
   454  	var wg sync.WaitGroup
   455  	defer func() {
   456  		wg.Wait()
   457  		close(srv.shutdown)
   458  	}()
   459  
   460  	for srv.isStarted() {
   461  		rw, err := l.Accept()
   462  		if err != nil {
   463  			if !srv.isStarted() {
   464  				return nil
   465  			}
   466  			if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
   467  				continue
   468  			}
   469  			return err
   470  		}
   471  		srv.lock.Lock()
   472  		// Track the connection to allow unblocking reads on shutdown.
   473  		srv.conns[rw] = struct{}{}
   474  		srv.lock.Unlock()
   475  		wg.Add(1)
   476  		go srv.serveTCPConn(&wg, rw)
   477  	}
   478  
   479  	return nil
   480  }
   481  
   482  // serveUDP starts a UDP listener for the server.
   483  func (srv *Server) serveUDP(l net.PacketConn) error {
   484  	defer l.Close()
   485  
   486  	reader := Reader(defaultReader{srv})
   487  	if srv.DecorateReader != nil {
   488  		reader = srv.DecorateReader(reader)
   489  	}
   490  
   491  	lUDP, isUDP := l.(*net.UDPConn)
   492  	readerPC, canPacketConn := reader.(PacketConnReader)
   493  	if !isUDP && !canPacketConn {
   494  		return &Error{err: "PacketConnReader was not implemented on Reader returned from DecorateReader but is required for net.PacketConn"}
   495  	}
   496  
   497  	if srv.NotifyStartedFunc != nil {
   498  		srv.NotifyStartedFunc()
   499  	}
   500  
   501  	var wg sync.WaitGroup
   502  	defer func() {
   503  		wg.Wait()
   504  		close(srv.shutdown)
   505  	}()
   506  
   507  	rtimeout := srv.getReadTimeout()
   508  	// deadline is not used here
   509  	for srv.isStarted() {
   510  		var (
   511  			m    []byte
   512  			sPC  net.Addr
   513  			sUDP *SessionUDP
   514  			err  error
   515  		)
   516  		if isUDP {
   517  			m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
   518  		} else {
   519  			m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
   520  		}
   521  		if err != nil {
   522  			if !srv.isStarted() {
   523  				return nil
   524  			}
   525  			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
   526  				continue
   527  			}
   528  			return err
   529  		}
   530  		if len(m) < headerSize {
   531  			if cap(m) == srv.UDPSize {
   532  				srv.udpPool.Put(m[:srv.UDPSize])
   533  			}
   534  			continue
   535  		}
   536  		wg.Add(1)
   537  		go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
   538  	}
   539  
   540  	return nil
   541  }
   542  
   543  // Serve a new TCP connection.
   544  func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
   545  	w := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
   546  	if srv.DecorateWriter != nil {
   547  		w.writer = srv.DecorateWriter(w)
   548  	} else {
   549  		w.writer = w
   550  	}
   551  
   552  	reader := Reader(defaultReader{srv})
   553  	if srv.DecorateReader != nil {
   554  		reader = srv.DecorateReader(reader)
   555  	}
   556  
   557  	idleTimeout := tcpIdleTimeout
   558  	if srv.IdleTimeout != nil {
   559  		idleTimeout = srv.IdleTimeout()
   560  	}
   561  
   562  	timeout := srv.getReadTimeout()
   563  
   564  	limit := srv.MaxTCPQueries
   565  	if limit == 0 {
   566  		limit = maxTCPQueries
   567  	}
   568  
   569  	for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
   570  		m, err := reader.ReadTCP(w.tcp, timeout)
   571  		if err != nil {
   572  			// TODO(tmthrgd): handle error
   573  			break
   574  		}
   575  		srv.serveDNS(m, w)
   576  		if w.closed {
   577  			break // Close() was called
   578  		}
   579  		if w.hijacked {
   580  			break // client will call Close() themselves
   581  		}
   582  		// The first read uses the read timeout, the rest use the
   583  		// idle timeout.
   584  		timeout = idleTimeout
   585  	}
   586  
   587  	if !w.hijacked {
   588  		w.Close()
   589  	}
   590  
   591  	srv.lock.Lock()
   592  	delete(srv.conns, w.tcp)
   593  	srv.lock.Unlock()
   594  
   595  	wg.Done()
   596  }
   597  
   598  // Serve a new UDP request.
   599  func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
   600  	w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
   601  	if srv.DecorateWriter != nil {
   602  		w.writer = srv.DecorateWriter(w)
   603  	} else {
   604  		w.writer = w
   605  	}
   606  
   607  	srv.serveDNS(m, w)
   608  	wg.Done()
   609  }
   610  
// serveDNS decodes one raw request m received on w and dispatches it.
//
// The header is unpacked first and passed to srv.MsgAcceptFunc:
//
//   - MsgAccept: the rest of the message is unpacked and, on success, handed
//     to srv.Handler; if unpacking fails it is treated like MsgReject.
//   - MsgReject / MsgRejectNotImplemented: a header-only FORMERR reply (or
//     NOTIMP, keeping the request's opcode) is written.
//   - MsgIgnore: no reply.
//
// A message whose header cannot be unpacked gets no reply at all.
// NOTE(review): any other action value falls out of the switch and reaches
// the handler with only the header set on req.
//
// For UDP requests, m is returned to srv.udpPool once it is no longer read
// here; only buffers whose capacity equals srv.UDPSize are recycled.
func (srv *Server) serveDNS(m []byte, w *response) {
	dh, off, err := unpackMsgHdr(m, 0)
	if err != nil {
		// Let client hang, they are sending crap; any reply can be used to amplify.
		return
	}

	req := new(Msg)
	req.setHdr(dh)

	// MsgAcceptFunc sees only the header, so unwanted messages are filtered
	// out before the full unpack.
	switch action := srv.MsgAcceptFunc(dh); action {
	case MsgAccept:
		if req.unpack(dh, m, off) == nil {
			break
		}

		// The body did not unpack: answer as if the message had been rejected.
		fallthrough
	case MsgReject, MsgRejectNotImplemented:
		// The opcode is saved because SetRcodeFormatError may overwrite it;
		// it is restored for the NOTIMP reply.
		opcode := req.Opcode
		req.SetRcodeFormatError(req)
		req.Zero = false
		if action == MsgRejectNotImplemented {
			req.Opcode = opcode
			req.Rcode = RcodeNotImplemented
		}

		// Are we allowed to delete any OPT records here?
		req.Ns, req.Answer, req.Extra = nil, nil, nil

		w.WriteMsg(req)
		fallthrough
	case MsgIgnore:
		// Not dispatched to the handler: recycle the UDP buffer and stop.
		if w.udp != nil && cap(m) == srv.UDPSize {
			srv.udpPool.Put(m[:srv.UDPSize])
		}

		return
	}

	w.tsigStatus = nil
	if w.tsigProvider != nil {
		if t := req.IsTsig(); t != nil {
			// Verify against the raw wire bytes while m is still valid; the
			// result is exposed to the handler via w.TsigStatus().
			w.tsigStatus = TsigVerifyWithProvider(m, w.tsigProvider, "", false)
			w.tsigTimersOnly = false
			w.tsigRequestMAC = t.MAC
		}
	}

	// m is not read past this point, so the UDP buffer is recycled before the
	// handler runs. NOTE(review): this relies on req.unpack having copied
	// everything req needs out of m; confirm.
	if w.udp != nil && cap(m) == srv.UDPSize {
		srv.udpPool.Put(m[:srv.UDPSize])
	}

	srv.Handler.ServeDNS(w, req) // Writes back to the client
}
   665  
   666  func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
   667  	// If we race with ShutdownContext, the read deadline may
   668  	// have been set in the distant past to unblock the read
   669  	// below. We must not override it, otherwise we may block
   670  	// ShutdownContext.
   671  	srv.lock.RLock()
   672  	if srv.started {
   673  		conn.SetReadDeadline(time.Now().Add(timeout))
   674  	}
   675  	srv.lock.RUnlock()
   676  
   677  	var length uint16
   678  	if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
   679  		return nil, err
   680  	}
   681  
   682  	m := make([]byte, length)
   683  	if _, err := io.ReadFull(conn, m); err != nil {
   684  		return nil, err
   685  	}
   686  
   687  	return m, nil
   688  }
   689  
   690  func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
   691  	srv.lock.RLock()
   692  	if srv.started {
   693  		// See the comment in readTCP above.
   694  		conn.SetReadDeadline(time.Now().Add(timeout))
   695  	}
   696  	srv.lock.RUnlock()
   697  
   698  	m := srv.udpPool.Get().([]byte)
   699  	n, s, err := ReadFromSessionUDP(conn, m)
   700  	if err != nil {
   701  		srv.udpPool.Put(m)
   702  		return nil, nil, err
   703  	}
   704  	m = m[:n]
   705  	return m, s, nil
   706  }
   707  
   708  func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
   709  	srv.lock.RLock()
   710  	if srv.started {
   711  		// See the comment in readTCP above.
   712  		conn.SetReadDeadline(time.Now().Add(timeout))
   713  	}
   714  	srv.lock.RUnlock()
   715  
   716  	m := srv.udpPool.Get().([]byte)
   717  	n, addr, err := conn.ReadFrom(m)
   718  	if err != nil {
   719  		srv.udpPool.Put(m)
   720  		return nil, nil, err
   721  	}
   722  	m = m[:n]
   723  	return m, addr, nil
   724  }
   725  
   726  // WriteMsg implements the ResponseWriter.WriteMsg method.
   727  func (w *response) WriteMsg(m *Msg) (err error) {
   728  	if w.closed {
   729  		return &Error{err: "WriteMsg called after Close"}
   730  	}
   731  
   732  	var data []byte
   733  	if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
   734  		if t := m.IsTsig(); t != nil {
   735  			data, w.tsigRequestMAC, err = TsigGenerateWithProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
   736  			if err != nil {
   737  				return err
   738  			}
   739  			_, err = w.writer.Write(data)
   740  			return err
   741  		}
   742  	}
   743  	data, err = m.Pack()
   744  	if err != nil {
   745  		return err
   746  	}
   747  	_, err = w.writer.Write(data)
   748  	return err
   749  }
   750  
   751  // Write implements the ResponseWriter.Write method.
   752  func (w *response) Write(m []byte) (int, error) {
   753  	if w.closed {
   754  		return 0, &Error{err: "Write called after Close"}
   755  	}
   756  
   757  	switch {
   758  	case w.udp != nil:
   759  		if u, ok := w.udp.(*net.UDPConn); ok {
   760  			return WriteToSessionUDP(u, m, w.udpSession)
   761  		}
   762  		return w.udp.WriteTo(m, w.pcSession)
   763  	case w.tcp != nil:
   764  		if len(m) > MaxMsgSize {
   765  			return 0, &Error{err: "message too large"}
   766  		}
   767  
   768  		msg := make([]byte, 2+len(m))
   769  		binary.BigEndian.PutUint16(msg, uint16(len(m)))
   770  		copy(msg[2:], m)
   771  		return w.tcp.Write(msg)
   772  	default:
   773  		panic("dns: internal error: udp and tcp both nil")
   774  	}
   775  }
   776  
   777  // LocalAddr implements the ResponseWriter.LocalAddr method.
   778  func (w *response) LocalAddr() net.Addr {
   779  	switch {
   780  	case w.udp != nil:
   781  		return w.udp.LocalAddr()
   782  	case w.tcp != nil:
   783  		return w.tcp.LocalAddr()
   784  	default:
   785  		panic("dns: internal error: udp and tcp both nil")
   786  	}
   787  }
   788  
   789  // RemoteAddr implements the ResponseWriter.RemoteAddr method.
   790  func (w *response) RemoteAddr() net.Addr {
   791  	switch {
   792  	case w.udpSession != nil:
   793  		return w.udpSession.RemoteAddr()
   794  	case w.pcSession != nil:
   795  		return w.pcSession
   796  	case w.tcp != nil:
   797  		return w.tcp.RemoteAddr()
   798  	default:
   799  		panic("dns: internal error: udpSession, pcSession and tcp are all nil")
   800  	}
   801  }
   802  
   803  // TsigStatus implements the ResponseWriter.TsigStatus method.
   804  func (w *response) TsigStatus() error { return w.tsigStatus }
   805  
   806  // TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method.
   807  func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
   808  
   809  // Hijack implements the ResponseWriter.Hijack method.
   810  func (w *response) Hijack() { w.hijacked = true }
   811  
   812  // Close implements the ResponseWriter.Close method
   813  func (w *response) Close() error {
   814  	if w.closed {
   815  		return &Error{err: "connection already closed"}
   816  	}
   817  	w.closed = true
   818  
   819  	switch {
   820  	case w.udp != nil:
   821  		// Can't close the udp conn, as that is actually the listener.
   822  		return nil
   823  	case w.tcp != nil:
   824  		return w.tcp.Close()
   825  	default:
   826  		panic("dns: internal error: udp and tcp both nil")
   827  	}
   828  }
   829  
   830  // ConnectionState() implements the ConnectionStater.ConnectionState() interface.
   831  func (w *response) ConnectionState() *tls.ConnectionState {
   832  	type tlsConnectionStater interface {
   833  		ConnectionState() tls.ConnectionState
   834  	}
   835  	if v, ok := w.tcp.(tlsConnectionStater); ok {
   836  		t := v.ConnectionState()
   837  		return &t
   838  	}
   839  	return nil
   840  }
   841  

View as plain text