...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds/credscaching_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds

     1  // Copyright (C) MongoDB, Inc. 2023-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package creds
     8  
     9  import (
    10  	"context"
    11  	"fmt"
    12  	"io"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"sync/atomic"
    17  	"testing"
    18  	"time"
    19  
    20  	"go.mongodb.org/mongo-driver/internal/assert"
    21  	"go.mongodb.org/mongo-driver/internal/aws/credentials"
    22  	"go.mongodb.org/mongo-driver/internal/credproviders"
    23  )
    24  
    25  type pipeTransport struct {
    26  	url    string
    27  	param  string
    28  	client *http.Client
    29  }
    30  
    31  // RoundTrip reassembles the original request URI into the query parameter and forwards the request.
    32  func (t pipeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    33  	uri, err := url.Parse(t.url)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	values := uri.Query()
    38  	values.Add(t.param, req.URL.String())
    39  	uri.RawQuery = values.Encode()
    40  	req.URL = uri
    41  	return t.client.Do(req)
    42  }
    43  
    44  func TestAWSCredentialProviderCaching(t *testing.T) {
    45  	const (
    46  		urienv         = "TEST_CONTAINER_CREDENTIALS_RELATIVE_URI"
    47  		keyenv         = "TEST_ACCESS_KEY"
    48  		awsRelativeURI = "http://169.254.170.2/"
    49  		testEndpoint   = "foo"
    50  		param          = "source"
    51  	)
    52  
    53  	t.Setenv(urienv, testEndpoint)
    54  
    55  	testCases := []struct {
    56  		expiration time.Duration
    57  		reqCount   uint32
    58  	}{
    59  		{
    60  			expiration: 20 * time.Minute,
    61  			reqCount:   1,
    62  		},
    63  		{
    64  			expiration: 5 * time.Minute,
    65  			reqCount:   2,
    66  		},
    67  		{
    68  			expiration: -1 * time.Minute,
    69  			reqCount:   2,
    70  		},
    71  	}
    72  	for _, tc := range testCases {
    73  		t.Run(fmt.Sprintf("expires in %s", tc.expiration.String()), func(t *testing.T) {
    74  			var cnt uint32
    75  			// the test server counts the requests and replies mock responses.
    76  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    77  				if r.URL.Query().Get(param) != awsRelativeURI+testEndpoint {
    78  					w.WriteHeader(http.StatusNotFound)
    79  					return
    80  				}
    81  				atomic.AddUint32(&cnt, 1)
    82  				t := time.Now().Add(tc.expiration).Format(time.RFC3339)
    83  				_, err := io.WriteString(w, fmt.Sprintf(`{
    84  					"AccessKeyId": "id",
    85  					"SecretAccessKey": "key",
    86  					"Token": "token",
    87  					"Expiration": "%s"
    88  				}`, t))
    89  				if err != nil {
    90  					w.WriteHeader(http.StatusInternalServerError)
    91  				}
    92  			}))
    93  			defer ts.Close()
    94  
    95  			client := &http.Client{
    96  				Transport: pipeTransport{
    97  					url:    ts.URL,
    98  					param:  param,
    99  					client: ts.Client(),
   100  				},
   101  			}
   102  
   103  			env := credproviders.NewEnvProvider()
   104  			env.AwsAccessKeyIDEnv = credproviders.EnvVar(keyenv)
   105  			ecs := credproviders.NewECSProvider(client, expiryWindow)
   106  			ecs.AwsContainerCredentialsRelativeURIEnv = credproviders.EnvVar(urienv)
   107  
   108  			p := AWSCredentialProvider{credentials.NewChainCredentials([]credentials.Provider{env, ecs})}
   109  			var err error
   110  			_, err = p.GetCredentialsDoc(context.Background())
   111  			assert.NoError(t, err, "error in GetCredentialsDoc")
   112  			_, err = p.GetCredentialsDoc(context.Background())
   113  			assert.NoError(t, err, "error in GetCredentialsDoc")
   114  			assert.Equal(t, tc.reqCount, atomic.LoadUint32(&cnt), "expected and actual credentials retrieval count don't match")
   115  		})
   116  	}
   117  }
   118  

View as plain text