...

Source file src/cloud.google.com/go/bigquery/storage/managedwriter/adapt/protoconversion.go

Documentation: cloud.google.com/go/bigquery/storage/managedwriter/adapt

     1  // Copyright 2021 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package adapt
    16  
    17  import (
    18  	"encoding/base64"
    19  	"fmt"
    20  	"sort"
    21  	"strings"
    22  
    23  	"cloud.google.com/go/bigquery/storage/apiv1/storagepb"
    24  	"google.golang.org/protobuf/proto"
    25  	"google.golang.org/protobuf/reflect/protodesc"
    26  	"google.golang.org/protobuf/reflect/protoreflect"
    27  	"google.golang.org/protobuf/types/descriptorpb"
    28  	"google.golang.org/protobuf/types/known/wrapperspb"
    29  )
    30  
// bqModeToFieldLabelMapProto2 maps a BigQuery field mode to the field label used
// when generating proto2 descriptors.  Each mode maps to the same-named label.
var bqModeToFieldLabelMapProto2 = map[storagepb.TableFieldSchema_Mode]descriptorpb.FieldDescriptorProto_Label{
	storagepb.TableFieldSchema_NULLABLE: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
	storagepb.TableFieldSchema_REPEATED: descriptorpb.FieldDescriptorProto_LABEL_REPEATED,
	storagepb.TableFieldSchema_REQUIRED: descriptorpb.FieldDescriptorProto_LABEL_REQUIRED,
}

// bqModeToFieldLabelMapProto3 maps a BigQuery field mode to the field label used
// when generating proto3 descriptors.  Unlike the proto2 table, REQUIRED is
// mapped to LABEL_OPTIONAL here.
var bqModeToFieldLabelMapProto3 = map[storagepb.TableFieldSchema_Mode]descriptorpb.FieldDescriptorProto_Label{
	storagepb.TableFieldSchema_NULLABLE: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
	storagepb.TableFieldSchema_REPEATED: descriptorpb.FieldDescriptorProto_LABEL_REPEATED,
	storagepb.TableFieldSchema_REQUIRED: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
}
    42  
    43  func convertModeToLabel(mode storagepb.TableFieldSchema_Mode, useProto3 bool) *descriptorpb.FieldDescriptorProto_Label {
    44  	if useProto3 {
    45  		return bqModeToFieldLabelMapProto3[mode].Enum()
    46  	}
    47  	return bqModeToFieldLabelMapProto2[mode].Enum()
    48  }
    49  
// bqTypeToFieldTypeMap allows conversion between a BQ schema type and the
// FieldDescriptorProto type used to carry it.
//
// STRUCT and RANGE are both represented as messages.  A schema type without an
// entry here yields the zero FieldDescriptorProto_Type on lookup.
var bqTypeToFieldTypeMap = map[storagepb.TableFieldSchema_Type]descriptorpb.FieldDescriptorProto_Type{
	storagepb.TableFieldSchema_BIGNUMERIC: descriptorpb.FieldDescriptorProto_TYPE_BYTES,
	storagepb.TableFieldSchema_BOOL:       descriptorpb.FieldDescriptorProto_TYPE_BOOL,
	storagepb.TableFieldSchema_BYTES:      descriptorpb.FieldDescriptorProto_TYPE_BYTES,
	storagepb.TableFieldSchema_DATE:       descriptorpb.FieldDescriptorProto_TYPE_INT32,
	storagepb.TableFieldSchema_DATETIME:   descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_DOUBLE:     descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
	storagepb.TableFieldSchema_GEOGRAPHY:  descriptorpb.FieldDescriptorProto_TYPE_STRING,
	storagepb.TableFieldSchema_INT64:      descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_NUMERIC:    descriptorpb.FieldDescriptorProto_TYPE_BYTES,
	storagepb.TableFieldSchema_STRING:     descriptorpb.FieldDescriptorProto_TYPE_STRING,
	storagepb.TableFieldSchema_STRUCT:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
	storagepb.TableFieldSchema_TIME:       descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_TIMESTAMP:  descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_RANGE:      descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
}
    67  
