...

Source file src/go.mongodb.org/mongo-driver/internal/aws/signer/v4/v4.go

Documentation: go.mongodb.org/mongo-driver/internal/aws/signer/v4

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
     8  // - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4.go
     9  // See THIRD-PARTY-NOTICES for original license terms
    10  
    11  package v4
    12  
    13  import (
    14  	"crypto/hmac"
    15  	"crypto/sha256"
    16  	"encoding/hex"
    17  	"fmt"
    18  	"io"
    19  	"io/ioutil"
    20  	"net/http"
    21  	"net/url"
    22  	"sort"
    23  	"strings"
    24  	"time"
    25  
    26  	"go.mongodb.org/mongo-driver/internal/aws"
    27  	"go.mongodb.org/mongo-driver/internal/aws/credentials"
    28  )
    29  
const (
	// authorizationHeader is the request header the finished signature is
	// written to; it is also checked to detect an already-signed request.
	authorizationHeader     = "Authorization"
	// authHeaderSignatureElem is the prefix of the signature element inside
	// the Authorization header value.
	authHeaderSignatureElem = "Signature="

	// authHeaderPrefix is the algorithm identifier that opens both the
	// Authorization header value and the string to sign.
	authHeaderPrefix = "AWS4-HMAC-SHA256"
	// timeFormat is the time layout used by formatTime (X-Amz-Date header
	// and string to sign).
	timeFormat       = "20060102T150405Z"
	// shortTimeFormat is the date-only layout used by formatShortTime
	// (signing scope and signing-key derivation).
	shortTimeFormat  = "20060102"
	// awsV4Request terminates the signing scope and is the final input of
	// the signing-key derivation.
	awsV4Request     = "aws4_request"

	// emptyStringSHA256 is the hex-encoded SHA256 of an empty string; it is
	// used as the body digest when a request is signed without a body.
	emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
)
    42  
// ignoredHeaders is the rule handed to buildCanonicalHeaders: any request
// header for which its IsValid method reports false is left out of the
// signature.
//
// NOTE(review): rules, excludeList and mapRule are declared elsewhere in the
// package; presumably excludeList rejects exactly the names listed in the
// wrapped mapRule — confirm against their definitions.
var ignoredHeaders = rules{
	excludeList{
		mapRule{
			authorizationHeader: struct{}{},
			"User-Agent":        struct{}{},
			"X-Amzn-Trace-Id":   struct{}{},
		},
	},
}
    52  
// Signer applies AWS v4 signing to given request. Use this to sign requests
// that need to be signed with AWS V4 Signatures.
//
// Construct one with NewSigner; signing dereferences Credentials, so it must
// be non-nil before Sign is called.
type Signer struct {
	// The authentication credentials the request will be signed against.
	// This value must be set to sign requests.
	Credentials *credentials.Credentials
}
    60  
    61  // NewSigner returns a Signer pointer configured with the credentials provided.
    62  func NewSigner(credentials *credentials.Credentials) *Signer {
    63  	v4 := &Signer{
    64  		Credentials: credentials,
    65  	}
    66  
    67  	return v4
    68  }
    69  
