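// Package s3 provides a storagedriver.StorageDriver implementation to
// store blobs in Amazon S3 cloud storage.
//
// This package leverages the official aws client library for interfacing with
// S3.
//
// Because S3 is a key, value store the Stat call does not support last modification
// time for directories (directories are an abstraction for key, value stores).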
package s3

import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"io/ioutil"
	"math"
	"net/http"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
	"github.com/aws/aws-sdk-go/aws/ec2metadata"
	"github.com/aws/aws-sdk-go/aws/endpoints"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"

	dcontext "github.com/docker/distribution/context"
	"github.com/docker/distribution/registry/client/transport"
	storagedriver "github.com/docker/distribution/registry/storage/driver"
	"github.com/docker/distribution/registry/storage/driver/base"
	"github.com/docker/distribution/registry/storage/driver/factory"
)

const driverName = "s3aws"
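
// minChunkSize defines the minimum multipart upload chunk size.
// The S3 API requires multipart upload chunks to be at least 5MB.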
const minChunkSize = 5 << 20
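
// maxChunkSize defines the maximum multipart upload chunk size allowed by S3 (5GB).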
const maxChunkSize = 5 << 30

const defaultChunkSize = 2 * minChunkSize

const (
	// defaultMultipartCopyChunkSize defines the default chunk size for all
	// but the last Upload Part - Copy operation of a multipart copy.
	// Empirically, 32 MB is optimal.
	defaultMultipartCopyChunkSize = 32 << 20

	// defaultMultipartCopyMaxConcurrency defines the default maximum number
	// of concurrent Upload Part - Copy operations for a multipart copy.
	defaultMultipartCopyMaxConcurrency = 100

	// defaultMultipartCopyThresholdSize defines the default object size
	// above which multipart copy will be used. (PUT Object - Copy is used
	// for objects at or below this size.) Empirically, 32 MB is optimal.
	defaultMultipartCopyThresholdSize = 32 << 20
)

// listMax is the largest amount of objects you can request from S3 in a list call.
const listMax = 1000

// noStorageClass defines the value to be used if storage class is not supported by the S3 endpoint.
const noStorageClass = "NONE"

// validRegions maps known S3 region identifiers to region descriptors.
var validRegions = map[string]struct{}{}

// validObjectACLs contains known S3 object ACLs.
var validObjectACLs = map[string]struct{}{}
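
// DriverParameters is a struct that encapsulates all of the driver parameters
// after all values have been set.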
type DriverParameters struct {
	AccessKey                   string
	SecretKey                   string
	Bucket                      string
	Region                      string
	RegionEndpoint              string
	Encrypt                     bool
	KeyID                       string
	Secure                      bool
	SkipVerify                  bool
	V4Auth                      bool
	ChunkSize                   int64
	MultipartCopyChunkSize      int64
	MultipartCopyMaxConcurrency int64
	MultipartCopyThresholdSize  int64
	RootDirectory               string
	StorageClass                string
	UserAgent                   string
	ObjectACL                   string
	SessionToken                string
}

func init() {
	partitions := endpoints.DefaultPartitions()
	for _, p := range partitions {
		for region := range p.Regions() {
			validRegions[region] = struct{}{}
		}
	}

	for _, objectACL := range []string{
		s3.ObjectCannedACLPrivate,
		s3.ObjectCannedACLPublicRead,
		s3.ObjectCannedACLPublicReadWrite,
		s3.ObjectCannedACLAuthenticatedRead,
		s3.ObjectCannedACLAwsExecRead,
		s3.ObjectCannedACLBucketOwnerRead,
		s3.ObjectCannedACLBucketOwnerFullControl,
	} {
		validObjectACLs[objectACL] = struct{}{}
	}

	// Register this driver under the generic "s3" name in addition to "s3aws".
	factory.Register("s3", &s3DriverFactory{})
	factory.Register(driverName, &s3DriverFactory{})
}

// s3DriverFactory implements the factory.StorageDriverFactory interface.
type s3DriverFactory struct{}

func (factory *s3DriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
	return FromParameters(parameters)
}

var _ storagedriver.StorageDriver = &driver{}

type driver struct {
	S3                          *s3.S3
	Bucket                      string
	ChunkSize                   int64
	Encrypt                     bool
	KeyID                       string
	MultipartCopyChunkSize      int64
	MultipartCopyMaxConcurrency int64
	MultipartCopyThresholdSize  int64
	RootDirectory               string
	StorageClass                string
	ObjectACL                   string
}

type baseEmbed struct {
	base.Base
}
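
// Driver is a storagedriver.StorageDriver implementation backed by Amazon S3.
// Objects are stored at absolute keys in the provided bucket.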
type Driver struct {
	baseEmbed
}
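
// FromParameters constructs a new Driver with a given parameters map.
// Required parameters:
// - accesskey
// - secretkey
// - region
// - bucket
// - encrypt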
func FromParameters(parameters map[string]interface{}) (*Driver, error) {
	// Providing no values for these is valid in case the user is authenticating
	// with an IAM on an ec2 instance (in which case the instance credentials will
	// be summoned when GetAuth is called).
	accessKey := parameters["accesskey"]
	if accessKey == nil {
		accessKey = ""
	}
	secretKey := parameters["secretkey"]
	if secretKey == nil {
		secretKey = ""
	}

	regionEndpoint := parameters["regionendpoint"]
	if regionEndpoint == nil {
		regionEndpoint = ""
	}

	regionName := parameters["region"]
	if regionName == nil || fmt.Sprint(regionName) == "" {
		return nil, fmt.Errorf("no region parameter provided")
	}
	region := fmt.Sprint(regionName)

	if regionEndpoint == "" {
		if _, ok := validRegions[region]; !ok {
			return nil, fmt.Errorf("invalid region provided: %v", region)
		}
	}

	bucket := parameters["bucket"]
	if bucket == nil || fmt.Sprint(bucket) == "" {
		return nil, fmt.Errorf("no bucket parameter provided")
	}

	encryptBool := false
	encrypt := parameters["encrypt"]
	switch encrypt := encrypt.(type) {
	case string:
		b, err := strconv.ParseBool(encrypt)
		if err != nil {
			return nil, fmt.Errorf("the encrypt parameter should be a boolean")
		}
		encryptBool = b
	case bool:
		encryptBool = encrypt
	case nil:
		// do nothing
	default:
		return nil, fmt.Errorf("the encrypt parameter should be a boolean")
	}

	secureBool := true
	secure := parameters["secure"]
	switch secure := secure.(type) {
	case string:
		b, err := strconv.ParseBool(secure)
		if err != nil {
			return nil, fmt.Errorf("the secure parameter should be a boolean")
		}
		secureBool = b
	case bool:
		secureBool = secure
	case nil:
		// do nothing
	default:
		return nil, fmt.Errorf("the secure parameter should be a boolean")
	}

	skipVerifyBool := false
	skipVerify := parameters["skipverify"]
	switch skipVerify := skipVerify.(type) {
	case string:
		b, err := strconv.ParseBool(skipVerify)
		if err != nil {
			return nil, fmt.Errorf("the skipVerify parameter should be a boolean")
		}
		skipVerifyBool = b
	case bool:
		skipVerifyBool = skipVerify
	case nil:
		// do nothing
	default:
		return nil, fmt.Errorf("the skipVerify parameter should be a boolean")
	}

	v4Bool := true
	v4auth := parameters["v4auth"]
	switch v4auth := v4auth.(type) {
	case string:
		b, err := strconv.ParseBool(v4auth)
		if err != nil {
			return nil, fmt.Errorf("the v4auth parameter should be a boolean")
		}
		v4Bool = b
	case bool:
		v4Bool = v4auth
	case nil:
		// do nothing
	default:
		return nil, fmt.Errorf("the v4auth parameter should be a boolean")
	}

	keyID := parameters["keyid"]
	if keyID == nil {
		keyID = ""
	}

	chunkSize, err := getParameterAsInt64(parameters, "chunksize", defaultChunkSize, minChunkSize, maxChunkSize)
	if err != nil {
		return nil, err
	}

	multipartCopyChunkSize, err := getParameterAsInt64(parameters, "multipartcopychunksize", defaultMultipartCopyChunkSize, minChunkSize, maxChunkSize)
	if err != nil {
		return nil, err
	}

	multipartCopyMaxConcurrency, err := getParameterAsInt64(parameters, "multipartcopymaxconcurrency", defaultMultipartCopyMaxConcurrency, 1, math.MaxInt64)
	if err != nil {
		return nil, err
	}

	multipartCopyThresholdSize, err := getParameterAsInt64(parameters, "multipartcopythresholdsize", defaultMultipartCopyThresholdSize, 0, maxChunkSize)
	if err != nil {
		return nil, err
	}

	rootDirectory := parameters["rootdirectory"]
	if rootDirectory == nil {
		rootDirectory = ""
	}

	storageClass := s3.StorageClassStandard
	storageClassParam := parameters["storageclass"]
	if storageClassParam != nil {
		storageClassString, ok := storageClassParam.(string)
		if !ok {
			return nil, fmt.Errorf("the storageclass parameter must be one of %v, %v invalid",
				[]string{s3.StorageClassStandard, s3.StorageClassReducedRedundancy}, storageClassParam)
		}
		// All valid storage class parameters are UPPERCASE, so be a bit more flexible here.
		storageClassString = strings.ToUpper(storageClassString)
		if storageClassString != noStorageClass &&
			storageClassString != s3.StorageClassStandard &&
			storageClassString != s3.StorageClassReducedRedundancy {
			return nil, fmt.Errorf("the storageclass parameter must be one of %v, %v invalid",
				[]string{noStorageClass, s3.StorageClassStandard, s3.StorageClassReducedRedundancy}, storageClassParam)
		}
		storageClass = storageClassString
	}

	userAgent := parameters["useragent"]
	if userAgent == nil {
		userAgent = ""
	}

	objectACL := s3.ObjectCannedACLPrivate
	objectACLParam := parameters["objectacl"]
	if objectACLParam != nil {
		objectACLString, ok := objectACLParam.(string)
		if !ok {
			return nil, fmt.Errorf("invalid value for objectacl parameter: %v", objectACLParam)
		}

		if _, ok = validObjectACLs[objectACLString]; !ok {
			return nil, fmt.Errorf("invalid value for objectacl parameter: %v", objectACLParam)
		}
		objectACL = objectACLString
	}

	sessionToken := ""

	params := DriverParameters{
		fmt.Sprint(accessKey),
		fmt.Sprint(secretKey),
		fmt.Sprint(bucket),
		region,
		fmt.Sprint(regionEndpoint),
		encryptBool,
		fmt.Sprint(keyID),
		secureBool,
		skipVerifyBool,
		v4Bool,
		chunkSize,
		multipartCopyChunkSize,
		multipartCopyMaxConcurrency,
		multipartCopyThresholdSize,
		fmt.Sprint(rootDirectory),
		storageClass,
		fmt.Sprint(userAgent),
		objectACL,
		fmt.Sprint(sessionToken),
	}

	return New(params)
}
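
// getParameterAsInt64 converts parameters[name] to an int64 value (from either
// a string or any integer type), falling back to defaultt when the parameter is
// absent, and validating that the result lies between min and max (inclusive).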
func getParameterAsInt64(parameters map[string]interface{}, name string, defaultt int64, min int64, max int64) (int64, error) {
	rv := defaultt
	param := parameters[name]
	switch v := param.(type) {
	case string:
		vv, err := strconv.ParseInt(v, 0, 64)
		if err != nil {
			return 0, fmt.Errorf("%s parameter must be an integer, %v invalid", name, param)
		}
		rv = vv
	case int64:
		rv = v
	case int, uint, int32, uint32, uint64:
		rv = reflect.ValueOf(v).Convert(reflect.TypeOf(rv)).Int()
	case nil:
		// do nothing
	default:
		return 0, fmt.Errorf("invalid value for %s: %#v", name, param)
	}

	if rv < min || rv > max {
		return 0, fmt.Errorf("the %s %#v parameter should be a number between %d and %d (inclusive)", name, rv, min, max)
	}

	return rv, nil
}
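
// New constructs a new Driver with the given AWS credentials, region, encryption
// flag, and bucket name.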
func New(params DriverParameters) (*Driver, error) {
	if !params.V4Auth &&
		(params.RegionEndpoint == "" ||
			strings.Contains(params.RegionEndpoint, "s3.amazonaws.com")) {
		return nil, fmt.Errorf("on Amazon S3 this storage driver can only be used with v4 authentication")
	}

	awsConfig := aws.NewConfig()
	sess, err := session.NewSession()
	if err != nil {
		return nil, fmt.Errorf("failed to create new session: %v", err)
	}
	creds := credentials.NewChainCredentials([]credentials.Provider{
		&credentials.StaticProvider{
			Value: credentials.Value{
				AccessKeyID:     params.AccessKey,
				SecretAccessKey: params.SecretKey,
				SessionToken:    params.SessionToken,
			},
		},
		&credentials.EnvProvider{},
		&credentials.SharedCredentialsProvider{},
		&ec2rolecreds.EC2RoleProvider{Client: ec2metadata.New(sess)},
	})

	if params.RegionEndpoint != "" {
		awsConfig.WithS3ForcePathStyle(true)
		awsConfig.WithEndpoint(params.RegionEndpoint)
	}

	awsConfig.WithCredentials(creds)
	awsConfig.WithRegion(params.Region)
	awsConfig.WithDisableSSL(!params.Secure)

	if params.UserAgent != "" || params.SkipVerify {
		httpTransport := http.DefaultTransport
		if params.SkipVerify {
			httpTransport = &http.Transport{
				TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
			}
		}
		if params.UserAgent != "" {
			awsConfig.WithHTTPClient(&http.Client{
				Transport: transport.NewTransport(httpTransport, transport.NewHeaderRequestModifier(http.Header{http.CanonicalHeaderKey("User-Agent"): []string{params.UserAgent}})),
			})
		} else {
			awsConfig.WithHTTPClient(&http.Client{
				Transport: transport.NewTransport(httpTransport),
			})
		}
	}

	sess, err = session.NewSession(awsConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create new session with aws config: %v", err)
	}
	s3obj := s3.New(sess)

	// enable S3-compatible signature v2 signing instead
	if !params.V4Auth {
		setv2Handlers(s3obj)
	}

	d := &driver{
		S3:                          s3obj,
		Bucket:                      params.Bucket,
		ChunkSize:                   params.ChunkSize,
		Encrypt:                     params.Encrypt,
		KeyID:                       params.KeyID,
		MultipartCopyChunkSize:      params.MultipartCopyChunkSize,
		MultipartCopyMaxConcurrency: params.MultipartCopyMaxConcurrency,
		MultipartCopyThresholdSize:  params.MultipartCopyThresholdSize,
		RootDirectory:               params.RootDirectory,
		StorageClass:                params.StorageClass,
		ObjectACL:                   params.ObjectACL,
	}

	return &Driver{
		baseEmbed: baseEmbed{
			Base: base.Base{
				StorageDriver: d,
			},
		},
	}, nil
}
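
// Name returns the human-readable "name" of the driver.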
func (d *driver) Name() string {
	return driverName
}
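
// GetContent retrieves the content stored at "path" as a []byte.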
func (d *driver) GetContent(ctx context.Context, path string) ([]byte, error) {
	reader, err := d.Reader(ctx, path, 0)
	if err != nil {
		return nil, err
	}
	return ioutil.ReadAll(reader)
}
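
// PutContent stores the []byte content at a location designated by "path".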
func (d *driver) PutContent(ctx context.Context, path string, contents []byte) error {
	_, err := d.S3.PutObject(&s3.PutObjectInput{
		Bucket:               aws.String(d.Bucket),
		Key:                  aws.String(d.s3Path(path)),
		ContentType:          d.getContentType(),
		ACL:                  d.getACL(),
		ServerSideEncryption: d.getEncryptionMode(),
		SSEKMSKeyId:          d.getSSEKMSKeyID(),
		StorageClass:         d.getStorageClass(),
		Body:                 bytes.NewReader(contents),
	})
	return parseError(path, err)
}
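
// Reader retrieves an io.ReadCloser for the content stored at "path" with a
// given byte offset.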
func (d *driver) Reader(ctx context.Context, path string, offset int64) (io.ReadCloser, error) {
	resp, err := d.S3.GetObject(&s3.GetObjectInput{
		Bucket: aws.String(d.Bucket),
		Key:    aws.String(d.s3Path(path)),
		Range:  aws.String("bytes=" + strconv.FormatInt(offset, 10) + "-"),
	})

	if err != nil {
		if s3Err, ok := err.(awserr.Error); ok && s3Err.Code() == "InvalidRange" {
			return ioutil.NopCloser(bytes.NewReader(nil)), nil
		}

		return nil, parseError(path, err)
	}
	return resp.Body, nil
}
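
// Writer returns a FileWriter which will store the content written to it at
// the location designated by "path" after the call to Commit.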
func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (storagedriver.FileWriter, error) {
	key := d.s3Path(path)
	if !appendParam {
		// TODO (brianbland): cancel other uploads at this path
		resp, err := d.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
			Bucket:               aws.String(d.Bucket),
			Key:                  aws.String(key),
			ContentType:          d.getContentType(),
			ACL:                  d.getACL(),
			ServerSideEncryption: d.getEncryptionMode(),
			SSEKMSKeyId:          d.getSSEKMSKeyID(),
			StorageClass:         d.getStorageClass(),
		})
		if err != nil {
			return nil, err
		}
		return d.newWriter(key, *resp.UploadId, nil), nil
	}
	resp, err := d.S3.ListMultipartUploads(&s3.ListMultipartUploadsInput{
		Bucket: aws.String(d.Bucket),
		Prefix: aws.String(key),
	})
	if err != nil {
		return nil, parseError(path, err)
	}
	var allParts []*s3.Part
	for _, multi := range resp.Uploads {
		if key != *multi.Key {
			continue
		}
		resp, err := d.S3.ListParts(&s3.ListPartsInput{
			Bucket:   aws.String(d.Bucket),
			Key:      aws.String(key),
			UploadId: multi.UploadId,
		})
		if err != nil {
			return nil, parseError(path, err)
		}
		allParts = append(allParts, resp.Parts...)
		for *resp.IsTruncated {
			resp, err = d.S3.ListParts(&s3.ListPartsInput{
				Bucket:           aws.String(d.Bucket),
				Key:              aws.String(key),
				UploadId:         multi.UploadId,
				PartNumberMarker: resp.NextPartNumberMarker,
			})
			if err != nil {
				return nil, parseError(path, err)
			}
			allParts = append(allParts, resp.Parts...)
		}
		return d.newWriter(key, *multi.UploadId, allParts), nil
	}
	return nil, storagedriver.PathNotFoundError{Path: path}
}
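
// Stat retrieves the FileInfo for the given path, including the current size
// in bytes and the creation time.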
func (d *driver) Stat(ctx context.Context, path string) (storagedriver.FileInfo, error) {
	resp, err := d.S3.ListObjects(&s3.ListObjectsInput{
		Bucket:  aws.String(d.Bucket),
		Prefix:  aws.String(d.s3Path(path)),
		MaxKeys: aws.Int64(1),
	})
	if err != nil {
		return nil, err
	}

	fi := storagedriver.FileInfoFields{
		Path: path,
	}

	if len(resp.Contents) == 1 {
		if *resp.Contents[0].Key != d.s3Path(path) {
			fi.IsDir = true
		} else {
			fi.IsDir = false
			fi.Size = *resp.Contents[0].Size
			fi.ModTime = *resp.Contents[0].LastModified
		}
	} else if len(resp.CommonPrefixes) == 1 {
		fi.IsDir = true
	} else {
		return nil, storagedriver.PathNotFoundError{Path: path}
	}

	return storagedriver.FileInfoInternal{FileInfoFields: fi}, nil
}
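
// List returns a list of the objects that are direct descendants of the given path.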
func (d *driver) List(ctx context.Context, opath string) ([]string, error) {
	path := opath
	if path != "/" && path[len(path)-1] != '/' {
		path = path + "/"
	}

	// This is to cover for the cases when the rootDirectory of the driver is either "" or "/".
	// In those cases, there is no root prefix to replace and we must actually add a "/" to all
	// results in order to keep them as valid paths as recognized by storagedriver.PathRegexp.
	prefix := ""
	if d.s3Path("") == "" {
		prefix = "/"
	}

	resp, err := d.S3.ListObjects(&s3.ListObjectsInput{
		Bucket:    aws.String(d.Bucket),
		Prefix:    aws.String(d.s3Path(path)),
		Delimiter: aws.String("/"),
		MaxKeys:   aws.Int64(listMax),
	})
	if err != nil {
		return nil, parseError(opath, err)
	}

	files := []string{}
	directories := []string{}

	for {
		for _, key := range resp.Contents {
			files = append(files, strings.Replace(*key.Key, d.s3Path(""), prefix, 1))
		}

		for _, commonPrefix := range resp.CommonPrefixes {
			commonPrefix := *commonPrefix.Prefix
			directories = append(directories, strings.Replace(commonPrefix[0:len(commonPrefix)-1], d.s3Path(""), prefix, 1))
		}

		if *resp.IsTruncated {
			resp, err = d.S3.ListObjects(&s3.ListObjectsInput{
				Bucket:    aws.String(d.Bucket),
				Prefix:    aws.String(d.s3Path(path)),
				Delimiter: aws.String("/"),
				MaxKeys:   aws.Int64(listMax),
				Marker:    resp.NextMarker,
			})
			if err != nil {
				return nil, err
			}
		} else {
			break
		}
	}

	if opath != "/" {
		if len(files) == 0 && len(directories) == 0 {
			// Treat empty response as missing directory, since we don't actually
			// have a directory in s3.
			return nil, storagedriver.PathNotFoundError{Path: opath}
		}
	}

	return append(files, directories...), nil
}
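
// Move moves an object stored at sourcePath to destPath, removing the original
// object.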
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
	// AWS has no atomic move: copy the object, then delete the source.
	if err := d.copy(ctx, sourcePath, destPath); err != nil {
		return err
	}
	return d.Delete(ctx, sourcePath)
}
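
// copy copies an object stored at sourcePath to destPath.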
func (d *driver) copy(ctx context.Context, sourcePath string, destPath string) error {
	// S3 can copy objects up to 5 GB in size with a single PUT Object - Copy
	// operation. For larger objects, the multipart upload API must be used.
	//
	// Empirically, multipart copy is fastest with 32 MB parts and is faster
	// than PUT Object - Copy for objects larger than 32 MB.

	fileInfo, err := d.Stat(ctx, sourcePath)
	if err != nil {
		return parseError(sourcePath, err)
	}

	if fileInfo.Size() <= d.MultipartCopyThresholdSize {
		_, err := d.S3.CopyObject(&s3.CopyObjectInput{
			Bucket:               aws.String(d.Bucket),
			Key:                  aws.String(d.s3Path(destPath)),
			ContentType:          d.getContentType(),
			ACL:                  d.getACL(),
			ServerSideEncryption: d.getEncryptionMode(),
			SSEKMSKeyId:          d.getSSEKMSKeyID(),
			StorageClass:         d.getStorageClass(),
			CopySource:           aws.String(d.Bucket + "/" + d.s3Path(sourcePath)),
		})
		if err != nil {
			return parseError(sourcePath, err)
		}
		return nil
	}

	createResp, err := d.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
		Bucket:               aws.String(d.Bucket),
		Key:                  aws.String(d.s3Path(destPath)),
		ContentType:          d.getContentType(),
		ACL:                  d.getACL(),
		SSEKMSKeyId:          d.getSSEKMSKeyID(),
		ServerSideEncryption: d.getEncryptionMode(),
		StorageClass:         d.getStorageClass(),
	})
	if err != nil {
		return err
	}

	numParts := (fileInfo.Size() + d.MultipartCopyChunkSize - 1) / d.MultipartCopyChunkSize
	completedParts := make([]*s3.CompletedPart, numParts)
	errChan := make(chan error, numParts)
	// limiter caps the number of in-flight UploadPartCopy requests at
	// MultipartCopyMaxConcurrency.
	limiter := make(chan struct{}, d.MultipartCopyMaxConcurrency)

	for i := range completedParts {
		i := int64(i)
		go func() {
			limiter <- struct{}{}
			firstByte := i * d.MultipartCopyChunkSize
			lastByte := firstByte + d.MultipartCopyChunkSize - 1
			if lastByte >= fileInfo.Size() {
				lastByte = fileInfo.Size() - 1
			}
			uploadResp, err := d.S3.UploadPartCopy(&s3.UploadPartCopyInput{
				Bucket:          aws.String(d.Bucket),
				CopySource:      aws.String(d.Bucket + "/" + d.s3Path(sourcePath)),
				Key:             aws.String(d.s3Path(destPath)),
				PartNumber:      aws.Int64(i + 1),
				UploadId:        createResp.UploadId,
				CopySourceRange: aws.String(fmt.Sprintf("bytes=%d-%d", firstByte, lastByte)),
			})
			if err == nil {
				completedParts[i] = &s3.CompletedPart{
					ETag:       uploadResp.CopyPartResult.ETag,
					PartNumber: aws.Int64(i + 1),
				}
			}
			errChan <- err
			<-limiter
		}()
	}

	for range completedParts {
		err := <-errChan
		if err != nil {
			return err
		}
	}

	_, err = d.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
		Bucket:          aws.String(d.Bucket),
		Key:             aws.String(d.s3Path(destPath)),
		UploadId:        createResp.UploadId,
		MultipartUpload: &s3.CompletedMultipartUpload{Parts: completedParts},
	})
	return err
}

// min returns the smaller of two ints.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}
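
// Delete recursively deletes all objects stored at "path" and its subpaths.
// We must be careful since S3 does not guarantee read after delete consistency.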
func (d *driver) Delete(ctx context.Context, path string) error {
	s3Objects := make([]*s3.ObjectIdentifier, 0, listMax)
	s3Path := d.s3Path(path)
	listObjectsInput := &s3.ListObjectsInput{
		Bucket: aws.String(d.Bucket),
		Prefix: aws.String(s3Path),
	}
ListLoop:
	for {
		// list all the objects
		resp, err := d.S3.ListObjects(listObjectsInput)

		// resp.Contents can only be empty on the first call;
		// if there were no more results to return after the first call,
		// resp.IsTruncated would have been false and the loop would be
		// exited without recalling ListObjects.
		if err != nil || len(resp.Contents) == 0 {
			return storagedriver.PathNotFoundError{Path: path}
		}

		for _, key := range resp.Contents {
			// Stop if we encounter a key that is not a subpath (so that deleting "/a" does not delete "/ab").
			if len(*key.Key) > len(s3Path) && (*key.Key)[len(s3Path)] != '/' {
				break ListLoop
			}
			s3Objects = append(s3Objects, &s3.ObjectIdentifier{
				Key: key.Key,
			})
		}

		// resp.Contents must have at least one element or we would have returned not found.
		listObjectsInput.Marker = resp.Contents[len(resp.Contents)-1].Key

		// From the s3 api docs, IsTruncated "specifies whether (true) or not (false)
		// all of the results were returned". If everything has been returned, break.
		if resp.IsTruncated == nil || !*resp.IsTruncated {
			break
		}
	}

	// need to chunk objects into groups of 1000 per s3 restrictions
	total := len(s3Objects)
	for i := 0; i < total; i += 1000 {
		_, err := d.S3.DeleteObjects(&s3.DeleteObjectsInput{
			Bucket: aws.String(d.Bucket),
			Delete: &s3.Delete{
				Objects: s3Objects[i:min(i+1000, total)],
				Quiet:   aws.Bool(false),
			},
		})
		if err != nil {
			return err
		}
	}
	return nil
}
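
// URLFor returns a URL which may be used to retrieve the content stored at the
// given path, possibly using the given options.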
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
	methodString := "GET"
	method, ok := options["method"]
	if ok {
		methodString, ok = method.(string)
		if !ok || (methodString != "GET" && methodString != "HEAD") {
			return "", storagedriver.ErrUnsupportedMethod{}
		}
	}

	expiresIn := 20 * time.Minute
	expires, ok := options["expiry"]
	if ok {
		et, ok := expires.(time.Time)
		if ok {
			expiresIn = time.Until(et)
		}
	}

	var req *request.Request

	switch methodString {
	case "GET":
		req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{
			Bucket: aws.String(d.Bucket),
			Key:    aws.String(d.s3Path(path)),
		})
	case "HEAD":
		req, _ = d.S3.HeadObjectRequest(&s3.HeadObjectInput{
			Bucket: aws.String(d.Bucket),
			Key:    aws.String(d.s3Path(path)),
		})
	default:
		panic("unreachable")
	}

	return req.Presign(expiresIn)
}
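
// Walk traverses a filesystem defined within driver, starting from the given
// path, calling f on each file.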
func (d *driver) Walk(ctx context.Context, from string, f storagedriver.WalkFn) error {
	path := from
	if !strings.HasSuffix(path, "/") {
		path = path + "/"
	}

	prefix := ""
	if d.s3Path("") == "" {
		prefix = "/"
	}

	var objectCount int64
	if err := d.doWalk(ctx, &objectCount, d.s3Path(path), prefix, f); err != nil {
		return err
	}

	// S3 doesn't have the concept of empty directories, so it'll return path not
	// found if there are no objects.
	if objectCount == 0 {
		return storagedriver.PathNotFoundError{Path: from}
	}

	return nil
}

type walkInfoContainer struct {
	storagedriver.FileInfoFields
	prefix *string
}

// Path provides the full path of the target of this file info.
func (wi walkInfoContainer) Path() string {
	return wi.FileInfoFields.Path
}

// Size returns current length in bytes of the file. The return value can
// be used to write to the end of the file at path. The value is
// meaningless if IsDir returns true.
func (wi walkInfoContainer) Size() int64 {
	return wi.FileInfoFields.Size
}

// ModTime returns the modification time for the file. For backends that
// don't have a modification time, the creation time is best effort.
func (wi walkInfoContainer) ModTime() time.Time {
	return wi.FileInfoFields.ModTime
}

// IsDir returns true if the path is a directory.
func (wi walkInfoContainer) IsDir() bool {
	return wi.FileInfoFields.IsDir
}

func (d *driver) doWalk(parentCtx context.Context, objectCount *int64, path, prefix string, f storagedriver.WalkFn) error {
	var retError error

	listObjectsInput := &s3.ListObjectsV2Input{
		Bucket:    aws.String(d.Bucket),
		Prefix:    aws.String(path),
		Delimiter: aws.String("/"),
		MaxKeys:   aws.Int64(listMax),
	}

	ctx, done := dcontext.WithTrace(parentCtx)
	defer done("s3aws.ListObjectsV2Pages(%s)", path)
	listObjectErr := d.S3.ListObjectsV2PagesWithContext(ctx, listObjectsInput, func(objects *s3.ListObjectsV2Output, lastPage bool) bool {
		var count int64
		// KeyCount was introduced with version 2 of the GET Bucket operation in S3.
		// Some S3 implementations don't support V2 now, so we fall back to manual
		// calculation of the key count if required.
		if objects.KeyCount != nil {
			count = *objects.KeyCount
			*objectCount += *objects.KeyCount
		} else {
			count = int64(len(objects.Contents) + len(objects.CommonPrefixes))
			*objectCount += count
		}

		walkInfos := make([]walkInfoContainer, 0, count)

		for _, dir := range objects.CommonPrefixes {
			commonPrefix := *dir.Prefix
			walkInfos = append(walkInfos, walkInfoContainer{
				prefix: dir.Prefix,
				FileInfoFields: storagedriver.FileInfoFields{
					IsDir: true,
					Path:  strings.Replace(commonPrefix[:len(commonPrefix)-1], d.s3Path(""), prefix, 1),
				},
			})
		}

		for _, file := range objects.Contents {
			walkInfos = append(walkInfos, walkInfoContainer{
				FileInfoFields: storagedriver.FileInfoFields{
					IsDir:   false,
					Size:    *file.Size,
					ModTime: *file.LastModified,
					Path:    strings.Replace(*file.Key, d.s3Path(""), prefix, 1),
				},
			})
		}

		sort.SliceStable(walkInfos, func(i, j int) bool { return walkInfos[i].FileInfoFields.Path < walkInfos[j].FileInfoFields.Path })

		for _, walkInfo := range walkInfos {
			err := f(walkInfo)

			if err == storagedriver.ErrSkipDir {
				if walkInfo.IsDir() {
					continue
				} else {
					break
				}
			} else if err != nil {
				retError = err
				return false
			}

			if walkInfo.IsDir() {
				if err := d.doWalk(ctx, objectCount, *walkInfo.prefix, prefix, f); err != nil {
					retError = err
					return false
				}
			}
		}
		return true
	})

	if retError != nil {
		return retError
	}

	if listObjectErr != nil {
		return listObjectErr
	}

	return nil
}

func (d *driver) s3Path(path string) string {
	return strings.TrimLeft(strings.TrimRight(d.RootDirectory, "/")+path, "/")
}

// S3BucketKey returns the s3 bucket key for the given storage driver path.
func (d *Driver) S3BucketKey(path string) string {
	return d.StorageDriver.(*driver).s3Path(path)
}

func parseError(path string, err error) error {
	if s3Err, ok := err.(awserr.Error); ok && s3Err.Code() == "NoSuchKey" {
		return storagedriver.PathNotFoundError{Path: path}
	}

	return err
}

func (d *driver) getEncryptionMode() *string {
	if !d.Encrypt {
		return nil
	}
	if d.KeyID == "" {
		return aws.String("AES256")
	}
	return aws.String("aws:kms")
}

func (d *driver) getSSEKMSKeyID() *string {
	if d.KeyID != "" {
		return aws.String(d.KeyID)
	}
	return nil
}

func (d *driver) getContentType() *string {
	return aws.String("application/octet-stream")
}

func (d *driver) getACL() *string {
	return aws.String(d.ObjectACL)
}

func (d *driver) getStorageClass() *string {
	if d.StorageClass == noStorageClass {
		return nil
	}
	return aws.String(d.StorageClass)
}
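
// writer attempts to upload parts to S3 in a buffered fashion where the last
// part is at least as large as the chunksize, so the multipart upload could be
// cleanly resumed in the future. This is violated if Close is called after less
// than a full chunk is written.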
type writer struct {
	driver      *driver
	key         string
	uploadID    string
	parts       []*s3.Part
	size        int64
	readyPart   []byte
	pendingPart []byte
	closed      bool
	committed   bool
	cancelled   bool
}

func (d *driver) newWriter(key, uploadID string, parts []*s3.Part) storagedriver.FileWriter {
	var size int64
	for _, part := range parts {
		size += *part.Size
	}
	return &writer{
		driver:   d,
		key:      key,
		uploadID: uploadID,
		parts:    parts,
		size:     size,
	}
}

type completedParts []*s3.CompletedPart

func (a completedParts) Len() int           { return len(a) }
func (a completedParts) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a completedParts) Less(i, j int) bool { return *a[i].PartNumber < *a[j].PartNumber }

func (w *writer) Write(p []byte) (int, error) {
	if w.closed {
		return 0, fmt.Errorf("already closed")
	} else if w.committed {
		return 0, fmt.Errorf("already committed")
	} else if w.cancelled {
		return 0, fmt.Errorf("already cancelled")
	}

	// If the last written part is smaller than minChunkSize, we need to make a
	// new multipart upload: complete the current one, then start over.
	if len(w.parts) > 0 && int(*w.parts[len(w.parts)-1].Size) < minChunkSize {
		var completedUploadedParts completedParts
		for _, part := range w.parts {
			completedUploadedParts = append(completedUploadedParts, &s3.CompletedPart{
				ETag:       part.ETag,
				PartNumber: part.PartNumber,
			})
		}

		sort.Sort(completedUploadedParts)

		_, err := w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
			Bucket:   aws.String(w.driver.Bucket),
			Key:      aws.String(w.key),
			UploadId: aws.String(w.uploadID),
			MultipartUpload: &s3.CompletedMultipartUpload{
				Parts: completedUploadedParts,
			},
		})
		if err != nil {
			w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
				Bucket:   aws.String(w.driver.Bucket),
				Key:      aws.String(w.key),
				UploadId: aws.String(w.uploadID),
			})
			return 0, err
		}

		// Include the KMS key ID so the restarted upload keeps using the
		// configured key, matching the other CreateMultipartUpload calls.
		resp, err := w.driver.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
			Bucket:               aws.String(w.driver.Bucket),
			Key:                  aws.String(w.key),
			ContentType:          w.driver.getContentType(),
			ACL:                  w.driver.getACL(),
			ServerSideEncryption: w.driver.getEncryptionMode(),
			SSEKMSKeyId:          w.driver.getSSEKMSKeyID(),
			StorageClass:         w.driver.getStorageClass(),
		})
		if err != nil {
			return 0, err
		}
		w.uploadID = *resp.UploadId

		// If the entire written file is smaller than minChunkSize, we need to make
		// a new part from scratch.
		if w.size < minChunkSize {
			resp, err := w.driver.S3.GetObject(&s3.GetObjectInput{
				Bucket: aws.String(w.driver.Bucket),
				Key:    aws.String(w.key),
			})
			if err != nil {
				return 0, err
			}
			defer resp.Body.Close()
			w.parts = nil
			w.readyPart, err = ioutil.ReadAll(resp.Body)
			if err != nil {
				return 0, err
			}
		} else {
			// Otherwise we can use the old file as the new first part.
			copyPartResp, err := w.driver.S3.UploadPartCopy(&s3.UploadPartCopyInput{
				Bucket:     aws.String(w.driver.Bucket),
				CopySource: aws.String(w.driver.Bucket + "/" + w.key),
				Key:        aws.String(w.key),
				PartNumber: aws.Int64(1),
				UploadId:   resp.UploadId,
			})
			if err != nil {
				return 0, err
			}
			w.parts = []*s3.Part{
				{
					ETag:       copyPartResp.CopyPartResult.ETag,
					PartNumber: aws.Int64(1),
					Size:       aws.Int64(w.size),
				},
			}
		}
	}

	var n int

	for len(p) > 0 {
		// If no parts are ready to write, fill up the first part.
		if neededBytes := int(w.driver.ChunkSize) - len(w.readyPart); neededBytes > 0 {
			if len(p) >= neededBytes {
				w.readyPart = append(w.readyPart, p[:neededBytes]...)
				n += neededBytes
				p = p[neededBytes:]
			} else {
				w.readyPart = append(w.readyPart, p...)
				n += len(p)
				p = nil
			}
		}

		if neededBytes := int(w.driver.ChunkSize) - len(w.pendingPart); neededBytes > 0 {
			if len(p) >= neededBytes {
				w.pendingPart = append(w.pendingPart, p[:neededBytes]...)
				n += neededBytes
				p = p[neededBytes:]
				err := w.flushPart()
				if err != nil {
					w.size += int64(n)
					return n, err
				}
			} else {
				w.pendingPart = append(w.pendingPart, p...)
				n += len(p)
				p = nil
			}
		}
	}
	w.size += int64(n)
	return n, nil
}

func (w *writer) Size() int64 {
	return w.size
}

func (w *writer) Close() error {
	if w.closed {
		return fmt.Errorf("already closed")
	}
	w.closed = true
	return w.flushPart()
}

func (w *writer) Cancel() error {
	if w.closed {
		return fmt.Errorf("already closed")
	} else if w.committed {
		return fmt.Errorf("already committed")
	}
	w.cancelled = true
	_, err := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
		Bucket:   aws.String(w.driver.Bucket),
		Key:      aws.String(w.key),
		UploadId: aws.String(w.uploadID),
	})
	return err
}

func (w *writer) Commit() error {
	if w.closed {
		return fmt.Errorf("already closed")
	} else if w.committed {
		return fmt.Errorf("already committed")
	} else if w.cancelled {
		return fmt.Errorf("already cancelled")
	}
	err := w.flushPart()
	if err != nil {
		return err
	}
	w.committed = true

	var completedUploadedParts completedParts
	for _, part := range w.parts {
		completedUploadedParts = append(completedUploadedParts, &s3.CompletedPart{
			ETag:       part.ETag,
			PartNumber: part.PartNumber,
		})
	}

	sort.Sort(completedUploadedParts)

	_, err = w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
		Bucket:   aws.String(w.driver.Bucket),
		Key:      aws.String(w.key),
		UploadId: aws.String(w.uploadID),
		MultipartUpload: &s3.CompletedMultipartUpload{
			Parts: completedUploadedParts,
		},
	})
	if err != nil {
		w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
			Bucket:   aws.String(w.driver.Bucket),
			Key:      aws.String(w.key),
			UploadId: aws.String(w.uploadID),
		})
		return err
	}
	return nil
}
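
// flushPart flushes buffers to write a part to S3.
// Only called by Write (with both buffers full) and Close/Commit (always).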
func (w *writer) flushPart() error {
	if len(w.readyPart) == 0 && len(w.pendingPart) == 0 {
		// nothing to write
		return nil
	}
	if len(w.pendingPart) < int(w.driver.ChunkSize) {
		// closing with a small pending part:
		// combine ready and pending to avoid writing a small part
		w.readyPart = append(w.readyPart, w.pendingPart...)
		w.pendingPart = nil
	}

	partNumber := aws.Int64(int64(len(w.parts) + 1))
	resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{
		Bucket:     aws.String(w.driver.Bucket),
		Key:        aws.String(w.key),
		PartNumber: partNumber,
		UploadId:   aws.String(w.uploadID),
		Body:       bytes.NewReader(w.readyPart),
	})
	if err != nil {
		return err
	}
	w.parts = append(w.parts, &s3.Part{
		ETag:       resp.ETag,
		PartNumber: partNumber,
		Size:       aws.Int64(int64(len(w.readyPart))),
	})
	w.readyPart = w.pendingPart
	w.pendingPart = nil
	return nil
}