1
2
3
4
5
6 package s2
7
8 import (
9 "encoding/binary"
10 "errors"
11 "fmt"
12 "strconv"
13
14 "github.com/klauspost/compress/internal/race"
15 )
16
// Sentinel errors reported by this package. Callers should compare against
// them with errors.Is.
var (
	// ErrCorrupt reports that the input is invalid.
	ErrCorrupt = errors.New("s2: corrupt input")

	// ErrCRC reports that a CRC checksum did not match the data.
	ErrCRC = errors.New("s2: corrupt input, crc mismatch")

	// ErrTooLarge reports that the decoded block is too large to handle.
	// decodedLen, for example, returns it when the length header does not
	// fit in an int on 32-bit platforms.
	ErrTooLarge = errors.New("s2: decoded block is too large")

	// ErrUnsupported reports that the input uses a feature that is not
	// supported.
	ErrUnsupported = errors.New("s2: unsupported input")
)
27
28
29 func DecodedLen(src []byte) (int, error) {
30 v, _, err := decodedLen(src)
31 return v, err
32 }
33
34
35
36 func decodedLen(src []byte) (blockLen, headerLen int, err error) {
37 v, n := binary.Uvarint(src)
38 if n <= 0 || v > 0xffffffff {
39 return 0, 0, ErrCorrupt
40 }
41
42 const wordSize = 32 << (^uint(0) >> 32 & 1)
43 if wordSize == 32 && v > 0x7fffffff {
44 return 0, 0, ErrTooLarge
45 }
46 return int(v), n, nil
47 }
48
// decodeErrCodeCorrupt is the non-zero status the internal block decoders
// (e.g. s2DecodeDict) return when the input is malformed; 0 means success.
const decodeErrCodeCorrupt = 1
52
53
54
55
56
57
58 func Decode(dst, src []byte) ([]byte, error) {
59 dLen, s, err := decodedLen(src)
60 if err != nil {
61 return nil, err
62 }
63 if dLen <= cap(dst) {
64 dst = dst[:dLen]
65 } else {
66 dst = make([]byte, dLen)
67 }
68
69 race.WriteSlice(dst)
70 race.ReadSlice(src[s:])
71
72 if s2Decode(dst, src[s:]) != 0 {
73 return nil, ErrCorrupt
74 }
75 return dst, nil
76 }
77
78
79
80
81
82
// s2DecodeDict writes the decoding of src to dst, resolving back-references
// that reach before the start of dst against dict.dict.
//
// dst must already have exactly the decoded length: decoding that does not
// fill dst completely is reported as corrupt (see the final d != len(dst)
// check). src is the encoded block body; presumably it is positioned after
// the uvarint length header, as Decode does with src[s:] — confirm at the
// call sites.
//
// When dict is nil the call is forwarded to s2Decode and its result is
// returned. Otherwise it returns 0 on success or decodeErrCodeCorrupt on any
// malformed input.
func s2DecodeDict(dst, src []byte, dict *Dict) int {
	if dict == nil {
		return s2Decode(dst, src)
	}
	// Compile-time tracing switches; both are off in normal builds.
	const debug = false
	const debugErrs = debug

	if debug {
		fmt.Println("Starting decode, dst len:", len(dst))
	}
	// d is the write position in dst and s the read position in src.
	// offset persists across iterations so that a repeat (copy1 with a zero
	// offset) reuses the previous offset; it starts at the dictionary's
	// repeat offset.
	var d, s, length int
	offset := len(dict.dict) - dict.repeat

	// Fast path: while s < len(src)-5, the tag byte and up to 5 following
	// header bytes are known to be in bounds, so the header reads below are
	// not bounds-checked individually.
	for s < len(src)-5 {
		// The low two bits of the tag byte select the element type.
		switch src[s] & 0x03 {
		case tagLiteral:
			// Upper 6 bits: values below 60 hold length-1 directly; 60..63
			// mean length-1 follows in 1..4 little-endian bytes.
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				x = uint32(src[s-1])
			case x == 61:
				in := src[s : s+3]
				x = uint32(in[1]) | uint32(in[2])<<8
				s += 3
			case x == 62:
				in := src[s : s+4]
				// Load 4 bytes (tag included) and shift the tag byte out,
				// leaving the 24-bit value.
				x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
				x >>= 8
				s += 4
			case x == 63:
				in := src[s : s+5]
				x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
				s += 5
			}
			length = int(x) + 1
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}
			// The literal must fit in both the rest of dst and the rest of
			// src. When int is 32 bits, int(x)+1 can wrap to a value <= 0.
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			// 2-byte header: 11-bit offset (tag bits 5..7 shifted up, plus
			// the next byte) and a 3-bit length code in tag bits 2..4.
			s += 2
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			length = int(src[s-2]) >> 2 & 0x7
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// Zero offset: keep the previous offset. Length codes 5..7
				// read 1..3 extra bytes for a longer length; 0..4 are used
				// as-is.
				switch length {
				case 5:
					length = int(src[s]) + 4
					s += 1
				case 6:
					in := src[s : s+2]
					length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
					s += 2
				case 7:
					in := src[s : s+3]
					length = int((uint32(in[2])<<16)|(uint32(in[1])<<8)|uint32(in[0])) + (1 << 16)
					s += 3
				default:
				}
			} else {
				offset = toffset
			}
			length += 4
		case tagCopy2:
			// 3-byte header: 16-bit little-endian offset; length is
			// 1 + (tag >> 2), i.e. 1..64.
			in := src[s : s+3]
			offset = int(uint32(in[1]) | uint32(in[2])<<8)
			length = 1 + int(in[0])>>2
			s += 3

		case tagCopy4:
			// 5-byte header: 32-bit little-endian offset; length 1..64. On
			// 32-bit platforms the offset can convert to a negative int,
			// which the offset <= 0 check below rejects.
			in := src[s : s+5]
			offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
			length = 1 + int(in[0])>>2
			s += 5
		}

		// Every copy needs a positive offset and must fit in dst.
		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// The match starts before dst[0], so it is copied from the dictionary.
		if d < offset {
			// Dictionary references are only accepted while
			// d <= MaxDictSrcOffset.
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			// Index in dict.dict that corresponds to dst[d-offset]. The
			// whole match must lie inside the dictionary; it may not run on
			// into dst.
			startOff := len(dict.dict) - offset + d
			if startOff < 0 || startOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Printf("offset (%d) + length (%d) bigger than dict (%d)\n", offset, length, len(dict.dict))
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("dict copy, length:", length, "offset:", offset, "d-after:", d+length, "dict start offset:", startOff)
			}
			copy(dst[d:d+length], dict.dict[startOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Source and destination ranges do not overlap, so a plain copy is
		// enough.
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Overlapping copy (offset <= length): copy forward one byte at a
		// time so bytes written earlier in this match are read again,
		// repeating the pattern. The built-in copy has memmove semantics and
		// would not produce this result. Re-slicing b to len(a) presumably
		// lets the compiler drop per-element bounds checks — NOTE(review):
		// confirm.
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// Tail: fewer than 6 bytes of src remain, so every header read is
	// preceded by an explicit src bounds check. Element handling otherwise
	// mirrors the fast loop above.
	for s < len(src) {
		switch src[s] & 0x03 {
		case tagLiteral:
			x := uint32(src[s] >> 2)
			switch {
			case x < 60:
				s++
			case x == 60:
				s += 2
				if uint(s) > uint(len(src)) {
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-1])
			case x == 61:
				s += 3
				if uint(s) > uint(len(src)) {
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-2]) | uint32(src[s-1])<<8
			case x == 62:
				s += 4
				if uint(s) > uint(len(src)) {
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
			case x == 63:
				s += 5
				if uint(s) > uint(len(src)) {
					if debugErrs {
						fmt.Println("src went oob")
					}
					return decodeErrCodeCorrupt
				}
				x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
			}
			length = int(x) + 1
			// Same limits as the fast loop, including the 32-bit wrap guard.
			if length > len(dst)-d || length > len(src)-s || (strconv.IntSize == 32 && length <= 0) {
				if debugErrs {
					fmt.Println("corrupt literal: length:", length, "d-left:", len(dst)-d, "src-left:", len(src)-s)
				}
				return decodeErrCodeCorrupt
			}
			if debug {
				fmt.Println("literals, length:", length, "d-after:", d+length)
			}

			copy(dst[d:], src[s:s+length])
			d += length
			s += length
			continue

		case tagCopy1:
			s += 2
			if uint(s) > uint(len(src)) {
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = int(src[s-2]) >> 2 & 0x7
			toffset := int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
			if toffset == 0 {
				if debug {
					fmt.Print("(repeat) ")
				}
				// Zero offset: keep the previous offset; codes 5..7 read
				// extra length bytes.
				switch length {
				case 5:
					s += 1
					if uint(s) > uint(len(src)) {
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-1])) + 4
				case 6:
					s += 2
					if uint(s) > uint(len(src)) {
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-2])|(uint32(src[s-1])<<8)) + (1 << 8)
				case 7:
					s += 3
					if uint(s) > uint(len(src)) {
						if debugErrs {
							fmt.Println("src went oob")
						}
						return decodeErrCodeCorrupt
					}
					length = int(uint32(src[s-3])|(uint32(src[s-2])<<8)|(uint32(src[s-1])<<16)) + (1 << 16)
				default:
				}
			} else {
				offset = toffset
			}
			length += 4
		case tagCopy2:
			s += 3
			if uint(s) > uint(len(src)) {
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-3])>>2
			offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)

		case tagCopy4:
			s += 5
			if uint(s) > uint(len(src)) {
				if debugErrs {
					fmt.Println("src went oob")
				}
				return decodeErrCodeCorrupt
			}
			length = 1 + int(src[s-5])>>2
			offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
		}

		// Every copy needs a positive offset and must fit in dst.
		if offset <= 0 || length > len(dst)-d {
			if debugErrs {
				fmt.Println("match error; offset:", offset, "length:", length, "dst-left:", len(dst)-d)
			}
			return decodeErrCodeCorrupt
		}

		// The match starts before dst[0], so it is copied from the dictionary.
		if d < offset {
			if d > MaxDictSrcOffset {
				if debugErrs {
					fmt.Println("dict after", MaxDictSrcOffset, "d:", d, "offset:", offset, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			// Index in dict.dict that corresponds to dst[d-offset]; the
			// match must end within the dictionary and start at index >= 0.
			rOff := len(dict.dict) - (offset - d)
			if debug {
				fmt.Println("starting dict entry from dict offset", len(dict.dict)-rOff)
			}
			if rOff+length > len(dict.dict) {
				if debugErrs {
					fmt.Println("err: END offset", rOff+length, "bigger than dict", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			if rOff < 0 {
				if debugErrs {
					fmt.Println("err: START offset", rOff, "less than 0", len(dict.dict), "dict offset:", rOff, "length:", length)
				}
				return decodeErrCodeCorrupt
			}
			copy(dst[d:d+length], dict.dict[rOff:])
			d += length
			continue
		}

		if debug {
			fmt.Println("copy, length:", length, "offset:", offset, "d-after:", d+length)
		}

		// Non-overlapping source and destination: a plain copy is enough.
		if offset > length {
			copy(dst[d:d+length], dst[d-offset:])
			d += length
			continue
		}

		// Overlapping copy: forward byte-by-byte so the pattern repeats
		// (built-in copy would behave like memmove instead).
		a := dst[d : d+length]
		b := dst[d-offset:]
		b = b[:len(a)]
		for i := range a {
			a[i] = b[i]
		}
		d += length
	}

	// Decoding must have produced exactly len(dst) bytes.
	if d != len(dst) {
		if debugErrs {
			fmt.Println("wanted length", len(dst), "got", d)
		}
		return decodeErrCodeCorrupt
	}
	return 0
}
444
View as plain text