...

Source file src/github.com/miekg/dns/msg.go

Documentation: github.com/miekg/dns

     1  // DNS packet assembly, see RFC 1035. Converting from - Unpack() -
     2  // and to - Pack() - wire format.
// All the packers and unpackers take a (msg []byte, off int)
// and return (off1 int, err error). If they return a non-nil error, they
// generally also return off1 == len(msg), so that the next unpacker will
// also fail. This lets us avoid checking the error until the end of a
// packing sequence.
     8  
     9  package dns
    10  
    11  //go:generate go run msg_generate.go
    12  
    13  import (
    14  	"crypto/rand"
    15  	"encoding/binary"
    16  	"fmt"
    17  	"math/big"
    18  	"strconv"
    19  	"strings"
    20  )
    21  
    22  const (
    23  	maxCompressionOffset    = 2 << 13 // We have 14 bits for the compression pointer
    24  	maxDomainNameWireOctets = 255     // See RFC 1035 section 2.3.4
    25  
    26  	// This is the maximum number of compression pointers that should occur in a
    27  	// semantically valid message. Each label in a domain name must be at least one
    28  	// octet and is separated by a period. The root label won't be represented by a
    29  	// compression pointer to a compression pointer, hence the -2 to exclude the
    30  	// smallest valid root label.
    31  	//
    32  	// It is possible to construct a valid message that has more compression pointers
    33  	// than this, and still doesn't loop, by pointing to a previous pointer. This is
    34  	// not something a well written implementation should ever do, so we leave them
    35  	// to trip the maximum compression pointer check.
    36  	maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2
    37  
    38  	// This is the maximum length of a domain name in presentation format. The
    39  	// maximum wire length of a domain name is 255 octets (see above), with the
    40  	// maximum label length being 63. The wire format requires one extra byte over
    41  	// the presentation format, reducing the number of octets by 1. Each label in
    42  	// the name will be separated by a single period, with each octet in the label
    43  	// expanding to at most 4 bytes (\DDD). If all other labels are of the maximum
    44  	// length, then the final label can only be 61 octets long to not exceed the
    45  	// maximum allowed wire length.
    46  	maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1
    47  )
    48  
    49  // Errors defined in this package.
    50  var (
    51  	ErrAlg           error = &Error{err: "bad algorithm"}                  // ErrAlg indicates an error with the (DNSSEC) algorithm.
    52  	ErrAuth          error = &Error{err: "bad authentication"}             // ErrAuth indicates an error in the TSIG authentication.
    53  	ErrBuf           error = &Error{err: "buffer size too small"}          // ErrBuf indicates that the buffer used is too small for the message.
    54  	ErrConnEmpty     error = &Error{err: "conn has no connection"}         // ErrConnEmpty indicates a connection is being used before it is initialized.
    55  	ErrExtendedRcode error = &Error{err: "bad extended rcode"}             // ErrExtendedRcode ...
    56  	ErrFqdn          error = &Error{err: "domain must be fully qualified"} // ErrFqdn indicates that a domain name does not have a closing dot.
    57  	ErrId            error = &Error{err: "id mismatch"}                    // ErrId indicates there is a mismatch with the message's ID.
    58  	ErrKeyAlg        error = &Error{err: "bad key algorithm"}              // ErrKeyAlg indicates that the algorithm in the key is not valid.
    59  	ErrKey           error = &Error{err: "bad key"}
    60  	ErrKeySize       error = &Error{err: "bad key size"}
    61  	ErrLongDomain    error = &Error{err: fmt.Sprintf("domain name exceeded %d wire-format octets", maxDomainNameWireOctets)}
    62  	ErrNoSig         error = &Error{err: "no signature found"}
    63  	ErrPrivKey       error = &Error{err: "bad private key"}
    64  	ErrRcode         error = &Error{err: "bad rcode"}
    65  	ErrRdata         error = &Error{err: "bad rdata"}
    66  	ErrRRset         error = &Error{err: "bad rrset"}
    67  	ErrSecret        error = &Error{err: "no secrets defined"}
    68  	ErrShortRead     error = &Error{err: "short read"}
    69  	ErrSig           error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
    70  	ErrSoa           error = &Error{err: "no SOA"}        // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
    71  	ErrTime          error = &Error{err: "bad time"}      // ErrTime indicates a timing error in TSIG authentication.
    72  )
    73  
    74  // Id by default returns a 16-bit random number to be used as a message id. The
    75  // number is drawn from a cryptographically secure random number generator.
    76  // This being a variable the function can be reassigned to a custom function.
    77  // For instance, to make it return a static value for testing:
    78  //
    79  //	dns.Id = func() uint16 { return 3 }
    80  var Id = id
    81  
    82  // id returns a 16 bits random number to be used as a
    83  // message id. The random provided should be good enough.
    84  func id() uint16 {
    85  	var output uint16
    86  	err := binary.Read(rand.Reader, binary.BigEndian, &output)
    87  	if err != nil {
    88  		panic("dns: reading random id failed: " + err.Error())
    89  	}
    90  	return output
    91  }
    92  
    93  // MsgHdr is a a manually-unpacked version of (id, bits).
    94  type MsgHdr struct {
    95  	Id                 uint16
    96  	Response           bool
    97  	Opcode             int
    98  	Authoritative      bool
    99  	Truncated          bool
   100  	RecursionDesired   bool
   101  	RecursionAvailable bool
   102  	Zero               bool
   103  	AuthenticatedData  bool
   104  	CheckingDisabled   bool
   105  	Rcode              int
   106  }
   107  
   108  // Msg contains the layout of a DNS message.
   109  type Msg struct {
   110  	MsgHdr
   111  	Compress bool       `json:"-"` // If true, the message will be compressed when converted to wire format.
   112  	Question []Question // Holds the RR(s) of the question section.
   113  	Answer   []RR       // Holds the RR(s) of the answer section.
   114  	Ns       []RR       // Holds the RR(s) of the authority section.
   115  	Extra    []RR       // Holds the RR(s) of the additional section.
   116  }
   117  
   118  // ClassToString is a maps Classes to strings for each CLASS wire type.
   119  var ClassToString = map[uint16]string{
   120  	ClassINET:   "IN",
   121  	ClassCSNET:  "CS",
   122  	ClassCHAOS:  "CH",
   123  	ClassHESIOD: "HS",
   124  	ClassNONE:   "NONE",
   125  	ClassANY:    "ANY",
   126  }
   127  
   128  // OpcodeToString maps Opcodes to strings.
   129  var OpcodeToString = map[int]string{
   130  	OpcodeQuery:  "QUERY",
   131  	OpcodeIQuery: "IQUERY",
   132  	OpcodeStatus: "STATUS",
   133  	OpcodeNotify: "NOTIFY",
   134  	OpcodeUpdate: "UPDATE",
   135  }
   136  
   137  // RcodeToString maps Rcodes to strings.
   138  var RcodeToString = map[int]string{
   139  	RcodeSuccess:        "NOERROR",
   140  	RcodeFormatError:    "FORMERR",
   141  	RcodeServerFailure:  "SERVFAIL",
   142  	RcodeNameError:      "NXDOMAIN",
   143  	RcodeNotImplemented: "NOTIMP",
   144  	RcodeRefused:        "REFUSED",
   145  	RcodeYXDomain:       "YXDOMAIN", // See RFC 2136
   146  	RcodeYXRrset:        "YXRRSET",
   147  	RcodeNXRrset:        "NXRRSET",
   148  	RcodeNotAuth:        "NOTAUTH",
   149  	RcodeNotZone:        "NOTZONE",
   150  	RcodeBadSig:         "BADSIG", // Also known as RcodeBadVers, see RFC 6891
   151  	//	RcodeBadVers:        "BADVERS",
   152  	RcodeBadKey:    "BADKEY",
   153  	RcodeBadTime:   "BADTIME",
   154  	RcodeBadMode:   "BADMODE",
   155  	RcodeBadName:   "BADNAME",
   156  	RcodeBadAlg:    "BADALG",
   157  	RcodeBadTrunc:  "BADTRUNC",
   158  	RcodeBadCookie: "BADCOOKIE",
   159  }
   160  
   161  // compressionMap is used to allow a more efficient compression map
   162  // to be used for internal packDomainName calls without changing the
   163  // signature or functionality of public API.
   164  //
   165  // In particular, map[string]uint16 uses 25% less per-entry memory
   166  // than does map[string]int.
   167  type compressionMap struct {
   168  	ext map[string]int    // external callers
   169  	int map[string]uint16 // internal callers
   170  }
   171  
   172  func (m compressionMap) valid() bool {
   173  	return m.int != nil || m.ext != nil
   174  }
   175  
   176  func (m compressionMap) insert(s string, pos int) {
   177  	if m.ext != nil {
   178  		m.ext[s] = pos
   179  	} else {
   180  		m.int[s] = uint16(pos)
   181  	}
   182  }
   183  
   184  func (m compressionMap) find(s string) (int, bool) {
   185  	if m.ext != nil {
   186  		pos, ok := m.ext[s]
   187  		return pos, ok
   188  	}
   189  
   190  	pos, ok := m.int[s]
   191  	return int(pos), ok
   192  }
   193  
   194  // Domain names are a sequence of counted strings
   195  // split at the dots. They end with a zero-length string.
   196  
   197  // PackDomainName packs a domain name s into msg[off:].
   198  // If compression is wanted compress must be true and the compression
   199  // map needs to hold a mapping between domain names and offsets
   200  // pointing into msg.
   201  func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
   202  	return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
   203  }
   204  
   205  func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
   206  	// XXX: A logical copy of this function exists in IsDomainName and
   207  	// should be kept in sync with this function.
   208  
   209  	ls := len(s)
   210  	if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
   211  		return off, nil
   212  	}
   213  
   214  	// If not fully qualified, error out.
   215  	if !IsFqdn(s) {
   216  		return len(msg), ErrFqdn
   217  	}
   218  
   219  	// Each dot ends a segment of the name.
   220  	// We trade each dot byte for a length byte.
   221  	// Except for escaped dots (\.), which are normal dots.
   222  	// There is also a trailing zero.
   223  
   224  	// Compression
   225  	pointer := -1
   226  
   227  	// Emit sequence of counted strings, chopping at dots.
   228  	var (
   229  		begin     int
   230  		compBegin int
   231  		compOff   int
   232  		bs        []byte
   233  		wasDot    bool
   234  	)
   235  loop:
   236  	for i := 0; i < ls; i++ {
   237  		var c byte
   238  		if bs == nil {
   239  			c = s[i]
   240  		} else {
   241  			c = bs[i]
   242  		}
   243  
   244  		switch c {
   245  		case '\\':
   246  			if off+1 > len(msg) {
   247  				return len(msg), ErrBuf
   248  			}
   249  
   250  			if bs == nil {
   251  				bs = []byte(s)
   252  			}
   253  
   254  			// check for \DDD
   255  			if isDDD(bs[i+1:]) {
   256  				bs[i] = dddToByte(bs[i+1:])
   257  				copy(bs[i+1:ls-3], bs[i+4:])
   258  				ls -= 3
   259  				compOff += 3
   260  			} else {
   261  				copy(bs[i:ls-1], bs[i+1:])
   262  				ls--
   263  				compOff++
   264  			}
   265  
   266  			wasDot = false
   267  		case '.':
   268  			if i == 0 && len(s) > 1 {
   269  				// leading dots are not legal except for the root zone
   270  				return len(msg), ErrRdata
   271  			}
   272  
   273  			if wasDot {
   274  				// two dots back to back is not legal
   275  				return len(msg), ErrRdata
   276  			}
   277  			wasDot = true
   278  
   279  			labelLen := i - begin
   280  			if labelLen >= 1<<6 { // top two bits of length must be clear
   281  				return len(msg), ErrRdata
   282  			}
   283  
   284  			// off can already (we're in a loop) be bigger than len(msg)
   285  			// this happens when a name isn't fully qualified
   286  			if off+1+labelLen > len(msg) {
   287  				return len(msg), ErrBuf
   288  			}
   289  
   290  			// Don't try to compress '.'
   291  			// We should only compress when compress is true, but we should also still pick
   292  			// up names that can be used for *future* compression(s).
   293  			if compression.valid() && !isRootLabel(s, bs, begin, ls) {
   294  				if p, ok := compression.find(s[compBegin:]); ok {
   295  					// The first hit is the longest matching dname
   296  					// keep the pointer offset we get back and store
   297  					// the offset of the current name, because that's
   298  					// where we need to insert the pointer later
   299  
   300  					// If compress is true, we're allowed to compress this dname
   301  					if compress {
   302  						pointer = p // Where to point to
   303  						break loop
   304  					}
   305  				} else if off < maxCompressionOffset {
   306  					// Only offsets smaller than maxCompressionOffset can be used.
   307  					compression.insert(s[compBegin:], off)
   308  				}
   309  			}
   310  
   311  			// The following is covered by the length check above.
   312  			msg[off] = byte(labelLen)
   313  
   314  			if bs == nil {
   315  				copy(msg[off+1:], s[begin:i])
   316  			} else {
   317  				copy(msg[off+1:], bs[begin:i])
   318  			}
   319  			off += 1 + labelLen
   320  
   321  			begin = i + 1
   322  			compBegin = begin + compOff
   323  		default:
   324  			wasDot = false
   325  		}
   326  	}
   327  
   328  	// Root label is special
   329  	if isRootLabel(s, bs, 0, ls) {
   330  		return off, nil
   331  	}
   332  
   333  	// If we did compression and we find something add the pointer here
   334  	if pointer != -1 {
   335  		// We have two bytes (14 bits) to put the pointer in
   336  		binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
   337  		return off + 2, nil
   338  	}
   339  
   340  	if off < len(msg) {
   341  		msg[off] = 0
   342  	}
   343  
   344  	return off + 1, nil
   345  }
   346  
   347  // isRootLabel returns whether s or bs, from off to end, is the root
   348  // label ".".
   349  //
   350  // If bs is nil, s will be checked, otherwise bs will be checked.
   351  func isRootLabel(s string, bs []byte, off, end int) bool {
   352  	if bs == nil {
   353  		return s[off:end] == "."
   354  	}
   355  
   356  	return end-off == 1 && bs[off] == '.'
   357  }
   358  
   359  // Unpack a domain name.
   360  // In addition to the simple sequences of counted strings above,
   361  // domain names are allowed to refer to strings elsewhere in the
   362  // packet, to avoid repeating common suffixes when returning
   363  // many entries in a single domain.  The pointers are marked
   364  // by a length byte with the top two bits set.  Ignoring those
   365  // two bits, that byte and the next give a 14 bit offset from msg[0]
   366  // where we should pick up the trail.
   367  // Note that if we jump elsewhere in the packet,
   368  // we return off1 == the offset after the first pointer we found,
   369  // which is where the next record will start.
   370  // In theory, the pointers are only allowed to jump backward.
   371  // We let them jump anywhere and stop jumping after a while.
   372  
   373  // UnpackDomainName unpacks a domain name into a string. It returns
   374  // the name, the new offset into msg and any error that occurred.
   375  //
   376  // When an error is encountered, the unpacked name will be discarded
   377  // and len(msg) will be returned as the offset.
   378  func UnpackDomainName(msg []byte, off int) (string, int, error) {
   379  	s := make([]byte, 0, maxDomainNamePresentationLength)
   380  	off1 := 0
   381  	lenmsg := len(msg)
   382  	budget := maxDomainNameWireOctets
   383  	ptr := 0 // number of pointers followed
   384  Loop:
   385  	for {
   386  		if off >= lenmsg {
   387  			return "", lenmsg, ErrBuf
   388  		}
   389  		c := int(msg[off])
   390  		off++
   391  		switch c & 0xC0 {
   392  		case 0x00:
   393  			if c == 0x00 {
   394  				// end of name
   395  				break Loop
   396  			}
   397  			// literal string
   398  			if off+c > lenmsg {
   399  				return "", lenmsg, ErrBuf
   400  			}
   401  			budget -= c + 1 // +1 for the label separator
   402  			if budget <= 0 {
   403  				return "", lenmsg, ErrLongDomain
   404  			}
   405  			for _, b := range msg[off : off+c] {
   406  				if isDomainNameLabelSpecial(b) {
   407  					s = append(s, '\\', b)
   408  				} else if b < ' ' || b > '~' {
   409  					s = append(s, escapeByte(b)...)
   410  				} else {
   411  					s = append(s, b)
   412  				}
   413  			}
   414  			s = append(s, '.')
   415  			off += c
   416  		case 0xC0:
   417  			// pointer to somewhere else in msg.
   418  			// remember location after first ptr,
   419  			// since that's how many bytes we consumed.
   420  			// also, don't follow too many pointers --
   421  			// maybe there's a loop.
   422  			if off >= lenmsg {
   423  				return "", lenmsg, ErrBuf
   424  			}
   425  			c1 := msg[off]
   426  			off++
   427  			if ptr == 0 {
   428  				off1 = off
   429  			}
   430  			if ptr++; ptr > maxCompressionPointers {
   431  				return "", lenmsg, &Error{err: "too many compression pointers"}
   432  			}
   433  			// pointer should guarantee that it advances and points forwards at least
   434  			// but the condition on previous three lines guarantees that it's
   435  			// at least loop-free
   436  			off = (c^0xC0)<<8 | int(c1)
   437  		default:
   438  			// 0x80 and 0x40 are reserved
   439  			return "", lenmsg, ErrRdata
   440  		}
   441  	}
   442  	if ptr == 0 {
   443  		off1 = off
   444  	}
   445  	if len(s) == 0 {
   446  		return ".", off1, nil
   447  	}
   448  	return string(s), off1, nil
   449  }
   450  
   451  func packTxt(txt []string, msg []byte, offset int) (int, error) {
   452  	if len(txt) == 0 {
   453  		if offset >= len(msg) {
   454  			return offset, ErrBuf
   455  		}
   456  		msg[offset] = 0
   457  		return offset, nil
   458  	}
   459  	var err error
   460  	for _, s := range txt {
   461  		offset, err = packTxtString(s, msg, offset)
   462  		if err != nil {
   463  			return offset, err
   464  		}
   465  	}
   466  	return offset, nil
   467  }
   468  
   469  func packTxtString(s string, msg []byte, offset int) (int, error) {
   470  	lenByteOffset := offset
   471  	if offset >= len(msg) || len(s) > 256*4+1 /* If all \DDD */ {
   472  		return offset, ErrBuf
   473  	}
   474  	offset++
   475  	for i := 0; i < len(s); i++ {
   476  		if len(msg) <= offset {
   477  			return offset, ErrBuf
   478  		}
   479  		if s[i] == '\\' {
   480  			i++
   481  			if i == len(s) {
   482  				break
   483  			}
   484  			// check for \DDD
   485  			if isDDD(s[i:]) {
   486  				msg[offset] = dddToByte(s[i:])
   487  				i += 2
   488  			} else {
   489  				msg[offset] = s[i]
   490  			}
   491  		} else {
   492  			msg[offset] = s[i]
   493  		}
   494  		offset++
   495  	}
   496  	l := offset - lenByteOffset - 1
   497  	if l > 255 {
   498  		return offset, &Error{err: "string exceeded 255 bytes in txt"}
   499  	}
   500  	msg[lenByteOffset] = byte(l)
   501  	return offset, nil
   502  }
   503  
   504  func packOctetString(s string, msg []byte, offset int) (int, error) {
   505  	if offset >= len(msg) || len(s) > 256*4+1 {
   506  		return offset, ErrBuf
   507  	}
   508  	for i := 0; i < len(s); i++ {
   509  		if len(msg) <= offset {
   510  			return offset, ErrBuf
   511  		}
   512  		if s[i] == '\\' {
   513  			i++
   514  			if i == len(s) {
   515  				break
   516  			}
   517  			// check for \DDD
   518  			if isDDD(s[i:]) {
   519  				msg[offset] = dddToByte(s[i:])
   520  				i += 2
   521  			} else {
   522  				msg[offset] = s[i]
   523  			}
   524  		} else {
   525  			msg[offset] = s[i]
   526  		}
   527  		offset++
   528  	}
   529  	return offset, nil
   530  }
   531  
   532  func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
   533  	off = off0
   534  	var s string
   535  	for off < len(msg) && err == nil {
   536  		s, off, err = unpackString(msg, off)
   537  		if err == nil {
   538  			ss = append(ss, s)
   539  		}
   540  	}
   541  	return
   542  }
   543  
   544  // Helpers for dealing with escaped bytes
   545  func isDigit(b byte) bool { return b >= '0' && b <= '9' }
   546  
   547  func isDDD[T ~[]byte | ~string](s T) bool {
   548  	return len(s) >= 3 && isDigit(s[0]) && isDigit(s[1]) && isDigit(s[2])
   549  }
   550  
   551  func dddToByte[T ~[]byte | ~string](s T) byte {
   552  	_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
   553  	return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
   554  }
   555  
   556  // Helper function for packing and unpacking
   557  func intToBytes(i *big.Int, length int) []byte {
   558  	buf := i.Bytes()
   559  	if len(buf) < length {
   560  		b := make([]byte, length)
   561  		copy(b[length-len(buf):], buf)
   562  		return b
   563  	}
   564  	return buf
   565  }
   566  
   567  // PackRR packs a resource record rr into msg[off:].
   568  // See PackDomainName for documentation about the compression.
   569  func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
   570  	headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
   571  	if err == nil {
   572  		// packRR no longer sets the Rdlength field on the rr, but
   573  		// callers might be expecting it so we set it here.
   574  		rr.Header().Rdlength = uint16(off1 - headerEnd)
   575  	}
   576  	return off1, err
   577  }
   578  
   579  func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
   580  	if rr == nil {
   581  		return len(msg), len(msg), &Error{err: "nil rr"}
   582  	}
   583  
   584  	headerEnd, err = rr.Header().packHeader(msg, off, compression, compress)
   585  	if err != nil {
   586  		return headerEnd, len(msg), err
   587  	}
   588  
   589  	off1, err = rr.pack(msg, headerEnd, compression, compress)
   590  	if err != nil {
   591  		return headerEnd, len(msg), err
   592  	}
   593  
   594  	rdlength := off1 - headerEnd
   595  	if int(uint16(rdlength)) != rdlength { // overflow
   596  		return headerEnd, len(msg), ErrRdata
   597  	}
   598  
   599  	// The RDLENGTH field is the last field in the header and we set it here.
   600  	binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
   601  	return headerEnd, off1, nil
   602  }
   603  
   604  // UnpackRR unpacks msg[off:] into an RR.
   605  func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
   606  	h, off, msg, err := unpackHeader(msg, off)
   607  	if err != nil {
   608  		return nil, len(msg), err
   609  	}
   610  
   611  	return UnpackRRWithHeader(h, msg, off)
   612  }
   613  
   614  // UnpackRRWithHeader unpacks the record type specific payload given an existing
   615  // RR_Header.
   616  func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
   617  	if newFn, ok := TypeToRR[h.Rrtype]; ok {
   618  		rr = newFn()
   619  		*rr.Header() = h
   620  	} else {
   621  		rr = &RFC3597{Hdr: h}
   622  	}
   623  
   624  	if off < 0 || off > len(msg) {
   625  		return &h, off, &Error{err: "bad off"}
   626  	}
   627  
   628  	end := off + int(h.Rdlength)
   629  	if end < off || end > len(msg) {
   630  		return &h, end, &Error{err: "bad rdlength"}
   631  	}
   632  
   633  	if noRdata(h) {
   634  		return rr, off, nil
   635  	}
   636  
   637  	off, err = rr.unpack(msg, off)
   638  	if err != nil {
   639  		return nil, end, err
   640  	}
   641  	if off != end {
   642  		return &h, end, &Error{err: "bad rdlength"}
   643  	}
   644  
   645  	return rr, off, nil
   646  }
   647  
   648  // unpackRRslice unpacks msg[off:] into an []RR.
   649  // If we cannot unpack the whole array, then it will return nil
   650  func unpackRRslice(l int, msg []byte, off int) (dst1 []RR, off1 int, err error) {
   651  	var r RR
   652  	// Don't pre-allocate, l may be under attacker control
   653  	var dst []RR
   654  	for i := 0; i < l; i++ {
   655  		off1 := off
   656  		r, off, err = UnpackRR(msg, off)
   657  		if err != nil {
   658  			off = len(msg)
   659  			break
   660  		}
   661  		// If offset does not increase anymore, l is a lie
   662  		if off1 == off {
   663  			break
   664  		}
   665  		dst = append(dst, r)
   666  	}
   667  	if err != nil && off == len(msg) {
   668  		dst = nil
   669  	}
   670  	return dst, off, err
   671  }
   672  
   673  // Convert a MsgHdr to a string, with dig-like headers:
   674  //
   675  // ;; opcode: QUERY, status: NOERROR, id: 48404
   676  //
   677  // ;; flags: qr aa rd ra;
   678  func (h *MsgHdr) String() string {
   679  	if h == nil {
   680  		return "<nil> MsgHdr"
   681  	}
   682  
   683  	s := ";; opcode: " + OpcodeToString[h.Opcode]
   684  	s += ", status: " + RcodeToString[h.Rcode]
   685  	s += ", id: " + strconv.Itoa(int(h.Id)) + "\n"
   686  
   687  	s += ";; flags:"
   688  	if h.Response {
   689  		s += " qr"
   690  	}
   691  	if h.Authoritative {
   692  		s += " aa"
   693  	}
   694  	if h.Truncated {
   695  		s += " tc"
   696  	}
   697  	if h.RecursionDesired {
   698  		s += " rd"
   699  	}
   700  	if h.RecursionAvailable {
   701  		s += " ra"
   702  	}
   703  	if h.Zero { // Hmm
   704  		s += " z"
   705  	}
   706  	if h.AuthenticatedData {
   707  		s += " ad"
   708  	}
   709  	if h.CheckingDisabled {
   710  		s += " cd"
   711  	}
   712  
   713  	s += ";"
   714  	return s
   715  }
   716  
   717  // Pack packs a Msg: it is converted to to wire format.
   718  // If the dns.Compress is true the message will be in compressed wire format.
   719  func (dns *Msg) Pack() (msg []byte, err error) {
   720  	return dns.PackBuffer(nil)
   721  }
   722  
   723  // PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
   724  func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
   725  	// If this message can't be compressed, avoid filling the
   726  	// compression map and creating garbage.
   727  	if dns.Compress && dns.isCompressible() {
   728  		compression := make(map[string]uint16) // Compression pointer mappings.
   729  		return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
   730  	}
   731  
   732  	return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
   733  }
   734  
   735  // packBufferWithCompressionMap packs a Msg, using the given buffer buf.
   736  func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) {
   737  	if dns.Rcode < 0 || dns.Rcode > 0xFFF {
   738  		return nil, ErrRcode
   739  	}
   740  
   741  	// Set extended rcode unconditionally if we have an opt, this will allow
   742  	// resetting the extended rcode bits if they need to.
   743  	if opt := dns.IsEdns0(); opt != nil {
   744  		opt.SetExtendedRcode(uint16(dns.Rcode))
   745  	} else if dns.Rcode > 0xF {
   746  		// If Rcode is an extended one and opt is nil, error out.
   747  		return nil, ErrExtendedRcode
   748  	}
   749  
   750  	// Convert convenient Msg into wire-like Header.
   751  	var dh Header
   752  	dh.Id = dns.Id
   753  	dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
   754  	if dns.Response {
   755  		dh.Bits |= _QR
   756  	}
   757  	if dns.Authoritative {
   758  		dh.Bits |= _AA
   759  	}
   760  	if dns.Truncated {
   761  		dh.Bits |= _TC
   762  	}
   763  	if dns.RecursionDesired {
   764  		dh.Bits |= _RD
   765  	}
   766  	if dns.RecursionAvailable {
   767  		dh.Bits |= _RA
   768  	}
   769  	if dns.Zero {
   770  		dh.Bits |= _Z
   771  	}
   772  	if dns.AuthenticatedData {
   773  		dh.Bits |= _AD
   774  	}
   775  	if dns.CheckingDisabled {
   776  		dh.Bits |= _CD
   777  	}
   778  
   779  	dh.Qdcount = uint16(len(dns.Question))
   780  	dh.Ancount = uint16(len(dns.Answer))
   781  	dh.Nscount = uint16(len(dns.Ns))
   782  	dh.Arcount = uint16(len(dns.Extra))
   783  
   784  	// We need the uncompressed length here, because we first pack it and then compress it.
   785  	msg = buf
   786  	uncompressedLen := msgLenWithCompressionMap(dns, nil)
   787  	if packLen := uncompressedLen + 1; len(msg) < packLen {
   788  		msg = make([]byte, packLen)
   789  	}
   790  
   791  	// Pack it in: header and then the pieces.
   792  	off := 0
   793  	off, err = dh.pack(msg, off, compression, compress)
   794  	if err != nil {
   795  		return nil, err
   796  	}
   797  	for _, r := range dns.Question {
   798  		off, err = r.pack(msg, off, compression, compress)
   799  		if err != nil {
   800  			return nil, err
   801  		}
   802  	}
   803  	for _, r := range dns.Answer {
   804  		_, off, err = packRR(r, msg, off, compression, compress)
   805  		if err != nil {
   806  			return nil, err
   807  		}
   808  	}
   809  	for _, r := range dns.Ns {
   810  		_, off, err = packRR(r, msg, off, compression, compress)
   811  		if err != nil {
   812  			return nil, err
   813  		}
   814  	}
   815  	for _, r := range dns.Extra {
   816  		_, off, err = packRR(r, msg, off, compression, compress)
   817  		if err != nil {
   818  			return nil, err
   819  		}
   820  	}
   821  	return msg[:off], nil
   822  }
   823  
   824  func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
   825  	// If we are at the end of the message we should return *just* the
   826  	// header. This can still be useful to the caller. 9.9.9.9 sends these
   827  	// when responding with REFUSED for instance.
   828  	if off == len(msg) {
   829  		// reset sections before returning
   830  		dns.Question, dns.Answer, dns.Ns, dns.Extra = nil, nil, nil, nil
   831  		return nil
   832  	}
   833  
   834  	// Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are
   835  	// attacker controlled. This means we can't use them to pre-allocate
   836  	// slices.
   837  	dns.Question = nil
   838  	for i := 0; i < int(dh.Qdcount); i++ {
   839  		off1 := off
   840  		var q Question
   841  		q, off, err = unpackQuestion(msg, off)
   842  		if err != nil {
   843  			return err
   844  		}
   845  		if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
   846  			dh.Qdcount = uint16(i)
   847  			break
   848  		}
   849  		dns.Question = append(dns.Question, q)
   850  	}
   851  
   852  	dns.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off)
   853  	// The header counts might have been wrong so we need to update it
   854  	dh.Ancount = uint16(len(dns.Answer))
   855  	if err == nil {
   856  		dns.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off)
   857  	}
   858  	// The header counts might have been wrong so we need to update it
   859  	dh.Nscount = uint16(len(dns.Ns))
   860  	if err == nil {
   861  		dns.Extra, _, err = unpackRRslice(int(dh.Arcount), msg, off)
   862  	}
   863  	// The header counts might have been wrong so we need to update it
   864  	dh.Arcount = uint16(len(dns.Extra))
   865  
   866  	// Set extended Rcode
   867  	if opt := dns.IsEdns0(); opt != nil {
   868  		dns.Rcode |= opt.ExtendedRcode()
   869  	}
   870  
   871  	// TODO(miek) make this an error?
   872  	// use PackOpt to let people tell how detailed the error reporting should be?
   873  	// if off != len(msg) {
   874  	// 	// println("dns: extra bytes in dns packet", off, "<", len(msg))
   875  	// }
   876  	return err
   877  
   878  }
   879  
   880  // Unpack unpacks a binary message to a Msg structure.
   881  func (dns *Msg) Unpack(msg []byte) (err error) {
   882  	dh, off, err := unpackMsgHdr(msg, 0)
   883  	if err != nil {
   884  		return err
   885  	}
   886  
   887  	dns.setHdr(dh)
   888  	return dns.unpack(dh, msg, off)
   889  }
   890  
   891  // Convert a complete message to a string with dig-like output.
   892  func (dns *Msg) String() string {
   893  	if dns == nil {
   894  		return "<nil> MsgHdr"
   895  	}
   896  	s := dns.MsgHdr.String() + " "
   897  	if dns.MsgHdr.Opcode == OpcodeUpdate {
   898  		s += "ZONE: " + strconv.Itoa(len(dns.Question)) + ", "
   899  		s += "PREREQ: " + strconv.Itoa(len(dns.Answer)) + ", "
   900  		s += "UPDATE: " + strconv.Itoa(len(dns.Ns)) + ", "
   901  		s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
   902  	} else {
   903  		s += "QUERY: " + strconv.Itoa(len(dns.Question)) + ", "
   904  		s += "ANSWER: " + strconv.Itoa(len(dns.Answer)) + ", "
   905  		s += "AUTHORITY: " + strconv.Itoa(len(dns.Ns)) + ", "
   906  		s += "ADDITIONAL: " + strconv.Itoa(len(dns.Extra)) + "\n"
   907  	}
   908  	opt := dns.IsEdns0()
   909  	if opt != nil {
   910  		// OPT PSEUDOSECTION
   911  		s += opt.String() + "\n"
   912  	}
   913  	if len(dns.Question) > 0 {
   914  		if dns.MsgHdr.Opcode == OpcodeUpdate {
   915  			s += "\n;; ZONE SECTION:\n"
   916  		} else {
   917  			s += "\n;; QUESTION SECTION:\n"
   918  		}
   919  		for _, r := range dns.Question {
   920  			s += r.String() + "\n"
   921  		}
   922  	}
   923  	if len(dns.Answer) > 0 {
   924  		if dns.MsgHdr.Opcode == OpcodeUpdate {
   925  			s += "\n;; PREREQUISITE SECTION:\n"
   926  		} else {
   927  			s += "\n;; ANSWER SECTION:\n"
   928  		}
   929  		for _, r := range dns.Answer {
   930  			if r != nil {
   931  				s += r.String() + "\n"
   932  			}
   933  		}
   934  	}
   935  	if len(dns.Ns) > 0 {
   936  		if dns.MsgHdr.Opcode == OpcodeUpdate {
   937  			s += "\n;; UPDATE SECTION:\n"
   938  		} else {
   939  			s += "\n;; AUTHORITY SECTION:\n"
   940  		}
   941  		for _, r := range dns.Ns {
   942  			if r != nil {
   943  				s += r.String() + "\n"
   944  			}
   945  		}
   946  	}
   947  	if len(dns.Extra) > 0 && (opt == nil || len(dns.Extra) > 1) {
   948  		s += "\n;; ADDITIONAL SECTION:\n"
   949  		for _, r := range dns.Extra {
   950  			if r != nil && r.Header().Rrtype != TypeOPT {
   951  				s += r.String() + "\n"
   952  			}
   953  		}
   954  	}
   955  	return s
   956  }
   957  
   958  // isCompressible returns whether the msg may be compressible.
   959  func (dns *Msg) isCompressible() bool {
   960  	// If we only have one question, there is nothing we can ever compress.
   961  	return len(dns.Question) > 1 || len(dns.Answer) > 0 ||
   962  		len(dns.Ns) > 0 || len(dns.Extra) > 0
   963  }
   964  
   965  // Len returns the message length when in (un)compressed wire format.
   966  // If dns.Compress is true compression it is taken into account. Len()
   967  // is provided to be a faster way to get the size of the resulting packet,
   968  // than packing it, measuring the size and discarding the buffer.
   969  func (dns *Msg) Len() int {
   970  	// If this message can't be compressed, avoid filling the
   971  	// compression map and creating garbage.
   972  	if dns.Compress && dns.isCompressible() {
   973  		compression := make(map[string]struct{})
   974  		return msgLenWithCompressionMap(dns, compression)
   975  	}
   976  
   977  	return msgLenWithCompressionMap(dns, nil)
   978  }
   979  
   980  func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
   981  	l := headerSize
   982  
   983  	for _, r := range dns.Question {
   984  		l += r.len(l, compression)
   985  	}
   986  	for _, r := range dns.Answer {
   987  		if r != nil {
   988  			l += r.len(l, compression)
   989  		}
   990  	}
   991  	for _, r := range dns.Ns {
   992  		if r != nil {
   993  			l += r.len(l, compression)
   994  		}
   995  	}
   996  	for _, r := range dns.Extra {
   997  		if r != nil {
   998  			l += r.len(l, compression)
   999  		}
  1000  	}
  1001  
  1002  	return l
  1003  }
  1004  
  1005  func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int {
  1006  	if s == "" || s == "." {
  1007  		return 1
  1008  	}
  1009  
  1010  	escaped := strings.Contains(s, "\\")
  1011  
  1012  	if compression != nil && (compress || off < maxCompressionOffset) {
  1013  		// compressionLenSearch will insert the entry into the compression
  1014  		// map if it doesn't contain it.
  1015  		if l, ok := compressionLenSearch(compression, s, off); ok && compress {
  1016  			if escaped {
  1017  				return escapedNameLen(s[:l]) + 2
  1018  			}
  1019  
  1020  			return l + 2
  1021  		}
  1022  	}
  1023  
  1024  	if escaped {
  1025  		return escapedNameLen(s) + 1
  1026  	}
  1027  
  1028  	return len(s) + 1
  1029  }
  1030  
  1031  func escapedNameLen(s string) int {
  1032  	nameLen := len(s)
  1033  	for i := 0; i < len(s); i++ {
  1034  		if s[i] != '\\' {
  1035  			continue
  1036  		}
  1037  
  1038  		if isDDD(s[i+1:]) {
  1039  			nameLen -= 3
  1040  			i += 3
  1041  		} else {
  1042  			nameLen--
  1043  			i++
  1044  		}
  1045  	}
  1046  
  1047  	return nameLen
  1048  }
  1049  
  1050  func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) {
  1051  	for off, end := 0, false; !end; off, end = NextLabel(s, off) {
  1052  		if _, ok := c[s[off:]]; ok {
  1053  			return off, true
  1054  		}
  1055  
  1056  		if msgOff+off < maxCompressionOffset {
  1057  			c[s[off:]] = struct{}{}
  1058  		}
  1059  	}
  1060  
  1061  	return 0, false
  1062  }
  1063  
  1064  // Copy returns a new RR which is a deep-copy of r.
  1065  func Copy(r RR) RR { return r.copy() }
  1066  
  1067  // Len returns the length (in octets) of the uncompressed RR in wire format.
  1068  func Len(r RR) int { return r.len(0, nil) }
  1069  
  1070  // Copy returns a new *Msg which is a deep-copy of dns.
  1071  func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
  1072  
  1073  // CopyTo copies the contents to the provided message using a deep-copy and returns the copy.
  1074  func (dns *Msg) CopyTo(r1 *Msg) *Msg {
  1075  	r1.MsgHdr = dns.MsgHdr
  1076  	r1.Compress = dns.Compress
  1077  
  1078  	if len(dns.Question) > 0 {
  1079  		// TODO(miek): Question is an immutable value, ok to do a shallow-copy
  1080  		r1.Question = cloneSlice(dns.Question)
  1081  	}
  1082  
  1083  	rrArr := make([]RR, len(dns.Answer)+len(dns.Ns)+len(dns.Extra))
  1084  	r1.Answer, rrArr = rrArr[:0:len(dns.Answer)], rrArr[len(dns.Answer):]
  1085  	r1.Ns, rrArr = rrArr[:0:len(dns.Ns)], rrArr[len(dns.Ns):]
  1086  	r1.Extra = rrArr[:0:len(dns.Extra)]
  1087  
  1088  	for _, r := range dns.Answer {
  1089  		r1.Answer = append(r1.Answer, r.copy())
  1090  	}
  1091  
  1092  	for _, r := range dns.Ns {
  1093  		r1.Ns = append(r1.Ns, r.copy())
  1094  	}
  1095  
  1096  	for _, r := range dns.Extra {
  1097  		r1.Extra = append(r1.Extra, r.copy())
  1098  	}
  1099  
  1100  	return r1
  1101  }
  1102  
  1103  func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
  1104  	off, err := packDomainName(q.Name, msg, off, compression, compress)
  1105  	if err != nil {
  1106  		return off, err
  1107  	}
  1108  	off, err = packUint16(q.Qtype, msg, off)
  1109  	if err != nil {
  1110  		return off, err
  1111  	}
  1112  	off, err = packUint16(q.Qclass, msg, off)
  1113  	if err != nil {
  1114  		return off, err
  1115  	}
  1116  	return off, nil
  1117  }
  1118  
  1119  func unpackQuestion(msg []byte, off int) (Question, int, error) {
  1120  	var (
  1121  		q   Question
  1122  		err error
  1123  	)
  1124  	q.Name, off, err = UnpackDomainName(msg, off)
  1125  	if err != nil {
  1126  		return q, off, err
  1127  	}
  1128  	if off == len(msg) {
  1129  		return q, off, nil
  1130  	}
  1131  	q.Qtype, off, err = unpackUint16(msg, off)
  1132  	if err != nil {
  1133  		return q, off, err
  1134  	}
  1135  	if off == len(msg) {
  1136  		return q, off, nil
  1137  	}
  1138  	q.Qclass, off, err = unpackUint16(msg, off)
  1139  	if off == len(msg) {
  1140  		return q, off, nil
  1141  	}
  1142  	return q, off, err
  1143  }
  1144  
  1145  func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
  1146  	off, err := packUint16(dh.Id, msg, off)
  1147  	if err != nil {
  1148  		return off, err
  1149  	}
  1150  	off, err = packUint16(dh.Bits, msg, off)
  1151  	if err != nil {
  1152  		return off, err
  1153  	}
  1154  	off, err = packUint16(dh.Qdcount, msg, off)
  1155  	if err != nil {
  1156  		return off, err
  1157  	}
  1158  	off, err = packUint16(dh.Ancount, msg, off)
  1159  	if err != nil {
  1160  		return off, err
  1161  	}
  1162  	off, err = packUint16(dh.Nscount, msg, off)
  1163  	if err != nil {
  1164  		return off, err
  1165  	}
  1166  	off, err = packUint16(dh.Arcount, msg, off)
  1167  	if err != nil {
  1168  		return off, err
  1169  	}
  1170  	return off, nil
  1171  }
  1172  
  1173  func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
  1174  	var (
  1175  		dh  Header
  1176  		err error
  1177  	)
  1178  	dh.Id, off, err = unpackUint16(msg, off)
  1179  	if err != nil {
  1180  		return dh, off, err
  1181  	}
  1182  	dh.Bits, off, err = unpackUint16(msg, off)
  1183  	if err != nil {
  1184  		return dh, off, err
  1185  	}
  1186  	dh.Qdcount, off, err = unpackUint16(msg, off)
  1187  	if err != nil {
  1188  		return dh, off, err
  1189  	}
  1190  	dh.Ancount, off, err = unpackUint16(msg, off)
  1191  	if err != nil {
  1192  		return dh, off, err
  1193  	}
  1194  	dh.Nscount, off, err = unpackUint16(msg, off)
  1195  	if err != nil {
  1196  		return dh, off, err
  1197  	}
  1198  	dh.Arcount, off, err = unpackUint16(msg, off)
  1199  	if err != nil {
  1200  		return dh, off, err
  1201  	}
  1202  	return dh, off, nil
  1203  }
  1204  
  1205  // setHdr set the header in the dns using the binary data in dh.
  1206  func (dns *Msg) setHdr(dh Header) {
  1207  	dns.Id = dh.Id
  1208  	dns.Response = dh.Bits&_QR != 0
  1209  	dns.Opcode = int(dh.Bits>>11) & 0xF
  1210  	dns.Authoritative = dh.Bits&_AA != 0
  1211  	dns.Truncated = dh.Bits&_TC != 0
  1212  	dns.RecursionDesired = dh.Bits&_RD != 0
  1213  	dns.RecursionAvailable = dh.Bits&_RA != 0
  1214  	dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
  1215  	dns.AuthenticatedData = dh.Bits&_AD != 0
  1216  	dns.CheckingDisabled = dh.Bits&_CD != 0
  1217  	dns.Rcode = int(dh.Bits & 0xF)
  1218  }
  1219  

View as plain text