...

Source file src/github.com/miekg/dns/msg_helpers.go

Documentation: github.com/miekg/dns

     1  package dns
     2  
     3  import (
     4  	"encoding/base32"
     5  	"encoding/base64"
     6  	"encoding/binary"
     7  	"encoding/hex"
     8  	"net"
     9  	"sort"
    10  	"strings"
    11  )
    12  
    13  // helper functions called from the generated zmsg.go
    14  
// These functions are named after the tag they help pack/unpack; if there is no tag, they are
// named after the type they pack/unpack (string, int, etc). All are prefixed with unpack or pack,
// e.g. packDataA or packDataDomainNames.
    18  
    19  func unpackDataA(msg []byte, off int) (net.IP, int, error) {
    20  	if off+net.IPv4len > len(msg) {
    21  		return nil, len(msg), &Error{err: "overflow unpacking a"}
    22  	}
    23  	return cloneSlice(msg[off : off+net.IPv4len]), off + net.IPv4len, nil
    24  }
    25  
    26  func packDataA(a net.IP, msg []byte, off int) (int, error) {
    27  	switch len(a) {
    28  	case net.IPv4len, net.IPv6len:
    29  		// It must be a slice of 4, even if it is 16, we encode only the first 4
    30  		if off+net.IPv4len > len(msg) {
    31  			return len(msg), &Error{err: "overflow packing a"}
    32  		}
    33  
    34  		copy(msg[off:], a.To4())
    35  		off += net.IPv4len
    36  	case 0:
    37  		// Allowed, for dynamic updates.
    38  	default:
    39  		return len(msg), &Error{err: "overflow packing a"}
    40  	}
    41  	return off, nil
    42  }
    43  
    44  func unpackDataAAAA(msg []byte, off int) (net.IP, int, error) {
    45  	if off+net.IPv6len > len(msg) {
    46  		return nil, len(msg), &Error{err: "overflow unpacking aaaa"}
    47  	}
    48  	return cloneSlice(msg[off : off+net.IPv6len]), off + net.IPv6len, nil
    49  }
    50  
    51  func packDataAAAA(aaaa net.IP, msg []byte, off int) (int, error) {
    52  	switch len(aaaa) {
    53  	case net.IPv6len:
    54  		if off+net.IPv6len > len(msg) {
    55  			return len(msg), &Error{err: "overflow packing aaaa"}
    56  		}
    57  
    58  		copy(msg[off:], aaaa)
    59  		off += net.IPv6len
    60  	case 0:
    61  		// Allowed, dynamic updates.
    62  	default:
    63  		return len(msg), &Error{err: "overflow packing aaaa"}
    64  	}
    65  	return off, nil
    66  }
    67  
    68  // unpackHeader unpacks an RR header, returning the offset to the end of the header and a
    69  // re-sliced msg according to the expected length of the RR.
    70  func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, err error) {
    71  	hdr := RR_Header{}
    72  	if off == len(msg) {
    73  		return hdr, off, msg, nil
    74  	}
    75  
    76  	hdr.Name, off, err = UnpackDomainName(msg, off)
    77  	if err != nil {
    78  		return hdr, len(msg), msg, err
    79  	}
    80  	hdr.Rrtype, off, err = unpackUint16(msg, off)
    81  	if err != nil {
    82  		return hdr, len(msg), msg, err
    83  	}
    84  	hdr.Class, off, err = unpackUint16(msg, off)
    85  	if err != nil {
    86  		return hdr, len(msg), msg, err
    87  	}
    88  	hdr.Ttl, off, err = unpackUint32(msg, off)
    89  	if err != nil {
    90  		return hdr, len(msg), msg, err
    91  	}
    92  	hdr.Rdlength, off, err = unpackUint16(msg, off)
    93  	if err != nil {
    94  		return hdr, len(msg), msg, err
    95  	}
    96  	msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength)
    97  	return hdr, off, msg, err
    98  }
    99  
   100  // packHeader packs an RR header, returning the offset to the end of the header.
   101  // See PackDomainName for documentation about the compression.
   102  func (hdr RR_Header) packHeader(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
   103  	if off == len(msg) {
   104  		return off, nil
   105  	}
   106  
   107  	off, err := packDomainName(hdr.Name, msg, off, compression, compress)
   108  	if err != nil {
   109  		return len(msg), err
   110  	}
   111  	off, err = packUint16(hdr.Rrtype, msg, off)
   112  	if err != nil {
   113  		return len(msg), err
   114  	}
   115  	off, err = packUint16(hdr.Class, msg, off)
   116  	if err != nil {
   117  		return len(msg), err
   118  	}
   119  	off, err = packUint32(hdr.Ttl, msg, off)
   120  	if err != nil {
   121  		return len(msg), err
   122  	}
   123  	off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR.
   124  	if err != nil {
   125  		return len(msg), err
   126  	}
   127  	return off, nil
   128  }
   129  
