...

Source file src/github.com/Microsoft/go-winio/wim/lzx/lzx.go

Documentation: github.com/Microsoft/go-winio/wim/lzx

     1  // Package lzx implements a decompressor for the WIM variant of the
     2  // LZX compression algorithm.
     3  //
     4  // The LZX algorithm is an earlier variant of LZX DELTA, which is documented
     5  // at https://msdn.microsoft.com/en-us/library/cc483133(v=exchg.80).aspx.
     6  package lzx
     7  
     8  import (
     9  	"bytes"
    10  	"encoding/binary"
    11  	"errors"
    12  	"io"
    13  )
    14  
// Sizes, table-entry layout and block type codes used by the decompressor.
const (
	// maincodecount is the number of entries in the main tree's code length
	// table. Those lengths are read from the stream in two parts, split at
	// index maincodesplit.
	maincodecount = 496
	maincodesplit = 256
	// lencodecount is the number of entries in the length tree's code length
	// table.
	lencodecount  = 249
	// A decoding table entry stores the code length shifted left by lenshift
	// and the symbol in the bits selected by codemask.
	lenshift      = 9
	codemask      = 0x1ff
	// tablebits is the number of leading code bits used to index the
	// first-level decoding table, which has tablesize entries.
	tablebits     = 9
	tablesize     = 1 << tablebits

	// maxBlockSize is the size of a block whose header has the "full" bit
	// set, and the upper bound for an explicitly encoded block size.
	maxBlockSize = 32768
	// windowSize is the size of the output window and therefore the largest
	// uncompressed size NewReader accepts.
	windowSize   = 32768

	// maxTreePathLen is the longest huffman code length buildTable accepts.
	maxTreePathLen = 16

	// Parameters of the 0xe8 (x86 call) translation reversed by decodeE8.
	e8filesize  = 12000000
	maxe8offset = 0x3fffffff

	// Values of the 3-bit block type field in a block header.
	verbatimBlock      = 1
	alignedOffsetBlock = 2
	uncompressedBlock  = 3
)
    36  
// footerBits is indexed by the position slot decoded from a main tree match
// symbol. It gives the number of additional offset bits that
// readCompressedBlock reads from the stream for that slot.
var footerBits = [...]byte{
	0, 0, 0, 0, 1, 1, 2, 2,
	3, 3, 4, 4, 5, 5, 6, 6,
	7, 7, 8, 8, 9, 9, 10, 10,
	11, 11, 12, 12, 13, 13, 14,
}
    43  
// basePosition is indexed by the position slot decoded from a main tree match
// symbol. readCompressedBlock adds the slot's extra offset bits to this base
// and subtracts 2 to obtain the match offset.
var basePosition = [...]uint16{
	0, 1, 2, 3, 4, 6, 8, 12,
	16, 24, 32, 48, 64, 96, 128, 192,
	256, 384, 512, 768, 1024, 1536, 2048, 3072,
	4096, 6144, 8192, 12288, 16384, 24576, 32768,
}
    50  
var (
	// errCorrupt is reported when the compressed stream is structurally
	// invalid, as opposed to being truncated or failing to read.
	errCorrupt = errors.New("LZX data corrupt")
)
    54  
// Reader combines io.Reader and io.ByteReader.
//
// NOTE(review): this comment used to say that inputs which do not implement
// Reader are wrapped in a bufio.Reader. Nothing visible in this file does
// that: NewReader stores the io.Reader it is given unchanged and the
// decompressor buffers input itself. Confirm whether this type is still
// needed by callers.
type Reader interface {
	io.Reader
	io.ByteReader
}
    62  
