1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package utils
18
19 import (
20 "encoding/binary"
21 "errors"
22 "io"
23 "math"
24 "reflect"
25 "unsafe"
26
27 "github.com/apache/arrow/go/v15/arrow"
28 "github.com/apache/arrow/go/v15/arrow/bitutil"
29 "github.com/apache/arrow/go/v15/arrow/memory"
30 "github.com/apache/arrow/go/v15/internal/utils"
31 )
32
33
// trailingMask[n] is a uint64 whose low n bits are set and whose remaining
// bits are clear, for n in [0, 63]. It is built once, at package
// initialization, by shifting an all-ones word right by 64-n (a shift of 64
// yields 0, giving the empty mask for n == 0).
var trailingMask = func() (masks [64]uint64) {
	for n := range masks {
		masks[n] = math.MaxUint64 >> uint(64-n)
	}
	return masks
}()
42
43
44
45 func trailingBits(v uint64, bits uint) uint64 {
46 if bits >= 64 {
47 return v
48 }
49 return v & trailingMask[bits]
50 }
51
52
// reader is the data source a BitReader consumes. In this file, sequential
// Read is used by the bulk bool path, ReadAt by word refills and aligned
// reads, and Seek to resynchronize the sequential cursor with the tracked
// byte offset before bulk unpacking.
type reader interface {
	io.ReadSeeker
	io.ReaderAt
}
58
59
// buflen sizes BitReader.unpackBuf and caps how many values a single bulk
// unpacking step handles.
const buflen = 1 << 10
61
62
63
64
65
66
67
68 type BitReader struct {
69 reader reader
70 buffer uint64
71 byteoffset int64
72 bitoffset uint
73 raw [8]byte
74
75 unpackBuf [buflen]uint32
76 }
77
78
79
80 func NewBitReader(r reader) *BitReader {
81 return &BitReader{reader: r}
82 }
83
84
85 func (b *BitReader) CurOffset() int64 {
86 return b.byteoffset + bitutil.BytesForBits(int64(b.bitoffset))
87 }
88
89
90
91 func (b *BitReader) Reset(r reader) {
92 b.reader = r
93 b.buffer = 0
94 b.byteoffset = 0
95 b.bitoffset = 0
96 }
97
98
99
100
101
102 func (b *BitReader) GetVlqInt() (uint64, bool) {
103 tmp, err := binary.ReadUvarint(b)
104 if err != nil {
105 return 0, false
106 }
107 return tmp, true
108 }
109
110
111
112 func (b *BitReader) GetZigZagVlqInt() (int64, bool) {
113 u, ok := b.GetVlqInt()
114 if !ok {
115 return 0, false
116 }
117
118 return int64(u>>1) ^ -int64(u&1), true
119 }
120
121
122
123 func (b *BitReader) ReadByte() (byte, error) {
124 var tmp byte
125 if ok := b.GetAligned(1, &tmp); !ok {
126 return 0, errors.New("failed to read byte")
127 }
128
129 return tmp, nil
130 }
131
132
133
134
135
136
137
138 func (b *BitReader) GetAligned(nbytes int, v interface{}) bool {
139
140 typBytes := int(reflect.TypeOf(v).Elem().Size())
141 if nbytes > typBytes {
142 return false
143 }
144
145 bread := bitutil.BytesForBits(int64(b.bitoffset))
146
147 b.byteoffset += bread
148 n, err := b.reader.ReadAt(b.raw[:nbytes], b.byteoffset)
149 if err != nil && err != io.EOF {
150 return false
151 }
152 if n != nbytes {
153 return false
154 }
155
156 memory.Set(b.raw[n:typBytes], 0)
157
158 switch v := v.(type) {
159 case *byte:
160 *v = b.raw[0]
161 case *uint64:
162 *v = binary.LittleEndian.Uint64(b.raw[:typBytes])
163 case *uint32:
164 *v = binary.LittleEndian.Uint32(b.raw[:typBytes])
165 case *uint16:
166 *v = binary.LittleEndian.Uint16(b.raw[:typBytes])
167 default:
168 return false
169 }
170
171 b.byteoffset += int64(nbytes)
172
173 b.bitoffset = 0
174 b.fillbuffer()
175 return true
176 }
177
178
179 func (b *BitReader) fillbuffer() error {
180 n, err := b.reader.ReadAt(b.raw[:], b.byteoffset)
181 if err != nil && n == 0 && err != io.EOF {
182 return err
183 }
184 for i := n; i < 8; i++ {
185 b.raw[i] = 0
186 }
187 b.buffer = binary.LittleEndian.Uint64(b.raw[:])
188 return nil
189 }
190
191
192 func (b *BitReader) next(bits uint) (v uint64, err error) {
193 v = trailingBits(b.buffer, b.bitoffset+bits) >> b.bitoffset
194 b.bitoffset += bits
195
196 if b.bitoffset >= 64 {
197 b.byteoffset += 8
198 b.bitoffset -= 64
199 if err = b.fillbuffer(); err != nil {
200 return 0, err
201 }
202 v |= trailingBits(b.buffer, b.bitoffset) << (bits - b.bitoffset)
203 }
204 return
205 }
206
207
208 func (b *BitReader) GetBatchIndex(bits uint, out []IndexType) (i int, err error) {
209
210
211 if bits > 32 {
212 return 0, errors.New("must be 32 bits or less per read")
213 }
214
215 var val uint64
216
217 length := len(out)
218
219 for ; i < length && b.bitoffset != 0; i++ {
220 val, err = b.next(bits)
221 out[i] = IndexType(val)
222 if err != nil {
223 return
224 }
225 }
226
227 b.reader.Seek(b.byteoffset, io.SeekStart)
228
229 if i < length {
230 numUnpacked := unpack32(b.reader, (*(*[]uint32)(unsafe.Pointer(&out)))[i:], int(bits))
231 i += numUnpacked
232 b.byteoffset += int64(numUnpacked * int(bits) / 8)
233 }
234
235
236 b.fillbuffer()
237
238 for ; i < length; i++ {
239 val, err = b.next(bits)
240 out[i] = IndexType(val)
241 if err != nil {
242 break
243 }
244 }
245 return
246 }
247
248
249 func (b *BitReader) GetBatchBools(out []bool) (int, error) {
250 bits := uint(1)
251 length := len(out)
252
253 i := 0
254
255 for ; i < length && b.bitoffset != 0; i++ {
256 val, err := b.next(bits)
257 out[i] = val != 0
258 if err != nil {
259 return i, err
260 }
261 }
262
263 b.reader.Seek(b.byteoffset, io.SeekStart)
264 buf := arrow.Uint32Traits.CastToBytes(b.unpackBuf[:])
265 blen := buflen * 8
266 for i < length {
267
268
269 unpackSize := utils.Min(blen, length-i) / 8 * 8
270 n, err := b.reader.Read(buf[:bitutil.BytesForBits(int64(unpackSize))])
271 if err != nil {
272 return i, err
273 }
274 BytesToBools(buf[:n], out[i:])
275 i += unpackSize
276 b.byteoffset += int64(n)
277 }
278
279 b.fillbuffer()
280
281 for ; i < length; i++ {
282 val, err := b.next(bits)
283 out[i] = val != 0
284 if err != nil {
285 return i, err
286 }
287 }
288
289 return i, nil
290 }
291
292
293
294
295 func (b *BitReader) GetBatch(bits uint, out []uint64) (int, error) {
296
297
298 if bits > 64 {
299 return 0, errors.New("must be 64 bits or less per read")
300 }
301
302 length := len(out)
303
304 i := 0
305
306 for ; i < length && b.bitoffset != 0; i++ {
307 val, err := b.next(bits)
308 out[i] = val
309 if err != nil {
310 return i, err
311 }
312 }
313
314 b.reader.Seek(b.byteoffset, io.SeekStart)
315 for i < length {
316
317 unpackSize := utils.Min(buflen, length-i)
318 numUnpacked := unpack32(b.reader, b.unpackBuf[:unpackSize], int(bits))
319 if numUnpacked == 0 {
320 break
321 }
322
323 for k := 0; k < numUnpacked; k++ {
324 out[i+k] = uint64(b.unpackBuf[k])
325 }
326 i += numUnpacked
327 b.byteoffset += int64(numUnpacked * int(bits) / 8)
328 }
329
330 b.fillbuffer()
331
332 for ; i < length; i++ {
333 val, err := b.next(bits)
334 out[i] = val
335 if err != nil {
336 return i, err
337 }
338 }
339
340 return i, nil
341 }
342
343
344
345 func (b *BitReader) GetValue(width int) (uint64, bool) {
346 v := make([]uint64, 1)
347 n, _ := b.GetBatch(uint(width), v)
348 return v[0], n == 1
349 }
350
View as plain text