1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package pqarrow
18
19 import (
20 "encoding/base64"
21 "fmt"
22 "math"
23 "strconv"
24
25 "github.com/apache/arrow/go/v15/arrow"
26 "github.com/apache/arrow/go/v15/arrow/decimal128"
27 "github.com/apache/arrow/go/v15/arrow/flight"
28 "github.com/apache/arrow/go/v15/arrow/ipc"
29 "github.com/apache/arrow/go/v15/arrow/memory"
30 "github.com/apache/arrow/go/v15/parquet"
31 "github.com/apache/arrow/go/v15/parquet/file"
32 "github.com/apache/arrow/go/v15/parquet/metadata"
33 "github.com/apache/arrow/go/v15/parquet/schema"
34 "golang.org/x/xerrors"
35 )
36
37
38
39
40
41
42 type SchemaField struct {
43 Field *arrow.Field
44 Children []SchemaField
45 ColIndex int
46 LevelInfo file.LevelInfo
47 }
48
49
50 func (s *SchemaField) IsLeaf() bool { return s.ColIndex != -1 }
51
52
53
54 type SchemaManifest struct {
55 descr *schema.Schema
56 OriginSchema *arrow.Schema
57 SchemaMeta *arrow.Metadata
58
59 ColIndexToField map[int]*SchemaField
60 ChildToParent map[*SchemaField]*SchemaField
61 Fields []SchemaField
62 }
63
64
65 func (sm *SchemaManifest) GetColumnField(index int) (*SchemaField, error) {
66 if field, ok := sm.ColIndexToField[index]; ok {
67 return field, nil
68 }
69 return nil, fmt.Errorf("Column Index %d not found in schema manifest", index)
70 }
71
72
73
74 func (sm *SchemaManifest) GetParent(field *SchemaField) *SchemaField {
75 if p, ok := sm.ChildToParent[field]; ok {
76 return p
77 }
78 return nil
79 }
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100 func (sm *SchemaManifest) GetFieldIndices(indices []int) ([]int, error) {
101 added := make(map[int]bool)
102 ret := make([]int, 0)
103
104 for _, idx := range indices {
105 if idx < 0 || idx >= sm.descr.NumColumns() {
106 return nil, fmt.Errorf("column index %d is not valid", idx)
107 }
108
109 fieldNode := sm.descr.ColumnRoot(idx)
110 fieldIdx := sm.descr.Root().FieldIndexByField(fieldNode)
111 if fieldIdx == -1 {
112 return nil, fmt.Errorf("column index %d is not valid", idx)
113 }
114
115 if _, ok := added[fieldIdx]; !ok {
116 ret = append(ret, fieldIdx)
117 added[fieldIdx] = true
118 }
119 }
120 return ret, nil
121 }
122
123 func isDictionaryReadSupported(dt arrow.DataType) bool {
124 return arrow.IsBinaryLike(dt.ID())
125 }
126
127 func arrowTimestampToLogical(typ *arrow.TimestampType, unit arrow.TimeUnit) schema.LogicalType {
128 utc := typ.TimeZone == "" || typ.TimeZone == "UTC"
129
130
131
132
133
134
135
136 var scunit schema.TimeUnitType
137 switch unit {
138 case arrow.Millisecond:
139 scunit = schema.TimeUnitMillis
140 case arrow.Microsecond:
141 scunit = schema.TimeUnitMicros
142 case arrow.Nanosecond:
143 scunit = schema.TimeUnitNanos
144 case arrow.Second:
145
146 return schema.NoLogicalType{}
147 }
148
149 return schema.NewTimestampLogicalTypeForce(utc, scunit)
150 }
151
152 func getTimestampMeta(typ *arrow.TimestampType, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (parquet.Type, schema.LogicalType, error) {
153 coerce := arrprops.coerceTimestamps
154 target := typ.Unit
155 if coerce {
156 target = arrprops.coerceTimestampUnit
157 }
158
159
160 if arrprops.timestampAsInt96 && target == arrow.Nanosecond {
161 return parquet.Types.Int96, schema.NoLogicalType{}, nil
162 }
163
164 physical := parquet.Types.Int64
165 logicalType := arrowTimestampToLogical(typ, target)
166
167
168
169 if coerce {
170 if props.Version() == parquet.V1_0 || props.Version() == parquet.V2_4 {
171 switch target {
172 case arrow.Millisecond, arrow.Microsecond:
173 case arrow.Nanosecond, arrow.Second:
174 return physical, nil, fmt.Errorf("parquet version %s files can only coerce arrow timestamps to millis or micros", props.Version())
175 }
176 } else if target == arrow.Second {
177 return physical, nil, fmt.Errorf("parquet version %s files can only coerce arrow timestamps to millis, micros or nanos", props.Version())
178 }
179 return physical, logicalType, nil
180 }
181
182
183
184
185
186 if (props.Version() == parquet.V1_0 || props.Version() == parquet.V2_4) && typ.Unit == arrow.Nanosecond {
187 logicalType = arrowTimestampToLogical(typ, arrow.Microsecond)
188 return physical, logicalType, nil
189 }
190
191
192
193
194 if typ.Unit == arrow.Second {
195 logicalType = arrowTimestampToLogical(typ, arrow.Millisecond)
196 }
197
198 return physical, logicalType, nil
199 }
200
201
202
203
204
205
206
// DecimalSize returns the number of bytes needed to hold a two's complement
// decimal of the requested precision.
//
// Precisions 1 through 76 are answered from a precomputed table; larger
// precisions are computed with the same formula the table follows:
// ceil((precision*log2(10) + 1) / 8), i.e. the bits needed for the magnitude
// plus one sign bit, rounded up to whole bytes.
//
// It panics if precision is less than 1.
func DecimalSize(precision int32) int32 {
	if precision < 1 {
		panic("precision must be >= 1")
	}

	// byteblock[p] is the byte width for precision p; index 0 is a placeholder.
	var byteblock = [...]int32{
		-1, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7, 7, 8, 8, 9,
		9, 9, 10, 10, 11, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 16, 16, 16, 17,
		17, 18, 18, 18, 19, 19, 20, 20, 21, 21, 21, 22, 22, 23, 23, 23, 24, 24, 25, 25,
		26, 26, 26, 27, 27, 28, 28, 28, 29, 29, 30, 30, 31, 31, 31, 32, 32,
	}

	if precision <= 76 {
		return byteblock[precision]
	}

	// The previous expression, Ceil(precision/8)*Log2(10) + 1, rounded the
	// wrong sub-expression and over-estimated the width (34 instead of 33
	// bytes for precision 77).
	bits := float64(precision)*math.Log2(10) + 1
	return int32(math.Ceil(bits / 8.0))
}
227
228 func repFromNullable(isnullable bool) parquet.Repetition {
229 if isnullable {
230 return parquet.Repetitions.Optional
231 }
232 return parquet.Repetitions.Required
233 }
234
235 func structToNode(typ *arrow.StructType, name string, nullable bool, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) {
236 if typ.NumFields() == 0 {
237 return nil, fmt.Errorf("cannot write struct type '%s' with no children field to parquet. Consider adding a dummy child", name)
238 }
239
240 children := make(schema.FieldList, 0, typ.NumFields())
241 for _, f := range typ.Fields() {
242 n, err := fieldToNode(f.Name, f, props, arrprops)
243 if err != nil {
244 return nil, err
245 }
246 children = append(children, n)
247 }
248
249 return schema.NewGroupNode(name, repFromNullable(nullable), children, -1)
250 }
251
252 func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) {
253 var (
254 logicalType schema.LogicalType = schema.NoLogicalType{}
255 typ parquet.Type
256 repType = repFromNullable(field.Nullable)
257 length = -1
258 precision = -1
259 scale = -1
260 err error
261 )
262
263 switch field.Type.ID() {
264 case arrow.NULL:
265 typ = parquet.Types.Int32
266 logicalType = &schema.NullLogicalType{}
267 if repType != parquet.Repetitions.Optional {
268 return nil, xerrors.New("nulltype arrow field must be nullable")
269 }
270 case arrow.BOOL:
271 typ = parquet.Types.Boolean
272 case arrow.UINT8:
273 typ = parquet.Types.Int32
274 logicalType = schema.NewIntLogicalType(8, false)
275 case arrow.INT8:
276 typ = parquet.Types.Int32
277 logicalType = schema.NewIntLogicalType(8, true)
278 case arrow.UINT16:
279 typ = parquet.Types.Int32
280 logicalType = schema.NewIntLogicalType(16, false)
281 case arrow.INT16:
282 typ = parquet.Types.Int32
283 logicalType = schema.NewIntLogicalType(16, true)
284 case arrow.UINT32:
285 typ = parquet.Types.Int32
286 logicalType = schema.NewIntLogicalType(32, false)
287 case arrow.INT32:
288 typ = parquet.Types.Int32
289 logicalType = schema.NewIntLogicalType(32, true)
290 case arrow.UINT64:
291 typ = parquet.Types.Int64
292 logicalType = schema.NewIntLogicalType(64, false)
293 case arrow.INT64:
294 typ = parquet.Types.Int64
295 logicalType = schema.NewIntLogicalType(64, true)
296 case arrow.FLOAT32:
297 typ = parquet.Types.Float
298 case arrow.FLOAT64:
299 typ = parquet.Types.Double
300 case arrow.STRING, arrow.LARGE_STRING:
301 logicalType = schema.StringLogicalType{}
302 fallthrough
303 case arrow.BINARY, arrow.LARGE_BINARY:
304 typ = parquet.Types.ByteArray
305 case arrow.FIXED_SIZE_BINARY:
306 typ = parquet.Types.FixedLenByteArray
307 length = field.Type.(*arrow.FixedSizeBinaryType).ByteWidth
308 case arrow.DECIMAL, arrow.DECIMAL256:
309 dectype := field.Type.(arrow.DecimalType)
310 precision = int(dectype.GetPrecision())
311 scale = int(dectype.GetScale())
312
313 if props.StoreDecimalAsInteger() && 1 <= precision && precision <= 18 {
314 if precision <= 9 {
315 typ = parquet.Types.Int32
316 } else {
317 typ = parquet.Types.Int64
318 }
319 } else {
320 typ = parquet.Types.FixedLenByteArray
321 length = int(DecimalSize(int32(precision)))
322 }
323
324 logicalType = schema.NewDecimalLogicalType(int32(precision), int32(scale))
325 case arrow.DATE32:
326 typ = parquet.Types.Int32
327 logicalType = schema.DateLogicalType{}
328 case arrow.DATE64:
329 typ = parquet.Types.Int64
330 logicalType = schema.NewTimestampLogicalType(true, schema.TimeUnitMillis)
331 case arrow.TIMESTAMP:
332 typ, logicalType, err = getTimestampMeta(field.Type.(*arrow.TimestampType), props, arrprops)
333 if err != nil {
334 return nil, err
335 }
336 case arrow.TIME32:
337 typ = parquet.Types.Int32
338 logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMillis)
339 case arrow.TIME64:
340 typ = parquet.Types.Int64
341 timeType := field.Type.(*arrow.Time64Type)
342 if timeType.Unit == arrow.Nanosecond {
343 logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitNanos)
344 } else {
345 logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMicros)
346 }
347 case arrow.FLOAT16:
348 typ = parquet.Types.FixedLenByteArray
349 length = arrow.Float16SizeBytes
350 logicalType = schema.Float16LogicalType{}
351 case arrow.STRUCT:
352 return structToNode(field.Type.(*arrow.StructType), field.Name, field.Nullable, props, arrprops)
353 case arrow.FIXED_SIZE_LIST, arrow.LIST:
354 var elem arrow.DataType
355 if lt, ok := field.Type.(*arrow.ListType); ok {
356 elem = lt.Elem()
357 } else {
358 elem = field.Type.(*arrow.FixedSizeListType).Elem()
359 }
360
361 child, err := fieldToNode(name, arrow.Field{Name: name, Type: elem, Nullable: true}, props, arrprops)
362 if err != nil {
363 return nil, err
364 }
365
366 return schema.ListOf(child, repFromNullable(field.Nullable), -1)
367 case arrow.DICTIONARY:
368
369 dictType := field.Type.(*arrow.DictionaryType)
370 return fieldToNode(name, arrow.Field{Name: name, Type: dictType.ValueType, Nullable: field.Nullable, Metadata: field.Metadata},
371 props, arrprops)
372 case arrow.EXTENSION:
373 return fieldToNode(name, arrow.Field{
374 Name: name,
375 Type: field.Type.(arrow.ExtensionType).StorageType(),
376 Nullable: field.Nullable,
377 Metadata: arrow.MetadataFrom(map[string]string{
378 ipc.ExtensionTypeKeyName: field.Type.(arrow.ExtensionType).ExtensionName(),
379 ipc.ExtensionMetadataKeyName: field.Type.(arrow.ExtensionType).Serialize(),
380 }),
381 }, props, arrprops)
382 case arrow.MAP:
383 mapType := field.Type.(*arrow.MapType)
384 keyNode, err := fieldToNode("key", mapType.KeyField(), props, arrprops)
385 if err != nil {
386 return nil, err
387 }
388
389 valueNode, err := fieldToNode("value", mapType.ItemField(), props, arrprops)
390 if err != nil {
391 return nil, err
392 }
393
394 if arrprops.noMapLogicalType {
395 keyval := schema.FieldList{keyNode, valueNode}
396 keyvalNode, err := schema.NewGroupNode("key_value", parquet.Repetitions.Repeated, keyval, -1)
397 if err != nil {
398 return nil, err
399 }
400 return schema.NewGroupNode(field.Name, repFromNullable(field.Nullable), schema.FieldList{
401 keyvalNode,
402 }, -1)
403 }
404 return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), -1)
405 default:
406 return nil, fmt.Errorf("%w: support for %s", arrow.ErrNotImplemented, field.Type.ID())
407 }
408
409 return schema.NewPrimitiveNodeLogical(name, repType, logicalType, typ, length, fieldIDFromMeta(field.Metadata))
410 }
411
// fieldIDKey is the arrow field-metadata key under which the parquet field id
// is looked up by fieldIDFromMeta.
const (
	fieldIDKey = "PARQUET:field_id"
)
413
414 func fieldIDFromMeta(m arrow.Metadata) int32 {
415 if m.Len() == 0 {
416 return -1
417 }
418
419 key := m.FindKey(fieldIDKey)
420 if key < 0 {
421 return -1
422 }
423
424 id, err := strconv.ParseInt(m.Values()[key], 10, 32)
425 if err != nil {
426 return -1
427 }
428
429 if id < 0 {
430 return -1
431 }
432
433 return int32(id)
434 }
435
436
437
438 func ToParquet(sc *arrow.Schema, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (*schema.Schema, error) {
439 if props == nil {
440 props = parquet.NewWriterProperties()
441 }
442
443 nodes := make(schema.FieldList, 0, sc.NumFields())
444 for _, f := range sc.Fields() {
445 n, err := fieldToNode(f.Name, f, props, arrprops)
446 if err != nil {
447 return nil, err
448 }
449 nodes = append(nodes, n)
450 }
451
452 root, err := schema.NewGroupNode(props.RootName(), props.RootRepetition(), nodes, -1)
453 if err != nil {
454 return nil, err
455 }
456
457 return schema.NewSchema(root), err
458 }
459
460 type schemaTree struct {
461 manifest *SchemaManifest
462
463 schema *schema.Schema
464 props *ArrowReadProperties
465 }
466
467 func (s schemaTree) LinkParent(child, parent *SchemaField) {
468 s.manifest.ChildToParent[child] = parent
469 }
470
471 func (s schemaTree) RecordLeaf(leaf *SchemaField) {
472 s.manifest.ColIndexToField[leaf.ColIndex] = leaf
473 }
474
475 func arrowInt(log *schema.IntLogicalType) (arrow.DataType, error) {
476 switch log.BitWidth() {
477 case 8:
478 if log.IsSigned() {
479 return arrow.PrimitiveTypes.Int8, nil
480 }
481 return arrow.PrimitiveTypes.Uint8, nil
482 case 16:
483 if log.IsSigned() {
484 return arrow.PrimitiveTypes.Int16, nil
485 }
486 return arrow.PrimitiveTypes.Uint16, nil
487 case 32:
488 if log.IsSigned() {
489 return arrow.PrimitiveTypes.Int32, nil
490 }
491 return arrow.PrimitiveTypes.Uint32, nil
492 case 64:
493 if log.IsSigned() {
494 return arrow.PrimitiveTypes.Int64, nil
495 }
496 return arrow.PrimitiveTypes.Uint64, nil
497 default:
498 return nil, xerrors.New("invalid logical type for int32")
499 }
500 }
501
502 func arrowTime32(logical *schema.TimeLogicalType) (arrow.DataType, error) {
503 if logical.TimeUnit() == schema.TimeUnitMillis {
504 return arrow.FixedWidthTypes.Time32ms, nil
505 }
506
507 return nil, xerrors.New(logical.String() + " cannot annotate a time32")
508 }
509
510 func arrowTime64(logical *schema.TimeLogicalType) (arrow.DataType, error) {
511 switch logical.TimeUnit() {
512 case schema.TimeUnitMicros:
513 return arrow.FixedWidthTypes.Time64us, nil
514 case schema.TimeUnitNanos:
515 return arrow.FixedWidthTypes.Time64ns, nil
516 default:
517 return nil, xerrors.New(logical.String() + " cannot annotate int64")
518 }
519 }
520
521 func arrowTimestamp(logical *schema.TimestampLogicalType) (arrow.DataType, error) {
522 tz := "UTC"
523 if logical.IsFromConvertedType() {
524 tz = ""
525 }
526
527 switch logical.TimeUnit() {
528 case schema.TimeUnitMillis:
529 return &arrow.TimestampType{TimeZone: tz, Unit: arrow.Millisecond}, nil
530 case schema.TimeUnitMicros:
531 return &arrow.TimestampType{TimeZone: tz, Unit: arrow.Microsecond}, nil
532 case schema.TimeUnitNanos:
533 return &arrow.TimestampType{TimeZone: tz, Unit: arrow.Nanosecond}, nil
534 default:
535 return nil, xerrors.New("Unrecognized unit in timestamp logical type " + logical.String())
536 }
537 }
538
539 func arrowDecimal(logical *schema.DecimalLogicalType) arrow.DataType {
540 if logical.Precision() <= decimal128.MaxPrecision {
541 return &arrow.Decimal128Type{Precision: logical.Precision(), Scale: logical.Scale()}
542 }
543 return &arrow.Decimal256Type{Precision: logical.Precision(), Scale: logical.Scale()}
544 }
545
546 func arrowFromInt32(logical schema.LogicalType) (arrow.DataType, error) {
547 switch logtype := logical.(type) {
548 case schema.NoLogicalType:
549 return arrow.PrimitiveTypes.Int32, nil
550 case *schema.TimeLogicalType:
551 return arrowTime32(logtype)
552 case *schema.DecimalLogicalType:
553 return arrowDecimal(logtype), nil
554 case *schema.IntLogicalType:
555 return arrowInt(logtype)
556 case schema.DateLogicalType:
557 return arrow.FixedWidthTypes.Date32, nil
558 default:
559 return nil, xerrors.New(logical.String() + " cannot annotate int32")
560 }
561 }
562
563 func arrowFromInt64(logical schema.LogicalType) (arrow.DataType, error) {
564 if logical.IsNone() {
565 return arrow.PrimitiveTypes.Int64, nil
566 }
567
568 switch logtype := logical.(type) {
569 case *schema.IntLogicalType:
570 return arrowInt(logtype)
571 case *schema.DecimalLogicalType:
572 return arrowDecimal(logtype), nil
573 case *schema.TimeLogicalType:
574 return arrowTime64(logtype)
575 case *schema.TimestampLogicalType:
576 return arrowTimestamp(logtype)
577 default:
578 return nil, xerrors.New(logical.String() + " cannot annotate int64")
579 }
580 }
581
582 func arrowFromByteArray(logical schema.LogicalType) (arrow.DataType, error) {
583 switch logtype := logical.(type) {
584 case schema.StringLogicalType:
585 return arrow.BinaryTypes.String, nil
586 case *schema.DecimalLogicalType:
587 return arrowDecimal(logtype), nil
588 case schema.NoLogicalType,
589 schema.EnumLogicalType,
590 schema.JSONLogicalType,
591 schema.BSONLogicalType:
592 return arrow.BinaryTypes.Binary, nil
593 default:
594 return nil, xerrors.New("unhandled logicaltype " + logical.String() + " for byte_array")
595 }
596 }
597
598 func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, error) {
599 switch logtype := logical.(type) {
600 case *schema.DecimalLogicalType:
601 return arrowDecimal(logtype), nil
602 case schema.NoLogicalType, schema.IntervalLogicalType, schema.UUIDLogicalType:
603 return &arrow.FixedSizeBinaryType{ByteWidth: int(length)}, nil
604 case schema.Float16LogicalType:
605 return &arrow.Float16Type{}, nil
606 default:
607 return nil, xerrors.New("unhandled logical type " + logical.String() + " for fixed-length byte array")
608 }
609 }
610
611 func getArrowType(physical parquet.Type, logical schema.LogicalType, typeLen int) (arrow.DataType, error) {
612 if !logical.IsValid() || logical.Equals(schema.NullLogicalType{}) {
613 return arrow.Null, nil
614 }
615
616 switch physical {
617 case parquet.Types.Boolean:
618 return arrow.FixedWidthTypes.Boolean, nil
619 case parquet.Types.Int32:
620 return arrowFromInt32(logical)
621 case parquet.Types.Int64:
622 return arrowFromInt64(logical)
623 case parquet.Types.Int96:
624 return arrow.FixedWidthTypes.Timestamp_ns, nil
625 case parquet.Types.Float:
626 return arrow.PrimitiveTypes.Float32, nil
627 case parquet.Types.Double:
628 return arrow.PrimitiveTypes.Float64, nil
629 case parquet.Types.ByteArray:
630 return arrowFromByteArray(logical)
631 case parquet.Types.FixedLenByteArray:
632 return arrowFromFLBA(logical, typeLen)
633 default:
634 return nil, xerrors.New("invalid physical column type")
635 }
636 }
637
638 func populateLeaf(colIndex int, field *arrow.Field, currentLevels file.LevelInfo, ctx *schemaTree, parent *SchemaField, out *SchemaField) {
639 out.Field = field
640 out.ColIndex = colIndex
641 out.LevelInfo = currentLevels
642 ctx.RecordLeaf(out)
643 ctx.LinkParent(out, parent)
644 }
645
646 func listToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
647 if n.NumFields() != 1 {
648 return xerrors.New("LIST groups must have only 1 child")
649 }
650
651 if n.RepetitionType() == parquet.Repetitions.Repeated {
652 return xerrors.New("LIST groups must not be repeated")
653 }
654
655 currentLevels.Increment(n)
656
657 out.Children = make([]SchemaField, n.NumFields())
658 ctx.LinkParent(out, parent)
659 ctx.LinkParent(&out.Children[0], out)
660
661 listNode := n.Field(0)
662 if listNode.RepetitionType() != parquet.Repetitions.Repeated {
663 return xerrors.New("non-repeated nodes in a list group are not supported")
664 }
665
666 repeatedAncestorDef := currentLevels.IncrementRepeated()
667 if listNode.Type() == schema.Group {
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693 listGroup := listNode.(*schema.GroupNode)
694 if listGroup.NumFields() == 1 && !(listGroup.Name() == "array" || listGroup.Name() == (n.Name()+"_tuple")) {
695
696 if err := nodeToSchemaField(listGroup.Field(0), currentLevels, ctx, out, &out.Children[0]); err != nil {
697 return err
698 }
699 } else {
700 if err := groupToStructField(listGroup, currentLevels, ctx, out, &out.Children[0]); err != nil {
701 return err
702 }
703 }
704 } else {
705
706
707
708
709
710 primitiveNode := listNode.(*schema.PrimitiveNode)
711 colIndex := ctx.schema.ColumnIndexByNode(primitiveNode)
712 arrowType, err := getArrowType(primitiveNode.PhysicalType(), primitiveNode.LogicalType(), primitiveNode.TypeLength())
713 if err != nil {
714 return err
715 }
716
717 if ctx.props.ReadDict(colIndex) && isDictionaryReadSupported(arrowType) {
718 arrowType = &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: arrowType}
719 }
720
721 itemField := arrow.Field{Name: listNode.Name(), Type: arrowType, Nullable: false, Metadata: createFieldMeta(int(listNode.FieldID()))}
722 populateLeaf(colIndex, &itemField, currentLevels, ctx, out, &out.Children[0])
723 }
724
725 out.Field = &arrow.Field{Name: n.Name(), Type: arrow.ListOfField(
726 arrow.Field{Name: listNode.Name(), Type: out.Children[0].Field.Type, Nullable: true}),
727 Nullable: n.RepetitionType() == parquet.Repetitions.Optional, Metadata: createFieldMeta(int(n.FieldID()))}
728
729 out.LevelInfo = currentLevels
730
731
732 out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDef
733 return nil
734 }
735
736 func groupToStructField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
737 arrowFields := make([]arrow.Field, 0, n.NumFields())
738 out.Children = make([]SchemaField, n.NumFields())
739
740 for i := 0; i < n.NumFields(); i++ {
741 if err := nodeToSchemaField(n.Field(i), currentLevels, ctx, out, &out.Children[i]); err != nil {
742 return err
743 }
744 arrowFields = append(arrowFields, *out.Children[i].Field)
745 }
746
747 out.Field = &arrow.Field{Name: n.Name(), Type: arrow.StructOf(arrowFields...),
748 Nullable: n.RepetitionType() == parquet.Repetitions.Optional, Metadata: createFieldMeta(int(n.FieldID()))}
749 out.LevelInfo = currentLevels
750 return nil
751 }
752
753 func mapToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
754 if n.NumFields() != 1 {
755 return xerrors.New("MAP group must have exactly 1 child")
756 }
757 if n.RepetitionType() == parquet.Repetitions.Repeated {
758 return xerrors.New("MAP groups must not be repeated")
759 }
760
761 keyvalueNode := n.Field(0)
762 if keyvalueNode.RepetitionType() != parquet.Repetitions.Repeated {
763 return xerrors.New("Non-repeated keyvalue group in MAP group is not supported")
764 }
765
766 if keyvalueNode.Type() != schema.Group {
767 return xerrors.New("keyvalue node must be a group")
768 }
769
770 kvgroup := keyvalueNode.(*schema.GroupNode)
771 if kvgroup.NumFields() != 1 && kvgroup.NumFields() != 2 {
772 return fmt.Errorf("keyvalue node group must have exactly 1 or 2 child elements, Found %d", kvgroup.NumFields())
773 }
774
775 keyNode := kvgroup.Field(0)
776 if keyNode.RepetitionType() != parquet.Repetitions.Required {
777 return xerrors.New("MAP keys must be required")
778 }
779
780
781
782
783 if kvgroup.NumFields() == 1 {
784 return listToSchemaField(n, currentLevels, ctx, parent, out)
785 }
786
787 currentLevels.Increment(n)
788 repeatedAncestorDef := currentLevels.IncrementRepeated()
789 out.Children = make([]SchemaField, 1)
790
791 kvfield := &out.Children[0]
792 kvfield.Children = make([]SchemaField, 2)
793
794 keyField := &kvfield.Children[0]
795 valueField := &kvfield.Children[1]
796
797 ctx.LinkParent(out, parent)
798 ctx.LinkParent(kvfield, out)
799 ctx.LinkParent(keyField, kvfield)
800 ctx.LinkParent(valueField, kvfield)
801
802
803
804
805
806
807
808
809
810 if err := nodeToSchemaField(keyNode, currentLevels, ctx, kvfield, keyField); err != nil {
811 return err
812 }
813 if err := nodeToSchemaField(kvgroup.Field(1), currentLevels, ctx, kvfield, valueField); err != nil {
814 return err
815 }
816
817 kvfield.Field = &arrow.Field{Name: n.Name(), Type: arrow.StructOf(*keyField.Field, *valueField.Field),
818 Nullable: false, Metadata: createFieldMeta(int(kvgroup.FieldID()))}
819
820 kvfield.LevelInfo = currentLevels
821 out.Field = &arrow.Field{Name: n.Name(), Type: arrow.MapOf(keyField.Field.Type, valueField.Field.Type),
822 Nullable: n.RepetitionType() == parquet.Repetitions.Optional,
823 Metadata: createFieldMeta(int(n.FieldID()))}
824 out.LevelInfo = currentLevels
825
826
827 out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDef
828 return nil
829 }
830
831 func groupToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
832 if n.LogicalType().Equals(schema.NewListLogicalType()) {
833 return listToSchemaField(n, currentLevels, ctx, parent, out)
834 } else if n.LogicalType().Equals(schema.MapLogicalType{}) {
835 return mapToSchemaField(n, currentLevels, ctx, parent, out)
836 }
837
838 if n.RepetitionType() == parquet.Repetitions.Repeated {
839
840
841
842
843
844
845 out.Children = make([]SchemaField, 1)
846 repeatedAncestorDef := currentLevels.IncrementRepeated()
847 if err := groupToStructField(n, currentLevels, ctx, out, &out.Children[0]); err != nil {
848 return err
849 }
850
851 out.Field = &arrow.Field{Name: n.Name(), Type: arrow.ListOf(out.Children[0].Field.Type), Nullable: false,
852 Metadata: createFieldMeta(int(n.FieldID()))}
853 ctx.LinkParent(&out.Children[0], out)
854 out.LevelInfo = currentLevels
855 out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDef
856 return nil
857 }
858
859 currentLevels.Increment(n)
860 return groupToStructField(n, currentLevels, ctx, parent, out)
861 }
862
863 func createFieldMeta(fieldID int) arrow.Metadata {
864 return arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{strconv.Itoa(fieldID)})
865 }
866
867 func nodeToSchemaField(n schema.Node, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
868 ctx.LinkParent(out, parent)
869
870 if n.Type() == schema.Group {
871 return groupToSchemaField(n.(*schema.GroupNode), currentLevels, ctx, parent, out)
872 }
873
874
875
876
877
878
879
880
881
882
883
884 primitive := n.(*schema.PrimitiveNode)
885 colIndex := ctx.schema.ColumnIndexByNode(primitive)
886 arrowType, err := getArrowType(primitive.PhysicalType(), primitive.LogicalType(), primitive.TypeLength())
887 if err != nil {
888 return err
889 }
890
891 if ctx.props.ReadDict(colIndex) && isDictionaryReadSupported(arrowType) {
892 arrowType = &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: arrowType}
893 }
894
895 if primitive.RepetitionType() == parquet.Repetitions.Repeated {
896
897 repeatedAncestorDefLevel := currentLevels.IncrementRepeated()
898 out.Children = make([]SchemaField, 1)
899 child := arrow.Field{Name: primitive.Name(), Type: arrowType, Nullable: false}
900 populateLeaf(colIndex, &child, currentLevels, ctx, out, &out.Children[0])
901 out.Field = &arrow.Field{Name: primitive.Name(), Type: arrow.ListOf(child.Type), Nullable: false,
902 Metadata: createFieldMeta(int(primitive.FieldID()))}
903 out.LevelInfo = currentLevels
904 out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDefLevel
905 return nil
906 }
907
908 currentLevels.Increment(n)
909 populateLeaf(colIndex, &arrow.Field{Name: n.Name(), Type: arrowType,
910 Nullable: n.RepetitionType() == parquet.Repetitions.Optional,
911 Metadata: createFieldMeta(int(n.FieldID()))},
912 currentLevels, ctx, parent, out)
913 return nil
914 }
915
916 func getOriginSchema(meta metadata.KeyValueMetadata, mem memory.Allocator) (*arrow.Schema, error) {
917 if meta == nil {
918 return nil, nil
919 }
920
921 const arrowSchemaKey = "ARROW:schema"
922 serialized := meta.FindValue(arrowSchemaKey)
923 if serialized == nil {
924 return nil, nil
925 }
926
927 var (
928 decoded []byte
929 err error
930 )
931
932
933
934 if len(*serialized)%4 == 0 {
935 decoded, err = base64.StdEncoding.DecodeString(*serialized)
936 }
937
938
939 if len(decoded) == 0 || err != nil {
940 decoded, err = base64.RawStdEncoding.DecodeString(*serialized)
941 }
942
943 if err != nil {
944 return nil, err
945 }
946
947 return flight.DeserializeSchema(decoded, mem)
948 }
949
950 func getNestedFactory(origin, inferred arrow.DataType) func(fieldList []arrow.Field) arrow.DataType {
951 switch inferred.ID() {
952 case arrow.STRUCT:
953 if origin.ID() == arrow.STRUCT {
954 return func(list []arrow.Field) arrow.DataType {
955 return arrow.StructOf(list...)
956 }
957 }
958 case arrow.LIST:
959 switch origin.ID() {
960 case arrow.LIST:
961 return func(list []arrow.Field) arrow.DataType {
962 return arrow.ListOf(list[0].Type)
963 }
964 case arrow.FIXED_SIZE_LIST:
965 sz := origin.(*arrow.FixedSizeListType).Len()
966 return func(list []arrow.Field) arrow.DataType {
967 return arrow.FixedSizeListOf(sz, list[0].Type)
968 }
969 }
970 case arrow.MAP:
971 if origin.ID() == arrow.MAP {
972 return func(list []arrow.Field) arrow.DataType {
973 valType := list[0].Type.(*arrow.StructType)
974 return arrow.MapOf(valType.Field(0).Type, valType.Field(1).Type)
975 }
976 }
977 }
978 return nil
979 }
980
981 func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField) (modified bool, err error) {
982 nchildren := len(inferred.Children)
983 switch origin.Type.ID() {
984 case arrow.EXTENSION:
985 extType := origin.Type.(arrow.ExtensionType)
986 modified, err = applyOriginalStorageMetadata(arrow.Field{
987 Type: extType.StorageType(),
988 Metadata: origin.Metadata,
989 }, inferred)
990 if err != nil {
991 return
992 }
993
994 if !arrow.TypeEqual(extType.StorageType(), inferred.Field.Type) {
995 return modified, fmt.Errorf("%w: mismatch storage type '%s' for extension type '%s'",
996 arrow.ErrInvalid, inferred.Field.Type, extType)
997 }
998
999 inferred.Field.Type = extType
1000 modified = true
1001 case arrow.SPARSE_UNION, arrow.DENSE_UNION:
1002 err = xerrors.New("unimplemented type")
1003 case arrow.STRUCT:
1004 typ := origin.Type.(*arrow.StructType)
1005 if nchildren != typ.NumFields() {
1006 return
1007 }
1008
1009 factory := getNestedFactory(typ, inferred.Field.Type)
1010 if factory == nil {
1011 return
1012 }
1013
1014 modified = typ.ID() != inferred.Field.Type.ID()
1015 for idx := range inferred.Children {
1016 childMod, err := applyOriginalMetadata(typ.Field(idx), &inferred.Children[idx])
1017 if err != nil {
1018 return false, err
1019 }
1020 modified = modified || childMod
1021 }
1022 if modified {
1023 modifiedChildren := make([]arrow.Field, len(inferred.Children))
1024 for idx, child := range inferred.Children {
1025 modifiedChildren[idx] = *child.Field
1026 }
1027 inferred.Field.Type = factory(modifiedChildren)
1028 }
1029 case arrow.FIXED_SIZE_LIST, arrow.LIST, arrow.LARGE_LIST, arrow.MAP:
1030 if nchildren != 1 {
1031 return
1032 }
1033 factory := getNestedFactory(origin.Type, inferred.Field.Type)
1034 if factory == nil {
1035 return
1036 }
1037
1038 modified = origin.Type.ID() != inferred.Field.Type.ID()
1039 childModified, err := applyOriginalMetadata(arrow.Field{Type: origin.Type.(arrow.ListLikeType).Elem()}, &inferred.Children[0])
1040 if err != nil {
1041 return modified, err
1042 }
1043 modified = modified || childModified
1044 if modified {
1045 inferred.Field.Type = factory([]arrow.Field{*inferred.Children[0].Field})
1046 }
1047 case arrow.TIMESTAMP:
1048 if inferred.Field.Type.ID() != arrow.TIMESTAMP {
1049 return
1050 }
1051
1052 tsOtype := origin.Type.(*arrow.TimestampType)
1053 tsInfType := inferred.Field.Type.(*arrow.TimestampType)
1054
1055
1056
1057 if tsOtype.Unit == tsInfType.Unit && tsInfType.TimeZone == "UTC" && tsOtype.TimeZone != "" {
1058 inferred.Field.Type = origin.Type
1059 }
1060 modified = true
1061 case arrow.LARGE_STRING, arrow.LARGE_BINARY:
1062 inferred.Field.Type = origin.Type
1063 modified = true
1064 case arrow.DICTIONARY:
1065 if origin.Type.ID() != arrow.DICTIONARY || (inferred.Field.Type.ID() == arrow.DICTIONARY || !isDictionaryReadSupported(inferred.Field.Type)) {
1066 return
1067 }
1068
1069
1070
1071 dictOriginType := origin.Type.(*arrow.DictionaryType)
1072 inferred.Field.Type = &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32,
1073 ValueType: inferred.Field.Type, Ordered: dictOriginType.Ordered}
1074 modified = true
1075 case arrow.DECIMAL256:
1076 if inferred.Field.Type.ID() == arrow.DECIMAL128 {
1077 inferred.Field.Type = origin.Type
1078 modified = true
1079 }
1080 }
1081
1082 if origin.HasMetadata() {
1083 meta := origin.Metadata
1084 if inferred.Field.HasMetadata() {
1085 final := make(map[string]string)
1086 for idx, k := range meta.Keys() {
1087 final[k] = meta.Values()[idx]
1088 }
1089 for idx, k := range inferred.Field.Metadata.Keys() {
1090 final[k] = inferred.Field.Metadata.Values()[idx]
1091 }
1092 inferred.Field.Metadata = arrow.MetadataFrom(final)
1093 } else {
1094 inferred.Field.Metadata = meta
1095 }
1096 modified = true
1097 }
1098
1099 return
1100 }
1101
1102 func applyOriginalMetadata(origin arrow.Field, inferred *SchemaField) (bool, error) {
1103 return applyOriginalStorageMetadata(origin, inferred)
1104 }
1105
1106
1107
1108
1109
1110 func NewSchemaManifest(sc *schema.Schema, meta metadata.KeyValueMetadata, props *ArrowReadProperties) (*SchemaManifest, error) {
1111 var ctx schemaTree
1112 ctx.manifest = &SchemaManifest{
1113 ColIndexToField: make(map[int]*SchemaField),
1114 ChildToParent: make(map[*SchemaField]*SchemaField),
1115 descr: sc,
1116 Fields: make([]SchemaField, sc.Root().NumFields()),
1117 }
1118 ctx.props = props
1119 if ctx.props == nil {
1120 ctx.props = &ArrowReadProperties{}
1121 }
1122 ctx.schema = sc
1123
1124 var err error
1125 ctx.manifest.OriginSchema, err = getOriginSchema(meta, memory.DefaultAllocator)
1126 if err != nil {
1127 return nil, err
1128 }
1129
1130
1131 if ctx.manifest.OriginSchema != nil && len(ctx.manifest.OriginSchema.Fields()) != sc.Root().NumFields() {
1132 ctx.manifest.OriginSchema = nil
1133 }
1134
1135 for idx := range ctx.manifest.Fields {
1136 field := &ctx.manifest.Fields[idx]
1137 if err := nodeToSchemaField(sc.Root().Field(idx), file.LevelInfo{NullSlotUsage: 1}, &ctx, nil, field); err != nil {
1138 return nil, err
1139 }
1140
1141 if ctx.manifest.OriginSchema != nil {
1142 if _, err := applyOriginalMetadata(ctx.manifest.OriginSchema.Field(idx), field); err != nil {
1143 return nil, err
1144 }
1145 }
1146 }
1147 return ctx.manifest, nil
1148 }
1149
1150
1151 func FromParquet(sc *schema.Schema, props *ArrowReadProperties, kv metadata.KeyValueMetadata) (*arrow.Schema, error) {
1152 manifest, err := NewSchemaManifest(sc, kv, props)
1153 if err != nil {
1154 return nil, err
1155 }
1156
1157 fields := make([]arrow.Field, len(manifest.Fields))
1158 for idx, field := range manifest.Fields {
1159 fields[idx] = *field.Field
1160 }
1161
1162 if manifest.OriginSchema != nil {
1163 meta := manifest.OriginSchema.Metadata()
1164 return arrow.NewSchema(fields, &meta), nil
1165 }
1166 return arrow.NewSchema(fields, manifest.SchemaMeta), nil
1167 }
1168
View as plain text