// allowedRangeTypes lists the element types for which a RANGE value message may
// be generated; any other element type is rejected with an error.
var allowedRangeTypes = []storagepb.TableFieldSchema_Type{
	storagepb.TableFieldSchema_DATE,
	storagepb.TableFieldSchema_DATETIME,
	storagepb.TableFieldSchema_TIMESTAMP,
}
    73  
// packedTypes lists the primitive types which can leverage packed encoding when
// repeated/arrays.  It is consulted when building proto2 repeated fields to
// decide whether to set the packed option.
//
// Note: many/most of these aren't used when doing schema to proto conversion, but
// are included for completeness.
var packedTypes = []descriptorpb.FieldDescriptorProto_Type{
	descriptorpb.FieldDescriptorProto_TYPE_INT32,
	descriptorpb.FieldDescriptorProto_TYPE_INT64,
	descriptorpb.FieldDescriptorProto_TYPE_UINT32,
	descriptorpb.FieldDescriptorProto_TYPE_UINT64,
	descriptorpb.FieldDescriptorProto_TYPE_SINT32,
	descriptorpb.FieldDescriptorProto_TYPE_SINT64,
	descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
	descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
	descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
	descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
	descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
	descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
	descriptorpb.FieldDescriptorProto_TYPE_BOOL,
	descriptorpb.FieldDescriptorProto_TYPE_ENUM,
}
    94  
// bqTypeToWrapperMap gives the fully-qualified well-known wrapper message used
// for each BQ schema type.
//
// For TableFieldSchema NULLABLE mode in proto3 output, we use the wrapper types
// to allow for the proper representation of NULL values, as proto3 semantics
// would just use the default value.  STRUCT and RANGE have no entry, as they
// are handled as messages before this table is consulted.
var bqTypeToWrapperMap = map[storagepb.TableFieldSchema_Type]string{
	storagepb.TableFieldSchema_BIGNUMERIC: ".google.protobuf.BytesValue",
	storagepb.TableFieldSchema_BOOL:       ".google.protobuf.BoolValue",
	storagepb.TableFieldSchema_BYTES:      ".google.protobuf.BytesValue",
	storagepb.TableFieldSchema_DATE:       ".google.protobuf.Int32Value",
	storagepb.TableFieldSchema_DATETIME:   ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_DOUBLE:     ".google.protobuf.DoubleValue",
	storagepb.TableFieldSchema_GEOGRAPHY:  ".google.protobuf.StringValue",
	storagepb.TableFieldSchema_INT64:      ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_NUMERIC:    ".google.protobuf.BytesValue",
	storagepb.TableFieldSchema_STRING:     ".google.protobuf.StringValue",
	storagepb.TableFieldSchema_TIME:       ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_TIMESTAMP:  ".google.protobuf.Int64Value",
}
   111  
// wellKnownTypesWrapperName is the filename used by the well known wrapper
// types proto.  It is listed as a dependency of every generated file descriptor.
var wellKnownTypesWrapperName = "google/protobuf/wrappers.proto"

// rangeTypesPrefix is prepended to the lowercased element type name to form the
// message name of a generated RANGE value type.
var rangeTypesPrefix = "rangemessage_range_"
   116  
