...

Source file src/github.com/klauspost/compress/zstd/snappy.go

Documentation: github.com/klauspost/compress/zstd

     1  // Copyright 2019+ Klaus Post. All rights reserved.
     2  // License information can be found in the LICENSE file.
     3  // Based on work by Yann Collet, released under BSD License.
     4  
     5  package zstd
     6  
     7  import (
     8  	"encoding/binary"
     9  	"errors"
    10  	"hash/crc32"
    11  	"io"
    12  
    13  	"github.com/klauspost/compress/huff0"
    14  	snappy "github.com/klauspost/compress/internal/snapref"
    15  )
    16  
    17  const (
    18  	snappyTagLiteral = 0x00
    19  	snappyTagCopy1   = 0x01
    20  	snappyTagCopy2   = 0x02
    21  	snappyTagCopy4   = 0x03
    22  )
    23  
    24  const (
    25  	snappyChecksumSize = 4
    26  	snappyMagicBody    = "sNaPpY"
    27  
    28  	// snappyMaxBlockSize is the maximum size of the input to encodeBlock. It is not
    29  	// part of the wire format per se, but some parts of the encoder assume
    30  	// that an offset fits into a uint16.
    31  	//
    32  	// Also, for the framing format (Writer type instead of Encode function),
    33  	// https://github.com/google/snappy/blob/master/framing_format.txt says
    34  	// that "the uncompressed data in a chunk must be no longer than 65536
    35  	// bytes".
    36  	snappyMaxBlockSize = 65536
    37  
    38  	// snappyMaxEncodedLenOfMaxBlockSize equals MaxEncodedLen(snappyMaxBlockSize), but is
    39  	// hard coded to be a const instead of a variable, so that obufLen can also
    40  	// be a const. Their equivalence is confirmed by
    41  	// TestMaxEncodedLenOfMaxBlockSize.
    42  	snappyMaxEncodedLenOfMaxBlockSize = 76490
    43  )
    44  
    45  const (
    46  	chunkTypeCompressedData   = 0x00
    47  	chunkTypeUncompressedData = 0x01
    48  	chunkTypePadding          = 0xfe
    49  	chunkTypeStreamIdentifier = 0xff
    50  )
    51  
    52  var (
    53  	// ErrSnappyCorrupt reports that the input is invalid.
    54  	ErrSnappyCorrupt = errors.New("snappy: corrupt input")
    55  	// ErrSnappyTooLarge reports that the uncompressed length is too large.
    56  	ErrSnappyTooLarge = errors.New("snappy: decoded block is too large")
    57  	// ErrSnappyUnsupported reports that the input isn't supported.
    58  	ErrSnappyUnsupported = errors.New("snappy: unsupported input")
    59  
    60  	errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
    61  )
    62  
    63  // SnappyConverter can read Snappy-compressed streams and convert them to zstd.
    64  // Conversion is done by converting the stream directly from Snappy without intermediate
    65  // full decoding.
    66  // Therefore the compression ratio is much lower than what a full decompression
    67  // and recompression achieves, and a faulty Snappy stream may lead to a faulty Zstandard
    68  // stream without any errors being generated.
    69  // No CRC value is generated and not all CRC values of the Snappy stream are checked.
    70  // However, it provides really fast recompression of Snappy streams.
    71  // The converter can be reused to avoid allocations, even after errors.
    72  type SnappyConverter struct {
    73  	r     io.Reader // Snappy input of the current Convert call
    74  	err   error     // last error recorded by Convert/readFull; cleared when Convert starts
    75  	buf   []byte    // scratch buffer for the frame header, chunk headers and chunk payloads
    76  	block *blockEnc // block encoder, created on first use and kept for later calls
    77  }
    78  
    79  // Convert the Snappy stream supplied in 'in' and write the zStandard stream to 'w'.
    80  // If any error is detected on the Snappy stream it is returned.
    81  // The number of bytes written is returned.
    82  func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) {
    83  	initPredefined()
    84  	r.err = nil
    85  	r.r = in
    86  	if r.block == nil {
    87  		r.block = &blockEnc{}
    88  		r.block.init()
    89  	}
    90  	r.block.initNewEncode()
    91  	if len(r.buf) != snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize {
    92  		r.buf = make([]byte, snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize)
    93  	}
    94  	r.block.litEnc.Reuse = huff0.ReusePolicyNone
    95  	var written int64
    96  	var readHeader bool
    97  	{
    98  		header := frameHeader{WindowSize: snappyMaxBlockSize}.appendTo(r.buf[:0])
    99  
   100  		var n int
   101  		n, r.err = w.Write(header)
   102  		if r.err != nil {
   103  			return written, r.err
   104  		}
   105  		written += int64(n)
   106  	}
   107  
   108  	for {
   109  		if !r.readFull(r.buf[:4], true) {
   110  			// Add empty last block
   111  			r.block.reset(nil)
   112  			r.block.last = true
   113  			err := r.block.encodeLits(r.block.literals, false)
   114  			if err != nil {
   115  				return written, err
   116  			}
   117  			n, err := w.Write(r.block.output)
   118  			if err != nil {
   119  				return written, err
   120  			}
   121  			written += int64(n)
   122  
   123  			return written, r.err
   124  		}
   125  		chunkType := r.buf[0]
   126  		if !readHeader {
   127  			if chunkType != chunkTypeStreamIdentifier {
   128  				println("chunkType != chunkTypeStreamIdentifier", chunkType)
   129  				r.err = ErrSnappyCorrupt
   130  				return written, r.err
   131  			}
   132  			readHeader = true
   133  		}
   134  		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
   135  		if chunkLen > len(r.buf) {
   136  			println("chunkLen > len(r.buf)", chunkType)
   137  			r.err = ErrSnappyUnsupported
   138  			return written, r.err
   139  		}
   140  
   141  		// The chunk types are specified at
   142  		// https://github.com/google/snappy/blob/master/framing_format.txt
   143  		switch chunkType {
   144  		case chunkTypeCompressedData:
   145  			// Section 4.2. Compressed data (chunk type 0x00).
   146  			if chunkLen < snappyChecksumSize {
   147  				println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize)
   148  				r.err = ErrSnappyCorrupt
   149  				return written, r.err
   150  			}
   151  			buf := r.buf[:chunkLen]
   152  			if !r.readFull(buf, false) {
   153  				return written, r.err
   154  			}
   155  			//checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
   156  			buf = buf[snappyChecksumSize:]
   157  
   158  			n, hdr, err := snappyDecodedLen(buf)
   159  			if err != nil {
   160  				r.err = err
   161  				return written, r.err
   162  			}
   163  			buf = buf[hdr:]
   164  			if n > snappyMaxBlockSize {
   165  				println("n > snappyMaxBlockSize", n, snappyMaxBlockSize)
   166  				r.err = ErrSnappyCorrupt
   167  				return written, r.err
   168  			}
   169  			r.block.reset(nil)
   170  			r.block.pushOffsets()
   171  			if err := decodeSnappy(r.block, buf); err != nil {
   172  				r.err = err
   173  				return written, r.err
   174  			}
   175  			if r.block.size+r.block.extraLits != n {
   176  				printf("invalid size, want %d, got %d\n", n, r.block.size+r.block.extraLits)
   177  				r.err = ErrSnappyCorrupt
   178  				return written, r.err
   179  			}
   180  			err = r.block.encode(nil, false, false)
   181  			switch err {
   182  			case errIncompressible:
   183  				r.block.popOffsets()
   184  				r.block.reset(nil)
   185  				r.block.literals, err = snappy.Decode(r.block.literals[:n], r.buf[snappyChecksumSize:chunkLen])
   186  				if err != nil {
   187  					return written, err
   188  				}
   189  				err = r.block.encodeLits(r.block.literals, false)
   190  				if err != nil {
   191  					return written, err
   192  				}
   193  			case nil:
   194  			default:
   195  				return written, err
   196  			}
   197  
   198  			n, r.err = w.Write(r.block.output)
   199  			if r.err != nil {
   200  				return written, err
   201  			}
   202  			written += int64(n)
   203  			continue
   204  		case chunkTypeUncompressedData:
   205  			if debugEncoder {
   206  				println("Uncompressed, chunklen", chunkLen)
   207  			}
   208  			// Section 4.3. Uncompressed data (chunk type 0x01).
   209  			if chunkLen < snappyChecksumSize {
   210  				println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize)
   211  				r.err = ErrSnappyCorrupt
   212  				return written, r.err
   213  			}
   214  			r.block.reset(nil)
   215  			buf := r.buf[:snappyChecksumSize]
   216  			if !r.readFull(buf, false) {
   217  				return written, r.err
   218  			}
   219  			checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
   220  			// Read directly into r.decoded instead of via r.buf.
   221  			n := chunkLen - snappyChecksumSize
   222  			if n > snappyMaxBlockSize {
   223  				println("n > snappyMaxBlockSize", n, snappyMaxBlockSize)
   224  				r.err = ErrSnappyCorrupt
   225  				return written, r.err
   226  			}
   227  			r.block.literals = r.block.literals[:n]
   228  			if !r.readFull(r.block.literals, false) {
   229  				return written, r.err
   230  			}
   231  			if snappyCRC(r.block.literals) != checksum {
   232  				println("literals crc mismatch")
   233  				r.err = ErrSnappyCorrupt
   234  				return written, r.err
   235  			}
   236  			err := r.block.encodeLits(r.block.literals, false)
   237  			if err != nil {
   238  				return written, err
   239  			}
   240  			n, r.err = w.Write(r.block.output)
   241  			if r.err != nil {
   242  				return written, err
   243  			}
   244  			written += int64(n)
   245  			continue
   246  
   247  		case chunkTypeStreamIdentifier:
   248  			if debugEncoder {
   249  				println("stream id", chunkLen, len(snappyMagicBody))
   250  			}
   251  			// Section 4.1. Stream identifier (chunk type 0xff).
   252  			if chunkLen != len(snappyMagicBody) {
   253  				println("chunkLen != len(snappyMagicBody)", chunkLen, len(snappyMagicBody))
   254  				r.err = ErrSnappyCorrupt
   255  				return written, r.err
   256  			}
   257  			if !r.readFull(r.buf[:len(snappyMagicBody)], false) {
   258  				return written, r.err
   259  			}
   260  			for i := 0; i < len(snappyMagicBody); i++ {
   261  				if r.buf[i] != snappyMagicBody[i] {
   262  					println("r.buf[i] != snappyMagicBody[i]", r.buf[i], snappyMagicBody[i], i)
   263  					r.err = ErrSnappyCorrupt
   264  					return written, r.err
   265  				}
   266  			}
   267  			continue
   268  		}
   269  
   270  		if chunkType <= 0x7f {
   271  			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
   272  			println("chunkType <= 0x7f")
   273  			r.err = ErrSnappyUnsupported
   274  			return written, r.err
   275  		}
   276  		// Section 4.4 Padding (chunk type 0xfe).
   277  		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
   278  		if !r.readFull(r.buf[:chunkLen], false) {
   279  			return written, r.err
   280  		}
   281  	}
   282  }
   283  
   284  // decodeSnappy writes the decoding of src to dst. It assumes that the varint-encoded
   285  // length of the decompressed bytes has already been read.
   286  func decodeSnappy(blk *blockEnc, src []byte) error {
   287  	//decodeRef(make([]byte, snappyMaxBlockSize), src)
   288  	var s, length int
   289  	lits := blk.extraLits
   290  	var offset uint32
   291  	for s < len(src) {
   292  		switch src[s] & 0x03 {
   293  		case snappyTagLiteral:
   294  			x := uint32(src[s] >> 2)
   295  			switch {
   296  			case x < 60:
   297  				s++
   298  			case x == 60:
   299  				s += 2
   300  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   301  					println("uint(s) > uint(len(src)", s, src)
   302  					return ErrSnappyCorrupt
   303  				}
   304  				x = uint32(src[s-1])
   305  			case x == 61:
   306  				s += 3
   307  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   308  					println("uint(s) > uint(len(src)", s, src)
   309  					return ErrSnappyCorrupt
   310  				}
   311  				x = uint32(src[s-2]) | uint32(src[s-1])<<8
   312  			case x == 62:
   313  				s += 4
   314  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   315  					println("uint(s) > uint(len(src)", s, src)
   316  					return ErrSnappyCorrupt
   317  				}
   318  				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
   319  			case x == 63:
   320  				s += 5
   321  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   322  					println("uint(s) > uint(len(src)", s, src)
   323  					return ErrSnappyCorrupt
   324  				}
   325  				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
   326  			}
   327  			if x > snappyMaxBlockSize {
   328  				println("x > snappyMaxBlockSize", x, snappyMaxBlockSize)
   329  				return ErrSnappyCorrupt
   330  			}
   331  			length = int(x) + 1
   332  			if length <= 0 {
   333  				println("length <= 0 ", length)
   334  
   335  				return errUnsupportedLiteralLength
   336  			}
   337  			//if length > snappyMaxBlockSize-d || uint32(length) > len(src)-s {
   338  			//	return ErrSnappyCorrupt
   339  			//}
   340  
   341  			blk.literals = append(blk.literals, src[s:s+length]...)
   342  			//println(length, "litLen")
   343  			lits += length
   344  			s += length
   345  			continue
   346  
   347  		case snappyTagCopy1:
   348  			s += 2
   349  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   350  				println("uint(s) > uint(len(src)", s, len(src))
   351  				return ErrSnappyCorrupt
   352  			}
   353  			length = 4 + int(src[s-2])>>2&0x7
   354  			offset = uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])
   355  
   356  		case snappyTagCopy2:
   357  			s += 3
   358  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   359  				println("uint(s) > uint(len(src)", s, len(src))
   360  				return ErrSnappyCorrupt
   361  			}
   362  			length = 1 + int(src[s-3])>>2
   363  			offset = uint32(src[s-2]) | uint32(src[s-1])<<8
   364  
   365  		case snappyTagCopy4:
   366  			s += 5
   367  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   368  				println("uint(s) > uint(len(src)", s, len(src))
   369  				return ErrSnappyCorrupt
   370  			}
   371  			length = 1 + int(src[s-5])>>2
   372  			offset = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
   373  		}
   374  
   375  		if offset <= 0 || blk.size+lits < int(offset) /*|| length > len(blk)-d */ {
   376  			println("offset <= 0 || blk.size+lits < int(offset)", offset, blk.size+lits, int(offset), blk.size, lits)
   377  
   378  			return ErrSnappyCorrupt
   379  		}
   380  
   381  		// Check if offset is one of the recent offsets.
   382  		// Adjusts the output offset accordingly.
   383  		// Gives a tiny bit of compression, typically around 1%.
   384  		if false {
   385  			offset = blk.matchOffset(offset, uint32(lits))
   386  		} else {
   387  			offset += 3
   388  		}
   389  
   390  		blk.sequences = append(blk.sequences, seq{
   391  			litLen:   uint32(lits),
   392  			offset:   offset,
   393  			matchLen: uint32(length) - zstdMinMatch,
   394  		})
   395  		blk.size += length + lits
   396  		lits = 0
   397  	}
   398  	blk.extraLits = lits
   399  	return nil
   400  }
   401  
   402  func (r *SnappyConverter) readFull(p []byte, allowEOF bool) (ok bool) {
   403  	if _, r.err = io.ReadFull(r.r, p); r.err != nil {
   404  		if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
   405  			r.err = ErrSnappyCorrupt
   406  		}
   407  		return false
   408  	}
   409  	return true
   410  }
   411  
   412  var crcTable = crc32.MakeTable(crc32.Castagnoli)
   413  
   414  // crc implements the checksum specified in section 3 of
   415  // https://github.com/google/snappy/blob/master/framing_format.txt
   416  func snappyCRC(b []byte) uint32 {
   417  	c := crc32.Update(0, crcTable, b)
   418  	return c>>15 | c<<17 + 0xa282ead8
   419  }
   420  
   421  // snappyDecodedLen returns the length of the decoded block and the number of bytes
   422  // that the length header occupied.
   423  func snappyDecodedLen(src []byte) (blockLen, headerLen int, err error) {
   424  	v, n := binary.Uvarint(src)
   425  	if n <= 0 || v > 0xffffffff {
   426  		return 0, 0, ErrSnappyCorrupt
   427  	}
   428  
   429  	const wordSize = 32 << (^uint(0) >> 32 & 1)
   430  	if wordSize == 32 && v > 0x7fffffff {
   431  		return 0, 0, ErrSnappyTooLarge
   432  	}
   433  	return int(v), n, nil
   434  }
   435  

View as plain text