1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package schema
18
19 import (
20 "fmt"
21
22 "github.com/apache/arrow/go/v15/parquet"
23 format "github.com/apache/arrow/go/v15/parquet/internal/gen-go/parquet"
24 "github.com/apache/thrift/lib/go/thrift"
25 "golang.org/x/xerrors"
26 )
27
28
// NodeType identifies which of the two kinds of schema node a Node is.
type NodeType int

// The NodeType values reported by Node.Type.
const (
	// Primitive is the type set on nodes built by the NewPrimitiveNode*
	// constructors.
	Primitive NodeType = iota
	// Group is the type set on nodes built by the NewGroupNode* constructors.
	Group
)
36
37
38
39
40
// Node is the interface implemented by both PrimitiveNode and GroupNode, the
// two kinds of node that make up a schema tree.
type Node interface {
	// Name returns the name of the node.
	Name() string
	// Type reports whether the node is a Primitive or a Group node.
	Type() NodeType
	// RepetitionType returns the repetition the node was constructed with.
	RepetitionType() parquet.Repetition
	// ConvertedType returns the converted type annotation of the node.
	ConvertedType() ConvertedType
	// LogicalType returns the logical type annotation of the node.
	LogicalType() LogicalType
	// FieldID returns the field ID of the node. Negative IDs are left out when
	// the node is serialized to thrift.
	FieldID() int32
	// Parent returns the node's parent, or nil when it has none.
	Parent() Node
	// SetParent replaces the node's parent. The group constructors call it on
	// every field they are given.
	SetParent(Node)
	// Path returns the node's column path rendered as a string.
	Path() string
	// Equals reports whether the receiver and the argument describe the same
	// schema element.
	Equals(Node) bool
	// Visit walks the node (and, for groups, its children) with v.
	Visit(v Visitor)
	// toThrift converts the node into its thrift SchemaElement form.
	toThrift() *format.SchemaElement
}
55
56
57
58
59
60
61
62
63
// Visitor is implemented by callers that want to walk a schema tree through
// Node.Visit.
type Visitor interface {
	// VisitPre is called for a node before any of its children. GroupNode.Visit
	// descends into the children only when it returns true; PrimitiveNode.Visit
	// ignores the result.
	VisitPre(Node) bool
	// VisitPost is called for a node once its children (if any) have been
	// handled. It is called even when VisitPre returned false.
	VisitPost(Node)
}
68
69
70
71 func ColumnPathFromNode(n Node) parquet.ColumnPath {
72 if n == nil {
73 return nil
74 }
75
76 c := make([]string, 0)
77
78
79 cursor := n
80 for cursor.Parent() != nil {
81 c = append(c, cursor.Name())
82 cursor = cursor.Parent()
83 }
84
85
86
87 for i := len(c)/2 - 1; i >= 0; i-- {
88 opp := len(c) - 1 - i
89 c[i], c[opp] = c[opp], c[i]
90 }
91
92 return c
93 }
94
95
96 type node struct {
97 typ NodeType
98 parent Node
99
100 name string
101 repetition parquet.Repetition
102 fieldID int32
103 logicalType LogicalType
104 convertedType ConvertedType
105 colPath parquet.ColumnPath
106 }
107
108 func (n *node) toThrift() *format.SchemaElement { return nil }
109 func (n *node) Name() string { return n.name }
110 func (n *node) Type() NodeType { return n.typ }
111 func (n *node) RepetitionType() parquet.Repetition { return n.repetition }
112 func (n *node) ConvertedType() ConvertedType { return n.convertedType }
113 func (n *node) LogicalType() LogicalType { return n.logicalType }
114 func (n *node) FieldID() int32 { return n.fieldID }
115 func (n *node) Parent() Node { return n.parent }
116 func (n *node) SetParent(p Node) { n.parent = p }
117 func (n *node) Path() string {
118 return n.columnPath().String()
119 }
120 func (n *node) columnPath() parquet.ColumnPath {
121 if n.colPath == nil {
122 n.colPath = ColumnPathFromNode(n)
123 }
124 return n.colPath
125 }
126
127 func (n *node) Equals(rhs Node) bool {
128 return n.typ == rhs.Type() &&
129 n.Name() == rhs.Name() &&
130 n.RepetitionType() == rhs.RepetitionType() &&
131 n.ConvertedType() == rhs.ConvertedType() &&
132 n.FieldID() == rhs.FieldID() &&
133 n.LogicalType().Equals(rhs.LogicalType())
134 }
135
136 func (n *node) Visit(v Visitor) {}
137
138
139
140
141
// PrimitiveNode is the Node implementation for nodes of type Primitive. On top
// of the shared node attributes it records the physical type, the type length
// and any decimal metadata.
type PrimitiveNode struct {
	node

	// ColumnOrder is not read or written by the code visible in this file.
	ColumnOrder parquet.ColumnOrder
	// physicalType is the physical type passed to the constructor.
	physicalType parquet.Type
	// typeLen is the type length. NewPrimitiveNodeConverted leaves it at -1
	// unless the physical type is FixedLenByteArray.
	typeLen int
	// decimalMetaData holds the precision and scale derived from the logical
	// type or supplied for a Decimal converted type.
	decimalMetaData DecimalMetadata
}
150
151
152
153 func NewPrimitiveNodeLogical(name string, repetition parquet.Repetition, logicalType LogicalType, physicalType parquet.Type, typeLen int, id int32) (*PrimitiveNode, error) {
154 n := &PrimitiveNode{
155 node: node{typ: Primitive, name: name, repetition: repetition, logicalType: logicalType, fieldID: id},
156 physicalType: physicalType,
157 typeLen: typeLen,
158 }
159
160 if logicalType != nil {
161 if !logicalType.IsNested() {
162 if logicalType.IsApplicable(physicalType, int32(typeLen)) {
163 n.convertedType, n.decimalMetaData = n.logicalType.ToConvertedType()
164 } else {
165 return nil, fmt.Errorf("%s cannot be applied to primitive type %s", logicalType, physicalType)
166 }
167 } else {
168 return nil, fmt.Errorf("nested logical type %s cannot be applied to a non-group node", logicalType)
169 }
170 } else {
171 n.logicalType = NoLogicalType{}
172 n.convertedType, n.decimalMetaData = n.logicalType.ToConvertedType()
173 }
174
175 if !(n.logicalType != nil && !n.logicalType.IsNested() && n.logicalType.IsCompatible(n.convertedType, n.decimalMetaData)) {
176 return nil, fmt.Errorf("invalid logical type %s", n.logicalType)
177 }
178
179 if n.physicalType == parquet.Types.FixedLenByteArray && n.typeLen <= 0 {
180 return nil, xerrors.New("invalid fixed length byte array length")
181 }
182 return n, nil
183 }
184
185
186
187 func NewPrimitiveNodeConverted(name string, repetition parquet.Repetition, typ parquet.Type, converted ConvertedType, typeLen, precision, scale int, id int32) (*PrimitiveNode, error) {
188 n := &PrimitiveNode{
189 node: node{typ: Primitive, name: name, repetition: repetition, convertedType: converted, fieldID: id},
190 physicalType: typ,
191 typeLen: -1,
192 }
193
194 switch converted {
195 case ConvertedTypes.None:
196 case ConvertedTypes.UTF8, ConvertedTypes.JSON, ConvertedTypes.BSON:
197 if typ != parquet.Types.ByteArray {
198 return nil, fmt.Errorf("parquet: %s can only annotate BYTE_LEN fields", typ)
199 }
200 case ConvertedTypes.Decimal:
201 switch typ {
202 case parquet.Types.Int32, parquet.Types.Int64, parquet.Types.ByteArray, parquet.Types.FixedLenByteArray:
203 default:
204 return nil, xerrors.New("parquet: DECIMAL can only annotate INT32, INT64, BYTE_ARRAY and FIXED")
205 }
206
207 switch {
208 case precision <= 0:
209 return nil, fmt.Errorf("parquet: invalid decimal precision: %d, must be between 1 and 38 inclusive", precision)
210 case scale < 0:
211 return nil, fmt.Errorf("parquet: invalid decimal scale: %d, must be a number between 0 and precision inclusive", scale)
212 case scale > precision:
213 return nil, fmt.Errorf("parquet: invalid decimal scale %d, cannot be greater than precision: %d", scale, precision)
214 }
215 n.decimalMetaData.IsSet = true
216 n.decimalMetaData.Precision = int32(precision)
217 n.decimalMetaData.Scale = int32(scale)
218 case ConvertedTypes.Date,
219 ConvertedTypes.TimeMillis,
220 ConvertedTypes.Int8,
221 ConvertedTypes.Int16,
222 ConvertedTypes.Int32,
223 ConvertedTypes.Uint8,
224 ConvertedTypes.Uint16,
225 ConvertedTypes.Uint32:
226 if typ != parquet.Types.Int32 {
227 return nil, fmt.Errorf("parquet: %s can only annotate INT32", converted)
228 }
229 case ConvertedTypes.TimeMicros,
230 ConvertedTypes.TimestampMicros,
231 ConvertedTypes.TimestampMillis,
232 ConvertedTypes.Int64,
233 ConvertedTypes.Uint64:
234 if typ != parquet.Types.Int64 {
235 return nil, fmt.Errorf("parquet: %s can only annotate INT64", converted)
236 }
237 case ConvertedTypes.Interval:
238 if typ != parquet.Types.FixedLenByteArray || typeLen != 12 {
239 return nil, xerrors.New("parquet: INTERVAL can only annotate FIXED_LEN_BYTE_ARRAY(12)")
240 }
241 case ConvertedTypes.Enum:
242 if typ != parquet.Types.ByteArray {
243 return nil, xerrors.New("parquet: ENUM can only annotate BYTE_ARRAY fields")
244 }
245 case ConvertedTypes.NA:
246 default:
247 return nil, fmt.Errorf("parquet: %s cannot be applied to a primitive type", converted.String())
248 }
249
250 n.logicalType = n.convertedType.ToLogicalType(n.decimalMetaData)
251 if !(n.logicalType != nil && !n.logicalType.IsNested() && n.logicalType.IsCompatible(n.convertedType, n.decimalMetaData)) {
252 return nil, fmt.Errorf("invalid logical type %s", n.logicalType)
253 }
254
255 if n.physicalType == parquet.Types.FixedLenByteArray {
256 if typeLen <= 0 {
257 return nil, xerrors.New("invalid fixed len byte array length")
258 }
259 n.typeLen = typeLen
260 }
261
262 return n, nil
263 }
264
265 func PrimitiveNodeFromThrift(elem *format.SchemaElement) (*PrimitiveNode, error) {
266 fieldID := int32(-1)
267 if elem.IsSetFieldID() {
268 fieldID = elem.GetFieldID()
269 }
270
271 if elem.IsSetLogicalType() {
272 return NewPrimitiveNodeLogical(elem.GetName(), parquet.Repetition(elem.GetRepetitionType()),
273 getLogicalType(elem.GetLogicalType()), parquet.Type(elem.GetType()), int(elem.GetTypeLength()),
274 fieldID)
275 } else if elem.IsSetConvertedType() {
276 return NewPrimitiveNodeConverted(elem.GetName(), parquet.Repetition(elem.GetRepetitionType()),
277 parquet.Type(elem.GetType()), ConvertedType(elem.GetConvertedType()),
278 int(elem.GetTypeLength()), int(elem.GetPrecision()), int(elem.GetScale()), fieldID)
279 }
280 return NewPrimitiveNodeLogical(elem.GetName(), parquet.Repetition(elem.GetRepetitionType()), NoLogicalType{}, parquet.Type(elem.GetType()), int(elem.GetTypeLength()), fieldID)
281 }
282
283
284
285
286 func NewPrimitiveNode(name string, repetition parquet.Repetition, typ parquet.Type, fieldID, typeLength int32) (*PrimitiveNode, error) {
287 return NewPrimitiveNodeLogical(name, repetition, nil, typ, int(typeLength), fieldID)
288 }
289
290
291
292 func (p *PrimitiveNode) Equals(rhs Node) bool {
293 if !p.node.Equals(rhs) {
294 return false
295 }
296
297 other := rhs.(*PrimitiveNode)
298 if p == other {
299 return true
300 }
301
302 if p.PhysicalType() != other.PhysicalType() {
303 return false
304 }
305
306 equal := true
307 if p.ConvertedType() == ConvertedTypes.Decimal {
308 equal = equal &&
309 (p.decimalMetaData.Precision == other.decimalMetaData.Precision &&
310 p.decimalMetaData.Scale == other.decimalMetaData.Scale)
311 }
312 if p.PhysicalType() == parquet.Types.FixedLenByteArray {
313 equal = equal && p.TypeLength() == other.TypeLength()
314 }
315 return equal
316 }
317
318
319
320 func (p *PrimitiveNode) PhysicalType() parquet.Type { return p.physicalType }
321
322
323
324 func (p *PrimitiveNode) SetTypeLength(length int) {
325 if p.PhysicalType() == parquet.Types.FixedLenByteArray {
326 p.typeLen = length
327 }
328 }
329
330
331
332 func (p *PrimitiveNode) TypeLength() int { return p.typeLen }
333
334
335
336 func (p *PrimitiveNode) DecimalMetadata() DecimalMetadata { return p.decimalMetaData }
337
338
339
340 func (p *PrimitiveNode) Visit(v Visitor) {
341 v.VisitPre(p)
342 v.VisitPost(p)
343 }
344
345 func (p *PrimitiveNode) toThrift() *format.SchemaElement {
346 elem := &format.SchemaElement{
347 Name: p.Name(),
348 RepetitionType: format.FieldRepetitionTypePtr(format.FieldRepetitionType(p.RepetitionType())),
349 Type: format.TypePtr(format.Type(p.PhysicalType())),
350 }
351 if p.ConvertedType() != ConvertedTypes.None {
352 elem.ConvertedType = format.ConvertedTypePtr(format.ConvertedType(p.ConvertedType()))
353 }
354 if p.FieldID() >= 0 {
355 elem.FieldID = thrift.Int32Ptr(p.FieldID())
356 }
357 if p.logicalType != nil && p.logicalType.IsSerialized() && !p.logicalType.Equals(IntervalLogicalType{}) {
358 elem.LogicalType = p.logicalType.toThrift()
359 }
360 if p.physicalType == parquet.Types.FixedLenByteArray {
361 elem.TypeLength = thrift.Int32Ptr(int32(p.typeLen))
362 }
363 if p.decimalMetaData.IsSet {
364 elem.Precision = &p.decimalMetaData.Precision
365 elem.Scale = &p.decimalMetaData.Scale
366 }
367 return elem
368 }
369
370
371 type FieldList []Node
372
373
374 func (f FieldList) Len() int { return len(f) }
375
376
// GroupNode is the Node implementation for nodes of type Group. It carries an
// ordered list of child fields plus an index from field name to position.
type GroupNode struct {
	node
	// fields holds the children in the order supplied to the constructor.
	fields FieldList
	// nameToIdx maps a field name to the positions in fields that carry that
	// name; one name may map to several positions.
	nameToIdx strIntMultimap
}
382
383
384
385 func NewGroupNodeConverted(name string, repetition parquet.Repetition, fields FieldList, converted ConvertedType, id int32) (n *GroupNode, err error) {
386 n = &GroupNode{
387 node: node{typ: Group, name: name, repetition: repetition, convertedType: converted, fieldID: id},
388 fields: fields,
389 }
390 n.logicalType = n.convertedType.ToLogicalType(DecimalMetadata{})
391 if !(n.logicalType != nil && (n.logicalType.IsNested() || n.logicalType.IsNone()) && n.logicalType.IsCompatible(n.convertedType, DecimalMetadata{})) {
392 err = fmt.Errorf("invalid logical type %s", n.logicalType.String())
393 return
394 }
395
396 n.nameToIdx = make(strIntMultimap)
397 for idx, f := range n.fields {
398 f.SetParent(n)
399 n.nameToIdx.Add(f.Name(), idx)
400 }
401 return
402 }
403
404
405
406 func NewGroupNodeLogical(name string, repetition parquet.Repetition, fields FieldList, logical LogicalType, id int32) (n *GroupNode, err error) {
407 n = &GroupNode{
408 node: node{typ: Group, name: name, repetition: repetition, logicalType: logical, fieldID: id},
409 fields: fields,
410 }
411
412 if logical != nil {
413 if logical.IsNested() {
414 n.convertedType, _ = logical.ToConvertedType()
415 } else {
416 err = fmt.Errorf("logical type %s cannot be applied to group node", logical)
417 return
418 }
419 } else {
420 n.logicalType = NoLogicalType{}
421 n.convertedType, _ = n.logicalType.ToConvertedType()
422 }
423
424 if !(n.logicalType != nil && (n.logicalType.IsNested() || n.logicalType.IsNone()) && n.logicalType.IsCompatible(n.convertedType, DecimalMetadata{})) {
425 err = fmt.Errorf("invalid logical type %s", n.logicalType)
426 return
427 }
428
429 n.nameToIdx = make(strIntMultimap)
430 for idx, f := range n.fields {
431 f.SetParent(n)
432 n.nameToIdx.Add(f.Name(), idx)
433 }
434 return
435 }
436
437
438
439 func NewGroupNode(name string, repetition parquet.Repetition, fields FieldList, fieldID int32) (*GroupNode, error) {
440 return NewGroupNodeConverted(name, repetition, fields, ConvertedTypes.None, fieldID)
441 }
442
443
444
445 func Must(n Node, err error) Node {
446 if err != nil {
447 panic(err)
448 }
449 return n
450 }
451
452
453
454 func MustGroup(n Node, err error) *GroupNode {
455 if err != nil {
456 panic(err)
457 }
458 return n.(*GroupNode)
459 }
460
461
462
463 func MustPrimitive(n Node, err error) *PrimitiveNode {
464 if err != nil {
465 panic(err)
466 }
467 return n.(*PrimitiveNode)
468 }
469
470 func GroupNodeFromThrift(elem *format.SchemaElement, fields FieldList) (*GroupNode, error) {
471 id := int32(-1)
472 if elem.IsSetFieldID() {
473 id = elem.GetFieldID()
474 }
475
476 if elem.IsSetLogicalType() {
477 return NewGroupNodeLogical(elem.GetName(), parquet.Repetition(elem.GetRepetitionType()), fields, getLogicalType(elem.GetLogicalType()), id)
478 }
479
480 converted := ConvertedTypes.None
481 if elem.IsSetConvertedType() {
482 converted = ConvertedType(elem.GetConvertedType())
483 }
484 return NewGroupNodeConverted(elem.GetName(), parquet.Repetition(elem.GetRepetitionType()), fields, converted, id)
485 }
486
487 func (g *GroupNode) toThrift() *format.SchemaElement {
488 elem := &format.SchemaElement{
489 Name: g.name,
490 NumChildren: thrift.Int32Ptr(int32(len(g.fields))),
491 RepetitionType: format.FieldRepetitionTypePtr(format.FieldRepetitionType(g.RepetitionType())),
492 }
493 if g.convertedType != ConvertedTypes.None {
494 elem.ConvertedType = format.ConvertedTypePtr(format.ConvertedType(g.convertedType))
495 }
496 if g.fieldID >= 0 {
497 elem.FieldID = &g.fieldID
498 }
499 if g.logicalType != nil && g.logicalType.IsSerialized() {
500 elem.LogicalType = g.logicalType.toThrift()
501 }
502 return elem
503 }
504
505
506
507
508 func (g *GroupNode) Equals(rhs Node) bool {
509 if !g.node.Equals(rhs) {
510 return false
511 }
512
513 other := rhs.(*GroupNode)
514 if g == other {
515 return true
516 }
517 if len(g.fields) != len(other.fields) {
518 return false
519 }
520
521 for idx, field := range g.fields {
522 if !field.Equals(other.fields[idx]) {
523 return false
524 }
525 }
526 return true
527 }
528
529
530 func (g *GroupNode) NumFields() int {
531 return len(g.fields)
532 }
533
534
535 func (g *GroupNode) Field(i int) Node {
536 return g.fields[i]
537 }
538
539
540
541
542
543 func (g *GroupNode) FieldIndexByName(name string) int {
544 if idx, ok := g.nameToIdx[name]; ok {
545 return idx[0]
546 }
547 return -1
548 }
549
550
551
552 func (g *GroupNode) FieldIndexByField(n Node) int {
553 if search, ok := g.nameToIdx[n.Name()]; ok {
554 for _, idx := range search {
555 if n == g.fields[idx] {
556 return idx
557 }
558 }
559 }
560 return -1
561 }
562
563
564
565 func (g *GroupNode) Visit(v Visitor) {
566 if v.VisitPre(g) {
567 for _, field := range g.fields {
568 field.Visit(v)
569 }
570 }
571 v.VisitPost(g)
572 }
573
574
575
576
577
578 func (g *GroupNode) HasRepeatedFields() bool {
579 for _, field := range g.fields {
580 if field.RepetitionType() == parquet.Repetitions.Repeated {
581 return true
582 }
583 if field.Type() == Group {
584 return field.(*GroupNode).HasRepeatedFields()
585 }
586 }
587 return false
588 }
589
590
591 func NewInt32Node(name string, rep parquet.Repetition, fieldID int32) *PrimitiveNode {
592 return MustPrimitive(NewPrimitiveNode(name, rep, parquet.Types.Int32, fieldID, -1))
593 }
594
595
596 func NewInt64Node(name string, rep parquet.Repetition, fieldID int32) *PrimitiveNode {
597 return MustPrimitive(NewPrimitiveNode(name, rep, parquet.Types.Int64, fieldID, -1))
598 }
599
600
601 func NewInt96Node(name string, rep parquet.Repetition, fieldID int32) *PrimitiveNode {
602 return MustPrimitive(NewPrimitiveNode(name, rep, parquet.Types.Int96, fieldID, -1))
603 }
604
605
606 func NewFloat32Node(name string, rep parquet.Repetition, fieldID int32) *PrimitiveNode {
607 return MustPrimitive(NewPrimitiveNode(name, rep, parquet.Types.Float, fieldID, -1))
608 }
609
610
611 func NewFloat64Node(name string, rep parquet.Repetition, fieldID int32) *PrimitiveNode {
612 return MustPrimitive(NewPrimitiveNode(name, rep, parquet.Types.Double, fieldID, -1))
613 }
614
615
616 func NewBooleanNode(name string, rep parquet.Repetition, fieldID int32) *PrimitiveNode {
617 return MustPrimitive(NewPrimitiveNode(name, rep, parquet.Types.Boolean, fieldID, -1))
618 }
619
620
621 func NewByteArrayNode(name string, rep parquet.Repetition, fieldID int32) *PrimitiveNode {
622 return MustPrimitive(NewPrimitiveNode(name, rep, parquet.Types.ByteArray, fieldID, -1))
623 }
624
625
626
627 func NewFixedLenByteArrayNode(name string, rep parquet.Repetition, length int32, fieldID int32) *PrimitiveNode {
628 return MustPrimitive(NewPrimitiveNode(name, rep, parquet.Types.FixedLenByteArray, fieldID, length))
629 }
630
View as plain text