// Lower-level helpers used by the pack/unpack functions.
   131  
   132  // truncateMsgFromRdLength truncates msg to match the expected length of the RR.
   133  // Returns an error if msg is smaller than the expected size.
   134  func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []byte, err error) {
   135  	lenrd := off + int(rdlength)
   136  	if lenrd > len(msg) {
   137  		return msg, &Error{err: "overflowing header size"}
   138  	}
   139  	return msg[:lenrd], nil
   140  }
   141  
   142  var base32HexNoPadEncoding = base32.HexEncoding.WithPadding(base32.NoPadding)
   143  
   144  func fromBase32(s []byte) (buf []byte, err error) {
   145  	for i, b := range s {
   146  		if b >= 'a' && b <= 'z' {
   147  			s[i] = b - 32
   148  		}
   149  	}
   150  	buflen := base32HexNoPadEncoding.DecodedLen(len(s))
   151  	buf = make([]byte, buflen)
   152  	n, err := base32HexNoPadEncoding.Decode(buf, s)
   153  	buf = buf[:n]
   154  	return
   155  }
   156  
   157  func toBase32(b []byte) string {
   158  	return base32HexNoPadEncoding.EncodeToString(b)
   159  }
   160  
   161  func fromBase64(s []byte) (buf []byte, err error) {
   162  	buflen := base64.StdEncoding.DecodedLen(len(s))
   163  	buf = make([]byte, buflen)
   164  	n, err := base64.StdEncoding.Decode(buf, s)
   165  	buf = buf[:n]
   166  	return
   167  }
   168  
   169  func toBase64(b []byte) string { return base64.StdEncoding.EncodeToString(b) }
   170  
   171  // dynamicUpdate returns true if the Rdlength is zero.
   172  func noRdata(h RR_Header) bool { return h.Rdlength == 0 }
   173  
   174  func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) {
   175  	if off+1 > len(msg) {
   176  		return 0, len(msg), &Error{err: "overflow unpacking uint8"}
   177  	}
   178  	return msg[off], off + 1, nil
   179  }
   180  
   181  func packUint8(i uint8, msg []byte, off int) (off1 int, err error) {
   182  	if off+1 > len(msg) {
   183  		return len(msg), &Error{err: "overflow packing uint8"}
   184  	}
   185  	msg[off] = i
   186  	return off + 1, nil
   187  }
   188  
   189  func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) {
   190  	if off+2 > len(msg) {
   191  		return 0, len(msg), &Error{err: "overflow unpacking uint16"}
   192  	}
   193  	return binary.BigEndian.Uint16(msg[off:]), off + 2, nil
   194  }
   195  
   196  func packUint16(i uint16, msg []byte, off int) (off1 int, err error) {
   197  	if off+2 > len(msg) {
   198  		return len(msg), &Error{err: "overflow packing uint16"}
   199  	}
   200  	binary.BigEndian.PutUint16(msg[off:], i)
   201  	return off + 2, nil
   202  }
   203  
   204  func unpackUint32(msg []byte, off int) (i uint32, off1 int, err error) {
   205  	if off+4 > len(msg) {
   206  		return 0, len(msg), &Error{err: "overflow unpacking uint32"}
   207  	}
   208  	return binary.BigEndian.Uint32(msg[off:]), off + 4, nil
   209  }
   210  
   211  func packUint32(i uint32, msg []byte, off int) (off1 int, err error) {
   212  	if off+4 > len(msg) {
   213  		return len(msg), &Error{err: "overflow packing uint32"}
   214  	}
   215  	binary.BigEndian.PutUint32(msg[off:], i)
   216  	return off + 4, nil
   217  }
   218  
   219  func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) {
   220  	if off+6 > len(msg) {
   221  		return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"}
   222  	}
   223  	// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
   224  	i = uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
   225  		uint64(msg[off+4])<<8 | uint64(msg[off+5])
   226  	off += 6
   227  	return i, off, nil
   228  }
   229  
   230  func packUint48(i uint64, msg []byte, off int) (off1 int, err error) {
   231  	if off+6 > len(msg) {
   232  		return len(msg), &Error{err: "overflow packing uint64 as uint48"}
   233  	}
   234  	msg[off] = byte(i >> 40)
   235  	msg[off+1] = byte(i >> 32)
   236  	msg[off+2] = byte(i >> 24)
   237  	msg[off+3] = byte(i >> 16)
   238  	msg[off+4] = byte(i >> 8)
   239  	msg[off+5] = byte(i)
   240  	off += 6
   241  	return off, nil
   242  }
   243  
   244  func unpackUint64(msg []byte, off int) (i uint64, off1 int, err error) {
   245  	if off+8 > len(msg) {
   246  		return 0, len(msg), &Error{err: "overflow unpacking uint64"}
   247  	}
   248  	return binary.BigEndian.Uint64(msg[off:]), off + 8, nil
   249  }
   250  
   251  func packUint64(i uint64, msg []byte, off int) (off1 int, err error) {
   252  	if off+8 > len(msg) {
   253  		return len(msg), &Error{err: "overflow packing uint64"}
   254  	}
   255  	binary.BigEndian.PutUint64(msg[off:], i)
   256  	off += 8
   257  	return off, nil
   258  }
   259  
   260  func unpackString(msg []byte, off int) (string, int, error) {
   261  	if off+1 > len(msg) {
   262  		return "", off, &Error{err: "overflow unpacking txt"}
   263  	}
   264  	l := int(msg[off])
   265  	off++
   266  	if off+l > len(msg) {
   267  		return "", off, &Error{err: "overflow unpacking txt"}
   268  	}
   269  	var s strings.Builder
   270  	consumed := 0
   271  	for i, b := range msg[off : off+l] {
   272  		switch {
   273  		case b == '"' || b == '\\':
   274  			if consumed == 0 {
   275  				s.Grow(l * 2)
   276  			}
   277  			s.Write(msg[off+consumed : off+i])
   278  			s.WriteByte('\\')
   279  			s.WriteByte(b)
   280  			consumed = i + 1
   281  		case b < ' ' || b > '~': // unprintable
   282  			if consumed == 0 {
   283  				s.Grow(l * 2)
   284  			}
   285  			s.Write(msg[off+consumed : off+i])
   286  			s.WriteString(escapeByte(b))
   287  			consumed = i + 1
   288  		}
   289  	}
   290  	if consumed == 0 { // no escaping needed
   291  		return string(msg[off : off+l]), off + l, nil
   292  	}
   293  	s.Write(msg[off+consumed : off+l])
   294  	return s.String(), off + l, nil
   295  }
   296  
   297  func packString(s string, msg []byte, off int) (int, error) {
   298  	off, err := packTxtString(s, msg, off)
   299  	if err != nil {
   300  		return len(msg), err
   301  	}
   302  	return off, nil
   303  }
   304  
   305  func unpackStringBase32(msg []byte, off, end int) (string, int, error) {
   306  	if end > len(msg) {
   307  		return "", len(msg), &Error{err: "overflow unpacking base32"}
   308  	}
   309  	s := toBase32(msg[off:end])
   310  	return s, end, nil
   311  }
   312  
   313  func packStringBase32(s string, msg []byte, off int) (int, error) {
   314  	b32, err := fromBase32([]byte(s))
   315  	if err != nil {
   316  		return len(msg), err
   317  	}
   318  	if off+len(b32) > len(msg) {
   319  		return len(msg), &Error{err: "overflow packing base32"}
   320  	}
   321  	copy(msg[off:off+len(b32)], b32)
   322  	off += len(b32)
   323  	return off, nil
   324  }
   325  
   326  func unpackStringBase64(msg []byte, off, end int) (string, int, error) {
   327  	// Rest of the RR is base64 encoded value, so we don't need an explicit length
   328  	// to be set. Thus far all RR's that have base64 encoded fields have those as their
   329  	// last one. What we do need is the end of the RR!
   330  	if end > len(msg) {
   331  		return "", len(msg), &Error{err: "overflow unpacking base64"}
   332  	}
   333  	s := toBase64(msg[off:end])
   334  	return s, end, nil
   335  }
   336  
   337  func packStringBase64(s string, msg []byte, off int) (int, error) {
   338  	b64, err := fromBase64([]byte(s))
   339  	if err != nil {
   340  		return len(msg), err
   341  	}
   342  	if off+len(b64) > len(msg) {
   343  		return len(msg), &Error{err: "overflow packing base64"}
   344  	}
   345  	copy(msg[off:off+len(b64)], b64)
   346  	off += len(b64)
   347  	return off, nil
   348  }
   349  
   350  func unpackStringHex(msg []byte, off, end int) (string, int, error) {
   351  	// Rest of the RR is hex encoded value, so we don't need an explicit length
   352  	// to be set. NSEC and TSIG have hex fields with a length field.
   353  	// What we do need is the end of the RR!
   354  	if end > len(msg) {
   355  		return "", len(msg), &Error{err: "overflow unpacking hex"}
   356  	}
   357  
   358  	s := hex.EncodeToString(msg[off:end])
   359  	return s, end, nil
   360  }
   361  
   362  func packStringHex(s string, msg []byte, off int) (int, error) {
   363  	h, err := hex.DecodeString(s)
   364  	if err != nil {
   365  		return len(msg), err
   366  	}
   367  	if off+len(h) > len(msg) {
   368  		return len(msg), &Error{err: "overflow packing hex"}
   369  	}
   370  	copy(msg[off:off+len(h)], h)
   371  	off += len(h)
   372  	return off, nil
   373  }
   374  
   375  func unpackStringAny(msg []byte, off, end int) (string, int, error) {
   376  	if end > len(msg) {
   377  		return "", len(msg), &Error{err: "overflow unpacking anything"}
   378  	}
   379  	return string(msg[off:end]), end, nil
   380  }
   381  
   382  func packStringAny(s string, msg []byte, off int) (int, error) {
   383  	if off+len(s) > len(msg) {
   384  		return len(msg), &Error{err: "overflow packing anything"}
   385  	}
   386  	copy(msg[off:off+len(s)], s)
   387  	off += len(s)
   388  	return off, nil
   389  }
   390  
   391  func unpackStringTxt(msg []byte, off int) ([]string, int, error) {
   392  	txt, off, err := unpackTxt(msg, off)
   393  	if err != nil {
   394  		return nil, len(msg), err
   395  	}
   396  	return txt, off, nil
   397  }
   398  
   399  func packStringTxt(s []string, msg []byte, off int) (int, error) {
   400  	off, err := packTxt(s, msg, off)
   401  	if err != nil {
   402  		return len(msg), err
   403  	}
   404  	return off, nil
   405  }
   406  
   407  func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) {
   408  	var edns []EDNS0
   409  	for off < len(msg) {
   410  		if off+4 > len(msg) {
   411  			return nil, len(msg), &Error{err: "overflow unpacking opt"}
   412  		}
   413  		code := binary.BigEndian.Uint16(msg[off:])
   414  		off += 2
   415  		optlen := binary.BigEndian.Uint16(msg[off:])
   416  		off += 2
   417  		if off+int(optlen) > len(msg) {
   418  			return nil, len(msg), &Error{err: "overflow unpacking opt"}
   419  		}
   420  		opt := makeDataOpt(code)
   421  		if err := opt.unpack(msg[off : off+int(optlen)]); err != nil {
   422  			return nil, len(msg), err
   423  		}
   424  		edns = append(edns, opt)
   425  		off += int(optlen)
   426  	}
   427  	return edns, off, nil
   428  }
   429  
   430  func packDataOpt(options []EDNS0, msg []byte, off int) (int, error) {
   431  	for _, el := range options {
   432  		b, err := el.pack()
   433  		if err != nil || off+4 > len(msg) {
   434  			return len(msg), &Error{err: "overflow packing opt"}
   435  		}
   436  		binary.BigEndian.PutUint16(msg[off:], el.Option())      // Option code
   437  		binary.BigEndian.PutUint16(msg[off+2:], uint16(len(b))) // Length
   438  		off += 4
   439  		if off+len(b) > len(msg) {
   440  			return len(msg), &Error{err: "overflow packing opt"}
   441  		}
   442  		// Actual data
   443  		copy(msg[off:off+len(b)], b)
   444  		off += len(b)
   445  	}
   446  	return off, nil
   447  }
   448  
   449  func unpackStringOctet(msg []byte, off int) (string, int, error) {
   450  	s := string(msg[off:])
   451  	return s, len(msg), nil
   452  }
   453  
   454  func packStringOctet(s string, msg []byte, off int) (int, error) {
   455  	off, err := packOctetString(s, msg, off)
   456  	if err != nil {
   457  		return len(msg), err
   458  	}
   459  	return off, nil
   460  }
   461  
   462  func unpackDataNsec(msg []byte, off int) ([]uint16, int, error) {
   463  	var nsec []uint16
   464  	length, window, lastwindow := 0, 0, -1
   465  	for off < len(msg) {
   466  		if off+2 > len(msg) {
   467  			return nsec, len(msg), &Error{err: "overflow unpacking NSEC(3)"}
   468  		}
   469  		window = int(msg[off])
   470  		length = int(msg[off+1])
   471  		off += 2
   472  		if window <= lastwindow {
   473  			// RFC 4034: Blocks are present in the NSEC RR RDATA in
   474  			// increasing numerical order.
   475  			return nsec, len(msg), &Error{err: "out of order NSEC(3) block in type bitmap"}
   476  		}
   477  		if length == 0 {
   478  			// RFC 4034: Blocks with no types present MUST NOT be included.
   479  			return nsec, len(msg), &Error{err: "empty NSEC(3) block in type bitmap"}
   480  		}
   481  		if length > 32 {
   482  			return nsec, len(msg), &Error{err: "NSEC(3) block too long in type bitmap"}
   483  		}
   484  		if off+length > len(msg) {
   485  			return nsec, len(msg), &Error{err: "overflowing NSEC(3) block in type bitmap"}
   486  		}
   487  
   488  		// Walk the bytes in the window and extract the type bits
   489  		for j, b := range msg[off : off+length] {
   490  			// Check the bits one by one, and set the type
   491  			if b&0x80 == 0x80 {
   492  				nsec = append(nsec, uint16(window*256+j*8+0))
   493  			}
   494  			if b&0x40 == 0x40 {
   495  				nsec = append(nsec, uint16(window*256+j*8+1))
   496  			}
   497  			if b&0x20 == 0x20 {
   498  				nsec = append(nsec, uint16(window*256+j*8+2))
   499  			}
   500  			if b&0x10 == 0x10 {
   501  				nsec = append(nsec, uint16(window*256+j*8+3))
   502  			}
   503  			if b&0x8 == 0x8 {
   504  				nsec = append(nsec, uint16(window*256+j*8+4))
   505  			}
   506  			if b&0x4 == 0x4 {
   507  				nsec = append(nsec, uint16(window*256+j*8+5))
   508  			}
   509  			if b&0x2 == 0x2 {
   510  				nsec = append(nsec, uint16(window*256+j*8+6))
   511  			}
   512  			if b&0x1 == 0x1 {
   513  				nsec = append(nsec, uint16(window*256+j*8+7))
   514  			}
   515  		}
   516  		off += length
   517  		lastwindow = window
   518  	}
   519  	return nsec, off, nil
   520  }
   521  
   522  // typeBitMapLen is a helper function which computes the "maximum" length of
   523  // a the NSEC Type BitMap field.
   524  func typeBitMapLen(bitmap []uint16) int {
   525  	var l int
   526  	var lastwindow, lastlength uint16
   527  	for _, t := range bitmap {
   528  		window := t / 256
   529  		length := (t-window*256)/8 + 1
   530  		if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
   531  			l += int(lastlength) + 2
   532  			lastlength = 0
   533  		}
   534  		if window < lastwindow || length < lastlength {
   535  			// packDataNsec would return Error{err: "nsec bits out of order"} here, but
   536  			// when computing the length, we want do be liberal.
   537  			continue
   538  		}
   539  		lastwindow, lastlength = window, length
   540  	}
   541  	l += int(lastlength) + 2
   542  	return l
   543  }
   544  
   545  func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
   546  	if len(bitmap) == 0 {
   547  		return off, nil
   548  	}
   549  	if off > len(msg) {
   550  		return off, &Error{err: "overflow packing nsec"}
   551  	}
   552  	toZero := msg[off:]
   553  	if maxLen := typeBitMapLen(bitmap); maxLen < len(toZero) {
   554  		toZero = toZero[:maxLen]
   555  	}
   556  	for i := range toZero {
   557  		toZero[i] = 0
   558  	}
   559  	var lastwindow, lastlength uint16
   560  	for _, t := range bitmap {
   561  		window := t / 256
   562  		length := (t-window*256)/8 + 1
   563  		if window > lastwindow && lastlength != 0 { // New window, jump to the new offset
   564  			off += int(lastlength) + 2
   565  			lastlength = 0
   566  		}
   567  		if window < lastwindow || length < lastlength {
   568  			return len(msg), &Error{err: "nsec bits out of order"}
   569  		}
   570  		if off+2+int(length) > len(msg) {
   571  			return len(msg), &Error{err: "overflow packing nsec"}
   572  		}
   573  		// Setting the window #
   574  		msg[off] = byte(window)
   575  		// Setting the octets length
   576  		msg[off+1] = byte(length)
   577  		// Setting the bit value for the type in the right octet
   578  		msg[off+1+int(length)] |= byte(1 << (7 - t%8))
   579  		lastwindow, lastlength = window, length
   580  	}
   581  	off += int(lastlength) + 2
   582  	return off, nil
   583  }
   584  
   585  func unpackDataSVCB(msg []byte, off int) ([]SVCBKeyValue, int, error) {
   586  	var xs []SVCBKeyValue
   587  	var code uint16
   588  	var length uint16
   589  	var err error
   590  	for off < len(msg) {
   591  		code, off, err = unpackUint16(msg, off)
   592  		if err != nil {
   593  			return nil, len(msg), &Error{err: "overflow unpacking SVCB"}
   594  		}
   595  		length, off, err = unpackUint16(msg, off)
   596  		if err != nil || off+int(length) > len(msg) {
   597  			return nil, len(msg), &Error{err: "overflow unpacking SVCB"}
   598  		}
   599  		e := makeSVCBKeyValue(SVCBKey(code))
   600  		if e == nil {
   601  			return nil, len(msg), &Error{err: "bad SVCB key"}
   602  		}
   603  		if err := e.unpack(msg[off : off+int(length)]); err != nil {
   604  			return nil, len(msg), err
   605  		}
   606  		if len(xs) > 0 && e.Key() <= xs[len(xs)-1].Key() {
   607  			return nil, len(msg), &Error{err: "SVCB keys not in strictly increasing order"}
   608  		}
   609  		xs = append(xs, e)
   610  		off += int(length)
   611  	}
   612  	return xs, off, nil
   613  }
   614  
   615  func packDataSVCB(pairs []SVCBKeyValue, msg []byte, off int) (int, error) {
   616  	pairs = cloneSlice(pairs)
   617  	sort.Slice(pairs, func(i, j int) bool {
   618  		return pairs[i].Key() < pairs[j].Key()
   619  	})
   620  	prev := svcb_RESERVED
   621  	for _, el := range pairs {
   622  		if el.Key() == prev {
   623  			return len(msg), &Error{err: "repeated SVCB keys are not allowed"}
   624  		}
   625  		prev = el.Key()
   626  		packed, err := el.pack()
   627  		if err != nil {
   628  			return len(msg), err
   629  		}
   630  		off, err = packUint16(uint16(el.Key()), msg, off)
   631  		if err != nil {
   632  			return len(msg), &Error{err: "overflow packing SVCB"}
   633  		}
   634  		off, err = packUint16(uint16(len(packed)), msg, off)
   635  		if err != nil || off+len(packed) > len(msg) {
   636  			return len(msg), &Error{err: "overflow packing SVCB"}
   637  		}
   638  		copy(msg[off:off+len(packed)], packed)
   639  		off += len(packed)
   640  	}
   641  	return off, nil
   642  }
   643  
   644  func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
   645  	var (
   646  		servers []string
   647  		s       string
   648  		err     error
   649  	)
   650  	if end > len(msg) {
   651  		return nil, len(msg), &Error{err: "overflow unpacking domain names"}
   652  	}
   653  	for off < end {
   654  		s, off, err = UnpackDomainName(msg, off)
   655  		if err != nil {
   656  			return servers, len(msg), err
   657  		}
   658  		servers = append(servers, s)
   659  	}
   660  	return servers, off, nil
   661  }
   662  
   663  func packDataDomainNames(names []string, msg []byte, off int, compression compressionMap, compress bool) (int, error) {
   664  	var err error
   665  	for _, name := range names {
   666  		off, err = packDomainName(name, msg, off, compression, compress)
   667  		if err != nil {
   668  			return len(msg), err
   669  		}
   670  	}
   671  	return off, nil
   672  }
   673  
   674  func packDataApl(data []APLPrefix, msg []byte, off int) (int, error) {
   675  	var err error
   676  	for i := range data {
   677  		off, err = packDataAplPrefix(&data[i], msg, off)
   678  		if err != nil {
   679  			return len(msg), err
   680  		}
   681  	}
   682  	return off, nil
   683  }
   684  
   685  func packDataAplPrefix(p *APLPrefix, msg []byte, off int) (int, error) {
   686  	if len(p.Network.IP) != len(p.Network.Mask) {
   687  		return len(msg), &Error{err: "address and mask lengths don't match"}
   688  	}
   689  
   690  	var err error
   691  	prefix, _ := p.Network.Mask.Size()
   692  	addr := p.Network.IP.Mask(p.Network.Mask)[:(prefix+7)/8]
   693  
   694  	switch len(p.Network.IP) {
   695  	case net.IPv4len:
   696  		off, err = packUint16(1, msg, off)
   697  	case net.IPv6len:
   698  		off, err = packUint16(2, msg, off)
   699  	default:
   700  		err = &Error{err: "unrecognized address family"}
   701  	}
   702  	if err != nil {
   703  		return len(msg), err
   704  	}
   705  
   706  	off, err = packUint8(uint8(prefix), msg, off)
   707  	if err != nil {
   708  		return len(msg), err
   709  	}
   710  
   711  	var n uint8
   712  	if p.Negation {
   713  		n = 0x80
   714  	}
   715  
   716  	// trim trailing zero bytes as specified in RFC3123 Sections 4.1 and 4.2.
   717  	i := len(addr) - 1
   718  	for ; i >= 0 && addr[i] == 0; i-- {
   719  	}
   720  	addr = addr[:i+1]
   721  
   722  	adflen := uint8(len(addr)) & 0x7f
   723  	off, err = packUint8(n|adflen, msg, off)
   724  	if err != nil {
   725  		return len(msg), err
   726  	}
   727  
   728  	if off+len(addr) > len(msg) {
   729  		return len(msg), &Error{err: "overflow packing APL prefix"}
   730  	}
   731  	off += copy(msg[off:], addr)
   732  
   733  	return off, nil
   734  }
   735  
   736  func unpackDataApl(msg []byte, off int) ([]APLPrefix, int, error) {
   737  	var result []APLPrefix
   738  	for off < len(msg) {
   739  		prefix, end, err := unpackDataAplPrefix(msg, off)
   740  		if err != nil {
   741  			return nil, len(msg), err
   742  		}
   743  		off = end
   744  		result = append(result, prefix)
   745  	}
   746  	return result, off, nil
   747  }
   748  
   749  func unpackDataAplPrefix(msg []byte, off int) (APLPrefix, int, error) {
   750  	family, off, err := unpackUint16(msg, off)
   751  	if err != nil {
   752  		return APLPrefix{}, len(msg), &Error{err: "overflow unpacking APL prefix"}
   753  	}
   754  	prefix, off, err := unpackUint8(msg, off)
   755  	if err != nil {
   756  		return APLPrefix{}, len(msg), &Error{err: "overflow unpacking APL prefix"}
   757  	}
   758  	nlen, off, err := unpackUint8(msg, off)
   759  	if err != nil {
   760  		return APLPrefix{}, len(msg), &Error{err: "overflow unpacking APL prefix"}
   761  	}
   762  
   763  	var ip []byte
   764  	switch family {
   765  	case 1:
   766  		ip = make([]byte, net.IPv4len)
   767  	case 2:
   768  		ip = make([]byte, net.IPv6len)
   769  	default:
   770  		return APLPrefix{}, len(msg), &Error{err: "unrecognized APL address family"}
   771  	}
   772  	if int(prefix) > 8*len(ip) {
   773  		return APLPrefix{}, len(msg), &Error{err: "APL prefix too long"}
   774  	}
   775  	afdlen := int(nlen & 0x7f)
   776  	if afdlen > len(ip) {
   777  		return APLPrefix{}, len(msg), &Error{err: "APL length too long"}
   778  	}
   779  	if off+afdlen > len(msg) {
   780  		return APLPrefix{}, len(msg), &Error{err: "overflow unpacking APL address"}
   781  	}
   782  
   783  	// Address MUST NOT contain trailing zero bytes per RFC3123 Sections 4.1 and 4.2.
   784  	off += copy(ip, msg[off:off+afdlen])
   785  	if afdlen > 0 {
   786  		last := ip[afdlen-1]
   787  		if last == 0 {
   788  			return APLPrefix{}, len(msg), &Error{err: "extra APL address bits"}
   789  		}
   790  	}
   791  	ipnet := net.IPNet{
   792  		IP:   ip,
   793  		Mask: net.CIDRMask(int(prefix), 8*len(ip)),
   794  	}
   795  
   796  	return APLPrefix{
   797  		Negation: (nlen & 0x80) != 0,
   798  		Network:  ipnet,
   799  	}, off, nil
   800  }
   801  
   802  func unpackIPSECGateway(msg []byte, off int, gatewayType uint8) (net.IP, string, int, error) {
   803  	var retAddr net.IP
   804  	var retString string
   805  	var err error
   806  
   807  	switch gatewayType {
   808  	case IPSECGatewayNone: // do nothing
   809  	case IPSECGatewayIPv4:
   810  		retAddr, off, err = unpackDataA(msg, off)
   811  	case IPSECGatewayIPv6:
   812  		retAddr, off, err = unpackDataAAAA(msg, off)
   813  	case IPSECGatewayHost:
   814  		retString, off, err = UnpackDomainName(msg, off)
   815  	}
   816  
   817  	return retAddr, retString, off, err
   818  }
   819  
   820  func packIPSECGateway(gatewayAddr net.IP, gatewayString string, msg []byte, off int, gatewayType uint8, compression compressionMap, compress bool) (int, error) {
   821  	var err error
   822  
   823  	switch gatewayType {
   824  	case IPSECGatewayNone: // do nothing
   825  	case IPSECGatewayIPv4:
   826  		off, err = packDataA(gatewayAddr, msg, off)
   827  	case IPSECGatewayIPv6:
   828  		off, err = packDataAAAA(gatewayAddr, msg, off)
   829  	case IPSECGatewayHost:
   830  		off, err = packDomainName(gatewayString, msg, off, compression, compress)
   831  	}
   832  
   833  	return off, err
   834  }
   835  

View as plain text