// decompressor holds the complete decoding state for one compressed chunk.
// It is returned by NewReader as an io.ReadCloser.
type decompressor struct {
	// r is the compressed input stream.
	r            io.Reader
	// err is the first error recorded by fail; nil while decoding is healthy.
	err          error
	// unaligned is set after an uncompressed block of odd size; the next
	// block header read skips one input byte to compensate.
	unaligned    bool
	// nbits is the number of valid bits in c.
	nbits        byte
	// c is the bit buffer. Valid bits are kept in its most significant end.
	c            uint32
	// lru holds the three most recently used match offsets.
	lru          [3]uint16
	// uncompressed is the number of bytes the stream decodes to.
	uncompressed int
	// windowReader serves the decoded bytes; it is nil until the first Read
	// has decoded the whole stream.
	windowReader *bytes.Reader
	// mainlens and lenlens keep the code lengths of the main and length
	// trees. They persist across blocks because each block encodes its
	// lengths relative to the previous block's.
	mainlens     [maincodecount]byte
	lenlens      [lencodecount]byte
	// window receives the decoded output and is the source for match copies.
	window       [windowSize]byte
	// b buffers input read from r; the unread bytes are b[bo:bv].
	b            []byte
	bv           int
	bo           int
}
    79  
    80  //go:noinline
    81  func (f *decompressor) fail(err error) {
    82  	if f.err == nil {
    83  		f.err = err
    84  	}
    85  	f.bo = 0
    86  	f.bv = 0
    87  }
    88  
    89  func (f *decompressor) ensureAtLeast(n int) error {
    90  	if f.bv-f.bo >= n {
    91  		return nil
    92  	}
    93  
    94  	if f.err != nil {
    95  		return f.err
    96  	}
    97  
    98  	if f.bv != f.bo {
    99  		copy(f.b[:f.bv-f.bo], f.b[f.bo:f.bv])
   100  	}
   101  	n, err := io.ReadAtLeast(f.r, f.b[f.bv-f.bo:], n)
   102  	if err != nil {
   103  		if err == io.EOF { //nolint:errorlint
   104  			err = io.ErrUnexpectedEOF
   105  		} else {
   106  			f.fail(err)
   107  		}
   108  		return err
   109  	}
   110  	f.bv = f.bv - f.bo + n
   111  	f.bo = 0
   112  	return nil
   113  }
   114  
   115  // feed retrieves another 16-bit word from the stream and consumes
   116  // it into f.c. It returns false if there are no more bytes available.
   117  // Otherwise, on error, it sets f.err.
   118  func (f *decompressor) feed() bool {
   119  	err := f.ensureAtLeast(2)
   120  	if err == io.ErrUnexpectedEOF { //nolint:errorlint // returns io.ErrUnexpectedEOF by contract
   121  		return false
   122  	}
   123  	f.c |= (uint32(f.b[f.bo+1])<<8 | uint32(f.b[f.bo])) << (16 - f.nbits)
   124  	f.nbits += 16
   125  	f.bo += 2
   126  	return true
   127  }
   128  
   129  // getBits retrieves the next n bits from the byte stream. n
   130  // must be <= 16. It sets f.err on error.
   131  func (f *decompressor) getBits(n byte) uint16 {
   132  	if f.nbits < n {
   133  		if !f.feed() {
   134  			f.fail(io.ErrUnexpectedEOF)
   135  		}
   136  	}
   137  	c := uint16(f.c >> (32 - n))
   138  	f.c <<= n
   139  	f.nbits -= n
   140  	return c
   141  }
   142  
