...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/processcreds/provider_test.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/processcreds

     1  package processcreds
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"os"
    12  	"path/filepath"
    13  	"runtime"
    14  	"strings"
    15  	"testing"
    16  	"time"
    17  )
    18  
    19  func TestProviderBadCommand(t *testing.T) {
    20  	provider := NewProvider("/bad/process")
    21  	_, err := provider.Retrieve(context.Background())
    22  	var pe *ProviderError
    23  	if ok := errors.As(err, &pe); !ok {
    24  		t.Fatalf("expect error to be of type %T", pe)
    25  	}
    26  	if e, a := "error in credential_process", pe.Error(); !strings.Contains(a, e) {
    27  		t.Errorf("expected %v, got %v", e, a)
    28  	}
    29  }
    30  
    31  func TestProviderMoreEmptyCommands(t *testing.T) {
    32  	provider := NewProvider("")
    33  	_, err := provider.Retrieve(context.Background())
    34  	var pe *ProviderError
    35  	if ok := errors.As(err, &pe); !ok {
    36  		t.Fatalf("expect error to be of type %T", pe)
    37  	}
    38  	if e, a := "failed to prepare command", pe.Error(); !strings.Contains(a, e) {
    39  		t.Errorf("expected %v, got %v", e, a)
    40  	}
    41  }
    42  
    43  func TestProviderExpectErrors(t *testing.T) {
    44  	provider := NewProvider(
    45  		fmt.Sprintf(
    46  			"%s %s",
    47  			getOSCat(),
    48  			filepath.Join("testdata", "malformed.json"),
    49  		))
    50  	_, err := provider.Retrieve(context.Background())
    51  	var pe *ProviderError
    52  	if ok := errors.As(err, &pe); !ok {
    53  		t.Fatalf("expect error to be of type %T", pe)
    54  	}
    55  	if e, a := "parse failed of process output", pe.Error(); !strings.Contains(a, e) {
    56  		t.Errorf("expected %v, got %v", e, a)
    57  	}
    58  
    59  	provider = NewProvider(
    60  		fmt.Sprintf("%s %s",
    61  			getOSCat(),
    62  			filepath.Join("testdata", "wrongversion.json"),
    63  		))
    64  	_, err = provider.Retrieve(context.Background())
    65  	if ok := errors.As(err, &pe); !ok {
    66  		t.Fatalf("expect error to be of type %T", pe)
    67  	}
    68  	if e, a := "wrong version in process output", pe.Error(); !strings.Contains(a, e) {
    69  		t.Errorf("expected %v, got %v", e, a)
    70  	}
    71  
    72  	provider = NewProvider(
    73  		fmt.Sprintf(
    74  			"%s %s",
    75  			getOSCat(),
    76  			filepath.Join("testdata", "missingkey.json"),
    77  		))
    78  	_, err = provider.Retrieve(context.Background())
    79  	if ok := errors.As(err, &pe); !ok {
    80  		t.Fatalf("expect error to be of type %T", pe)
    81  	}
    82  	if e, a := "missing AccessKeyId", pe.Error(); !strings.Contains(a, e) {
    83  		t.Errorf("expected %v, got %v", e, a)
    84  	}
    85  
    86  	provider = NewProvider(
    87  		fmt.Sprintf(
    88  			"%s %s",
    89  			getOSCat(),
    90  			filepath.Join("testdata", "missingsecret.json"),
    91  		))
    92  	_, err = provider.Retrieve(context.Background())
    93  	if ok := errors.As(err, &pe); !ok {
    94  		t.Fatalf("expect error to be of type %T", pe)
    95  	}
    96  	if e, a := "missing SecretAccessKey", pe.Error(); !strings.Contains(a, e) {
    97  		t.Errorf("expected %v, got %v", e, a)
    98  	}
    99  }
   100  
   101  func TestProviderTimeout(t *testing.T) {
   102  	command := "/bin/sleep 2"
   103  	if runtime.GOOS == "windows" {
   104  		// "timeout" command does not work due to pipe redirection
   105  		command = "ping -n 2 127.0.0.1>nul"
   106  	}
   107  
   108  	provider := NewProvider(command, func(options *Options) {
   109  		options.Timeout = time.Duration(1) * time.Second
   110  	})
   111  	_, err := provider.Retrieve(context.Background())
   112  	var pe *ProviderError
   113  	if ok := errors.As(err, &pe); !ok {
   114  		t.Fatalf("expect error to be of type %T", pe)
   115  	}
   116  	if e, a := "credential process timed out", pe.Error(); !strings.Contains(a, e) {
   117  		t.Errorf("expected %v, got %v", e, a)
   118  	}
   119  }
   120  
   121  func TestProviderWithLongSessionToken(t *testing.T) {
   122  	provider := NewProvider(
   123  		fmt.Sprintf(
   124  			"%s %s",
   125  			getOSCat(),
   126  			filepath.Join("testdata", "longsessiontoken.json"),
   127  		))
   128  	v, err := provider.Retrieve(context.Background())
   129  	if err != nil {
   130  		t.Errorf("expected %v, got %v", "no error", err)
   131  	}
   132  
   133  	// Text string same length as session token returned by AWS for AssumeRoleWithWebIdentity
   134  	e := "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
   135  	if a := v.SessionToken; e != a {
   136  		t.Errorf("expected %v, got %v", e, a)
   137  	}
   138  }
   139  
   140  type credentialTest struct {
   141  	Version         int
   142  	AccessKeyID     string `json:"AccessKeyId"`
   143  	SecretAccessKey string
   144  	Expiration      string
   145  }
   146  
   147  func TestProviderStatic(t *testing.T) {
   148  	// static
   149  	provider := NewProvider(
   150  		fmt.Sprintf(
   151  			"%s %s",
   152  			getOSCat(),
   153  			filepath.Join("testdata", "static.json"),
   154  		))
   155  	v, err := provider.Retrieve(context.Background())
   156  	if err != nil {
   157  		t.Errorf("expected %v, got %v", "no error", err)
   158  	}
   159  	if v.CanExpire != false {
   160  		t.Errorf("expected %v, got %v", "static credentials/not expired", "can expire")
   161  	}
   162  
   163  }
   164  
   165  func TestProviderNotExpired(t *testing.T) {
   166  	// non-static, not expired
   167  	exp := &credentialTest{}
   168  	exp.Version = 1
   169  	exp.AccessKeyID = "accesskey"
   170  	exp.SecretAccessKey = "secretkey"
   171  	exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
   172  	b, err := json.Marshal(exp)
   173  	if err != nil {
   174  		t.Errorf("expected %v, got %v", "no error", err)
   175  	}
   176  
   177  	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expiring")
   178  	if err != nil {
   179  		t.Errorf("expected %v, got %v", "no error", err)
   180  	}
   181  	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
   182  		t.Errorf("expected %v, got %v", "no error", err)
   183  	}
   184  	defer func() {
   185  		if err = tmpFile.Close(); err != nil {
   186  			t.Errorf("expected %v, got %v", "no error", err)
   187  		}
   188  		if err = os.Remove(tmpFile.Name()); err != nil {
   189  			t.Errorf("expected %v, got %v", "no error", err)
   190  		}
   191  	}()
   192  	provider := NewProvider(
   193  		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
   194  	v, err := provider.Retrieve(context.Background())
   195  	if err != nil {
   196  		t.Errorf("expected %v, got %v", "no error", err)
   197  	}
   198  	if v.Expired() {
   199  		t.Errorf("expected %v, got %v", "not expired", "expired")
   200  	}
   201  }
   202  
   203  func TestProviderExpired(t *testing.T) {
   204  	// non-static, expired
   205  	exp := &credentialTest{}
   206  	exp.Version = 1
   207  	exp.AccessKeyID = "accesskey"
   208  	exp.SecretAccessKey = "secretkey"
   209  	exp.Expiration = time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339)
   210  	b, err := json.Marshal(exp)
   211  	if err != nil {
   212  		t.Errorf("expected %v, got %v", "no error", err)
   213  	}
   214  
   215  	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_expired")
   216  	if err != nil {
   217  		t.Errorf("expected %v, got %v", "no error", err)
   218  	}
   219  	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
   220  		t.Errorf("expected %v, got %v", "no error", err)
   221  	}
   222  	defer func() {
   223  		if err = tmpFile.Close(); err != nil {
   224  			t.Errorf("expected %v, got %v", "no error", err)
   225  		}
   226  		if err = os.Remove(tmpFile.Name()); err != nil {
   227  			t.Errorf("expected %v, got %v", "no error", err)
   228  		}
   229  	}()
   230  	provider := NewProvider(
   231  		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
   232  	v, err := provider.Retrieve(context.Background())
   233  	if err != nil {
   234  		t.Errorf("expected %v, got %v", "no error", err)
   235  	}
   236  	if !v.Expired() {
   237  		t.Errorf("expected %v, got %v", "expired", "not expired")
   238  	}
   239  }
   240  
   241  func TestProviderForceExpire(t *testing.T) {
   242  	// non-static, not expired
   243  
   244  	// setup test credentials file
   245  	exp := &credentialTest{}
   246  	exp.Version = 1
   247  	exp.AccessKeyID = "accesskey"
   248  	exp.SecretAccessKey = "secretkey"
   249  	exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
   250  	b, err := json.Marshal(exp)
   251  	if err != nil {
   252  		t.Errorf("expected %v, got %v", "no error", err)
   253  	}
   254  	tmpFile, err := ioutil.TempFile(os.TempDir(), "tmp_force_expire")
   255  	if err != nil {
   256  		t.Errorf("expected %v, got %v", "no error", err)
   257  	}
   258  	if _, err = io.Copy(tmpFile, bytes.NewReader(b)); err != nil {
   259  		t.Errorf("expected %v, got %v", "no error", err)
   260  	}
   261  	defer func() {
   262  		if err = tmpFile.Close(); err != nil {
   263  			t.Errorf("expected %v, got %v", "no error", err)
   264  		}
   265  		if err = os.Remove(tmpFile.Name()); err != nil {
   266  			t.Errorf("expected %v, got %v", "no error", err)
   267  		}
   268  	}()
   269  
   270  	// get credentials from file
   271  	provider := NewProvider(
   272  		fmt.Sprintf("%s %s", getOSCat(), tmpFile.Name()))
   273  	v, err := provider.Retrieve(context.Background())
   274  	if err != nil {
   275  		t.Errorf("expected %v, got %v", "no error", err)
   276  	}
   277  	if v.Expired() {
   278  		t.Errorf("expected %v, got %v", "not expired", "expired")
   279  	}
   280  
   281  	// Re-retrieve credentials
   282  	v, err = provider.Retrieve(context.Background())
   283  	if err != nil {
   284  		t.Errorf("expected %v, got %v", "no error", err)
   285  	}
   286  	if v.Expired() {
   287  		t.Errorf("expected %v, got %v", "not expired", "expired")
   288  	}
   289  }
   290  
   291  func TestProviderAltConstruct(t *testing.T) {
   292  	cmdBuilder := DefaultNewCommandBuilder{Args: []string{
   293  		fmt.Sprintf("%s %s", getOSCat(),
   294  			filepath.Join("testdata", "static.json"),
   295  		),
   296  	}}
   297  
   298  	provider := NewProviderCommand(cmdBuilder, func(options *Options) {
   299  		options.Timeout = time.Duration(1) * time.Second
   300  	})
   301  	v, err := provider.Retrieve(context.Background())
   302  	if err != nil {
   303  		t.Errorf("expected %v, got %v", "no error", err)
   304  	}
   305  	if v.CanExpire != false {
   306  		t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
   307  	}
   308  }
   309  
   310  func BenchmarkProcessProvider(b *testing.B) {
   311  	provider := NewProvider(
   312  		fmt.Sprintf(
   313  			"%s %s",
   314  			getOSCat(),
   315  			filepath.Join("testdata", "static.json"),
   316  		))
   317  	_, err := provider.Retrieve(context.Background())
   318  	if err != nil {
   319  		b.Fatal(err)
   320  	}
   321  
   322  	b.ResetTimer()
   323  	for i := 0; i < b.N; i++ {
   324  		b.StartTimer()
   325  		_, err := provider.Retrieve(context.Background())
   326  		if err != nil {
   327  			b.Fatal(err)
   328  		}
   329  		b.StopTimer()
   330  	}
   331  }
   332  
   333  func getOSCat() string {
   334  	if runtime.GOOS == "windows" {
   335  		return "type"
   336  	}
   337  	return "cat"
   338  }
   339  

View as plain text