1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package bigquery
16
17 import (
18 "bytes"
19 "encoding/base64"
20 "errors"
21 "fmt"
22 "io"
23 "math/big"
24
25 "cloud.google.com/go/civil"
26 "github.com/apache/arrow/go/v15/arrow"
27 "github.com/apache/arrow/go/v15/arrow/array"
28 "github.com/apache/arrow/go/v15/arrow/ipc"
29 "github.com/apache/arrow/go/v15/arrow/memory"
30 "google.golang.org/api/iterator"
31 )
32
33
// ArrowRecordBatch carries one serialized Arrow record batch together with
// the serialized Arrow schema needed to decode it. It implements io.Reader:
// reading it yields the Schema bytes followed by the Data bytes, which is the
// form handed to ipc.NewReader elsewhere in this file.
type ArrowRecordBatch struct {
	// reader is created lazily on the first call to Read.
	reader io.Reader

	// Data holds the serialized record batch bytes; Read emits them after
	// Schema.
	Data []byte

	// Schema holds the serialized Arrow schema bytes; Read emits them first.
	Schema []byte

	// PartitionID identifies where the batch came from.
	// NOTE(review): presumably the source stream/partition — confirm with the
	// producer of these batches.
	PartitionID string
}

// Read implements io.Reader. It streams Schema followed by Data and returns
// io.EOF once both are exhausted.
//
// The two slices are exposed through read-only readers rather than by
// appending Data to a bytes.Buffer built on top of Schema: such a buffer takes
// ownership of the Schema slice, so the append could write into spare capacity
// of a backing array that the caller (or other batches sharing the same
// schema bytes) still uses. This also avoids copying Data. A single Read call
// never spans the Schema/Data boundary, which io.Reader permits.
func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
	if r.reader == nil {
		r.reader = io.MultiReader(
			bytes.NewReader(r.Schema),
			bytes.NewReader(r.Data),
		)
	}
	return r.reader.Read(p)
}
53
54
55
56
// ArrowIterator is a source of serialized Arrow record batches.
type ArrowIterator interface {
	// Next returns the next record batch. The reader returned by
	// NewArrowIteratorReader treats an iterator.Done error from Next as the
	// end of the stream.
	Next() (*ArrowRecordBatch, error)
	// Schema returns the Schema associated with the iterated data.
	// NOTE(review): presumably the BigQuery table schema of the rows — confirm
	// against the implementations.
	Schema() Schema
	// SerializedArrowSchema returns the serialized Arrow schema bytes; the
	// reader returned by NewArrowIteratorReader emits them before any batch
	// data.
	SerializedArrowSchema() []byte
}
62
63
64
65
66 func NewArrowIteratorReader(it ArrowIterator) io.Reader {
67 return &arrowIteratorReader{
68 it: it,
69 }
70 }
71
// arrowIteratorReader adapts an ArrowIterator to io.Reader; it is the
// concrete type returned by NewArrowIteratorReader.
type arrowIteratorReader struct {
	// buf holds bytes not yet handed to the caller. It is nil until the first
	// Read, which seeds it with the iterator's serialized Arrow schema.
	buf *bytes.Buffer
	// it is the iterator that batches are pulled from when buf runs dry.
	it ArrowIterator
}
76
77
78 func (r *arrowIteratorReader) Read(p []byte) (int, error) {
79 if r.it == nil {
80 return -1, errors.New("bigquery: nil ArrowIterator")
81 }
82 if r.buf == nil {
83 buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
84 r.buf = buf
85 }
86 n, err := r.buf.Read(p)
87 if err == io.EOF {
88 batch, err := r.it.Next()
89 if err == iterator.Done {
90 return 0, io.EOF
91 }
92 r.buf.Write(batch.Data)
93 return r.Read(p)
94 }
95 return n, err
96 }
97
// arrowDecoder converts serialized Arrow record batches into BigQuery values
// or into retained arrow.Record instances.
type arrowDecoder struct {
	// allocator is passed to every IPC reader created for a batch.
	allocator memory.Allocator
	// tableSchema supplies the FieldSchema used to convert each column,
	// indexed by column position.
	tableSchema Schema
	// arrowSchema is the Arrow schema parsed from the serialized schema bytes
	// given to newArrowDecoder.
	arrowSchema *arrow.Schema
}
103
104 func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {
105 buf := bytes.NewBuffer(arrowSerializedSchema)
106 r, err := ipc.NewReader(buf)
107 if err != nil {
108 return nil, err
109 }
110 defer r.Release()
111 p := &arrowDecoder{
112 tableSchema: schema,
113 arrowSchema: r.Schema(),
114 allocator: memory.DefaultAllocator,
115 }
116 return p, nil
117 }
118
119 func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
120 return ipc.NewReader(
121 arrowRecordBatch,
122 ipc.WithSchema(ap.arrowSchema),
123 ipc.WithAllocator(ap.allocator),
124 )
125 }
126
127
128 func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
129 r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
130 if err != nil {
131 return nil, err
132 }
133 defer r.Release()
134 rs := make([][]Value, 0)
135 for r.Next() {
136 rec := r.Record()
137 values, err := ap.convertArrowRecordValue(rec)
138 if err != nil {
139 return nil, err
140 }
141 rs = append(rs, values...)
142 }
143 return rs, nil
144 }
145
146
147 func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
148 r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
149 if err != nil {
150 return nil, err
151 }
152 defer r.Release()
153 records := []arrow.Record{}
154 for r.Next() {
155 rec := r.Record()
156 rec.Retain()
157 records = append(records, rec)
158 }
159 return records, nil
160 }
161
162
163 func (ap *arrowDecoder) convertArrowRecordValue(record arrow.Record) ([][]Value, error) {
164 rs := make([][]Value, record.NumRows())
165 for i := range rs {
166 rs[i] = make([]Value, record.NumCols())
167 }
168 for j, col := range record.Columns() {
169 fs := ap.tableSchema[j]
170 ft := ap.arrowSchema.Field(j).Type
171 for i := 0; i < col.Len(); i++ {
172 v, err := convertArrowValue(col, i, ft, fs)
173 if err != nil {
174 return nil, fmt.Errorf("found arrow type %s, but could not convert value: %v", ap.arrowSchema.Field(j).Type, err)
175 }
176 rs[i][j] = v
177 }
178 }
179 return rs, nil
180 }
181
182
183
184
185 func convertArrowValue(col arrow.Array, i int, ft arrow.DataType, fs *FieldSchema) (Value, error) {
186 if !col.IsValid(i) {
187 return nil, nil
188 }
189 switch ft.(type) {
190 case *arrow.BooleanType:
191 v := col.(*array.Boolean).Value(i)
192 return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
193 case *arrow.Int8Type:
194 v := col.(*array.Int8).Value(i)
195 return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
196 case *arrow.Int16Type:
197 v := col.(*array.Int16).Value(i)
198 return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
199 case *arrow.Int32Type:
200 v := col.(*array.Int32).Value(i)
201 return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
202 case *arrow.Int64Type:
203 v := col.(*array.Int64).Value(i)
204 return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
205 case *arrow.Float16Type:
206 v := col.(*array.Float16).Value(i)
207 return convertBasicType(fmt.Sprintf("%v", v.Float32()), fs.Type)
208 case *arrow.Float32Type:
209 v := col.(*array.Float32).Value(i)
210 return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
211 case *arrow.Float64Type:
212 v := col.(*array.Float64).Value(i)
213 return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
214 case *arrow.BinaryType:
215 v := col.(*array.Binary).Value(i)
216 encoded := base64.StdEncoding.EncodeToString(v)
217 return convertBasicType(encoded, fs.Type)
218 case *arrow.StringType:
219 v := col.(*array.String).Value(i)
220 return convertBasicType(v, fs.Type)
221 case *arrow.Date32Type:
222 v := col.(*array.Date32).Value(i)
223 return convertBasicType(v.FormattedString(), fs.Type)
224 case *arrow.Date64Type:
225 v := col.(*array.Date64).Value(i)
226 return convertBasicType(v.FormattedString(), fs.Type)
227 case *arrow.TimestampType:
228 v := col.(*array.Timestamp).Value(i)
229 dft := ft.(*arrow.TimestampType)
230 t := v.ToTime(dft.Unit)
231 if dft.TimeZone == "" {
232 return Value(civil.DateTimeOf(t)), nil
233 }
234 return Value(t.UTC()), nil
235 case *arrow.Time32Type:
236 v := col.(*array.Time32).Value(i)
237 return convertBasicType(v.FormattedString(arrow.Microsecond), fs.Type)
238 case *arrow.Time64Type:
239 v := col.(*array.Time64).Value(i)
240 return convertBasicType(v.FormattedString(arrow.Microsecond), fs.Type)
241 case *arrow.Decimal128Type:
242 dft := ft.(*arrow.Decimal128Type)
243 v := col.(*array.Decimal128).Value(i)
244 rat := big.NewRat(1, 1)
245 rat.Num().SetBytes(v.BigInt().Bytes())
246 d := rat.Denom()
247 d.Exp(big.NewInt(10), big.NewInt(int64(dft.Scale)), nil)
248 return Value(rat), nil
249 case *arrow.Decimal256Type:
250 dft := ft.(*arrow.Decimal256Type)
251 v := col.(*array.Decimal256).Value(i)
252 rat := big.NewRat(1, 1)
253 rat.Num().SetBytes(v.BigInt().Bytes())
254 d := rat.Denom()
255 d.Exp(big.NewInt(10), big.NewInt(int64(dft.Scale)), nil)
256 return Value(rat), nil
257 case *arrow.ListType:
258 arr := col.(*array.List)
259 dft := ft.(*arrow.ListType)
260 values := []Value{}
261 start, end := arr.ValueOffsets(i)
262 slice := array.NewSlice(arr.ListValues(), start, end)
263 for j := 0; j < slice.Len(); j++ {
264 v, err := convertArrowValue(slice, j, dft.Elem(), fs)
265 if err != nil {
266 return nil, err
267 }
268 values = append(values, v)
269 }
270 return values, nil
271 case *arrow.StructType:
272 arr := col.(*array.Struct)
273 nestedValues := []Value{}
274 fields := ft.(*arrow.StructType).Fields()
275 if fs.Type == RangeFieldType {
276 rangeFieldSchema := &FieldSchema{
277 Type: fs.RangeElementType.Type,
278 }
279 start, err := convertArrowValue(arr.Field(0), i, fields[0].Type, rangeFieldSchema)
280 if err != nil {
281 return nil, err
282 }
283 end, err := convertArrowValue(arr.Field(1), i, fields[1].Type, rangeFieldSchema)
284 if err != nil {
285 return nil, err
286 }
287 rangeValue := &RangeValue{Start: start, End: end}
288 return Value(rangeValue), nil
289 }
290 for fIndex, f := range fields {
291 v, err := convertArrowValue(arr.Field(fIndex), i, f.Type, fs.Schema[fIndex])
292 if err != nil {
293 return nil, err
294 }
295 nestedValues = append(nestedValues, v)
296 }
297 return nestedValues, nil
298 default:
299 return nil, fmt.Errorf("unknown arrow type: %v", ft)
300 }
301 }
302
View as plain text