...

Source file src/github.com/docker/distribution/registry/storage/driver/middleware/cloudfront/middleware.go

Documentation: github.com/docker/distribution/registry/storage/driver/middleware/cloudfront

     1  // Package middleware - cloudfront wrapper for storage libs
     2  // N.B. currently only works with S3, not arbitrary sites
     3  package middleware
     4  
     5  import (
     6  	"context"
     7  	"crypto/x509"
     8  	"encoding/pem"
     9  	"fmt"
    10  	"io/ioutil"
    11  	"net/url"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/aws/aws-sdk-go/service/cloudfront/sign"
    16  	dcontext "github.com/docker/distribution/context"
    17  	storagedriver "github.com/docker/distribution/registry/storage/driver"
    18  	storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
    19  )
    20  
    21  // cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
    22  // constructs temporary signed CloudFront URLs from the storagedriver layer URL,
    23  // then issues HTTP Temporary Redirects to this CloudFront content URL.
    24  type cloudFrontStorageMiddleware struct {
    25  	storagedriver.StorageDriver
    26  	awsIPs    *awsIPs
    27  	urlSigner *sign.URLSigner
    28  	baseURL   string
    29  	duration  time.Duration
    30  }
    31  
    32  var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
    33  
    34  // newCloudFrontLayerHandler constructs and returns a new CloudFront
    35  // LayerHandler implementation.
    36  // Required options: baseurl, privatekey, keypairid
    37  
    38  // Optional options: ipFilteredBy, awsregion
    39  // ipfilteredby: valid value "none|aws|awsregion". "none", do not filter any IP, default value. "aws", only aws IP goes
    40  //
    41  //	to S3 directly. "awsregion", only regions listed in awsregion options goes to S3 directly
    42  //
    43  // awsregion: a comma separated string of AWS regions.
    44  func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
    45  	// parse baseurl
    46  	base, ok := options["baseurl"]
    47  	if !ok {
    48  		return nil, fmt.Errorf("no baseurl provided")
    49  	}
    50  	baseURL, ok := base.(string)
    51  	if !ok {
    52  		return nil, fmt.Errorf("baseurl must be a string")
    53  	}
    54  	if !strings.Contains(baseURL, "://") {
    55  		baseURL = "https://" + baseURL
    56  	}
    57  	if !strings.HasSuffix(baseURL, "/") {
    58  		baseURL += "/"
    59  	}
    60  	if _, err := url.Parse(baseURL); err != nil {
    61  		return nil, fmt.Errorf("invalid baseurl: %v", err)
    62  	}
    63  
    64  	// parse privatekey to get pkPath
    65  	pk, ok := options["privatekey"]
    66  	if !ok {
    67  		return nil, fmt.Errorf("no privatekey provided")
    68  	}
    69  	pkPath, ok := pk.(string)
    70  	if !ok {
    71  		return nil, fmt.Errorf("privatekey must be a string")
    72  	}
    73  
    74  	// parse keypairid
    75  	kpid, ok := options["keypairid"]
    76  	if !ok {
    77  		return nil, fmt.Errorf("no keypairid provided")
    78  	}
    79  	keypairID, ok := kpid.(string)
    80  	if !ok {
    81  		return nil, fmt.Errorf("keypairid must be a string")
    82  	}
    83  
    84  	// get urlSigner from the file specified in pkPath
    85  	pkBytes, err := ioutil.ReadFile(pkPath)
    86  	if err != nil {
    87  		return nil, fmt.Errorf("failed to read privatekey file: %s", err)
    88  	}
    89  
    90  	block, _ := pem.Decode(pkBytes)
    91  	if block == nil {
    92  		return nil, fmt.Errorf("failed to decode private key as an rsa private key")
    93  	}
    94  	privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  	urlSigner := sign.NewURLSigner(keypairID, privateKey)
    99  
   100  	// parse duration
   101  	duration := 20 * time.Minute
   102  	if d, ok := options["duration"]; ok {
   103  		switch d := d.(type) {
   104  		case time.Duration:
   105  			duration = d
   106  		case string:
   107  			dur, err := time.ParseDuration(d)
   108  			if err != nil {
   109  				return nil, fmt.Errorf("invalid duration: %s", err)
   110  			}
   111  			duration = dur
   112  		}
   113  	}
   114  
   115  	// parse updatefrenquency
   116  	updateFrequency := defaultUpdateFrequency
   117  	if u, ok := options["updatefrenquency"]; ok {
   118  		switch u := u.(type) {
   119  		case time.Duration:
   120  			updateFrequency = u
   121  		case string:
   122  			updateFreq, err := time.ParseDuration(u)
   123  			if err != nil {
   124  				return nil, fmt.Errorf("invalid updatefrenquency: %s", err)
   125  			}
   126  			duration = updateFreq
   127  		}
   128  	}
   129  
   130  	// parse iprangesurl
   131  	ipRangesURL := defaultIPRangesURL
   132  	if i, ok := options["iprangesurl"]; ok {
   133  		if iprangeurl, ok := i.(string); ok {
   134  			ipRangesURL = iprangeurl
   135  		} else {
   136  			return nil, fmt.Errorf("iprangesurl must be a string")
   137  		}
   138  	}
   139  
   140  	// parse ipfilteredby
   141  	var awsIPs *awsIPs
   142  	if i, ok := options["ipfilteredby"]; ok {
   143  		if ipFilteredBy, ok := i.(string); ok {
   144  			switch strings.ToLower(strings.TrimSpace(ipFilteredBy)) {
   145  			case "", "none":
   146  				awsIPs = nil
   147  			case "aws":
   148  				awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil)
   149  			case "awsregion":
   150  				var awsRegion []string
   151  				if i, ok := options["awsregion"]; ok {
   152  					if regions, ok := i.(string); ok {
   153  						for _, awsRegions := range strings.Split(regions, ",") {
   154  							awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
   155  						}
   156  						awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion)
   157  					} else {
   158  						return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
   159  					}
   160  				} else {
   161  					return nil, fmt.Errorf("awsRegion is not defined")
   162  				}
   163  			default:
   164  				return nil, fmt.Errorf("ipfilteredby only allows a string the following value: none|aws|awsregion")
   165  			}
   166  		} else {
   167  			return nil, fmt.Errorf("ipfilteredby only allows a string with the following value: none|aws|awsregion")
   168  		}
   169  	}
   170  
   171  	return &cloudFrontStorageMiddleware{
   172  		StorageDriver: storageDriver,
   173  		urlSigner:     urlSigner,
   174  		baseURL:       baseURL,
   175  		duration:      duration,
   176  		awsIPs:        awsIPs,
   177  	}, nil
   178  }
   179  
   180  // S3BucketKeyer is any type that is capable of returning the S3 bucket key
   181  // which should be cached by AWS CloudFront.
   182  type S3BucketKeyer interface {
   183  	S3BucketKey(path string) string
   184  }
   185  
   186  // URLFor attempts to find a url which may be used to retrieve the file at the given path.
   187  // Returns an error if the file cannot be found.
   188  func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
   189  	// TODO(endophage): currently only supports S3
   190  	keyer, ok := lh.StorageDriver.(S3BucketKeyer)
   191  	if !ok {
   192  		dcontext.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver")
   193  		return lh.StorageDriver.URLFor(ctx, path, options)
   194  	}
   195  
   196  	if eligibleForS3(ctx, lh.awsIPs) {
   197  		return lh.StorageDriver.URLFor(ctx, path, options)
   198  	}
   199  
   200  	// Get signed cloudfront url.
   201  	cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration))
   202  	if err != nil {
   203  		return "", err
   204  	}
   205  	return cfURL, nil
   206  }
   207  
   208  // init registers the cloudfront layerHandler backend.
   209  func init() {
   210  	storagemiddleware.Register("cloudfront", storagemiddleware.InitFunc(newCloudFrontStorageMiddleware))
   211  }
   212  

View as plain text