// signingCtx carries the inputs and the intermediate results of signing a
// single request. The exported fields are filled in by signWithBody; the
// unexported string fields are populated step by step by build.
type signingCtx struct {
	// ServiceName and Region are part of the signing scope and of the
	// signing-key derivation.
	ServiceName      string
	Region           string
	// Request is the request being signed; its headers, Host and
	// URL.RawQuery are modified while signing.
	Request          *http.Request
	// Body is the payload that is hashed when the request carries no
	// X-Amz-Content-Sha256 header. It may be nil.
	Body             io.ReadSeeker
	// Query holds the request's parsed query values (each key's values
	// sorted); it is re-encoded into Request.URL.RawQuery.
	Query            url.Values
	// Time is the timestamp the signature is computed for.
	Time             time.Time
	// SignedHeaderVals maps lower-cased header names to the values that were
	// included in the signature. It stays nil when no header qualified.
	SignedHeaderVals http.Header

	// credValues are the credentials retrieved for this signing operation.
	credValues credentials.Value

	// bodyDigest is the hex payload hash used in the canonical request.
	bodyDigest       string
	// signedHeaders is the sorted, ';'-joined list of signed header names.
	signedHeaders    string
	// canonicalHeaders is the newline-joined "name:value" header list.
	canonicalHeaders string
	// canonicalString is the canonical request that gets hashed.
	canonicalString  string
	// credentialString is the signing scope (date/region/service/terminator).
	credentialString string
	// stringToSign is the final text fed to the HMAC.
	stringToSign     string
	// signature is the hex-encoded result placed in the Authorization header.
	signature        string
}
    89  
    90  // Sign signs AWS v4 requests with the provided body, service name, region the
    91  // request is made to, and time the request is signed at. The signTime allows
    92  // you to specify that a request is signed for the future, and cannot be
    93  // used until then.
    94  //
    95  // Returns a list of HTTP headers that were included in the signature or an
    96  // error if signing the request failed. Generally for signed requests this value
    97  // is not needed as the full request context will be captured by the http.Request
    98  // value. It is included for reference though.
    99  //
   100  // Sign will set the request's Body to be the `body` parameter passed in. If
   101  // the body is not already an io.ReadCloser, it will be wrapped within one. If
   102  // a `nil` body parameter passed to Sign, the request's Body field will be
   103  // also set to nil. Its important to note that this functionality will not
   104  // change the request's ContentLength of the request.
   105  //
   106  // Sign differs from Presign in that it will sign the request using HTTP
   107  // header values. This type of signing is intended for http.Request values that
   108  // will not be shared, or are shared in a way the header values on the request
   109  // will not be lost.
   110  //
   111  // The requests body is an io.ReadSeeker so the SHA256 of the body can be
   112  // generated. To bypass the signer computing the hash you can set the
   113  // "X-Amz-Content-Sha256" header with a precomputed value. The signer will
   114  // only compute the hash if the request header value is empty.
   115  func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
   116  	return v4.signWithBody(r, body, service, region, signTime)
   117  }
   118  
   119  func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
   120  	ctx := &signingCtx{
   121  		Request:     r,
   122  		Body:        body,
   123  		Query:       r.URL.Query(),
   124  		Time:        signTime,
   125  		ServiceName: service,
   126  		Region:      region,
   127  	}
   128  
   129  	for key := range ctx.Query {
   130  		sort.Strings(ctx.Query[key])
   131  	}
   132  
   133  	if ctx.isRequestSigned() {
   134  		ctx.Time = time.Now()
   135  	}
   136  
   137  	var err error
   138  	ctx.credValues, err = v4.Credentials.GetWithContext(r.Context())
   139  	if err != nil {
   140  		return http.Header{}, err
   141  	}
   142  
   143  	ctx.sanitizeHostForHeader()
   144  	ctx.assignAmzQueryValues()
   145  	if err := ctx.build(); err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	var reader io.ReadCloser
   150  	if body != nil {
   151  		var ok bool
   152  		if reader, ok = body.(io.ReadCloser); !ok {
   153  			reader = ioutil.NopCloser(body)
   154  		}
   155  	}
   156  	r.Body = reader
   157  
   158  	return ctx.SignedHeaderVals, nil
   159  }
   160  
   161  // sanitizeHostForHeader removes default port from host and updates request.Host
   162  func (ctx *signingCtx) sanitizeHostForHeader() {
   163  	r := ctx.Request
   164  	host := getHost(r)
   165  	port := portOnly(host)
   166  	if port != "" && isDefaultPort(r.URL.Scheme, port) {
   167  		r.Host = stripPort(host)
   168  	}
   169  }
   170  
   171  func (ctx *signingCtx) assignAmzQueryValues() {
   172  	if ctx.credValues.SessionToken != "" {
   173  		ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
   174  	}
   175  }
   176  
   177  func (ctx *signingCtx) build() error {
   178  	ctx.buildTime()             // no depends
   179  	ctx.buildCredentialString() // no depends
   180  
   181  	if err := ctx.buildBodyDigest(); err != nil {
   182  		return err
   183  	}
   184  
   185  	unsignedHeaders := ctx.Request.Header
   186  
   187  	ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
   188  	ctx.buildCanonicalString() // depends on canon headers / signed headers
   189  	ctx.buildStringToSign()    // depends on canon string
   190  	ctx.buildSignature()       // depends on string to sign
   191  
   192  	parts := []string{
   193  		authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
   194  		"SignedHeaders=" + ctx.signedHeaders,
   195  		authHeaderSignatureElem + ctx.signature,
   196  	}
   197  	ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
   198  
   199  	return nil
   200  }
   201  
   202  func (ctx *signingCtx) buildTime() {
   203  	ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
   204  }
   205  
   206  func (ctx *signingCtx) buildCredentialString() {
   207  	ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
   208  }
   209  
   210  func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
   211  	headers := make([]string, 0, len(header)+1)
   212  	headers = append(headers, "host")
   213  	for k, v := range header {
   214  		if !r.IsValid(k) {
   215  			continue // ignored header
   216  		}
   217  		if ctx.SignedHeaderVals == nil {
   218  			ctx.SignedHeaderVals = make(http.Header)
   219  		}
   220  
   221  		lowerCaseKey := strings.ToLower(k)
   222  		if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
   223  			// include additional values
   224  			ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
   225  			continue
   226  		}
   227  
   228  		headers = append(headers, lowerCaseKey)
   229  		ctx.SignedHeaderVals[lowerCaseKey] = v
   230  	}
   231  	sort.Strings(headers)
   232  
   233  	ctx.signedHeaders = strings.Join(headers, ";")
   234  
   235  	headerItems := make([]string, len(headers))
   236  	for i, k := range headers {
   237  		if k == "host" {
   238  			if ctx.Request.Host != "" {
   239  				headerItems[i] = "host:" + ctx.Request.Host
   240  			} else {
   241  				headerItems[i] = "host:" + ctx.Request.URL.Host
   242  			}
   243  		} else {
   244  			headerValues := make([]string, len(ctx.SignedHeaderVals[k]))
   245  			for i, v := range ctx.SignedHeaderVals[k] {
   246  				headerValues[i] = strings.TrimSpace(v)
   247  			}
   248  			headerItems[i] = k + ":" +
   249  				strings.Join(headerValues, ",")
   250  		}
   251  	}
   252  	stripExcessSpaces(headerItems)
   253  	ctx.canonicalHeaders = strings.Join(headerItems, "\n")
   254  }
   255  
   256  func (ctx *signingCtx) buildCanonicalString() {
   257  	ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
   258  
   259  	uri := getURIPath(ctx.Request.URL)
   260  
   261  	uri = EscapePath(uri, false)
   262  
   263  	ctx.canonicalString = strings.Join([]string{
   264  		ctx.Request.Method,
   265  		uri,
   266  		ctx.Request.URL.RawQuery,
   267  		ctx.canonicalHeaders + "\n",
   268  		ctx.signedHeaders,
   269  		ctx.bodyDigest,
   270  	}, "\n")
   271  }
   272  
   273  func (ctx *signingCtx) buildStringToSign() {
   274  	ctx.stringToSign = strings.Join([]string{
   275  		authHeaderPrefix,
   276  		formatTime(ctx.Time),
   277  		ctx.credentialString,
   278  		hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
   279  	}, "\n")
   280  }
   281  
   282  func (ctx *signingCtx) buildSignature() {
   283  	creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
   284  	signature := hmacSHA256(creds, []byte(ctx.stringToSign))
   285  	ctx.signature = hex.EncodeToString(signature)
   286  }
   287  
   288  func (ctx *signingCtx) buildBodyDigest() error {
   289  	hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
   290  	if hash == "" {
   291  		if ctx.Body == nil {
   292  			hash = emptyStringSHA256
   293  		} else {
   294  			if !aws.IsReaderSeekable(ctx.Body) {
   295  				return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
   296  			}
   297  			hashBytes, err := makeSha256Reader(ctx.Body)
   298  			if err != nil {
   299  				return err
   300  			}
   301  			hash = hex.EncodeToString(hashBytes)
   302  		}
   303  	}
   304  	ctx.bodyDigest = hash
   305  
   306  	return nil
   307  }
   308  
   309  // isRequestSigned returns if the request is currently signed or presigned
   310  func (ctx *signingCtx) isRequestSigned() bool {
   311  	return ctx.Request.Header.Get("Authorization") != ""
   312  }
   313  
   314  func hmacSHA256(key []byte, data []byte) []byte {
   315  	hash := hmac.New(sha256.New, key)
   316  	hash.Write(data)
   317  	return hash.Sum(nil)
   318  }
   319  
   320  func hashSHA256(data []byte) []byte {
   321  	hash := sha256.New()
   322  	hash.Write(data)
   323  	return hash.Sum(nil)
   324  }
   325  
   326  func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
   327  	hash := sha256.New()
   328  	start, err := reader.Seek(0, io.SeekCurrent)
   329  	if err != nil {
   330  		return nil, err
   331  	}
   332  	defer func() {
   333  		// ensure error is return if unable to seek back to start of payload.
   334  		_, err = reader.Seek(start, io.SeekStart)
   335  	}()
   336  
   337  	// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
   338  	// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
   339  	size, err := aws.SeekerLen(reader)
   340  	if err != nil {
   341  		_, _ = io.Copy(hash, reader)
   342  	} else {
   343  		_, _ = io.CopyN(hash, reader, size)
   344  	}
   345  
   346  	return hash.Sum(nil), nil
   347  }
   348  
   349  const doubleSpace = "  "
   350  
   351  // stripExcessSpaces will rewrite the passed in slice's string values to not
   352  // contain multiple side-by-side spaces.
   353  func stripExcessSpaces(vals []string) {
   354  	var j, k, l, m, spaces int
   355  	for i, str := range vals {
   356  		// revive:disable:empty-block
   357  
   358  		// Trim trailing spaces
   359  		for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
   360  		}
   361  
   362  		// Trim leading spaces
   363  		for k = 0; k < j && str[k] == ' '; k++ {
   364  		}
   365  
   366  		// revive:enable:empty-block
   367  
   368  		str = str[k : j+1]
   369  
   370  		// Strip multiple spaces.
   371  		j = strings.Index(str, doubleSpace)
   372  		if j < 0 {
   373  			vals[i] = str
   374  			continue
   375  		}
   376  
   377  		buf := []byte(str)
   378  		for k, m, l = j, j, len(buf); k < l; k++ {
   379  			if buf[k] == ' ' {
   380  				if spaces == 0 {
   381  					// First space.
   382  					buf[m] = buf[k]
   383  					m++
   384  				}
   385  				spaces++
   386  			} else {
   387  				// End of multiple spaces.
   388  				spaces = 0
   389  				buf[m] = buf[k]
   390  				m++
   391  			}
   392  		}
   393  
   394  		vals[i] = string(buf[:m])
   395  	}
   396  }
   397  
   398  func buildSigningScope(region, service string, dt time.Time) string {
   399  	return strings.Join([]string{
   400  		formatShortTime(dt),
   401  		region,
   402  		service,
   403  		awsV4Request,
   404  	}, "/")
   405  }
   406  
   407  func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
   408  	keyDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
   409  	keyRegion := hmacSHA256(keyDate, []byte(region))
   410  	keyService := hmacSHA256(keyRegion, []byte(service))
   411  	signingKey := hmacSHA256(keyService, []byte(awsV4Request))
   412  	return signingKey
   413  }
   414  
   415  func formatShortTime(dt time.Time) string {
   416  	return dt.UTC().Format(shortTimeFormat)
   417  }
   418  
   419  func formatTime(dt time.Time) string {
   420  	return dt.UTC().Format(timeFormat)
   421  }
   422  

View as plain text