1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package bigquery
16
17 import (
18 "context"
19 "fmt"
20 "strings"
21 "time"
22
23 "cloud.google.com/go/internal/optional"
24 "cloud.google.com/go/internal/trace"
25 bq "google.golang.org/api/bigquery/v2"
26 )
27
28
29
30
31
32
33
34
35 type Model struct {
36 ProjectID string
37 DatasetID string
38
39
40 ModelID string
41
42 c *Client
43 }
44
45
46
47
48
49 func (m *Model) Identifier(f IdentifierFormat) (string, error) {
50 switch f {
51 case LegacySQLID:
52 return fmt.Sprintf("%s:%s.%s", m.ProjectID, m.DatasetID, m.ModelID), nil
53 case StandardSQLID:
54
55
56 out := fmt.Sprintf("%s.%s.%s", m.ProjectID, m.DatasetID, m.ModelID)
57 if strings.Contains(out, "-") {
58 out = fmt.Sprintf("`%s`", out)
59 }
60 return out, nil
61 default:
62 return "", ErrUnknownIdentifierFormat
63 }
64 }
65
66
67 func (m *Model) FullyQualifiedName() string {
68 s, _ := m.Identifier(LegacySQLID)
69 return s
70 }
71
72 func (m *Model) toBQ() *bq.ModelReference {
73 return &bq.ModelReference{
74 ProjectId: m.ProjectID,
75 DatasetId: m.DatasetID,
76 ModelId: m.ModelID,
77 }
78 }
79
80
81 func (m *Model) Metadata(ctx context.Context) (mm *ModelMetadata, err error) {
82 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Metadata")
83 defer func() { trace.EndSpan(ctx, err) }()
84
85 req := m.c.bqs.Models.Get(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
86 setClientHeader(req.Header())
87 var model *bq.Model
88 err = runWithRetry(ctx, func() (err error) {
89 ctx = trace.StartSpan(ctx, "bigquery.models.get")
90 model, err = req.Do()
91 trace.EndSpan(ctx, err)
92 return err
93 })
94 if err != nil {
95 return nil, err
96 }
97 return bqToModelMetadata(model)
98 }
99
100
101 func (m *Model) Update(ctx context.Context, mm ModelMetadataToUpdate, etag string) (md *ModelMetadata, err error) {
102 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Update")
103 defer func() { trace.EndSpan(ctx, err) }()
104
105 bqm, err := mm.toBQ()
106 if err != nil {
107 return nil, err
108 }
109 call := m.c.bqs.Models.Patch(m.ProjectID, m.DatasetID, m.ModelID, bqm).Context(ctx)
110 setClientHeader(call.Header())
111 if etag != "" {
112 call.Header().Set("If-Match", etag)
113 }
114 var res *bq.Model
115 if err := runWithRetry(ctx, func() (err error) {
116 ctx = trace.StartSpan(ctx, "bigquery.models.patch")
117 res, err = call.Do()
118 trace.EndSpan(ctx, err)
119 return err
120 }); err != nil {
121 return nil, err
122 }
123 return bqToModelMetadata(res)
124 }
125
126
127 func (m *Model) Delete(ctx context.Context) (err error) {
128 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Delete")
129 defer func() { trace.EndSpan(ctx, err) }()
130
131 req := m.c.bqs.Models.Delete(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
132 setClientHeader(req.Header())
133 return req.Do()
134 }
135
136
137 type ModelMetadata struct {
138
139 Description string
140
141
142 Name string
143
144
145
146
147
148 Type string
149
150
151 CreationTime time.Time
152
153
154 LastModifiedTime time.Time
155
156
157 ExpirationTime time.Time
158
159
160
161 Location string
162
163
164 EncryptionConfig *EncryptionConfig
165
166
167 featureColumns []*bq.StandardSqlField
168
169
170
171 labelColumns []*bq.StandardSqlField
172
173
174 trainingRuns []*bq.TrainingRun
175
176 Labels map[string]string
177
178
179
180 ETag string
181 }
182
183
184
185 type TrainingRun bq.TrainingRun
186
187
188
189
190 func (mm *ModelMetadata) RawTrainingRuns() []*TrainingRun {
191 if mm.trainingRuns == nil {
192 return nil
193 }
194 var runs []*TrainingRun
195
196 for _, v := range mm.trainingRuns {
197 r := TrainingRun(*v)
198 runs = append(runs, &r)
199 }
200 return runs
201 }
202
203
204
205
206 func (mm *ModelMetadata) RawLabelColumns() ([]*StandardSQLField, error) {
207 return bqToModelCols(mm.labelColumns)
208 }
209
210
211
212
213 func (mm *ModelMetadata) RawFeatureColumns() ([]*StandardSQLField, error) {
214 return bqToModelCols(mm.featureColumns)
215 }
216
217 func bqToModelCols(s []*bq.StandardSqlField) ([]*StandardSQLField, error) {
218 if s == nil {
219 return nil, nil
220 }
221 var cols []*StandardSQLField
222 for _, v := range s {
223 c, err := bqToStandardSQLField(v)
224 if err != nil {
225 return nil, err
226 }
227 cols = append(cols, c)
228 }
229 return cols, nil
230 }
231
232 func bqToModelMetadata(m *bq.Model) (*ModelMetadata, error) {
233 md := &ModelMetadata{
234 Description: m.Description,
235 Name: m.FriendlyName,
236 Type: m.ModelType,
237 Location: m.Location,
238 Labels: m.Labels,
239 ExpirationTime: unixMillisToTime(m.ExpirationTime),
240 CreationTime: unixMillisToTime(m.CreationTime),
241 LastModifiedTime: unixMillisToTime(m.LastModifiedTime),
242 EncryptionConfig: bqToEncryptionConfig(m.EncryptionConfiguration),
243 featureColumns: m.FeatureColumns,
244 labelColumns: m.LabelColumns,
245 trainingRuns: m.TrainingRuns,
246 ETag: m.Etag,
247 }
248 return md, nil
249 }
250
251
252
253 type ModelMetadataToUpdate struct {
254
255 Description optional.String
256
257
258 Name optional.String
259
260
261
262 ExpirationTime time.Time
263
264
265 EncryptionConfig *EncryptionConfig
266
267 labelUpdater
268 }
269
270 func (mm *ModelMetadataToUpdate) toBQ() (*bq.Model, error) {
271 m := &bq.Model{}
272 forceSend := func(field string) {
273 m.ForceSendFields = append(m.ForceSendFields, field)
274 }
275
276 if mm.Description != nil {
277 m.Description = optional.ToString(mm.Description)
278 forceSend("Description")
279 }
280
281 if mm.Name != nil {
282 m.FriendlyName = optional.ToString(mm.Name)
283 forceSend("FriendlyName")
284 }
285
286 if mm.EncryptionConfig != nil {
287 m.EncryptionConfiguration = mm.EncryptionConfig.toBQ()
288 }
289
290 if !validExpiration(mm.ExpirationTime) {
291 return nil, invalidTimeError(mm.ExpirationTime)
292 }
293 if mm.ExpirationTime == NeverExpire {
294 m.NullFields = append(m.NullFields, "ExpirationTime")
295 } else if !mm.ExpirationTime.IsZero() {
296 m.ExpirationTime = mm.ExpirationTime.UnixNano() / 1e6
297 forceSend("ExpirationTime")
298 }
299 labels, forces, nulls := mm.update()
300 m.Labels = labels
301 m.ForceSendFields = append(m.ForceSendFields, forces...)
302 m.NullFields = append(m.NullFields, nulls...)
303 return m, nil
304 }
305
View as plain text