...

Source file src/github.com/klauspost/compress/s2/decode.go

Documentation: github.com/klauspost/compress/s2

     1  // Copyright 2011 The Snappy-Go Authors. All rights reserved.
     2  // Copyright (c) 2019 Klaus Post. All rights reserved.
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  package s2
     7  
     8  import (
     9  	"encoding/binary"
    10  	"errors"
    11  	"fmt"
    12  	"strconv"
    13  
    14  	"github.com/klauspost/compress/internal/race"
    15  )
    16  
    17  var (
    18  	// ErrCorrupt reports that the input is invalid.
    19  	ErrCorrupt = errors.New("s2: corrupt input")
    20  	// ErrCRC reports that the input failed CRC validation (streams only)
    21  	ErrCRC = errors.New("s2: corrupt input, crc mismatch")
    22  	// ErrTooLarge reports that the uncompressed length is too large.
    23  	ErrTooLarge = errors.New("s2: decoded block is too large")
    24  	// ErrUnsupported reports that the input isn't supported.
    25  	ErrUnsupported = errors.New("s2: unsupported input")
    26  )
    27  
    28  // DecodedLen returns the length of the decoded block.
    29  func DecodedLen(src []byte) (int, error) {
    30  	v, _, err := decodedLen(src)
    31  	return v, err
    32  }
    33  
    34  // decodedLen returns the length of the decoded block and the number of bytes
    35  // that the length header occupied.
    36  func decodedLen(src []byte) (blockLen, headerLen int, err error) {
    37  	v, n := binary.Uvarint(src)
    38  	if n <= 0 || v > 0xffffffff {
    39  		return 0, 0, ErrCorrupt
    40  	}
    41  
    42  	const wordSize = 32 << (^uint(0) >> 32 & 1)
    43  	if wordSize == 32 && v > 0x7fffffff {
    44  		return 0, 0, ErrTooLarge
    45  	}
    46  	return int(v), n, nil
    47  }
    48  
    49  const (
    50  	decodeErrCodeCorrupt = 1
    51  )
    52  
    53  // Decode returns the decoded form of src. The returned slice may be a sub-
    54  // slice of dst if dst was large enough to hold the entire decoded block.
    55  // Otherwise, a newly allocated slice will be returned.
    56  //
    57  // The dst and src must not overlap. It is valid to pass a nil dst.
    58  func Decode(dst, src []byte) ([]byte, error) {
    59  	dLen, s, err := decodedLen(src)
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	if dLen <= cap(dst) {
    64  		dst = dst[:dLen]
    65  	} else {
    66  		dst = make([]byte, dLen)
    67  	}
    68  
    69  	race.WriteSlice(dst)
    70  	race.ReadSlice(src[s:])
    71  
    72  	if s2Decode(dst, src[s:]) != 0 {
    73  		return nil, ErrCorrupt
    74  	}
    75  	return dst, nil
    76  }
    77  
