...

Source file src/cloud.google.com/go/bigquery/model.go

Documentation: cloud.google.com/go/bigquery

     1  // Copyright 2019 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package bigquery
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strings"
    21  	"time"
    22  
    23  	"cloud.google.com/go/internal/optional"
    24  	"cloud.google.com/go/internal/trace"
    25  	bq "google.golang.org/api/bigquery/v2"
    26  )
    27  
    28  // Model represent a reference to a BigQuery ML model.
    29  // Within the API, models are used largely for communicating
    30  // statistical information about a given model, as creation of models is only
    31  // supported via BigQuery queries (e.g. CREATE MODEL .. AS ..).
    32  //
    33  // For more info, see documentation for Bigquery ML,
    34  // see: https://cloud.google.com/bigquery/docs/bigqueryml
    35  type Model struct {
    36  	ProjectID string
    37  	DatasetID string
    38  	// ModelID must contain only letters (a-z, A-Z), numbers (0-9), or underscores (_).
    39  	// The maximum length is 1,024 characters.
    40  	ModelID string
    41  
    42  	c *Client
    43  }
    44  
    45  // Identifier returns the ID of the model in the requested format.
    46  //
    47  // For Standard SQL format, the identifier will be quoted if the
    48  // ProjectID contains dash (-) characters.
    49  func (m *Model) Identifier(f IdentifierFormat) (string, error) {
    50  	switch f {
    51  	case LegacySQLID:
    52  		return fmt.Sprintf("%s:%s.%s", m.ProjectID, m.DatasetID, m.ModelID), nil
    53  	case StandardSQLID:
    54  		// Per https://cloud.google.com/bigquery-ml/docs/reference/standard-sql/bigqueryml-syntax-create#model_name
    55  		// we quote the entire identifier.
    56  		out := fmt.Sprintf("%s.%s.%s", m.ProjectID, m.DatasetID, m.ModelID)
    57  		if strings.Contains(out, "-") {
    58  			out = fmt.Sprintf("`%s`", out)
    59  		}
    60  		return out, nil
    61  	default:
    62  		return "", ErrUnknownIdentifierFormat
    63  	}
    64  }
    65  
    66  // FullyQualifiedName returns the ID of the model in projectID:datasetID.modelid format.
    67  func (m *Model) FullyQualifiedName() string {
    68  	s, _ := m.Identifier(LegacySQLID)
    69  	return s
    70  }
    71  
    72  func (m *Model) toBQ() *bq.ModelReference {
    73  	return &bq.ModelReference{
    74  		ProjectId: m.ProjectID,
    75  		DatasetId: m.DatasetID,
    76  		ModelId:   m.ModelID,
    77  	}
    78  }
    79  
    80  // Metadata fetches the metadata for a model, which includes ML training statistics.
    81  func (m *Model) Metadata(ctx context.Context) (mm *ModelMetadata, err error) {
    82  	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Metadata")
    83  	defer func() { trace.EndSpan(ctx, err) }()
    84  
    85  	req := m.c.bqs.Models.Get(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
    86  	setClientHeader(req.Header())
    87  	var model *bq.Model
    88  	err = runWithRetry(ctx, func() (err error) {
    89  		ctx = trace.StartSpan(ctx, "bigquery.models.get")
    90  		model, err = req.Do()
    91  		trace.EndSpan(ctx, err)
    92  		return err
    93  	})
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	return bqToModelMetadata(model)
    98  }
    99  
   100  // Update updates mutable fields in an ML model.
   101  func (m *Model) Update(ctx context.Context, mm ModelMetadataToUpdate, etag string) (md *ModelMetadata, err error) {
   102  	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Update")
   103  	defer func() { trace.EndSpan(ctx, err) }()
   104  
   105  	bqm, err := mm.toBQ()
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	call := m.c.bqs.Models.Patch(m.ProjectID, m.DatasetID, m.ModelID, bqm).Context(ctx)
   110  	setClientHeader(call.Header())
   111  	if etag != "" {
   112  		call.Header().Set("If-Match", etag)
   113  	}
   114  	var res *bq.Model
   115  	if err := runWithRetry(ctx, func() (err error) {
   116  		ctx = trace.StartSpan(ctx, "bigquery.models.patch")
   117  		res, err = call.Do()
   118  		trace.EndSpan(ctx, err)
   119  		return err
   120  	}); err != nil {
   121  		return nil, err
   122  	}
   123  	return bqToModelMetadata(res)
   124  }
   125  
   126  // Delete deletes an ML model.
   127  func (m *Model) Delete(ctx context.Context) (err error) {
   128  	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Delete")
   129  	defer func() { trace.EndSpan(ctx, err) }()
   130  
   131  	req := m.c.bqs.Models.Delete(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
   132  	setClientHeader(req.Header())
   133  	return req.Do()
   134  }
   135  
   136  // ModelMetadata represents information about a BigQuery ML model.
   137  type ModelMetadata struct {
   138  	// The user-friendly description of the model.
   139  	Description string
   140  
   141  	// The user-friendly name of the model.
   142  	Name string
   143  
   144  	// The type of the model.  Possible values include:
   145  	// "LINEAR_REGRESSION" - a linear regression model
   146  	// "LOGISTIC_REGRESSION" - a logistic regression model
   147  	// "KMEANS" - a k-means clustering model
   148  	Type string
   149  
   150  	// The creation time of the model.
   151  	CreationTime time.Time
   152  
   153  	// The last modified time of the model.
   154  	LastModifiedTime time.Time
   155  
   156  	// The expiration time of the model.
   157  	ExpirationTime time.Time
   158  
   159  	// The geographic location where the model resides.  This value is
   160  	// inherited from the encapsulating dataset.
   161  	Location string
   162  
   163  	// Custom encryption configuration (e.g., Cloud KMS keys).
   164  	EncryptionConfig *EncryptionConfig
   165  
   166  	// The input feature columns used to train the model.
   167  	featureColumns []*bq.StandardSqlField
   168  
   169  	// The label columns used to train the model.  Output
   170  	// from the model will have a "predicted_" prefix for these columns.
   171  	labelColumns []*bq.StandardSqlField
   172  
   173  	// Information for all training runs, ordered by increasing start times.
   174  	trainingRuns []*bq.TrainingRun
   175  
   176  	Labels map[string]string
   177  
   178  	// ETag is the ETag obtained when reading metadata. Pass it to Model.Update
   179  	// to ensure that the metadata hasn't changed since it was read.
   180  	ETag string
   181  }
   182  
   183  // TrainingRun represents information about a single training run for a BigQuery ML model.
   184  // Experimental:  This information may be modified or removed in future versions of this package.
   185  type TrainingRun bq.TrainingRun
   186  
   187  // RawTrainingRuns exposes the underlying training run stats for a model using types from
   188  // "google.golang.org/api/bigquery/v2", which are subject to change without warning.
   189  // It is EXPERIMENTAL and subject to change or removal without notice.
   190  func (mm *ModelMetadata) RawTrainingRuns() []*TrainingRun {
   191  	if mm.trainingRuns == nil {
   192  		return nil
   193  	}
   194  	var runs []*TrainingRun
   195  
   196  	for _, v := range mm.trainingRuns {
   197  		r := TrainingRun(*v)
   198  		runs = append(runs, &r)
   199  	}
   200  	return runs
   201  }
   202  
   203  // RawLabelColumns exposes the underlying label columns used to train an ML model and uses types from
   204  // "google.golang.org/api/bigquery/v2", which are subject to change without warning.
   205  // It is EXPERIMENTAL and subject to change or removal without notice.
   206  func (mm *ModelMetadata) RawLabelColumns() ([]*StandardSQLField, error) {
   207  	return bqToModelCols(mm.labelColumns)
   208  }
   209  
   210  // RawFeatureColumns exposes the underlying feature columns used to train an ML model and uses types from
   211  // "google.golang.org/api/bigquery/v2", which are subject to change without warning.
   212  // It is EXPERIMENTAL and subject to change or removal without notice.
   213  func (mm *ModelMetadata) RawFeatureColumns() ([]*StandardSQLField, error) {
   214  	return bqToModelCols(mm.featureColumns)
   215  }
   216  
   217  func bqToModelCols(s []*bq.StandardSqlField) ([]*StandardSQLField, error) {
   218  	if s == nil {
   219  		return nil, nil
   220  	}
   221  	var cols []*StandardSQLField
   222  	for _, v := range s {
   223  		c, err := bqToStandardSQLField(v)
   224  		if err != nil {
   225  			return nil, err
   226  		}
   227  		cols = append(cols, c)
   228  	}
   229  	return cols, nil
   230  }
   231  
   232  func bqToModelMetadata(m *bq.Model) (*ModelMetadata, error) {
   233  	md := &ModelMetadata{
   234  		Description:      m.Description,
   235  		Name:             m.FriendlyName,
   236  		Type:             m.ModelType,
   237  		Location:         m.Location,
   238  		Labels:           m.Labels,
   239  		ExpirationTime:   unixMillisToTime(m.ExpirationTime),
   240  		CreationTime:     unixMillisToTime(m.CreationTime),
   241  		LastModifiedTime: unixMillisToTime(m.LastModifiedTime),
   242  		EncryptionConfig: bqToEncryptionConfig(m.EncryptionConfiguration),
   243  		featureColumns:   m.FeatureColumns,
   244  		labelColumns:     m.LabelColumns,
   245  		trainingRuns:     m.TrainingRuns,
   246  		ETag:             m.Etag,
   247  	}
   248  	return md, nil
   249  }
   250  
   251  // ModelMetadataToUpdate is used when updating an ML model's metadata.
   252  // Only non-nil fields will be updated.
   253  type ModelMetadataToUpdate struct {
   254  	// The user-friendly description of this model.
   255  	Description optional.String
   256  
   257  	// The user-friendly name of this model.
   258  	Name optional.String
   259  
   260  	// The time when this model expires.  To remove a model's expiration,
   261  	// set ExpirationTime to NeverExpire.  The zero value is ignored.
   262  	ExpirationTime time.Time
   263  
   264  	// The model's encryption configuration.
   265  	EncryptionConfig *EncryptionConfig
   266  
   267  	labelUpdater
   268  }
   269  
   270  func (mm *ModelMetadataToUpdate) toBQ() (*bq.Model, error) {
   271  	m := &bq.Model{}
   272  	forceSend := func(field string) {
   273  		m.ForceSendFields = append(m.ForceSendFields, field)
   274  	}
   275  
   276  	if mm.Description != nil {
   277  		m.Description = optional.ToString(mm.Description)
   278  		forceSend("Description")
   279  	}
   280  
   281  	if mm.Name != nil {
   282  		m.FriendlyName = optional.ToString(mm.Name)
   283  		forceSend("FriendlyName")
   284  	}
   285  
   286  	if mm.EncryptionConfig != nil {
   287  		m.EncryptionConfiguration = mm.EncryptionConfig.toBQ()
   288  	}
   289  
   290  	if !validExpiration(mm.ExpirationTime) {
   291  		return nil, invalidTimeError(mm.ExpirationTime)
   292  	}
   293  	if mm.ExpirationTime == NeverExpire {
   294  		m.NullFields = append(m.NullFields, "ExpirationTime")
   295  	} else if !mm.ExpirationTime.IsZero() {
   296  		m.ExpirationTime = mm.ExpirationTime.UnixNano() / 1e6
   297  		forceSend("ExpirationTime")
   298  	}
   299  	labels, forces, nulls := mm.update()
   300  	m.Labels = labels
   301  	m.ForceSendFields = append(m.ForceSendFields, forces...)
   302  	m.NullFields = append(m.NullFields, nulls...)
   303  	return m, nil
   304  }
   305  

View as plain text