...

Source file src/go.mongodb.org/mongo-driver/internal/aws/signer/v4/v4_test.go

Documentation: go.mongodb.org/mongo-driver/internal/aws/signer/v4

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
     8  // - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4_test.go
     9  // See THIRD-PARTY-NOTICES for original license terms
    10  
    11  package v4
    12  
    13  import (
    14  	"bytes"
    15  	"io"
    16  	"io/ioutil"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"reflect"
    20  	"strconv"
    21  	"strings"
    22  	"testing"
    23  	"time"
    24  
    25  	"go.mongodb.org/mongo-driver/internal/aws"
    26  	"go.mongodb.org/mongo-driver/internal/aws/credentials"
    27  	"go.mongodb.org/mongo-driver/internal/credproviders"
    28  )
    29  
    30  func epochTime() time.Time { return time.Unix(0, 0) }
    31  
    32  func TestStripExcessHeaders(t *testing.T) {
    33  	vals := []string{
    34  		"",
    35  		"123",
    36  		"1 2 3",
    37  		"1 2 3 ",
    38  		"  1 2 3",
    39  		"1  2 3",
    40  		"1  23",
    41  		"1  2  3",
    42  		"1  2  ",
    43  		" 1  2  ",
    44  		"12   3",
    45  		"12   3   1",
    46  		"12           3     1",
    47  		"12     3       1abc123",
    48  	}
    49  
    50  	expected := []string{
    51  		"",
    52  		"123",
    53  		"1 2 3",
    54  		"1 2 3",
    55  		"1 2 3",
    56  		"1 2 3",
    57  		"1 23",
    58  		"1 2 3",
    59  		"1 2",
    60  		"1 2",
    61  		"12 3",
    62  		"12 3 1",
    63  		"12 3 1",
    64  		"12 3 1abc123",
    65  	}
    66  
    67  	stripExcessSpaces(vals)
    68  	for i := 0; i < len(vals); i++ {
    69  		if e, a := expected[i], vals[i]; e != a {
    70  			t.Errorf("%d, expect %v, got %v", i, e, a)
    71  		}
    72  	}
    73  }
    74  
    75  func buildRequest(body string) (*http.Request, io.ReadSeeker) {
    76  	reader := strings.NewReader(body)
    77  	return buildRequestWithBodyReader("dynamodb", "us-east-1", reader)
    78  }
    79  
    80  func buildRequestReaderSeeker(serviceName, region, body string) (*http.Request, io.ReadSeeker) {
    81  	reader := &readerSeekerWrapper{strings.NewReader(body)}
    82  	return buildRequestWithBodyReader(serviceName, region, reader)
    83  }
    84  
    85  func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, io.ReadSeeker) {
    86  	var bodyLen int
    87  
    88  	type lenner interface {
    89  		Len() int
    90  	}
    91  	if lr, ok := body.(lenner); ok {
    92  		bodyLen = lr.Len()
    93  	}
    94  
    95  	endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
    96  	req, _ := http.NewRequest("POST", endpoint, body)
    97  	req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
    98  	req.Header.Set("X-Amz-Target", "prefix.Operation")
    99  	req.Header.Set("Content-Type", "application/x-amz-json-1.0")
   100  
   101  	if bodyLen > 0 {
   102  		req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
   103  	}
   104  
   105  	req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
   106  	req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
   107  	req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
   108  
   109  	var seeker io.ReadSeeker
   110  	if sr, ok := body.(io.ReadSeeker); ok {
   111  		seeker = sr
   112  	} else {
   113  		seeker = aws.ReadSeekCloser(body)
   114  	}
   115  
   116  	return req, seeker
   117  }
   118  
   119  func buildSigner() Signer {
   120  	return Signer{
   121  		Credentials: newTestStaticCredentials(),
   122  	}
   123  }
   124  
   125  func newTestStaticCredentials() *credentials.Credentials {
   126  	return credentials.NewCredentials(&credproviders.StaticProvider{Value: credentials.Value{
   127  		AccessKeyID:     "AKID",
   128  		SecretAccessKey: "SECRET",
   129  		SessionToken:    "SESSION",
   130  	}})
   131  }
   132  
   133  func TestSignRequest(t *testing.T) {
   134  	req, body := buildRequest("{}")
   135  	signer := buildSigner()
   136  	_, err := signer.Sign(req, body, "dynamodb", "us-east-1", epochTime())
   137  	if err != nil {
   138  		t.Errorf("Expected no err, got %v", err)
   139  	}
   140  
   141  	expectedDate := "19700101T000000Z"
   142  	expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
   143  
   144  	q := req.Header
   145  	if e, a := expectedSig, q.Get("Authorization"); e != a {
   146  		t.Errorf("expect\n%v\nactual\n%v\n", e, a)
   147  	}
   148  	if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
   149  		t.Errorf("expect\n%v\nactual\n%v\n", e, a)
   150  	}
   151  }
   152  
   153  func TestSignUnseekableBody(t *testing.T) {
   154  	req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
   155  	signer := buildSigner()
   156  	_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
   157  	if err == nil {
   158  		t.Fatalf("expect error signing request")
   159  	}
   160  
   161  	if e, a := "unseekable request body", err.Error(); !strings.Contains(a, e) {
   162  		t.Errorf("expect %q to be in %q", e, a)
   163  	}
   164  }
   165  
   166  func TestSignPreComputedHashUnseekableBody(t *testing.T) {
   167  	req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
   168  
   169  	signer := buildSigner()
   170  
   171  	req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256")
   172  	_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
   173  	if err != nil {
   174  		t.Fatalf("expect no error, got %v", err)
   175  	}
   176  
   177  	hash := req.Header.Get("X-Amz-Content-Sha256")
   178  	if e, a := "some-content-sha256", hash; e != a {
   179  		t.Errorf("expect %v, got %v", e, a)
   180  	}
   181  }
   182  
   183  func TestSignPrecomputedBodyChecksum(t *testing.T) {
   184  	req, body := buildRequest("hello")
   185  	req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
   186  	signer := buildSigner()
   187  	_, err := signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
   188  	if err != nil {
   189  		t.Errorf("Expected no err, got %v", err)
   190  	}
   191  	hash := req.Header.Get("X-Amz-Content-Sha256")
   192  	if e, a := "PRECOMPUTED", hash; e != a {
   193  		t.Errorf("expect %v, got %v", e, a)
   194  	}
   195  }
   196  
   197  func TestSignWithRequestBody(t *testing.T) {
   198  	creds := newTestStaticCredentials()
   199  	signer := NewSigner(creds)
   200  
   201  	expectBody := []byte("abc123")
   202  
   203  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   204  		b, err := ioutil.ReadAll(r.Body)
   205  		r.Body.Close()
   206  		if err != nil {
   207  			t.Errorf("expect no error, got %v", err)
   208  		}
   209  		if e, a := expectBody, b; !reflect.DeepEqual(e, a) {
   210  			t.Errorf("expect %v, got %v", e, a)
   211  		}
   212  		w.WriteHeader(http.StatusOK)
   213  	}))
   214  	defer server.Close()
   215  
   216  	req, err := http.NewRequest("POST", server.URL, nil)
   217  	if err != nil {
   218  		t.Errorf("expect not no error, got %v", err)
   219  	}
   220  
   221  	_, err = signer.Sign(req, bytes.NewReader(expectBody), "service", "region", time.Now())
   222  	if err != nil {
   223  		t.Errorf("expect not no error, got %v", err)
   224  	}
   225  
   226  	resp, err := http.DefaultClient.Do(req)
   227  	if err != nil {
   228  		t.Errorf("expect not no error, got %v", err)
   229  	}
   230  	if e, a := http.StatusOK, resp.StatusCode; e != a {
   231  		t.Errorf("expect %v, got %v", e, a)
   232  	}
   233  }
   234  
   235  func TestSignWithRequestBody_Overwrite(t *testing.T) {
   236  	creds := newTestStaticCredentials()
   237  	signer := NewSigner(creds)
   238  
   239  	var expectBody []byte
   240  
   241  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   242  		b, err := ioutil.ReadAll(r.Body)
   243  		r.Body.Close()
   244  		if err != nil {
   245  			t.Errorf("expect not no error, got %v", err)
   246  		}
   247  		if e, a := len(expectBody), len(b); e != a {
   248  			t.Errorf("expect %v, got %v", e, a)
   249  		}
   250  		w.WriteHeader(http.StatusOK)
   251  	}))
   252  	defer server.Close()
   253  
   254  	req, err := http.NewRequest("GET", server.URL, strings.NewReader("invalid body"))
   255  	if err != nil {
   256  		t.Errorf("expect not no error, got %v", err)
   257  	}
   258  
   259  	_, err = signer.Sign(req, nil, "service", "region", time.Now())
   260  	req.ContentLength = 0
   261  
   262  	if err != nil {
   263  		t.Errorf("expect not no error, got %v", err)
   264  	}
   265  
   266  	resp, err := http.DefaultClient.Do(req)
   267  	if err != nil {
   268  		t.Errorf("expect not no error, got %v", err)
   269  	}
   270  	if e, a := http.StatusOK, resp.StatusCode; e != a {
   271  		t.Errorf("expect %v, got %v", e, a)
   272  	}
   273  }
   274  
   275  func TestBuildCanonicalRequest(t *testing.T) {
   276  	req, body := buildRequest("{}")
   277  	req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
   278  	ctx := &signingCtx{
   279  		ServiceName: "dynamodb",
   280  		Region:      "us-east-1",
   281  		Request:     req,
   282  		Body:        body,
   283  		Query:       req.URL.Query(),
   284  		Time:        time.Now(),
   285  	}
   286  
   287  	ctx.buildCanonicalString()
   288  	expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=z&Foo=o&Foo=m&Foo=a"
   289  	if e, a := expected, ctx.Request.URL.String(); e != a {
   290  		t.Errorf("expect %v, got %v", e, a)
   291  	}
   292  }
   293  
   294  func TestSignWithBody_ReplaceRequestBody(t *testing.T) {
   295  	creds := newTestStaticCredentials()
   296  	req, seekerBody := buildRequest("{}")
   297  	req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
   298  
   299  	s := NewSigner(creds)
   300  	origBody := req.Body
   301  
   302  	_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
   303  	if err != nil {
   304  		t.Fatalf("expect no error, got %v", err)
   305  	}
   306  
   307  	if req.Body == origBody {
   308  		t.Errorf("expect request body to not be origBody")
   309  	}
   310  
   311  	if req.Body == nil {
   312  		t.Errorf("expect request body to be changed but was nil")
   313  	}
   314  }
   315  
   316  func TestRequestHost(t *testing.T) {
   317  	req, body := buildRequest("{}")
   318  	req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
   319  	req.Host = "myhost"
   320  	ctx := &signingCtx{
   321  		ServiceName: "dynamodb",
   322  		Region:      "us-east-1",
   323  		Request:     req,
   324  		Body:        body,
   325  		Query:       req.URL.Query(),
   326  		Time:        time.Now(),
   327  	}
   328  
   329  	ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
   330  	if !strings.Contains(ctx.canonicalHeaders, "host:"+req.Host) {
   331  		t.Errorf("canonical host header invalid")
   332  	}
   333  }
   334  
   335  func TestSign_buildCanonicalHeaders(t *testing.T) {
   336  	serviceName := "mockAPI"
   337  	region := "mock-region"
   338  	endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
   339  
   340  	req, err := http.NewRequest("POST", endpoint, nil)
   341  	if err != nil {
   342  		t.Fatalf("failed to create request, %v", err)
   343  	}
   344  
   345  	req.Header.Set("FooInnerSpace", "   inner      space    ")
   346  	req.Header.Set("FooLeadingSpace", "    leading-space")
   347  	req.Header.Add("FooMultipleSpace", "no-space")
   348  	req.Header.Add("FooMultipleSpace", "\ttab-space")
   349  	req.Header.Add("FooMultipleSpace", "trailing-space    ")
   350  	req.Header.Set("FooNoSpace", "no-space")
   351  	req.Header.Set("FooTabSpace", "\ttab-space\t")
   352  	req.Header.Set("FooTrailingSpace", "trailing-space    ")
   353  	req.Header.Set("FooWrappedSpace", "   wrapped-space    ")
   354  
   355  	ctx := &signingCtx{
   356  		ServiceName: serviceName,
   357  		Region:      region,
   358  		Request:     req,
   359  		Body:        nil,
   360  		Query:       req.URL.Query(),
   361  		Time:        time.Now(),
   362  	}
   363  
   364  	ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
   365  
   366  	expectCanonicalHeaders := strings.Join([]string{
   367  		`fooinnerspace:inner space`,
   368  		`fooleadingspace:leading-space`,
   369  		`foomultiplespace:no-space,tab-space,trailing-space`,
   370  		`foonospace:no-space`,
   371  		`footabspace:tab-space`,
   372  		`footrailingspace:trailing-space`,
   373  		`foowrappedspace:wrapped-space`,
   374  		`host:mockAPI.mock-region.amazonaws.com`,
   375  	}, "\n")
   376  	if e, a := expectCanonicalHeaders, ctx.canonicalHeaders; e != a {
   377  		t.Errorf("expect:\n%s\n\nactual:\n%s", e, a)
   378  	}
   379  }
   380  
   381  func BenchmarkSignRequest(b *testing.B) {
   382  	signer := buildSigner()
   383  	req, body := buildRequestReaderSeeker("dynamodb", "us-east-1", "{}")
   384  	for i := 0; i < b.N; i++ {
   385  		_, err := signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
   386  		if err != nil {
   387  			b.Errorf("Expected no err, got %v", err)
   388  		}
   389  	}
   390  }
   391  
   392  var stripExcessSpaceCases = []string{
   393  	`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
   394  	`123   321   123   321`,
   395  	`   123   321   123   321   `,
   396  	`   123    321    123          321   `,
   397  	"123",
   398  	"1 2 3",
   399  	"  1 2 3",
   400  	"1  2 3",
   401  	"1  23",
   402  	"1  2  3",
   403  	"1  2  ",
   404  	" 1  2  ",
   405  	"12   3",
   406  	"12   3   1",
   407  	"12           3     1",
   408  	"12     3       1abc123",
   409  }
   410  
   411  func BenchmarkStripExcessSpaces(b *testing.B) {
   412  	for i := 0; i < b.N; i++ {
   413  		// Make sure to start with a copy of the cases
   414  		cases := append([]string{}, stripExcessSpaceCases...)
   415  		stripExcessSpaces(cases)
   416  	}
   417  }
   418  
   419  // readerSeekerWrapper mimics the interface provided by request.offsetReader
   420  type readerSeekerWrapper struct {
   421  	r *strings.Reader
   422  }
   423  
   424  func (r *readerSeekerWrapper) Read(p []byte) (n int, err error) {
   425  	return r.r.Read(p)
   426  }
   427  
   428  func (r *readerSeekerWrapper) Seek(offset int64, whence int) (int64, error) {
   429  	return r.r.Seek(offset, whence)
   430  }
   431  
   432  func (r *readerSeekerWrapper) Len() int {
   433  	return r.r.Len()
   434  }
   435  

View as plain text