1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31 package schema
32
33 import (
34 "fmt"
35 "io"
36 "strings"
37
38 "github.com/apache/arrow/go/v15/parquet"
39 format "github.com/apache/arrow/go/v15/parquet/internal/gen-go/parquet"
40 "golang.org/x/xerrors"
41 )
42
43
44
45
46
47
48
49
50
51
52
// Schema is the flattened view of a parquet schema tree. It holds the root
// node together with the leaf columns and lookup tables that NewSchema builds
// from it.
type Schema struct {
	// root is the root of the schema tree; Root and HasRepeatedFields assert
	// it is a *GroupNode.
	root Node

	// leaves holds one Column per primitive node, in depth-first order.
	leaves []*Column
	// nodeToLeaf maps each primitive node to its index in leaves.
	nodeToLeaf map[*PrimitiveNode]int
	// leafToBase maps a leaf index to the base node recorded for it by buildTree.
	leafToBase map[int]Node
	// leafToIndex maps a node path to every leaf index registered under that path.
	leafToIndex strIntMultimap
}
61
62
63 func FromParquet(elems []*format.SchemaElement) (Node, error) {
64 if len(elems) == 0 {
65 return nil, xerrors.New("parquet: empty schema (no root)")
66 }
67
68 if elems[0].GetNumChildren() == 0 {
69 if len(elems) > 1 {
70 return nil, xerrors.New("parquet: schema had multiple nodes but root had no children")
71 }
72
73 return GroupNodeFromThrift(elems[0], []Node{})
74 }
75
76
77
78 var (
79 pos = 0
80 nextNode func() (Node, error)
81 )
82
83 nextNode = func() (Node, error) {
84 if pos == len(elems) {
85 return nil, xerrors.New("parquet: malformed schema: not enough elements")
86 }
87
88 elem := elems[pos]
89 pos++
90
91 if elem.GetNumChildren() == 0 {
92 return PrimitiveNodeFromThrift(elem)
93 }
94
95 fields := make([]Node, 0, elem.GetNumChildren())
96 for i := 0; i < int(elem.GetNumChildren()); i++ {
97 n, err := nextNode()
98 if err != nil {
99 return nil, err
100 }
101 fields = append(fields, n)
102 }
103
104 return GroupNodeFromThrift(elem, fields)
105 }
106
107 return nextNode()
108 }
109
110
111 func (s *Schema) Root() *GroupNode {
112 return s.root.(*GroupNode)
113 }
114
115
116
117 func (s *Schema) NumColumns() int {
118 return len(s.leaves)
119 }
120
121
122
123
124 func (s *Schema) Equals(rhs *Schema) bool {
125 if s.NumColumns() != rhs.NumColumns() {
126 return false
127 }
128
129 for idx, c := range s.leaves {
130 if !c.Equals(rhs.Column(idx)) {
131 return false
132 }
133 }
134 return true
135 }
136
137 func (s *Schema) buildTree(n Node, maxDefLvl, maxRepLvl int16, base Node) {
138 switch n.RepetitionType() {
139 case parquet.Repetitions.Repeated:
140 maxRepLvl++
141 fallthrough
142 case parquet.Repetitions.Optional:
143 maxDefLvl++
144 }
145
146 switch n := n.(type) {
147 case *GroupNode:
148 for _, f := range n.fields {
149 s.buildTree(f, maxDefLvl, maxRepLvl, base)
150 }
151 case *PrimitiveNode:
152 s.nodeToLeaf[n] = len(s.leaves)
153 s.leaves = append(s.leaves, NewColumn(n, maxDefLvl, maxRepLvl))
154 s.leafToBase[len(s.leaves)-1] = base
155 s.leafToIndex.Add(n.Path(), len(s.leaves)-1)
156 }
157 }
158
159
160 func (s *Schema) Column(i int) *Column {
161 return s.leaves[i]
162 }
163
164
165
166
167
168 func (s *Schema) ColumnIndexByName(nodePath string) int {
169 if search, ok := s.leafToIndex[nodePath]; ok {
170 return search[0]
171 }
172 return -1
173 }
174
175
176
177
178 func (s *Schema) ColumnIndexByNode(n Node) int {
179 if search, ok := s.leafToIndex[n.Path()]; ok {
180 for _, idx := range search {
181 if n == s.Column(idx).SchemaNode() {
182 return idx
183 }
184 }
185 }
186 return -1
187 }
188
189
190
191 func (s *Schema) ColumnRoot(i int) Node {
192 return s.leafToBase[i]
193 }
194
195
196 func (s *Schema) HasRepeatedFields() bool {
197 return s.root.(*GroupNode).HasRepeatedFields()
198 }
199
200
201
202 func (s *Schema) UpdateColumnOrders(orders []parquet.ColumnOrder) error {
203 if len(orders) != s.NumColumns() {
204 return xerrors.New("parquet: malformed schema: not enough ColumnOrder values")
205 }
206
207 visitor := schemaColumnOrderUpdater{orders, 0}
208 s.root.Visit(&visitor)
209 return nil
210 }
211
212 func (s *Schema) String() string {
213 var b strings.Builder
214 PrintSchema(s.root, &b, 2)
215 return b.String()
216 }
217
218
219
220
221 func NewSchema(root *GroupNode) *Schema {
222 s := &Schema{
223 root,
224 make([]*Column, 0),
225 make(map[*PrimitiveNode]int),
226 make(map[int]Node),
227 make(strIntMultimap),
228 }
229
230 for _, f := range root.fields {
231 s.buildTree(f, 0, 0, f)
232 }
233 return s
234 }
235
236 type schemaColumnOrderUpdater struct {
237 colOrders []parquet.ColumnOrder
238 leafCount int
239 }
240
241 func (s *schemaColumnOrderUpdater) VisitPre(n Node) bool {
242 if n.Type() == Primitive {
243 leaf := n.(*PrimitiveNode)
244 leaf.ColumnOrder = s.colOrders[s.leafCount]
245 s.leafCount++
246 }
247 return true
248 }
249
250 func (s *schemaColumnOrderUpdater) VisitPost(Node) {}
251
252 type toThriftVisitor struct {
253 elements []*format.SchemaElement
254 }
255
256 func (t *toThriftVisitor) VisitPre(n Node) bool {
257 t.elements = append(t.elements, n.toThrift())
258 return true
259 }
260
261 func (t *toThriftVisitor) VisitPost(Node) {}
262
263
264
265 func ToThrift(schema *GroupNode) []*format.SchemaElement {
266 t := &toThriftVisitor{make([]*format.SchemaElement, 0)}
267 schema.Visit(t)
268 return t.elements
269 }
270
271 type schemaPrinter struct {
272 w io.Writer
273 indent int
274 indentWidth int
275 }
276
277 func (s *schemaPrinter) VisitPre(n Node) bool {
278 fmt.Fprint(s.w, strings.Repeat(" ", s.indent))
279 if n.Type() == Group {
280 g := n.(*GroupNode)
281 fmt.Fprintf(s.w, "%s group field_id=%d %s", g.RepetitionType(), g.FieldID(), g.Name())
282 _, invalid := g.logicalType.(UnknownLogicalType)
283 _, none := g.logicalType.(NoLogicalType)
284
285 if g.logicalType != nil && !invalid && !none {
286 fmt.Fprintf(s.w, " (%s)", g.logicalType)
287 } else if g.convertedType != ConvertedTypes.None {
288 fmt.Fprintf(s.w, " (%s)", g.convertedType)
289 }
290
291 fmt.Fprintln(s.w, " {")
292 s.indent += s.indentWidth
293 } else {
294 p := n.(*PrimitiveNode)
295 fmt.Fprintf(s.w, "%s %s field_id=%d %s", p.RepetitionType(), strings.ToLower(p.PhysicalType().String()), p.FieldID(), p.Name())
296 _, invalid := p.logicalType.(UnknownLogicalType)
297 _, none := p.logicalType.(NoLogicalType)
298
299 if p.logicalType != nil && !invalid && !none {
300 fmt.Fprintf(s.w, " (%s)", p.logicalType)
301 } else if p.convertedType == ConvertedTypes.Decimal {
302 fmt.Fprintf(s.w, " (%s(%d,%d))", p.convertedType, p.DecimalMetadata().Precision, p.DecimalMetadata().Scale)
303 } else if p.convertedType != ConvertedTypes.None {
304 fmt.Fprintf(s.w, " (%s)", p.convertedType)
305 }
306 fmt.Fprintln(s.w, ";")
307 }
308 return true
309 }
310
311 func (s *schemaPrinter) VisitPost(n Node) {
312 if n.Type() == Group {
313 s.indent -= s.indentWidth
314 fmt.Fprint(s.w, strings.Repeat(" ", s.indent))
315 fmt.Fprintln(s.w, "}")
316 }
317 }
318
319
320
321 func PrintSchema(n Node, w io.Writer, indentWidth int) {
322 n.Visit(&schemaPrinter{w, 0, indentWidth})
323 }
324
// strIntMultimap maps a string key to every int value added under it, in
// insertion order.
type strIntMultimap map[string][]int

// Add appends val to the values stored under key. It reports whether key
// was already present in the map before this call.
func (f strIntMultimap) Add(key string, val int) bool {
	// For a missing key, existing is nil and append yields []int{val}.
	existing, present := f[key]
	f[key] = append(existing, val)
	return present
}
335
View as plain text