...
1
2
3
4
5 package zstd
6
7 import (
8 "encoding/binary"
9 "errors"
10 "fmt"
11 "io"
12 "math/bits"
13 )
14
15
16
17
// bitReader reads a bitstream backwards: bytes are loaded into value starting
// from the end of the input, and bits are taken from the most significant end
// of value.
type bitReader struct {
	in       []byte // input bytes not yet loaded into value; consumed from the end
	value    uint64 // bit register; unread bits sit below the top bitsRead bits
	bitsRead uint8  // bits of value already consumed; 64 means empty, >64 means overread
}
23
24
25 func (b *bitReader) init(in []byte) error {
26 if len(in) < 1 {
27 return errors.New("corrupt stream: too short")
28 }
29 b.in = in
30
31 v := in[len(in)-1]
32 if v == 0 {
33 return errors.New("corrupt stream, did not find end of stream")
34 }
35 b.bitsRead = 64
36 b.value = 0
37 if len(in) >= 8 {
38 b.fillFastStart()
39 } else {
40 b.fill()
41 b.fill()
42 }
43 b.bitsRead += 8 - uint8(highBits(uint32(v)))
44 return nil
45 }
46
47
48 func (b *bitReader) getBits(n uint8) int {
49 if n == 0 {
50 return 0
51 }
52 return int(b.get32BitsFast(n))
53 }
54
55
56
57 func (b *bitReader) get32BitsFast(n uint8) uint32 {
58 const regMask = 64 - 1
59 v := uint32((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask))
60 b.bitsRead += n
61 return v
62 }
63
64
65
66 func (b *bitReader) fillFast() {
67 if b.bitsRead < 32 {
68 return
69 }
70 v := b.in[len(b.in)-4:]
71 b.in = b.in[:len(b.in)-4]
72 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
73 b.value = (b.value << 32) | uint64(low)
74 b.bitsRead -= 32
75 }
76
77
78 func (b *bitReader) fillFastStart() {
79 v := b.in[len(b.in)-8:]
80 b.in = b.in[:len(b.in)-8]
81 b.value = binary.LittleEndian.Uint64(v)
82 b.bitsRead = 0
83 }
84
85
86 func (b *bitReader) fill() {
87 if b.bitsRead < 32 {
88 return
89 }
90 if len(b.in) >= 4 {
91 v := b.in[len(b.in)-4:]
92 b.in = b.in[:len(b.in)-4]
93 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
94 b.value = (b.value << 32) | uint64(low)
95 b.bitsRead -= 32
96 return
97 }
98
99 b.bitsRead -= uint8(8 * len(b.in))
100 for len(b.in) > 0 {
101 b.value = (b.value << 8) | uint64(b.in[len(b.in)-1])
102 b.in = b.in[:len(b.in)-1]
103 }
104 }
105
106
107 func (b *bitReader) finished() bool {
108 return len(b.in) == 0 && b.bitsRead >= 64
109 }
110
111
112 func (b *bitReader) overread() bool {
113 return b.bitsRead > 64
114 }
115
116
117 func (b *bitReader) remain() uint {
118 return 8*uint(len(b.in)) + 64 - uint(b.bitsRead)
119 }
120
121
122 func (b *bitReader) close() error {
123
124 b.in = nil
125 if !b.finished() {
126 return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
127 }
128 if b.bitsRead > 64 {
129 return io.ErrUnexpectedEOF
130 }
131 return nil
132 }
133
// highBits returns the zero-based index of the most significant set bit of
// val: 0 for 1, 31 for 1<<31. val must be non-zero; for val == 0 the unsigned
// subtraction wraps and the result is 0xFFFFFFFF.
func highBits(val uint32) uint32 {
	return uint32(bits.Len32(val)) - 1
}
137
View as plain text