// huffman is a decoding table built by buildTable.
type huffman struct {
	// extra holds the second-level tables for codes longer than tablebits.
	// It is indexed by the value stored in the corresponding table entry.
	extra   [][]uint16
	// maxbits is the longest code length in the tree; 0 means the tree is
	// empty.
	maxbits byte
	// table is indexed by the next tablebits bits of input. An entry of at
	// least 1<<lenshift is a final result (code length and symbol); a smaller
	// entry is an index into extra.
	table   [tablesize]uint16
}
   148  
   149  // buildTable builds a huffman decoding table from a slice of code lengths,
   150  // one per code, in order. Each code length must be <= maxTreePathLen.
   151  // See https://en.wikipedia.org/wiki/Canonical_Huffman_code.
   152  func buildTable(codelens []byte) *huffman {
   153  	// Determine the number of codes of each length, and the
   154  	// maximum length.
   155  	var count [maxTreePathLen + 1]uint
   156  	var max byte
   157  	for _, cl := range codelens {
   158  		count[cl]++
   159  		if max < cl {
   160  			max = cl
   161  		}
   162  	}
   163  
   164  	if max == 0 {
   165  		return &huffman{}
   166  	}
   167  
   168  	// Determine the first code of each length.
   169  	var first [maxTreePathLen + 1]uint
   170  	code := uint(0)
   171  	for i := byte(1); i <= max; i++ {
   172  		code <<= 1
   173  		first[i] = code
   174  		code += count[i]
   175  	}
   176  
   177  	if code != 1<<max {
   178  		return nil
   179  	}
   180  
   181  	// Build a table for code lookup. For code sizes < max,
   182  	// put all possible suffixes for the code into the table, too.
   183  	// For max > tablebits, split long codes into additional tables
   184  	// of suffixes of max-tablebits length.
   185  	h := &huffman{maxbits: max}
   186  	if max > tablebits {
   187  		core := first[tablebits+1] / 2 // Number of codes that fit without extra tables
   188  		nextra := 1<<tablebits - core  // Number of extra entries
   189  		h.extra = make([][]uint16, nextra)
   190  		for code := core; code < 1<<tablebits; code++ {
   191  			h.table[code] = uint16(code - core)
   192  			h.extra[code-core] = make([]uint16, 1<<(max-tablebits))
   193  		}
   194  	}
   195  
   196  	for i, cl := range codelens {
   197  		if cl != 0 {
   198  			code := first[cl]
   199  			first[cl]++
   200  			v := uint16(cl)<<lenshift | uint16(i)
   201  			if cl <= tablebits {
   202  				extendedCode := code << (tablebits - cl)
   203  				for j := uint(0); j < 1<<(tablebits-cl); j++ {
   204  					h.table[extendedCode+j] = v
   205  				}
   206  			} else {
   207  				prefix := code >> (cl - tablebits)
   208  				suffix := code & (1<<(cl-tablebits) - 1)
   209  				extendedCode := suffix << (max - cl)
   210  				for j := uint(0); j < 1<<(max-cl); j++ {
   211  					h.extra[h.table[prefix]][extendedCode+j] = v
   212  				}
   213  			}
   214  		}
   215  	}
   216  
   217  	return h
   218  }
   219  
   220  // getCode retrieves the next code using the provided
   221  // huffman tree. It sets f.err on error.
   222  func (f *decompressor) getCode(h *huffman) uint16 {
   223  	if h.maxbits > 0 {
   224  		if f.nbits < maxTreePathLen {
   225  			f.feed()
   226  		}
   227  
   228  		// For codes with length < tablebits, it doesn't matter
   229  		// what the remainder of the bits used for table lookup
   230  		// are, since entries with all possible suffixes were
   231  		// added to the table.
   232  		c := h.table[f.c>>(32-tablebits)]
   233  		if !(c >= 1<<lenshift) {
   234  			// The code is not in c.
   235  			c = h.extra[c][f.c<<tablebits>>(32-(h.maxbits-tablebits))]
   236  		}
   237  
   238  		n := byte(c >> lenshift)
   239  		if f.nbits >= n {
   240  			// Only consume the length of the code, not the maximum
   241  			// code length.
   242  			f.c <<= n
   243  			f.nbits -= n
   244  			return c & codemask
   245  		}
   246  
   247  		f.fail(io.ErrUnexpectedEOF)
   248  		return 0
   249  	}
   250  
   251  	// This is an empty tree. It should not be used.
   252  	f.fail(errCorrupt)
   253  	return 0
   254  }
   255  
   256  // readTree updates the huffman tree path lengths in lens by
   257  // reading and decoding lengths from the byte stream. lens
   258  // should be prepopulated with the previous block's tree's path
   259  // lengths. For the first block, lens should be zero.
   260  func (f *decompressor) readTree(lens []byte) error {
   261  	// Get the pre-tree for the main tree.
   262  	var pretreeLen [20]byte
   263  	for i := range pretreeLen {
   264  		pretreeLen[i] = byte(f.getBits(4))
   265  	}
   266  	if f.err != nil {
   267  		return f.err
   268  	}
   269  	h := buildTable(pretreeLen[:])
   270  
   271  	// The lengths are encoded as a series of huffman codes
   272  	// encoded by the pre-tree.
   273  	for i := 0; i < len(lens); {
   274  		c := byte(f.getCode(h))
   275  		if f.err != nil {
   276  			return f.err
   277  		}
   278  		switch {
   279  		case c <= 16: // length is delta from previous length
   280  			lens[i] = (lens[i] + 17 - c) % 17
   281  			i++
   282  		case c == 17: // next n + 4 lengths are zero
   283  			zeroes := int(f.getBits(4)) + 4
   284  			if i+zeroes > len(lens) {
   285  				return errCorrupt
   286  			}
   287  			for j := 0; j < zeroes; j++ {
   288  				lens[i+j] = 0
   289  			}
   290  			i += zeroes
   291  		case c == 18: // next n + 20 lengths are zero
   292  			zeroes := int(f.getBits(5)) + 20
   293  			if i+zeroes > len(lens) {
   294  				return errCorrupt
   295  			}
   296  			for j := 0; j < zeroes; j++ {
   297  				lens[i+j] = 0
   298  			}
   299  			i += zeroes
   300  		case c == 19: // next n + 4 lengths all have the same value
   301  			same := int(f.getBits(1)) + 4
   302  			if i+same > len(lens) {
   303  				return errCorrupt
   304  			}
   305  			c = byte(f.getCode(h))
   306  			if c > 16 {
   307  				return errCorrupt
   308  			}
   309  			l := (lens[i] + 17 - c) % 17
   310  			for j := 0; j < same; j++ {
   311  				lens[i+j] = l
   312  			}
   313  			i += same
   314  		default:
   315  			return errCorrupt
   316  		}
   317  	}
   318  
   319  	if f.err != nil {
   320  		return f.err
   321  	}
   322  	return nil
   323  }
   324  
   325  func (f *decompressor) readBlockHeader() (byte, uint16, error) {
   326  	// If the previous block was an unaligned uncompressed block, restore
   327  	// 2-byte alignment.
   328  	if f.unaligned {
   329  		err := f.ensureAtLeast(1)
   330  		if err != nil {
   331  			return 0, 0, err
   332  		}
   333  		f.bo++
   334  		f.unaligned = false
   335  	}
   336  
   337  	blockType := f.getBits(3)
   338  	full := f.getBits(1)
   339  	var blockSize uint16
   340  	if full != 0 {
   341  		blockSize = maxBlockSize
   342  	} else {
   343  		blockSize = f.getBits(16)
   344  		if blockSize > maxBlockSize {
   345  			return 0, 0, errCorrupt
   346  		}
   347  	}
   348  
   349  	if f.err != nil {
   350  		return 0, 0, f.err
   351  	}
   352  
   353  	switch blockType {
   354  	case verbatimBlock, alignedOffsetBlock:
   355  		// The caller will read the huffman trees.
   356  	case uncompressedBlock:
   357  		if f.nbits > 16 {
   358  			panic("impossible: more than one 16-bit word remains")
   359  		}
   360  
   361  		// Drop the remaining bits in the current 16-bit word
   362  		// If there are no bits left, discard a full 16-bit word.
   363  		n := f.nbits
   364  		if n == 0 {
   365  			n = 16
   366  		}
   367  
   368  		f.getBits(n)
   369  
   370  		// Read the LRU values for the next block.
   371  		err := f.ensureAtLeast(12)
   372  		if err != nil {
   373  			return 0, 0, err
   374  		}
   375  
   376  		f.lru[0] = uint16(binary.LittleEndian.Uint32(f.b[f.bo : f.bo+4]))
   377  		f.lru[1] = uint16(binary.LittleEndian.Uint32(f.b[f.bo+4 : f.bo+8]))
   378  		f.lru[2] = uint16(binary.LittleEndian.Uint32(f.b[f.bo+8 : f.bo+12]))
   379  		f.bo += 12
   380  
   381  	default:
   382  		return 0, 0, errCorrupt
   383  	}
   384  
   385  	return byte(blockType), blockSize, nil
   386  }
   387  
   388  // readTrees reads the two or three huffman trees for the current block.
   389  // readAligned specifies whether to read the aligned offset tree.
   390  func (f *decompressor) readTrees(readAligned bool) (main *huffman, length *huffman, aligned *huffman, err error) {
   391  	// Aligned offset blocks start with a small aligned offset tree.
   392  	if readAligned {
   393  		var alignedLen [8]byte
   394  		for i := range alignedLen {
   395  			alignedLen[i] = byte(f.getBits(3))
   396  		}
   397  		aligned = buildTable(alignedLen[:])
   398  		if aligned == nil {
   399  			return main, length, aligned, errors.New("corrupt")
   400  		}
   401  	}
   402  
   403  	// The main tree is encoded in two parts.
   404  	err = f.readTree(f.mainlens[:maincodesplit])
   405  	if err != nil {
   406  		return main, length, aligned, err
   407  	}
   408  	err = f.readTree(f.mainlens[maincodesplit:])
   409  	if err != nil {
   410  		return main, length, aligned, err
   411  	}
   412  
   413  	main = buildTable(f.mainlens[:])
   414  	if main == nil {
   415  		return main, length, aligned, errors.New("corrupt")
   416  	}
   417  
   418  	// The length tree is encoding in a single part.
   419  	err = f.readTree(f.lenlens[:])
   420  	if err != nil {
   421  		return main, length, aligned, err
   422  	}
   423  
   424  	length = buildTable(f.lenlens[:])
   425  	if length == nil {
   426  		return main, length, aligned, errors.New("corrupt")
   427  	}
   428  
   429  	return main, length, aligned, f.err
   430  }
   431  
   432  // readCompressedBlock decodes a compressed block, writing into the window
   433  // starting at start and ending at end, and using the provided huffman trees.
   434  func (f *decompressor) readCompressedBlock(start, end uint16, hmain, hlength, haligned *huffman) (int, error) {
   435  	i := start
   436  	for i < end {
   437  		main := f.getCode(hmain)
   438  		if f.err != nil {
   439  			break
   440  		}
   441  		if main < 256 {
   442  			// Literal byte.
   443  			f.window[i] = byte(main)
   444  			i++
   445  			continue
   446  		}
   447  
   448  		// This is a match backward in the window. Determine
   449  		// the offset and dlength.
   450  		matchlen := (main - 256) % 8
   451  		slot := (main - 256) / 8
   452  
   453  		// The length is either the low bits of the code,
   454  		// or if this is 7, is encoded with the length tree.
   455  		if matchlen == 7 {
   456  			matchlen += f.getCode(hlength)
   457  		}
   458  		matchlen += 2
   459  
   460  		var matchoffset uint16
   461  		if slot < 3 { //nolint:nestif // todo: simplify nested complexity
   462  			// The offset is one of the LRU values.
   463  			matchoffset = f.lru[slot]
   464  			f.lru[slot] = f.lru[0]
   465  			f.lru[0] = matchoffset
   466  		} else {
   467  			// The offset is encoded as a combination of the
   468  			// slot and more bits from the bit stream.
   469  			offsetbits := footerBits[slot]
   470  			var verbatimbits, alignedbits uint16
   471  			if offsetbits > 0 {
   472  				if haligned != nil && offsetbits >= 3 {
   473  					// This is an aligned offset block. Combine
   474  					// the bits written verbatim with the aligned
   475  					// offset tree code.
   476  					verbatimbits = f.getBits(offsetbits-3) * 8
   477  					alignedbits = f.getCode(haligned)
   478  				} else {
   479  					// There are no aligned offset bits to read,
   480  					// only verbatim bits.
   481  					verbatimbits = f.getBits(offsetbits)
   482  					alignedbits = 0
   483  				}
   484  			}
   485  			matchoffset = basePosition[slot] + verbatimbits + alignedbits - 2
   486  			// Update the LRU cache.
   487  			f.lru[2] = f.lru[1]
   488  			f.lru[1] = f.lru[0]
   489  			f.lru[0] = matchoffset
   490  		}
   491  
   492  		if !(matchoffset <= i && matchlen <= end-i) {
   493  			f.fail(errCorrupt)
   494  			break
   495  		}
   496  		copyend := i + matchlen
   497  		for ; i < copyend; i++ {
   498  			f.window[i] = f.window[i-matchoffset]
   499  		}
   500  	}
   501  	return int(i - start), f.err
   502  }
   503  
   504  // readBlock decodes the current block and returns the number of uncompressed bytes.
   505  func (f *decompressor) readBlock(start uint16) (int, error) {
   506  	blockType, size, err := f.readBlockHeader()
   507  	if err != nil {
   508  		return 0, err
   509  	}
   510  
   511  	if blockType == uncompressedBlock {
   512  		if size%2 == 1 {
   513  			// Remember to realign the byte stream at the next block.
   514  			f.unaligned = true
   515  		}
   516  		copied := 0
   517  		if f.bo < f.bv {
   518  			copied = int(size)
   519  			s := int(start)
   520  			if copied > f.bv-f.bo {
   521  				copied = f.bv - f.bo
   522  			}
   523  			copy(f.window[s:s+copied], f.b[f.bo:f.bo+copied])
   524  			f.bo += copied
   525  		}
   526  		n, err := io.ReadFull(f.r, f.window[start+uint16(copied):start+size])
   527  		return copied + n, err
   528  	}
   529  
   530  	hmain, hlength, haligned, err := f.readTrees(blockType == alignedOffsetBlock)
   531  	if err != nil {
   532  		return 0, err
   533  	}
   534  
   535  	return f.readCompressedBlock(start, start+size, hmain, hlength, haligned)
   536  }
   537  
   538  // decodeE8 reverses the 0xe8 x86 instruction encoding that was performed
   539  // to the uncompressed data before it was compressed.
   540  func decodeE8(b []byte, off int64) {
   541  	if off > maxe8offset || len(b) < 10 {
   542  		return
   543  	}
   544  	for i := 0; i < len(b)-10; i++ {
   545  		if b[i] == 0xe8 {
   546  			currentPtr := int32(off) + int32(i)
   547  			abs := int32(binary.LittleEndian.Uint32(b[i+1 : i+5]))
   548  			if abs >= -currentPtr && abs < e8filesize {
   549  				var rel int32
   550  				if abs >= 0 {
   551  					rel = abs - currentPtr
   552  				} else {
   553  					rel = abs + e8filesize
   554  				}
   555  				binary.LittleEndian.PutUint32(b[i+1:i+5], uint32(rel))
   556  			}
   557  			i += 4
   558  		}
   559  	}
   560  }
   561  
   562  func (f *decompressor) Read(b []byte) (int, error) {
   563  	// Read and uncompress everything.
   564  	if f.windowReader == nil {
   565  		n := 0
   566  		for n < f.uncompressed {
   567  			k, err := f.readBlock(uint16(n))
   568  			if err != nil {
   569  				return 0, err
   570  			}
   571  			n += k
   572  		}
   573  		decodeE8(f.window[:f.uncompressed], 0)
   574  		f.windowReader = bytes.NewReader(f.window[:f.uncompressed])
   575  	}
   576  
   577  	// Just read directly from the window.
   578  	return f.windowReader.Read(b)
   579  }
   580  
// Close implements io.Closer. It does nothing and always returns nil; in
// particular it does not close the underlying reader.
func (*decompressor) Close() error {
	return nil
}
   584  
   585  // NewReader returns a new io.ReadCloser that decompresses a
   586  // WIM LZX stream until uncompressedSize bytes have been returned.
   587  func NewReader(r io.Reader, uncompressedSize int) (io.ReadCloser, error) {
   588  	if uncompressedSize > windowSize {
   589  		return nil, errors.New("uncompressed size is limited to 32KB")
   590  	}
   591  	f := &decompressor{
   592  		lru:          [3]uint16{1, 1, 1},
   593  		uncompressed: uncompressedSize,
   594  		b:            make([]byte, 4096),
   595  		r:            r,
   596  	}
   597  	return f, nil
   598  }
   599  

View as plain text