1 package storage
2
3
4
5
6 import (
7 "bytes"
8 "encoding/json"
9 "errors"
10 "fmt"
11 "io"
12 "mime/multipart"
13 "net/http"
14 "net/textproto"
15 "sort"
16 "strings"
17 )
18
19
// Operation identifies which table operation a BatchEntity performs when its
// batch is executed.
type Operation int

// Supported batch operations. Values start at 1, so the zero Operation
// matches none of them.
const (
	// InsertOp is encoded as a POST to the table's path.
	InsertOp Operation = iota + 1
	// DeleteOp is encoded as a DELETE to the entity's path.
	DeleteOp
	// ReplaceOp is encoded as a PUT to the entity's path.
	ReplaceOp
	// MergeOp is encoded as a MERGE to the entity's path.
	MergeOp
	// InsertOrReplaceOp is encoded as a PUT to the entity's path.
	InsertOrReplaceOp
	// InsertOrMergeOp is encoded as a MERGE to the entity's path.
	InsertOrMergeOp
)
31
32
33
34
// BatchEntity is one operation queued in a TableBatch.
type BatchEntity struct {
	// Entity is the entity the operation acts on. It is embedded, so its
	// fields and methods are promoted onto BatchEntity.
	*Entity
	// Force applies to DeleteOp, ReplaceOp and MergeOp. When true, those
	// operations send "If-Match: *" instead of the entity's ETag ("*" is also
	// used when the entity has no ETag). Other operations send no If-Match
	// header regardless of Force.
	Force bool
	// Op selects the operation. It determines the HTTP verb and the request
	// path written into the changeset.
	Op Operation
}
40
41
42
// TableBatch accumulates entity operations against one table so they can be
// submitted together, as a single changeset, by ExecuteBatch.
type TableBatch struct {
	// BatchEntitySlice holds the queued operations in the order they were
	// added; the changeset is generated in this same order.
	BatchEntitySlice []BatchEntity

	// Table is the table the batch targets. Its service client supplies the
	// endpoint, standard headers and authentication used by ExecuteBatch.
	Table *Table
}
49
50
// defaultChangesetHeaders are copied into the headers of every operation
// written to a changeset. They cover:
//   - JSON request bodies;
//   - JSON responses with minimal OData metadata;
//   - "Prefer: return-no-content", which asks the service not to return
//     content.
var defaultChangesetHeaders = map[string]string{
	"Prefer":       "return-no-content",
	"Content-Type": "application/json",
	"Accept":       "application/json;odata=minimalmetadata",
}
56
57
58 func (t *Table) NewBatch() *TableBatch {
59 return &TableBatch{
60 Table: t,
61 }
62 }
63
64
65 func (t *TableBatch) InsertEntity(entity *Entity) {
66 be := BatchEntity{Entity: entity, Force: false, Op: InsertOp}
67 t.BatchEntitySlice = append(t.BatchEntitySlice, be)
68 }
69
70
71 func (t *TableBatch) InsertOrReplaceEntity(entity *Entity, force bool) {
72 be := BatchEntity{Entity: entity, Force: false, Op: InsertOrReplaceOp}
73 t.BatchEntitySlice = append(t.BatchEntitySlice, be)
74 }
75
76
77 func (t *TableBatch) InsertOrReplaceEntityByForce(entity *Entity) {
78 t.InsertOrReplaceEntity(entity, true)
79 }
80
81
82 func (t *TableBatch) InsertOrMergeEntity(entity *Entity, force bool) {
83 be := BatchEntity{Entity: entity, Force: false, Op: InsertOrMergeOp}
84 t.BatchEntitySlice = append(t.BatchEntitySlice, be)
85 }
86
87
88 func (t *TableBatch) InsertOrMergeEntityByForce(entity *Entity) {
89 t.InsertOrMergeEntity(entity, true)
90 }
91
92
93 func (t *TableBatch) ReplaceEntity(entity *Entity) {
94 be := BatchEntity{Entity: entity, Force: false, Op: ReplaceOp}
95 t.BatchEntitySlice = append(t.BatchEntitySlice, be)
96 }
97
98
99 func (t *TableBatch) DeleteEntity(entity *Entity, force bool) {
100 be := BatchEntity{Entity: entity, Force: false, Op: DeleteOp}
101 t.BatchEntitySlice = append(t.BatchEntitySlice, be)
102 }
103
104
105 func (t *TableBatch) DeleteEntityByForce(entity *Entity, force bool) {
106 t.DeleteEntity(entity, true)
107 }
108
109
110 func (t *TableBatch) MergeEntity(entity *Entity) {
111 be := BatchEntity{Entity: entity, Force: false, Op: MergeOp}
112 t.BatchEntitySlice = append(t.BatchEntitySlice, be)
113 }
114
115
116
117
118
119
120 func (t *TableBatch) ExecuteBatch() error {
121
122 id, err := newUUID()
123 if err != nil {
124 return err
125 }
126
127 changesetBoundary := fmt.Sprintf("changeset_%s", id.String())
128 uri := t.Table.tsc.client.getEndpoint(tableServiceName, "$batch", nil)
129 changesetBody, err := t.generateChangesetBody(changesetBoundary)
130 if err != nil {
131 return err
132 }
133
134 id, err = newUUID()
135 if err != nil {
136 return err
137 }
138
139 boundary := fmt.Sprintf("batch_%s", id.String())
140 body, err := generateBody(changesetBody, changesetBoundary, boundary)
141 if err != nil {
142 return err
143 }
144
145 headers := t.Table.tsc.client.getStandardHeaders()
146 headers[headerContentType] = fmt.Sprintf("multipart/mixed; boundary=%s", boundary)
147
148 resp, err := t.Table.tsc.client.execBatchOperationJSON(http.MethodPost, uri, headers, bytes.NewReader(body.Bytes()), t.Table.tsc.auth)
149 if err != nil {
150 return err
151 }
152 defer drainRespBody(resp.resp)
153
154 if err = checkRespCode(resp.resp, []int{http.StatusAccepted}); err != nil {
155
156
157 operationFailedMessage := t.getFailedOperation(resp.odata.Err.Message.Value)
158 requestID, date, version := getDebugHeaders(resp.resp.Header)
159 return AzureStorageServiceError{
160 StatusCode: resp.resp.StatusCode,
161 Code: resp.odata.Err.Code,
162 RequestID: requestID,
163 Date: date,
164 APIVersion: version,
165 Message: operationFailedMessage,
166 }
167 }
168
169 return nil
170 }
171
172
173
174 func (t *TableBatch) getFailedOperation(errorMessage string) string {
175
176 sp := strings.Split(errorMessage, ":")
177 if len(sp) > 1 {
178 msg := fmt.Sprintf("Element %s in the batch returned an unexpected response code.\n%s", sp[0], errorMessage)
179 return msg
180 }
181
182
183 return errorMessage
184 }
185
186
187 func generateBody(changeSetBody *bytes.Buffer, changesetBoundary string, boundary string) (*bytes.Buffer, error) {
188
189 body := new(bytes.Buffer)
190 writer := multipart.NewWriter(body)
191 writer.SetBoundary(boundary)
192 h := make(textproto.MIMEHeader)
193 h.Set(headerContentType, fmt.Sprintf("multipart/mixed; boundary=%s\r\n", changesetBoundary))
194 batchWriter, err := writer.CreatePart(h)
195 if err != nil {
196 return nil, err
197 }
198 batchWriter.Write(changeSetBody.Bytes())
199 writer.Close()
200 return body, nil
201 }
202
203
204
205 func (t *TableBatch) generateChangesetBody(changesetBoundary string) (*bytes.Buffer, error) {
206
207 body := new(bytes.Buffer)
208 writer := multipart.NewWriter(body)
209 writer.SetBoundary(changesetBoundary)
210
211 for _, be := range t.BatchEntitySlice {
212 t.generateEntitySubset(&be, writer)
213 }
214
215 writer.Close()
216 return body, nil
217 }
218
219
220 func generateVerb(op Operation) (string, error) {
221 switch op {
222 case InsertOp:
223 return http.MethodPost, nil
224 case DeleteOp:
225 return http.MethodDelete, nil
226 case ReplaceOp, InsertOrReplaceOp:
227 return http.MethodPut, nil
228 case MergeOp, InsertOrMergeOp:
229 return "MERGE", nil
230 default:
231 return "", errors.New("Unable to detect operation")
232 }
233 }
234
235
236
237
238
239 func (t *TableBatch) generateQueryPath(op Operation, entity *Entity) string {
240 if op == InsertOp {
241 return entity.Table.buildPath()
242 }
243 return entity.buildPath()
244 }
245
246
247 func generateGenericOperationHeaders(be *BatchEntity) map[string]string {
248 retval := map[string]string{}
249
250 for k, v := range defaultChangesetHeaders {
251 retval[k] = v
252 }
253
254 if be.Op == DeleteOp || be.Op == ReplaceOp || be.Op == MergeOp {
255 if be.Force || be.Entity.OdataEtag == "" {
256 retval["If-Match"] = "*"
257 } else {
258 retval["If-Match"] = be.Entity.OdataEtag
259 }
260 }
261
262 return retval
263 }
264
265
266 func (t *TableBatch) generateEntitySubset(batchEntity *BatchEntity, writer *multipart.Writer) error {
267
268 h := make(textproto.MIMEHeader)
269 h.Set(headerContentType, "application/http")
270 h.Set(headerContentTransferEncoding, "binary")
271
272 verb, err := generateVerb(batchEntity.Op)
273 if err != nil {
274 return err
275 }
276
277 genericOpHeadersMap := generateGenericOperationHeaders(batchEntity)
278 queryPath := t.generateQueryPath(batchEntity.Op, batchEntity.Entity)
279 uri := t.Table.tsc.client.getEndpoint(tableServiceName, queryPath, nil)
280
281 operationWriter, err := writer.CreatePart(h)
282 if err != nil {
283 return err
284 }
285
286 urlAndVerb := fmt.Sprintf("%s %s HTTP/1.1\r\n", verb, uri)
287 operationWriter.Write([]byte(urlAndVerb))
288 writeHeaders(genericOpHeadersMap, &operationWriter)
289 operationWriter.Write([]byte("\r\n"))
290
291
292 if batchEntity.Op != DeleteOp {
293
294 body, err := json.Marshal(batchEntity.Entity)
295 if err != nil {
296 return err
297 }
298 operationWriter.Write(body)
299 }
300
301 return nil
302 }
303
// writeHeaders writes each entry of h to *writer as a "Name: value" line
// terminated by CRLF.
//
// Lines are emitted in ascending byte order of header name, so the output is
// deterministic. Write errors are not reported. *writer is only dereferenced
// when h has at least one entry.
func writeHeaders(h map[string]string, writer *io.Writer) {
	names := make([]string, 0, len(h))
	for name := range h {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		fmt.Fprintf(*writer, "%s: %s\r\n", name, h[name])
	}
}
315
View as plain text