...

Source file src/github.com/klauspost/compress/s2/decode_other.go

Documentation: github.com/klauspost/compress/s2

     1  // Copyright 2016 The Snappy-Go Authors. All rights reserved.
     2  // Copyright (c) 2019 Klaus Post. All rights reserved.
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  //go:build (!amd64 && !arm64) || appengine || !gc || noasm
     7  // +build !amd64,!arm64 appengine !gc noasm
     8  
     9  package s2
    10  
    11  import (
    12  	"fmt"
    13  	"strconv"
    14  )
    15  
    16  // decode writes the decoding of src to dst. It assumes that the varint-encoded
    17  // length of the decompressed bytes has already been read, and that len(dst)
    18  // equals that length.
    19  //
    20  // It returns 0 on success or a decodeErrCodeXxx error code on failure.
    21  func s2Decode(dst, src []byte) int {
    22  	const debug = false
    23  	if debug {
    24  		fmt.Println("Starting decode, dst len:", len(dst))
    25  	}
    26  	var d, s, length int
    27  	offset := 0
    28  
    29  	// As long as we can read at least 5 bytes...
    30  	for s < len(src)-5 {
    31  		// Removing bounds checks is SLOWER, when if doing
    32  		// in := src[s:s+5]
    33  		// Checked on Go 1.18
    34  		switch src[s] & 0x03 {
    35  		case tagLiteral:
    36  			x := uint32(src[s] >> 2)
    37  			switch {
    38  			case x < 60:
    39  				s++
    40  			case x == 60:
    41  				s += 2
    42  				x = uint32(src[s-1])
    43  			case x == 61:
    44  				in := src[s : s+3]
    45  				x = uint32(in[1]) | uint32(in[2])<<8
    46  				s += 3
    47  			case x == 62:
    48  				in := src[s : s+4]
    49  				// Load as 32 bit and shift down.
    50  				x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
    51  				x >>= 8
    52  				s += 4
    53  			case x == 63:
    54  				in := src[s : s+5]
    55  				x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
    56  				s += 5
    57  			}
    58  			length = int(x) + 1
    59  			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
    60  				if debug {
    61  					fmt.Println("corrupt: lit size", length)
    62  				}
    63  				return decodeErrCodeCorrupt
    64  			}
    65  			if debug {
    66  				fmt.Println("literals, length:", length, "d-after:", d+length)
    67  			}
    68  
    69  			copy(dst[d:], src[s:s+length])
    70  			d += length
    71  			s += length
    72  			continue
    73  
    74  		case tagCopy1:
    75  			s += 2
    76  			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
    77  			length = int(src[s-2]) >> 2 & 0x7
    78  			if toffset == 0 {
    79  				if debug {
    80  					fmt.Print("(repeat) ")
    81  				}
    82  				// keep last offset
    83  				switch length {
    84  				case 5:
    85  					length = int(src[s]) + 4
    86  					s += 1
    87  				case 6:
    88  					in := src[s : s+2]
    89  					length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
    90  					s += 2
    91  				case 7:
    92  					in := src[s : s+3]
    93  					length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16)
    94  					s += 3
    95  				default: // 0-> 4
    96  				}
    97  			} else {
    98  				offset = toffset
    99  			}
   100  			length += 4
   101  		case tagCopy2:
   102  			in := src[s : s+3]
   103  			offset = int(uint32(in[1]) | uint32(in[2])<<8)
   104  			length = 1 + int(in[0])>>2
   105  			s += 3
   106  
   107  		case tagCopy4:
   108  			in := src[s : s+5]
   109  			offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
   110  			length = 1 + int(in[0])>>2
   111  			s += 5
   112  		}
   113  
   114  		if offset <= 0 || d < offset || length > len(dst)-d {
   115  			if debug {
   116  				fmt.Println("corrupt: match, length", length, "offset:", offset, "dst avail:", len(dst)-d, "dst pos:", d)
   117  			}
   118  
   119  			return decodeErrCodeCorrupt
   120  		}
   121  
   122  		if debug {
   123  			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
   124  		}
   125  
   126  		// Copy from an earlier sub-slice of dst to a later sub-slice.
   127  		// If no overlap, use the built-in copy:
   128  		if offset > length {
   129  			copy(dst[d:d+length], dst[d-offset:])
   130  			d += length
   131  			continue
   132  		}
   133  
   134  		// Unlike the built-in copy function, this byte-by-byte copy always runs
   135  		// forwards, even if the slices overlap. Conceptually, this is:
   136  		//
   137  		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
   138  		//
   139  		// We align the slices into a and b and show the compiler they are the same size.
   140  		// This allows the loop to run without bounds checks.
   141  		a := dst[d : d+length]
   142  		b := dst[d-offset:]
   143  		b = b[:len(a)]
   144  		for i := range a {
   145  			a[i] = b[i]
   146  		}
   147  		d += length
   148  	}
   149  
   150  	// Remaining with extra checks...
   151  	for s < len(src) {
   152  		switch src[s] & 0x03 {
   153  		case tagLiteral:
   154  			x := uint32(src[s] >> 2)
   155  			switch {
   156  			case x < 60:
   157  				s++
   158  			case x == 60:
   159  				s += 2
   160  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   161  					return decodeErrCodeCorrupt
   162  				}
   163  				x = uint32(src[s-1])
   164  			case x == 61:
   165  				s += 3
   166  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   167  					return decodeErrCodeCorrupt
   168  				}
   169  				x = uint32(src[s-2]) | uint32(src[s-1])<<8
   170  			case x == 62:
   171  				s += 4
   172  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   173  					return decodeErrCodeCorrupt
   174  				}
   175  				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
   176  			case x == 63:
   177  				s += 5
   178  				if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   179  					return decodeErrCodeCorrupt
   180  				}
   181  				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
   182  			}
   183  			length = int(x) + 1
   184  			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
   185  				if debug {
   186  					fmt.Println("corrupt: lit size", length)
   187  				}
   188  				return decodeErrCodeCorrupt
   189  			}
   190  			if debug {
   191  				fmt.Println("literals, length:", length, "d-after:", d+length)
   192  			}
   193  
   194  			copy(dst[d:], src[s:s+length])
   195  			d += length
   196  			s += length
   197  			continue
   198  
   199  		case tagCopy1:
   200  			s += 2
   201  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   202  				return decodeErrCodeCorrupt
   203  			}
   204  			length = int(src[s-2]) >> 2 & 0x7
   205  			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
   206  			if toffset == 0 {
   207  				if debug {
   208  					fmt.Print("(repeat) ")
   209  				}
   210  				// keep last offset
   211  				switch length {
   212  				case 5:
   213  					s += 1
   214  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   215  						return decodeErrCodeCorrupt
   216  					}
   217  					length = int(uint32(src[s-1])) + 4
   218  				case 6:
   219  					s += 2
   220  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   221  						return decodeErrCodeCorrupt
   222  					}
   223  					length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8)
   224  				case 7:
   225  					s += 3
   226  					if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   227  						return decodeErrCodeCorrupt
   228  					}
   229  					length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16)
   230  				default: // 0-> 4
   231  				}
   232  			} else {
   233  				offset = toffset
   234  			}
   235  			length += 4
   236  		case tagCopy2:
   237  			s += 3
   238  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   239  				return decodeErrCodeCorrupt
   240  			}
   241  			length = 1 + int(src[s-3])>>2
   242  			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
   243  
   244  		case tagCopy4:
   245  			s += 5
   246  			if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
   247  				return decodeErrCodeCorrupt
   248  			}
   249  			length = 1 + int(src[s-5])>>2
   250  			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
   251  		}
   252  
   253  		if offset <= 0 || d < offset || length > len(dst)-d {
   254  			if debug {
   255  				fmt.Println("corrupt: match, length", length, "offset:", offset, "dst avail:", len(dst)-d, "dst pos:", d)
   256  			}
   257  			return decodeErrCodeCorrupt
   258  		}
   259  
   260  		if debug {
   261  			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
   262  		}
   263  
   264  		// Copy from an earlier sub-slice of dst to a later sub-slice.
   265  		// If no overlap, use the built-in copy:
   266  		if offset > length {
   267  			copy(dst[d:d+length], dst[d-offset:])
   268  			d += length
   269  			continue
   270  		}
   271  
   272  		// Unlike the built-in copy function, this byte-by-byte copy always runs
   273  		// forwards, even if the slices overlap. Conceptually, this is:
   274  		//
   275  		// d += forwardCopy(dst[d:d+length], dst[d-offset:])
   276  		//
   277  		// We align the slices into a and b and show the compiler they are the same size.
   278  		// This allows the loop to run without bounds checks.
   279  		a := dst[d : d+length]
   280  		b := dst[d-offset:]
   281  		b = b[:len(a)]
   282  		for i := range a {
   283  			a[i] = b[i]
   284  		}
   285  		d += length
   286  	}
   287  
   288  	if d != len(dst) {
   289  		return decodeErrCodeCorrupt
   290  	}
   291  	return 0
   292  }
   293  

View as plain text