
Source file src/cloud.google.com/go/bigquery/arrow.go

Documentation: cloud.google.com/go/bigquery

// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bigquery

import (
	"bytes"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"math/big"

	"cloud.google.com/go/civil"
	"github.com/apache/arrow/go/v15/arrow"
	"github.com/apache/arrow/go/v15/arrow/array"
	"github.com/apache/arrow/go/v15/arrow/ipc"
	"github.com/apache/arrow/go/v15/arrow/memory"
	"google.golang.org/api/iterator"
)

// ArrowRecordBatch represents an Arrow RecordBatch with the source PartitionID.
type ArrowRecordBatch struct {
	reader io.Reader
	// Serialized Arrow Record Batch.
	Data []byte
	// Serialized Arrow Schema.
	Schema []byte
	// Source partition ID. In the Storage API world, it represents the ReadStream.
	PartitionID string
}

// Read makes ArrowRecordBatch implement io.Reader: it streams the serialized
// schema followed by the serialized batch data.
func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
	if r.reader == nil {
		buf := bytes.NewBuffer(r.Schema)
		buf.Write(r.Data)
		r.reader = buf
	}
	return r.reader.Read(p)
}
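
// Because Read yields the schema bytes followed by the batch bytes, a single
// ArrowRecordBatch reads as a self-contained Arrow IPC stream. A minimal
// consumption sketch (illustrative only; assumes batch is a populated
// *ArrowRecordBatch):
//
//	r, err := ipc.NewReader(batch)
//	if err != nil {
//		// handle error
//	}
//	defer r.Release()
//	for r.Next() {
//		rec := r.Record() // process each arrow.Record
//		_ = rec
//	}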

// ArrowIterator represents a way to iterate through a stream of arrow records.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
type ArrowIterator interface {
	Next() (*ArrowRecordBatch, error)
	Schema() Schema
	SerializedArrowSchema() []byte
}

// NewArrowIteratorReader allows an ArrowIterator to be consumed as an io.Reader.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
func NewArrowIteratorReader(it ArrowIterator) io.Reader {
	return &arrowIteratorReader{
		it: it,
	}
}
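
// A usage sketch (illustrative; assumes it is an ArrowIterator, e.g. one
// obtained from a Storage API backed RowIterator):
//
//	r, err := ipc.NewReader(NewArrowIteratorReader(it))
//	if err != nil {
//		// handle error
//	}
//	defer r.Release()
//	for r.Next() {
//		rec := r.Record() // each arrow.Record as it streams in
//		_ = rec
//	}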

type arrowIteratorReader struct {
	buf *bytes.Buffer
	it  ArrowIterator
}

// Read makes arrowIteratorReader implement io.Reader.
func (r *arrowIteratorReader) Read(p []byte) (int, error) {
	if r.it == nil {
		return 0, errors.New("bigquery: nil ArrowIterator")
	}
	if r.buf == nil { // init with schema
		buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
		r.buf = buf
	}
	n, err := r.buf.Read(p)
	if err == io.EOF {
		// Buffer drained: pull the next batch from the iterator and retry.
		batch, err := r.it.Next()
		if err == iterator.Done {
			return 0, io.EOF
		}
		if err != nil {
			return 0, err
		}
		r.buf.Write(batch.Data)
		return r.Read(p)
	}
	return n, err
}

type arrowDecoder struct {
	allocator   memory.Allocator
	tableSchema Schema
	arrowSchema *arrow.Schema
}

func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {
	buf := bytes.NewBuffer(arrowSerializedSchema)
	r, err := ipc.NewReader(buf)
	if err != nil {
		return nil, err
	}
	defer r.Release()
	p := &arrowDecoder{
		tableSchema: schema,
		arrowSchema: r.Schema(),
		allocator:   memory.DefaultAllocator,
	}
	return p, nil
}
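
// A construction sketch (illustrative; assumes it is an ArrowIterator):
//
//	dec, err := newArrowDecoder(it.SerializedArrowSchema(), it.Schema())
//	if err != nil {
//		// handle error
//	}
//	batch, err := it.Next()
//	if err != nil {
//		// handle error (iterator.Done signals end of stream)
//	}
//	rows, err := dec.decodeArrowRecords(batch)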

func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
	return ipc.NewReader(
		arrowRecordBatch,
		ipc.WithSchema(ap.arrowSchema),
		ipc.WithAllocator(ap.allocator),
	)
}

// decodeArrowRecords decodes a BQ ArrowRecordBatch into rows of []Value.
func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
	r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
	if err != nil {
		return nil, err
	}
	defer r.Release()
	rs := make([][]Value, 0)
	for r.Next() {
		rec := r.Record()
		values, err := ap.convertArrowRecordValue(rec)
		if err != nil {
			return nil, err
		}
		rs = append(rs, values...)
	}
	return rs, nil
}

// decodeRetainedArrowRecords decodes a BQ ArrowRecordBatch into a list of retained arrow.Record.
// Callers are responsible for calling Release on each returned record.
func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
	r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
	if err != nil {
		return nil, err
	}
	defer r.Release()
	records := []arrow.Record{}
	for r.Next() {
		rec := r.Record()
		rec.Retain()
		records = append(records, rec)
	}
	return records, nil
}
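
// A release-pattern sketch for the retained variant (illustrative; assumes
// dec is an *arrowDecoder and batch an *ArrowRecordBatch):
//
//	records, err := dec.decodeRetainedArrowRecords(batch)
//	if err != nil {
//		// handle error
//	}
//	defer func() {
//		for _, rec := range records {
//			rec.Release()
//		}
//	}()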

// convertArrowRecordValue converts an arrow.Record into a series of Value slices.
func (ap *arrowDecoder) convertArrowRecordValue(record arrow.Record) ([][]Value, error) {
	rs := make([][]Value, record.NumRows())
	for i := range rs {
		rs[i] = make([]Value, record.NumCols())
	}
	for j, col := range record.Columns() {
		fs := ap.tableSchema[j]
		ft := ap.arrowSchema.Field(j).Type
		for i := 0; i < col.Len(); i++ {
			v, err := convertArrowValue(col, i, ft, fs)
			if err != nil {
				return nil, fmt.Errorf("found arrow type %s, but could not convert value: %v", ap.arrowSchema.Field(j).Type, err)
			}
			rs[i][j] = v
		}
	}
	return rs, nil
}

// convertArrowValue gets the row value in the given column and converts it to a Value.
// Arrow is a columnar format, so we navigate first by column and then fetch the row value.
// More details on conversions can be seen here: https://cloud.google.com/bigquery/docs/reference/storage#arrow_schema_details
func convertArrowValue(col arrow.Array, i int, ft arrow.DataType, fs *FieldSchema) (Value, error) {
	if !col.IsValid(i) {
		// NULL values decode to a nil Value.
		return nil, nil
	}
	switch ft.(type) {
	case *arrow.BooleanType:
		v := col.(*array.Boolean).Value(i)
		return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
	case *arrow.Int8Type:
		v := col.(*array.Int8).Value(i)
		return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
	case *arrow.Int16Type:
		v := col.(*array.Int16).Value(i)
		return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
	case *arrow.Int32Type:
		v := col.(*array.Int32).Value(i)
		return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
	case *arrow.Int64Type:
		v := col.(*array.Int64).Value(i)
		return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
	case *arrow.Float16Type:
		v := col.(*array.Float16).Value(i)
		return convertBasicType(fmt.Sprintf("%v", v.Float32()), fs.Type)
	case *arrow.Float32Type:
		v := col.(*array.Float32).Value(i)
		return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
	case *arrow.Float64Type:
		v := col.(*array.Float64).Value(i)
		return convertBasicType(fmt.Sprintf("%v", v), fs.Type)
	case *arrow.BinaryType:
		v := col.(*array.Binary).Value(i)
		encoded := base64.StdEncoding.EncodeToString(v)
		return convertBasicType(encoded, fs.Type)
	case *arrow.StringType:
		v := col.(*array.String).Value(i)
		return convertBasicType(v, fs.Type)
	case *arrow.Date32Type:
		v := col.(*array.Date32).Value(i)
		return convertBasicType(v.FormattedString(), fs.Type)
	case *arrow.Date64Type:
		v := col.(*array.Date64).Value(i)
		return convertBasicType(v.FormattedString(), fs.Type)
	case *arrow.TimestampType:
		v := col.(*array.Timestamp).Value(i)
		dft := ft.(*arrow.TimestampType)
		t := v.ToTime(dft.Unit)
		if dft.TimeZone == "" { // Datetime
			return Value(civil.DateTimeOf(t)), nil
		}
		return Value(t.UTC()), nil // Timestamp
	case *arrow.Time32Type:
		v := col.(*array.Time32).Value(i)
		return convertBasicType(v.FormattedString(arrow.Microsecond), fs.Type)
	case *arrow.Time64Type:
		v := col.(*array.Time64).Value(i)
		return convertBasicType(v.FormattedString(arrow.Microsecond), fs.Type)
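	// NUMERIC and BIGNUMERIC arrive as scaled integers; the decimal cases
	// below divide by 10^scale to recover the value. For example
	// (illustrative): a raw Decimal128 value of 12345 with scale 2 is
	// returned as the *big.Rat 12345/100, i.e. 123.45.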
	case *arrow.Decimal128Type:
		dft := ft.(*arrow.Decimal128Type)
		v := col.(*array.Decimal128).Value(i)
		// Use SetFrac with the signed integer value so that negative
		// decimals keep their sign (big.Int.Bytes drops it).
		denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(dft.Scale)), nil)
		rat := new(big.Rat).SetFrac(v.BigInt(), denom)
		return Value(rat), nil
	case *arrow.Decimal256Type:
		dft := ft.(*arrow.Decimal256Type)
		v := col.(*array.Decimal256).Value(i)
		denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(dft.Scale)), nil)
		rat := new(big.Rat).SetFrac(v.BigInt(), denom)
		return Value(rat), nil
	case *arrow.ListType:
		arr := col.(*array.List)
		dft := ft.(*arrow.ListType)
		values := []Value{}
		start, end := arr.ValueOffsets(i)
		slice := array.NewSlice(arr.ListValues(), start, end)
		for j := 0; j < slice.Len(); j++ {
			v, err := convertArrowValue(slice, j, dft.Elem(), fs)
			if err != nil {
				return nil, err
			}
			values = append(values, v)
		}
		return values, nil
	case *arrow.StructType:
		arr := col.(*array.Struct)
		nestedValues := []Value{}
		fields := ft.(*arrow.StructType).Fields()
		if fs.Type == RangeFieldType {
			rangeFieldSchema := &FieldSchema{
				Type: fs.RangeElementType.Type,
			}
			start, err := convertArrowValue(arr.Field(0), i, fields[0].Type, rangeFieldSchema)
			if err != nil {
				return nil, err
			}
			end, err := convertArrowValue(arr.Field(1), i, fields[1].Type, rangeFieldSchema)
			if err != nil {
				return nil, err
			}
			rangeValue := &RangeValue{Start: start, End: end}
			return Value(rangeValue), nil
		}
		for fIndex, f := range fields {
			v, err := convertArrowValue(arr.Field(fIndex), i, f.Type, fs.Schema[fIndex])
			if err != nil {
				return nil, err
			}
			nestedValues = append(nestedValues, v)
		}
		return nestedValues, nil
	default:
		return nil, fmt.Errorf("unknown arrow type: %v", ft)
	}
}