// dependencyCache is used to reduce the number of unique messages we generate by caching based on the tableschema.
//
// Keys of the general cache are based on the base64-encoded serialized
// tableschema value; range value messages are tracked separately by their
// element type.
type dependencyCache struct {
	// rangeTypes holds the generated RANGE value messages, keyed by element type.
	rangeTypes map[storagepb.TableFieldSchema_Type]protoreflect.MessageDescriptor
	// msgs is the general cache, keyed by the base64-encoded serialized schema.
	msgs map[string]protoreflect.MessageDescriptor
}
   126  
   127  func newDependencyCache() *dependencyCache {
   128  	return &dependencyCache{
   129  		rangeTypes: make(map[storagepb.TableFieldSchema_Type]protoreflect.MessageDescriptor),
   130  		msgs:       make(map[string]protoreflect.MessageDescriptor),
   131  	}
   132  }
   133  
   134  func (dm *dependencyCache) get(schema *storagepb.TableSchema) protoreflect.MessageDescriptor {
   135  	if dm == nil {
   136  		return nil
   137  	}
   138  	b, err := proto.Marshal(schema)
   139  	if err != nil {
   140  		return nil
   141  	}
   142  	encoded := base64.StdEncoding.EncodeToString(b)
   143  	if desc, ok := dm.msgs[encoded]; ok {
   144  		return desc
   145  	}
   146  	return nil
   147  }
   148  
   149  func (dm *dependencyCache) getFileDescriptorProtos() []*descriptorpb.FileDescriptorProto {
   150  	var fdpList []*descriptorpb.FileDescriptorProto
   151  	// emit encountered messages.
   152  	for _, d := range dm.msgs {
   153  		if fd := d.ParentFile(); fd != nil {
   154  			fdp := protodesc.ToFileDescriptorProto(fd)
   155  			fdpList = append(fdpList, fdp)
   156  		}
   157  	}
   158  	// emit any range value types used.
   159  	for _, d := range dm.rangeTypes {
   160  		if fd := d.ParentFile(); fd != nil {
   161  			fdp := protodesc.ToFileDescriptorProto(fd)
   162  			fdpList = append(fdpList, fdp)
   163  		}
   164  	}
   165  	return fdpList
   166  }
   167  
   168  func (dm *dependencyCache) add(schema *storagepb.TableSchema, descriptor protoreflect.MessageDescriptor) error {
   169  	if dm == nil {
   170  		return fmt.Errorf("cache is nil")
   171  	}
   172  	b, err := proto.Marshal(schema)
   173  	if err != nil {
   174  		return fmt.Errorf("failed to serialize tableschema: %w", err)
   175  	}
   176  	encoded := base64.StdEncoding.EncodeToString(b)
   177  	dm.msgs[encoded] = descriptor
   178  	return nil
   179  }
   180  
   181  func (dm *dependencyCache) addRangeByElementType(typ storagepb.TableFieldSchema_Type, useProto3 bool) (protoreflect.MessageDescriptor, error) {
   182  	if md, present := dm.rangeTypes[typ]; present {
   183  		// already added, do nothing.
   184  		return md, nil
   185  	}
   186  	// Not yet present.  Build the message.
   187  	allowed := false
   188  	for _, a := range allowedRangeTypes {
   189  		if typ == a {
   190  			allowed = true
   191  		}
   192  	}
   193  	if !allowed {
   194  		return nil, fmt.Errorf("range does not support %q as a valid element type", typ.String())
   195  	}
   196  	ts := &storagepb.TableSchema{
   197  		Fields: []*storagepb.TableFieldSchema{
   198  			{
   199  				Name: "start",
   200  				Type: typ,
   201  				Mode: storagepb.TableFieldSchema_NULLABLE,
   202  			},
   203  			{
   204  				Name: "end",
   205  				Type: typ,
   206  				Mode: storagepb.TableFieldSchema_NULLABLE,
   207  			},
   208  		},
   209  	}
   210  	// we put the range types outside the hierarchical namespace as they're effectively BQ-specific well-known types.
   211  	msgTypeName := fmt.Sprintf("%s%s", rangeTypesPrefix, strings.ToLower(typ.String()))
   212  	// use a new dependency cache, as we don't want to taint the main one due to matching schema
   213  	md, err := storageSchemaToDescriptorInternal(ts, msgTypeName, newDependencyCache(), useProto3)
   214  	if err != nil {
   215  		return nil, fmt.Errorf("failed to generate range descriptor %q: %v", msgTypeName, err)
   216  	}
   217  	dm.rangeTypes[typ] = md
   218  	return md, nil
   219  }
   220  
   221  func (dm *dependencyCache) getRange(typ storagepb.TableFieldSchema_Type) protoreflect.MessageDescriptor {
   222  	md, ok := dm.rangeTypes[typ]
   223  	if !ok {
   224  		return nil
   225  	}
   226  	return md
   227  }
   228  
   229  // StorageSchemaToProto2Descriptor builds a protoreflect.Descriptor for a given table schema using proto2 syntax.
   230  func StorageSchemaToProto2Descriptor(inSchema *storagepb.TableSchema, scope string) (protoreflect.Descriptor, error) {
   231  	dc := newDependencyCache()
   232  	// TODO: b/193064992 tracks support for wrapper types.  In the interim, disable wrapper usage.
   233  	return storageSchemaToDescriptorInternal(inSchema, scope, dc, false)
   234  }
   235  
   236  // StorageSchemaToProto3Descriptor builds a protoreflect.Descriptor for a given table schema using proto3 syntax.
   237  //
   238  // NOTE: Currently the write API doesn't yet support proto3 behaviors (default value, wrapper types, etc), but this is provided for
   239  // completeness.
   240  func StorageSchemaToProto3Descriptor(inSchema *storagepb.TableSchema, scope string) (protoreflect.Descriptor, error) {
   241  	dc := newDependencyCache()
   242  	return storageSchemaToDescriptorInternal(inSchema, scope, dc, true)
   243  }
   244  
   245  // Internal implementation of the conversion code.
   246  func storageSchemaToDescriptorInternal(inSchema *storagepb.TableSchema, scope string, cache *dependencyCache, useProto3 bool) (protoreflect.MessageDescriptor, error) {
   247  	if inSchema == nil {
   248  		return nil, newConversionError(scope, fmt.Errorf("no input schema was provided"))
   249  	}
   250  
   251  	var fields []*descriptorpb.FieldDescriptorProto
   252  	var deps []protoreflect.FileDescriptor
   253  	var fNumber int32
   254  
   255  	for _, f := range inSchema.GetFields() {
   256  		fNumber = fNumber + 1
   257  		currentScope := fmt.Sprintf("%s__%s", scope, f.GetName())
   258  
   259  		if f.Type == storagepb.TableFieldSchema_STRUCT {
   260  			// If we're dealing with a STRUCT type, we must deal with sub messages.
   261  			// As multiple submessages may share the same type definition, we use a dependency cache
   262  			// and interrogate it / populate it as we're going.
   263  			foundDesc := cache.get(&storagepb.TableSchema{Fields: f.GetFields()})
   264  			if foundDesc != nil {
   265  				// check to see if we already have this in current dependency list
   266  				haveDep := false
   267  				for _, dep := range deps {
   268  					if messageDependsOnFile(foundDesc, dep) {
   269  						haveDep = true
   270  						break
   271  					}
   272  				}
   273  				// If dep is missing, add to current dependencies.
   274  				if !haveDep {
   275  					deps = append(deps, foundDesc.ParentFile())
   276  				}
   277  				// Construct field descriptor for the message.
   278  				fdp, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, string(foundDesc.FullName()), useProto3)
   279  				if err != nil {
   280  					return nil, newConversionError(scope, fmt.Errorf("couldn't convert field to FieldDescriptorProto: %w", err))
   281  				}
   282  				fields = append(fields, fdp)
   283  			} else {
   284  				// Wrap the current struct's fields in a TableSchema outer message, and then build the submessage.
   285  				ts := &storagepb.TableSchema{
   286  					Fields: f.GetFields(),
   287  				}
   288  				desc, err := storageSchemaToDescriptorInternal(ts, currentScope, cache, useProto3)
   289  				if err != nil {
   290  					return nil, newConversionError(currentScope, fmt.Errorf("couldn't convert message: %w", err))
   291  				}
   292  				// Now that we have the submessage definition, we append it both to the local dependencies, as well
   293  				// as inserting it into the cache for possible reuse elsewhere.
   294  				deps = append(deps, desc.ParentFile())
   295  				err = cache.add(ts, desc)
   296  				if err != nil {
   297  					return nil, newConversionError(currentScope, fmt.Errorf("failed to add descriptor to dependency cache: %w", err))
   298  				}
   299  				fdp, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, currentScope, useProto3)
   300  				if err != nil {
   301  					return nil, newConversionError(currentScope, fmt.Errorf("couldn't compute field schema : %w", err))
   302  				}
   303  				fields = append(fields, fdp)
   304  			}
   305  		} else {
   306  			if f.Type == storagepb.TableFieldSchema_RANGE {
   307  				// Range handling is a special case of general struct handling.
   308  				ret := f.GetRangeElementType()
   309  				if ret == nil {
   310  					return nil, fmt.Errorf("field %q is a RANGE, but doesn't include RangeElementType info", f.GetName())
   311  				}
   312  				foundDesc, err := cache.addRangeByElementType(ret.GetType(), useProto3)
   313  				if err != nil {
   314  					return nil, err
   315  				}
   316  				if foundDesc != nil {
   317  					haveDep := false
   318  					for _, dep := range deps {
   319  						if messageDependsOnFile(foundDesc, dep) {
   320  							haveDep = true
   321  							break
   322  						}
   323  					}
   324  					// If dep is missing, add to current dependencies.
   325  					if !haveDep {
   326  						deps = append(deps, foundDesc.ParentFile())
   327  					}
   328  				}
   329  			}
   330  			fd, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, currentScope, useProto3)
   331  			if err != nil {
   332  				return nil, newConversionError(currentScope, err)
   333  			}
   334  			fields = append(fields, fd)
   335  		}
   336  	}
   337  	// Start constructing a DescriptorProto.
   338  	dp := &descriptorpb.DescriptorProto{
   339  		Name:  proto.String(scope),
   340  		Field: fields,
   341  	}
   342  
   343  	// Use the local dependencies to generate a list of filenames.
   344  	depNames := []string{wellKnownTypesWrapperName}
   345  	for _, d := range deps {
   346  		depNames = append(depNames, d.ParentFile().Path())
   347  	}
   348  
   349  	// Now, construct a FileDescriptorProto.
   350  	fdp := &descriptorpb.FileDescriptorProto{
   351  		MessageType: []*descriptorpb.DescriptorProto{dp},
   352  		Name:        proto.String(fmt.Sprintf("%s.proto", scope)),
   353  		Syntax:      proto.String("proto3"),
   354  		Dependency:  depNames,
   355  	}
   356  	if !useProto3 {
   357  		fdp.Syntax = proto.String("proto2")
   358  	}
   359  
   360  	// We'll need a FileDescriptorSet as we have a FileDescriptorProto for the current
   361  	// descriptor we're building, but we need to include all the referenced dependencies.
   362  
   363  	fdpList := []*descriptorpb.FileDescriptorProto{
   364  		fdp,
   365  		protodesc.ToFileDescriptorProto(wrapperspb.File_google_protobuf_wrappers_proto),
   366  	}
   367  	fdpList = append(fdpList, cache.getFileDescriptorProtos()...)
   368  
   369  	fds := &descriptorpb.FileDescriptorSet{
   370  		File: fdpList,
   371  	}
   372  
   373  	// Load the set into a registry, then interrogate it for the descriptor corresponding to the top level message.
   374  	files, err := protodesc.NewFiles(fds)
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  	found, err := files.FindDescriptorByName(protoreflect.FullName(scope))
   379  	if err != nil {
   380  		return nil, err
   381  	}
   382  	return found.(protoreflect.MessageDescriptor), nil
   383  }
   384  
   385  // messageDependsOnFile checks if the given message descriptor already belongs to the file descriptor.
   386  // To check for that, first we check if the message descriptor parent file is the same as the file descriptor.
   387  // If not, check if the message descriptor belongs is contained as a child of the file descriptor.
   388  func messageDependsOnFile(msg protoreflect.MessageDescriptor, file protoreflect.FileDescriptor) bool {
   389  	parentFile := msg.ParentFile()
   390  	parentFileName := parentFile.FullName()
   391  	if parentFileName != "" {
   392  		if parentFileName == file.FullName() {
   393  			return true
   394  		}
   395  	}
   396  	fileMessages := file.Messages()
   397  	for i := 0; i < fileMessages.Len(); i++ {
   398  		childMsg := fileMessages.Get(i)
   399  		if msg.FullName() == childMsg.FullName() {
   400  			return true
   401  		}
   402  	}
   403  	return false
   404  }
   405  
   406  // tableFieldSchemaToFieldDescriptorProto builds individual field descriptors for a proto message.
   407  //
   408  // For proto3, in cases where the mode is nullable we use the well known wrapper types.
   409  // For proto2, we propagate the mode->label annotation as expected.
   410  //
   411  // Messages are always nullable, and repeated fields are as well.
   412  func tableFieldSchemaToFieldDescriptorProto(field *storagepb.TableFieldSchema, idx int32, scope string, useProto3 bool) (*descriptorpb.FieldDescriptorProto, error) {
   413  	name := field.GetName()
   414  	var fdp *descriptorpb.FieldDescriptorProto
   415  
   416  	if field.GetType() == storagepb.TableFieldSchema_STRUCT {
   417  		fdp = &descriptorpb.FieldDescriptorProto{
   418  			Name:     proto.String(name),
   419  			Number:   proto.Int32(idx),
   420  			TypeName: proto.String(scope),
   421  			Label:    convertModeToLabel(field.GetMode(), useProto3),
   422  		}
   423  	} else if field.GetType() == storagepb.TableFieldSchema_RANGE {
   424  		fdp = &descriptorpb.FieldDescriptorProto{
   425  			Name:     proto.String(name),
   426  			Number:   proto.Int32(idx),
   427  			TypeName: proto.String(fmt.Sprintf("%s%s", rangeTypesPrefix, strings.ToLower(field.GetRangeElementType().GetType().String()))),
   428  			Label:    convertModeToLabel(field.GetMode(), useProto3),
   429  		}
   430  	} else {
   431  		// For (REQUIRED||REPEATED) fields for proto3, or all cases for proto2, we can use the expected scalar types.
   432  		if field.GetMode() != storagepb.TableFieldSchema_NULLABLE || !useProto3 {
   433  			outType := bqTypeToFieldTypeMap[field.GetType()]
   434  			fdp = &descriptorpb.FieldDescriptorProto{
   435  				Name:   proto.String(name),
   436  				Number: proto.Int32(idx),
   437  				Type:   outType.Enum(),
   438  				Label:  convertModeToLabel(field.GetMode(), useProto3),
   439  			}
   440  
   441  			// Special case: proto2 repeated fields may benefit from using packed annotation.
   442  			if field.GetMode() == storagepb.TableFieldSchema_REPEATED && !useProto3 {
   443  				for _, v := range packedTypes {
   444  					if outType == v {
   445  						fdp.Options = &descriptorpb.FieldOptions{
   446  							Packed: proto.Bool(true),
   447  						}
   448  						break
   449  					}
   450  				}
   451  			}
   452  		} else {
   453  			// For NULLABLE proto3 fields, use a wrapper type.
   454  			fdp = &descriptorpb.FieldDescriptorProto{
   455  				Name:     proto.String(name),
   456  				Number:   proto.Int32(idx),
   457  				Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
   458  				TypeName: proto.String(bqTypeToWrapperMap[field.GetType()]),
   459  				Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
   460  			}
   461  		}
   462  	}
   463  	if nameRequiresAnnotation(name) {
   464  		// Use a prefix + base64 encoded name when annotations bear the actual name.
   465  		// Base 64 standard encoding may also contain certain characters (+,/,=) which
   466  		// we remove from the generated name.
   467  		encoded := strings.Trim(base64.StdEncoding.EncodeToString([]byte(name)), "+/=")
   468  		fdp.Name = proto.String(fmt.Sprintf("col_%s", encoded))
   469  		opts := fdp.GetOptions()
   470  		if opts == nil {
   471  			fdp.Options = &descriptorpb.FieldOptions{}
   472  		}
   473  		proto.SetExtension(fdp.Options, storagepb.E_ColumnName, name)
   474  	}
   475  	return fdp, nil
   476  }
   477  
   478  // nameRequiresAnnotation determines whether a field name requires unicode-annotation.
   479  func nameRequiresAnnotation(in string) bool {
   480  	return !protoreflect.Name(in).IsValid()
   481  }
   482  
   483  // NormalizeDescriptor builds a self-contained DescriptorProto suitable for communicating schema
   484  // information with the BigQuery Storage write API.  It's primarily used for cases where users are
   485  // interested in sending data using a predefined protocol buffer message.
   486  //
   487  // The storage API accepts a single DescriptorProto for decoding message data.  In many cases, a message
   488  // is comprised of multiple independent messages, from the same .proto file or from multiple sources.  Rather
   489  // than being forced to communicate all these messages independently, what this method does is rewrite the
   490  // DescriptorProto to inline all messages as nested submessages.  As the backend only cares about the types
   491  // and not the namespaces when decoding, this is sufficient for the needs of the API's representation.
   492  //
   493  // In addition to nesting messages, this method also handles some encapsulation of enum types to avoid possible
   494  // conflicts due to ambiguities, and clears oneof indices as oneof isn't a concept that maps into BigQuery
   495  // schemas.
   496  //
   497  // To enable proto3 usage, this function will also rewrite proto3 descriptors into equivalent proto2 form.
   498  // Such rewrites include setting the appropriate default values for proto3 fields.
   499  func NormalizeDescriptor(in protoreflect.MessageDescriptor) (*descriptorpb.DescriptorProto, error) {
   500  	return normalizeDescriptorInternal(in, newStringSet(), newStringSet(), newStringSet(), nil)
   501  }
   502  
   503  func normalizeDescriptorInternal(in protoreflect.MessageDescriptor, visitedTypes, enumTypes, structTypes *stringSet, root *descriptorpb.DescriptorProto) (*descriptorpb.DescriptorProto, error) {
   504  	if in == nil {
   505  		return nil, fmt.Errorf("no messagedescriptor provided")
   506  	}
   507  	resultDP := &descriptorpb.DescriptorProto{}
   508  	if root == nil {
   509  		root = resultDP
   510  	}
   511  	fullProtoName := string(in.FullName())
   512  	resultDP.Name = proto.String(normalizeName(fullProtoName))
   513  	visitedTypes.add(fullProtoName)
   514  	for i := 0; i < in.Fields().Len(); i++ {
   515  		inField := in.Fields().Get(i)
   516  		resultFDP := protodesc.ToFieldDescriptorProto(inField)
   517  		// For messages without explicit presence, use default values to match implicit presence behavior.
   518  		if !inField.HasPresence() && inField.Cardinality() != protoreflect.Repeated {
   519  			switch resultFDP.GetType() {
   520  			case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
   521  				resultFDP.DefaultValue = proto.String("false")
   522  			case descriptorpb.FieldDescriptorProto_TYPE_BYTES, descriptorpb.FieldDescriptorProto_TYPE_STRING:
   523  				resultFDP.DefaultValue = proto.String("")
   524  			case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
   525  				// Resolve the proto3 default value.  The default value should be the value name.
   526  				defValue := inField.Enum().Values().ByNumber(inField.Default().Enum())
   527  				resultFDP.DefaultValue = proto.String(string(defValue.Name()))
   528  			case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
   529  				descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
   530  				descriptorpb.FieldDescriptorProto_TYPE_INT64,
   531  				descriptorpb.FieldDescriptorProto_TYPE_UINT64,
   532  				descriptorpb.FieldDescriptorProto_TYPE_INT32,
   533  				descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
   534  				descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
   535  				descriptorpb.FieldDescriptorProto_TYPE_UINT32,
   536  				descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
   537  				descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
   538  				descriptorpb.FieldDescriptorProto_TYPE_SINT32,
   539  				descriptorpb.FieldDescriptorProto_TYPE_SINT64:
   540  				resultFDP.DefaultValue = proto.String("0")
   541  			}
   542  		}
   543  		// Clear proto3 optional annotation, as the backend converter can
   544  		// treat this as a proto2 optional.
   545  		if resultFDP.Proto3Optional != nil {
   546  			resultFDP.Proto3Optional = nil
   547  		}
   548  		if resultFDP.OneofIndex != nil {
   549  			resultFDP.OneofIndex = nil
   550  		}
   551  		if inField.Kind() == protoreflect.MessageKind || inField.Kind() == protoreflect.GroupKind {
   552  			// Handle fields that reference messages.
   553  			// Groups are a proto2-ism which predated nested messages.
   554  			msgFullName := string(inField.Message().FullName())
   555  			if !skipNormalization(msgFullName) {
   556  				// for everything but well known types, normalize.
   557  				normName := normalizeName(string(msgFullName))
   558  				if structTypes.contains(msgFullName) {
   559  					resultFDP.TypeName = proto.String(normName)
   560  				} else {
   561  					if visitedTypes.contains(msgFullName) {
   562  						return nil, fmt.Errorf("recursive type not supported: %s", inField.FullName())
   563  					}
   564  					visitedTypes.add(msgFullName)
   565  					dp, err := normalizeDescriptorInternal(inField.Message(), visitedTypes, enumTypes, structTypes, root)
   566  					if err != nil {
   567  						return nil, fmt.Errorf("error converting message %s: %v", inField.FullName(), err)
   568  					}
   569  					root.NestedType = append(root.NestedType, dp)
   570  					visitedTypes.delete(msgFullName)
   571  					lastNested := root.GetNestedType()[len(root.GetNestedType())-1].GetName()
   572  					resultFDP.TypeName = proto.String(lastNested)
   573  				}
   574  			}
   575  		}
   576  		if inField.Kind() == protoreflect.EnumKind {
   577  			// For enums, in order to avoid value conflict, we will always define
   578  			// a enclosing struct called enum_full_name_E that includes the actual
   579  			// enum.
   580  			enumFullName := string(inField.Enum().FullName())
   581  			enclosingTypeName := normalizeName(enumFullName) + "_E"
   582  			enumName := string(inField.Enum().Name())
   583  			actualFullName := fmt.Sprintf("%s.%s", enclosingTypeName, enumName)
   584  			if enumTypes.contains(enumFullName) {
   585  				resultFDP.TypeName = proto.String(actualFullName)
   586  			} else {
   587  				enumDP := protodesc.ToEnumDescriptorProto(inField.Enum())
   588  				enumDP.Name = proto.String(enumName)
   589  				// Ensure values in enum are sorted.
   590  				vals := enumDP.GetValue()
   591  				sort.SliceStable(vals, func(i, j int) bool {
   592  					return vals[i].GetNumber() < vals[j].GetNumber()
   593  				})
   594  				// Append wrapped enum to nested types.
   595  				root.NestedType = append(root.NestedType, &descriptorpb.DescriptorProto{
   596  					Name:     proto.String(enclosingTypeName),
   597  					EnumType: []*descriptorpb.EnumDescriptorProto{enumDP},
   598  				})
   599  				resultFDP.TypeName = proto.String(actualFullName)
   600  				enumTypes.add(enumFullName)
   601  			}
   602  		}
   603  		resultDP.Field = append(resultDP.Field, resultFDP)
   604  	}
   605  	// To reduce comparison jitter, order the common slices fields where possible.
   606  	//
   607  	// First, fields are sorted by ID number.
   608  	fields := resultDP.GetField()
   609  	sort.SliceStable(fields, func(i, j int) bool {
   610  		return fields[i].GetNumber() < fields[j].GetNumber()
   611  	})
   612  	// Then, sort nested messages in NestedType by name.
   613  	nested := resultDP.GetNestedType()
   614  	sort.SliceStable(nested, func(i, j int) bool {
   615  		return nested[i].GetName() < nested[j].GetName()
   616  	})
   617  	structTypes.add(fullProtoName)
   618  	return resultDP, nil
   619  }
   620  
   621  type stringSet struct {
   622  	m map[string]struct{}
   623  }
   624  
   625  func (s *stringSet) contains(k string) bool {
   626  	_, ok := s.m[k]
   627  	return ok
   628  }
   629  
   630  func (s *stringSet) add(k string) {
   631  	s.m[k] = struct{}{}
   632  }
   633  
   634  func (s *stringSet) delete(k string) {
   635  	delete(s.m, k)
   636  }
   637  
   638  func newStringSet() *stringSet {
   639  	return &stringSet{
   640  		m: make(map[string]struct{}),
   641  	}
   642  }
   643  
   644  func normalizeName(in string) string {
   645  	return strings.Replace(in, ".", "_", -1)
   646  }
   647  
   648  // These types don't get normalized into the fully-contained structure.
   649  var normalizationSkipList = []string{
   650  	/*
   651  		TODO: when backend supports resolving well known types, this list should be enabled.
   652  		"google.protobuf.DoubleValue",
   653  		"google.protobuf.FloatValue",
   654  		"google.protobuf.Int64Value",
   655  		"google.protobuf.UInt64Value",
   656  		"google.protobuf.Int32Value",
   657  		"google.protobuf.Uint32Value",
   658  		"google.protobuf.BoolValue",
   659  		"google.protobuf.StringValue",
   660  		"google.protobuf.BytesValue",
   661  	*/
   662  }
   663  
   664  func skipNormalization(fullName string) bool {
   665  	for _, v := range normalizationSkipList {
   666  		if v == fullName {
   667  			return true
   668  		}
   669  	}
   670  	return false
   671  }
   672  

View as plain text