1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package encoding
18
19 import (
20 "fmt"
21 "math/bits"
22 "reflect"
23
24 "github.com/apache/arrow/go/v15/arrow"
25 "github.com/apache/arrow/go/v15/arrow/bitutil"
26 "github.com/apache/arrow/go/v15/arrow/memory"
27 "github.com/apache/arrow/go/v15/internal/bitutils"
28 "github.com/apache/arrow/go/v15/parquet"
29 format "github.com/apache/arrow/go/v15/parquet/internal/gen-go/parquet"
30 "github.com/apache/arrow/go/v15/parquet/internal/utils"
31 "github.com/apache/arrow/go/v15/parquet/schema"
32 )
33
34
35
36
37
38 type EncoderTraits interface {
39 Encoder(format.Encoding, bool, *schema.Column, memory.Allocator) TypedEncoder
40 }
41
42
43
44
45
46 func NewEncoder(t parquet.Type, e parquet.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder {
47 traits := getEncodingTraits(t)
48 if traits == nil {
49 return nil
50 }
51
52 if mem == nil {
53 mem = memory.DefaultAllocator
54 }
55 return traits.Encoder(format.Encoding(e), useDict, descr, mem)
56 }
57
58 type encoder struct {
59 descr *schema.Column
60 encoding format.Encoding
61 typeLen int
62 mem memory.Allocator
63
64 sink *PooledBufferWriter
65 }
66
67
68
69 func newEncoderBase(e format.Encoding, descr *schema.Column, mem memory.Allocator) encoder {
70 typelen := -1
71 if descr != nil && descr.PhysicalType() == parquet.Types.FixedLenByteArray {
72 typelen = int(descr.TypeLength())
73 }
74 return encoder{
75 descr: descr,
76 encoding: e,
77 mem: mem,
78 typeLen: typelen,
79 sink: NewPooledBufferWriter(1024),
80 }
81 }
82
83 func (e *encoder) Release() {
84 poolbuf := e.sink.buf
85 memory.Set(poolbuf.Buf(), 0)
86 poolbuf.ResizeNoShrink(0)
87 bufferPool.Put(poolbuf)
88 e.sink = nil
89 }
90
91
92 func (e *encoder) ReserveForWrite(n int) { e.sink.Reserve(n) }
93 func (e *encoder) EstimatedDataEncodedSize() int64 { return int64(e.sink.Len()) }
94 func (e *encoder) Encoding() parquet.Encoding { return parquet.Encoding(e.encoding) }
95 func (e *encoder) Allocator() memory.Allocator { return e.mem }
96 func (e *encoder) append(data []byte) { e.sink.Write(data) }
97
98
99
100
101 func (e *encoder) FlushValues() (Buffer, error) { return e.sink.Finish(), nil }
102
103
104 func (e *encoder) Bytes() []byte { return e.sink.Bytes() }
105
106
107 func (e *encoder) Reset() { e.sink.Reset(0) }
108
109 type dictEncoder struct {
110 encoder
111
112 dictEncodedSize int
113 idxBuffer *memory.Buffer
114 idxValues []int32
115 memo MemoTable
116
117 preservedDict arrow.Array
118 }
119
120
121
122 func newDictEncoderBase(descr *schema.Column, memo MemoTable, mem memory.Allocator) dictEncoder {
123 return dictEncoder{
124 encoder: newEncoderBase(format.Encoding_PLAIN_DICTIONARY, descr, mem),
125 idxBuffer: memory.NewResizableBuffer(mem),
126 memo: memo,
127 }
128 }
129
130
131
132 func (d *dictEncoder) Reset() {
133 d.encoder.Reset()
134 d.dictEncodedSize = 0
135 d.idxValues = d.idxValues[:0]
136 d.idxBuffer.ResizeNoShrink(0)
137 d.memo.Reset()
138 if d.preservedDict != nil {
139 d.preservedDict.Release()
140 d.preservedDict = nil
141 }
142 }
143
144 func (d *dictEncoder) Release() {
145 d.encoder.Release()
146 d.idxBuffer.Release()
147 if m, ok := d.memo.(BinaryMemoTable); ok {
148 m.Release()
149 } else {
150 d.memo.Reset()
151 }
152 if d.preservedDict != nil {
153 d.preservedDict.Release()
154 d.preservedDict = nil
155 }
156 }
157
158 func (d *dictEncoder) expandBuffer(newCap int) {
159 if cap(d.idxValues) >= newCap {
160 return
161 }
162
163 curLen := len(d.idxValues)
164 d.idxBuffer.ResizeNoShrink(arrow.Int32Traits.BytesRequired(bitutil.NextPowerOf2(newCap)))
165 d.idxValues = arrow.Int32Traits.CastFromBytes(d.idxBuffer.Buf())[: curLen : d.idxBuffer.Len()/arrow.Int32SizeBytes]
166 }
167
168 func (d *dictEncoder) PutIndices(data arrow.Array) error {
169 newValues := data.Len() - data.NullN()
170 curPos := len(d.idxValues)
171 newLen := newValues + curPos
172 d.expandBuffer(newLen)
173 d.idxValues = d.idxValues[:newLen:cap(d.idxValues)]
174
175 switch data.DataType().ID() {
176 case arrow.UINT8, arrow.INT8:
177 values := arrow.Uint8Traits.CastFromBytes(data.Data().Buffers()[1].Bytes())[data.Data().Offset():]
178 bitutils.VisitSetBitRunsNoErr(data.NullBitmapBytes(),
179 int64(data.Data().Offset()), int64(data.Len()),
180 func(pos, length int64) {
181 for i := int64(0); i < length; i++ {
182 d.idxValues[curPos] = int32(values[i+pos])
183 curPos++
184 }
185 })
186 case arrow.UINT16, arrow.INT16:
187 values := arrow.Uint16Traits.CastFromBytes(data.Data().Buffers()[1].Bytes())[data.Data().Offset():]
188 bitutils.VisitSetBitRunsNoErr(data.NullBitmapBytes(),
189 int64(data.Data().Offset()), int64(data.Len()),
190 func(pos, length int64) {
191 for i := int64(0); i < length; i++ {
192 d.idxValues[curPos] = int32(values[i+pos])
193 curPos++
194 }
195 })
196 case arrow.UINT32, arrow.INT32:
197 values := arrow.Uint32Traits.CastFromBytes(data.Data().Buffers()[1].Bytes())[data.Data().Offset():]
198 bitutils.VisitSetBitRunsNoErr(data.NullBitmapBytes(),
199 int64(data.Data().Offset()), int64(data.Len()),
200 func(pos, length int64) {
201 for i := int64(0); i < length; i++ {
202 d.idxValues[curPos] = int32(values[i+pos])
203 curPos++
204 }
205 })
206 case arrow.UINT64, arrow.INT64:
207 values := arrow.Uint64Traits.CastFromBytes(data.Data().Buffers()[1].Bytes())[data.Data().Offset():]
208 bitutils.VisitSetBitRunsNoErr(data.NullBitmapBytes(),
209 int64(data.Data().Offset()), int64(data.Len()),
210 func(pos, length int64) {
211 for i := int64(0); i < length; i++ {
212 d.idxValues[curPos] = int32(values[i+pos])
213 curPos++
214 }
215 })
216 default:
217 return fmt.Errorf("%w: passed non-integer array to PutIndices", arrow.ErrInvalid)
218 }
219
220 return nil
221 }
222
223
224 func (d *dictEncoder) addIndex(idx int) {
225 curLen := len(d.idxValues)
226 d.expandBuffer(curLen + 1)
227 d.idxValues = append(d.idxValues, int32(idx))
228 }
229
230
231
232 func (d *dictEncoder) FlushValues() (Buffer, error) {
233 buf := bufferPool.Get().(*memory.Buffer)
234 buf.Reserve(int(d.EstimatedDataEncodedSize()))
235 size, err := d.WriteIndices(buf.Buf())
236 if err != nil {
237 poolBuffer{buf}.Release()
238 return nil, err
239 }
240 buf.ResizeNoShrink(size)
241 return poolBuffer{buf}, nil
242 }
243
244
245
246 func (d *dictEncoder) EstimatedDataEncodedSize() int64 {
247 return 1 + int64(utils.MaxRLEBufferSize(d.BitWidth(), len(d.idxValues))+utils.MinRLEBufferSize(d.BitWidth()))
248 }
249
250
251 func (d *dictEncoder) NumEntries() int {
252 return d.memo.Size()
253 }
254
255
256
257 func (d *dictEncoder) BitWidth() int {
258 switch d.NumEntries() {
259 case 0:
260 return 0
261 case 1:
262 return 1
263 default:
264 return bits.Len32(uint32(d.NumEntries() - 1))
265 }
266 }
267
268
269 func (d *dictEncoder) WriteDict(out []byte) {
270 d.memo.WriteOut(out)
271 }
272
273
274
275
276 func (d *dictEncoder) WriteIndices(out []byte) (int, error) {
277 out[0] = byte(d.BitWidth())
278
279 enc := utils.NewRleEncoder(utils.NewWriterAtBuffer(out[1:]), d.BitWidth())
280 for _, idx := range d.idxValues {
281 if err := enc.Put(uint64(idx)); err != nil {
282 return -1, err
283 }
284 }
285 nbytes := enc.Flush()
286
287 d.idxValues = d.idxValues[:0]
288 return nbytes + 1, nil
289 }
290
291
292
293 func (d *dictEncoder) Put(v interface{}) {
294 memoIdx, found, err := d.memo.GetOrInsert(v)
295 if err != nil {
296 panic(err)
297 }
298 if !found {
299 d.dictEncodedSize += int(reflect.TypeOf(v).Size())
300 }
301 d.addIndex(memoIdx)
302 }
303
304
305 func (d *dictEncoder) DictEncodedSize() int {
306 return d.dictEncodedSize
307 }
308
309 func (d *dictEncoder) canPutDictionary(values arrow.Array) error {
310 switch {
311 case values.NullN() > 0:
312 return fmt.Errorf("%w: inserted dictionary cannot contain nulls",
313 arrow.ErrInvalid)
314 case d.NumEntries() > 0:
315 return fmt.Errorf("%w: can only call PutDictionary on an empty DictEncoder",
316 arrow.ErrInvalid)
317 }
318
319 return nil
320 }
321
322 func (d *dictEncoder) PreservedDictionary() arrow.Array { return d.preservedDict }
323
324
325
326 func spacedCompress(src, out interface{}, validBits []byte, validBitsOffset int64) int {
327 nvalid := 0
328
329
330
331 switch s := src.(type) {
332 case []int32:
333 o := out.([]int32)
334 reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
335 for {
336 run := reader.NextRun()
337 if run.Length == 0 {
338 break
339 }
340 copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
341 nvalid += int(run.Length)
342 }
343 case []int64:
344 o := out.([]int64)
345 reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
346 for {
347 run := reader.NextRun()
348 if run.Length == 0 {
349 break
350 }
351 copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
352 nvalid += int(run.Length)
353 }
354 case []float32:
355 o := out.([]float32)
356 reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
357 for {
358 run := reader.NextRun()
359 if run.Length == 0 {
360 break
361 }
362 copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
363 nvalid += int(run.Length)
364 }
365 case []float64:
366 o := out.([]float64)
367 reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
368 for {
369 run := reader.NextRun()
370 if run.Length == 0 {
371 break
372 }
373 copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
374 nvalid += int(run.Length)
375 }
376 case []parquet.ByteArray:
377 o := out.([]parquet.ByteArray)
378 reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
379 for {
380 run := reader.NextRun()
381 if run.Length == 0 {
382 break
383 }
384 copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
385 nvalid += int(run.Length)
386 }
387 case []parquet.FixedLenByteArray:
388 o := out.([]parquet.FixedLenByteArray)
389 reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
390 for {
391 run := reader.NextRun()
392 if run.Length == 0 {
393 break
394 }
395 copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
396 nvalid += int(run.Length)
397 }
398 case []bool:
399 o := out.([]bool)
400 reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
401 for {
402 run := reader.NextRun()
403 if run.Length == 0 {
404 break
405 }
406 copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
407 nvalid += int(run.Length)
408 }
409 }
410
411 return nvalid
412 }
413
View as plain text