// s2DecodeDict writes the decoding of src to dst. Back-references that reach
// before the start of dst are resolved against dict. It assumes that the
// varint-encoded length of the decompressed bytes has already been read, and
// that len(dst) equals that length. A nil dict delegates to s2Decode.
//
// Decoding runs in two phases:
//   - a fast loop, used while more than 5 bytes of src remain, that reads
//     element headers without explicit length checks;
//   - a tail loop that bounds-checks every header read.
//
// A copy whose offset exceeds the number of bytes already written (d) is
// served entirely from dict.dict. Such a copy is rejected unless d is at most
// MaxDictSrcOffset and the copied range lies completely inside dict.dict. It
// is never split between the dictionary and dst.
//
// It returns 0 on success or a decodeErrCodeXxx error code on failure. On
// failure, dst may already have been partially written.
func s2DecodeDict(dst, src []byte, dict *Dict) int {
	if dict == nil {
		return s2Decode(dst, src)
	}
	// Tracing switches. Both are false, so the debug branches never run.
	const debug = false
	const debugErrs = debug

	if debug {
		fmt.Println("Starting decode, dst len:", len(dst))
	}
	// d is the write position in dst, s the read position in src.
	// offset persists across elements because repeat codes (tagCopy1 with a
	// zero offset field) reuse the previous offset.
	var d, s, length int
	// Initial "last offset". With d == 0, a repeat resolves to dictionary
	// position dict.repeat (startOff = len(dict.dict) - offset + d below).
	offset := len(dict.dict) - dict.repeat

	// As long as we can read at least 5 bytes...
	// (s < len(src)-5 leaves at least 6 bytes, which covers the longest
	// element header of 5 bytes. The header reads below therefore need no
	// explicit length checks.)
	for s < len(src)-5 {
		// Removing bounds checks is SLOWER, when if doing
		// in := src[s:s+5]
		// Checked on Go 1.18
		// The low two bits of the tag byte select the element type.
		switch src[s] & 0x03 {
		case tagLiteral:
			// The upper six bits hold either length-1 directly (< 60) or the
			// number of little-endian length bytes that follow (60..63 => 1..4).
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				x = uint32(src[s-1])
			case x == 61:
				in := src[s : s+3]
				x = uint32(in[1]) | uint32(in[2])<<8
				s += 3
			case x == 62:
				in := src[s : s+4]
				// Load as 32 bit and shift down.
				x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
				x >>= 8
				s += 4
			case x == 63:
				in := src[s : s+5]
				x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
				s += 5
			}
			length = int(x) + 1
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}
			// The literal must fit in dst and be fully present in src. On
			// 32-bit ints, int(x)+1 can wrap to a non-positive value.
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			// Two-byte copy. The 11-bit offset is tag bits 5-7 (moved to bits
			// 8-10) plus the next byte. Tag bits 2-4 hold the length code.
			s += 2
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			length = int(src[s-2]) >> 2 & 0x7
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// keep last offset
				// Length codes 5, 6 and 7 read 1, 2 or 3 extra little-endian
				// length bytes. Codes 0-4 use the code itself.
				switch length {
				case 5:
					length = int(src[s]) + 4
					s += 1
				case 6:
					in := src[s : s+2]
					length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
					s += 2
				case 7:
					in := src[s : s+3]
					length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16)
					s += 3
				default: // 0-> 4
				}
			} else {
				offset = toffset
			}
			// Every tagCopy1 length, repeat or not, is biased by 4.
			length += 4
		case tagCopy2:
			// Three-byte copy: 16-bit little-endian offset, length 1..64.
			in := src[s : s+3]
			offset = int(uint32(in[1]) | uint32(in[2])<<8)
			length = 1 + int(in[0])>>2
			s += 3

		case tagCopy4:
			// Five-byte copy: 32-bit little-endian offset, length 1..64.
			in := src[s : s+5]
			offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
			length = 1 + int(in[0])>>2
			s += 5
		}

		// offset <= 0 also rejects 32-bit offsets that wrap negative when int
		// is 32 bits wide.
		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// copy from dict
		// The match starts before dst[0], so it is taken from the tail of
		// dict.dict.
		if d < offset {
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			startOff := len(dict.dict) - offset + d
			// The whole match must lie inside the dictionary.
			if startOff < 0 || startOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Printf("offset (%d) + length (%d) bigger than dict (%d)\n", offset, length, len(dict.dict))
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("dict copy, length:", length, "offset:", offset, "d-after:", d+length, "dict start offset:", startOff)
			}
			copy(dst[d:d+length], dict.dict[startOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Copy from an earlier sub-slice of dst to a later sub-slice.
		// If no overlap, use the built-in copy:
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Unlike the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		//
		// We align the slices into a and b and show the compiler they are the same size.
		// This allows the loop to run without bounds checks.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// Remaining with extra checks...
	// Same element decoding as above. Each header read first advances s,
	// then verifies s <= len(src) before touching the bytes it covers.
	for s < len(src) {
		switch src[s] & 0x03 {
		case tagLiteral:
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-1])
			case x == 61:
				s += 3
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-2]) | uint32(src[s-1])<<8
			case x == 62:
				s += 4
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
			case x == 63:
				s += 5
				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
			}
			length = int(x) + 1
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			s += 2
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = int(src[s-2]) >> 2 & 0x7
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// keep last offset
				switch length {
				case 5:
					s += 1
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-1])) + 4
				case 6:
					s += 2
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8)
				case 7:
					s += 3
					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16)
				default: // 0-> 4
				}
			} else {
				offset = toffset
			}
			length += 4
		case tagCopy2:
			s += 3
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-3])>>2
			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)

		case tagCopy4:
			s += 5
			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-5])>>2
			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
		}

		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// copy from dict
		// Same rules as the fast loop. rOff equals startOff there.
		if d < offset {
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			rOff := len(dict.dict) - (offset - d)
			// NOTE(review): len(dict.dict)-rOff evaluates to offset-d, not to
			// rOff, despite the label. Debug output only.
			if debug {
				fmt.Println("starting dict entry from dict offset", len(dict.dict)-rOff)
			}
			if rOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Println("err: END offset", rOff+length, "bigger than dict", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			if rOff < 0 {
				if debugErrs {
					fmt.Println("err: START offset", rOff, "less than 0", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			copy(dst[d:d+length], dict.dict[rOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Copy from an earlier sub-slice of dst to a later sub-slice.
		// If no overlap, use the built-in copy:
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Unlike the built-in copy function, this byte-by-byte copy always runs
		// forwards, even if the slices overlap. Conceptually, this is:
		//
		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
		//
		// We align the slices into a and b and show the compiler they are the same size.
		// This allows the loop to run without bounds checks.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// src is exhausted. Every byte of dst must have been produced, otherwise
	// the block is corrupt.
	if d != len(dst) {
		if debugErrs {
			fmt.Println("wanted length", len(dst), "got", d)
		}
		return decodeErrCodeCorrupt
	}
	return 0
}
   444  

View as plain text