1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package schema
18
19 import (
20 "testing"
21
22 "github.com/apache/arrow/go/v15/parquet"
23 format "github.com/apache/arrow/go/v15/parquet/internal/gen-go/parquet"
24 "github.com/stretchr/testify/assert"
25 "github.com/stretchr/testify/suite"
26 )
27
28 func NewPrimitive(name string, repetition format.FieldRepetitionType, typ format.Type, fieldID int32) *format.SchemaElement {
29 ret := &format.SchemaElement{
30 Name: name,
31 RepetitionType: format.FieldRepetitionTypePtr(repetition),
32 Type: format.TypePtr(typ),
33 }
34 if fieldID >= 0 {
35 ret.FieldID = &fieldID
36 }
37 return ret
38 }
39
40 func NewGroup(name string, repetition format.FieldRepetitionType, numChildren, fieldID int32) *format.SchemaElement {
41 ret := &format.SchemaElement{
42 Name: name,
43 RepetitionType: format.FieldRepetitionTypePtr(repetition),
44 NumChildren: &numChildren,
45 }
46 if fieldID >= 0 {
47 ret.FieldID = &fieldID
48 }
49 return ret
50 }
51
52 type SchemaFlattenSuite struct {
53 suite.Suite
54
55 name string
56 }
57
58 func (s *SchemaFlattenSuite) SetupSuite() {
59 s.name = "parquet_schema"
60 }
61
62 func (s *SchemaFlattenSuite) TestDecimalMetadata() {
63 group := MustGroup(NewGroupNodeConverted("group" , parquet.Repetitions.Repeated, FieldList{
64 MustPrimitive(NewPrimitiveNodeConverted("decimal" , parquet.Repetitions.Required, parquet.Types.Int64,
65 ConvertedTypes.Decimal, 0 , 8 , 4 , -1 )),
66 }, ConvertedTypes.List, -1 ))
67 elements := ToThrift(group)
68
69 s.Len(elements, 2)
70 s.Equal("decimal", elements[1].GetName())
71 s.True(elements[1].IsSetPrecision())
72 s.True(elements[1].IsSetScale())
73
74 group = MustGroup(NewGroupNodeLogical("group" , parquet.Repetitions.Repeated, FieldList{
75 MustPrimitive(NewPrimitiveNodeLogical("decimal" , parquet.Repetitions.Required, NewDecimalLogicalType(10 , 5 ),
76 parquet.Types.Int64, 0 , -1 )),
77 }, NewListLogicalType(), -1 ))
78 elements = ToThrift(group)
79 s.Equal("decimal", elements[1].Name)
80 s.True(elements[1].IsSetPrecision())
81 s.True(elements[1].IsSetScale())
82
83 group = MustGroup(NewGroupNodeConverted("group" , parquet.Repetitions.Repeated, FieldList{
84 NewInt64Node("int64" , parquet.Repetitions.Required, -1 )}, ConvertedTypes.List, -1 ))
85 elements = ToThrift(group)
86 s.Equal("int64", elements[1].Name)
87 s.False(elements[0].IsSetPrecision())
88 s.False(elements[1].IsSetPrecision())
89 s.False(elements[0].IsSetScale())
90 s.False(elements[1].IsSetScale())
91 }
92
93 func (s *SchemaFlattenSuite) TestNestedExample() {
94 elements := make([]*format.SchemaElement, 0)
95 elements = append(elements,
96 NewGroup(s.name, format.FieldRepetitionType_REPEATED, 2 , 0 ),
97 NewPrimitive("a" , format.FieldRepetitionType_REQUIRED, format.Type_INT32, 1 ),
98 NewGroup("bag" , format.FieldRepetitionType_OPTIONAL, 1 , 2 ))
99
100 elt := NewGroup("b" , format.FieldRepetitionType_REPEATED, 1 , 3 )
101 elt.ConvertedType = format.ConvertedTypePtr(format.ConvertedType_LIST)
102 elt.LogicalType = &format.LogicalType{LIST: format.NewListType()}
103 elements = append(elements, elt, NewPrimitive("item" , format.FieldRepetitionType_OPTIONAL, format.Type_INT64, 4 ))
104
105 fields := FieldList{NewInt32Node("a" , parquet.Repetitions.Required, 1 )}
106 list := MustGroup(NewGroupNodeConverted("b" , parquet.Repetitions.Repeated, FieldList{
107 NewInt64Node("item" , parquet.Repetitions.Optional, 4 )}, ConvertedTypes.List, 3 ))
108 fields = append(fields, MustGroup(NewGroupNode("bag" , parquet.Repetitions.Optional, FieldList{list}, 2 )))
109
110 sc := MustGroup(NewGroupNode(s.name, parquet.Repetitions.Repeated, fields, 0 ))
111
112 flattened := ToThrift(sc)
113 s.Len(flattened, len(elements))
114 for idx, elem := range flattened {
115 s.Equal(elements[idx], elem)
116 }
117 }
118
119 func TestSchemaFlatten(t *testing.T) {
120 suite.Run(t, new(SchemaFlattenSuite))
121 }
122
123 func TestInvalidConvertedTypeInDeserialize(t *testing.T) {
124 n := MustPrimitive(NewPrimitiveNodeLogical("string" , parquet.Repetitions.Required, StringLogicalType{},
125 parquet.Types.ByteArray, -1 , -1 ))
126 assert.True(t, n.LogicalType().Equals(StringLogicalType{}))
127 assert.True(t, n.LogicalType().IsValid())
128 assert.True(t, n.LogicalType().IsSerialized())
129 intermediary := n.toThrift()
130
131 intermediary.LogicalType.STRING = nil
132 assert.Panics(t, func() {
133 PrimitiveNodeFromThrift(intermediary)
134 })
135 }
136
137 func TestInvalidTimeUnitInTimeLogical(t *testing.T) {
138 n := MustPrimitive(NewPrimitiveNodeLogical("time" , parquet.Repetitions.Required,
139 NewTimeLogicalType(true , TimeUnitNanos), parquet.Types.Int64, -1 , -1 ))
140 intermediary := n.toThrift()
141
142 intermediary.LogicalType.TIME.Unit.NANOS = nil
143 assert.Panics(t, func() {
144 PrimitiveNodeFromThrift(intermediary)
145 })
146 }
147
148 func TestInvalidTimeUnitInTimestampLogical(t *testing.T) {
149 n := MustPrimitive(NewPrimitiveNodeLogical("time" , parquet.Repetitions.Required,
150 NewTimestampLogicalType(true , TimeUnitNanos), parquet.Types.Int64, -1 , -1 ))
151 intermediary := n.toThrift()
152
153 intermediary.LogicalType.TIMESTAMP.Unit.NANOS = nil
154 assert.Panics(t, func() {
155 PrimitiveNodeFromThrift(intermediary)
156 })
157 }
158
View as plain text