...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/stscreds/assume_role_provider_test.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/stscreds

     1  package stscreds_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/aws/aws-sdk-go-v2/aws"
    10  	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
    11  	"github.com/aws/aws-sdk-go-v2/service/sts"
    12  	"github.com/aws/aws-sdk-go-v2/service/sts/types"
    13  )
    14  
    15  type mockAssumeRole struct {
    16  	TestInput func(*sts.AssumeRoleInput)
    17  }
    18  
    19  func (s *mockAssumeRole) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) {
    20  	if s.TestInput != nil {
    21  		s.TestInput(params)
    22  	}
    23  	expiry := time.Now().Add(60 * time.Minute)
    24  
    25  	return &sts.AssumeRoleOutput{
    26  		Credentials: &types.Credentials{
    27  			// Just reflect the role arn to the provider.
    28  			AccessKeyId:     params.RoleArn,
    29  			SecretAccessKey: aws.String("assumedSecretAccessKey"),
    30  			SessionToken:    aws.String("assumedSessionToken"),
    31  			Expiration:      &expiry,
    32  		},
    33  	}, nil
    34  }
    35  
    36  const roleARN = "00000000000000000000000000000000000"
    37  const tokenCode = "00000000000000000000"
    38  
    39  func TestAssumeRoleProvider(t *testing.T) {
    40  	stub := &mockAssumeRole{}
    41  	p := stscreds.NewAssumeRoleProvider(stub, roleARN)
    42  
    43  	creds, err := p.Retrieve(context.Background())
    44  	if err != nil {
    45  		t.Fatalf("Expect no error, %v", err)
    46  	}
    47  
    48  	if e, a := roleARN, creds.AccessKeyID; e != a {
    49  		t.Errorf("Expect access key ID to be reflected role ARN")
    50  	}
    51  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
    52  		t.Errorf("Expect secret access key to match")
    53  	}
    54  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
    55  		t.Errorf("Expect session token to match")
    56  	}
    57  }
    58  
    59  func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
    60  	stub := &mockAssumeRole{
    61  		TestInput: func(in *sts.AssumeRoleInput) {
    62  			if e, a := "0123456789", *in.SerialNumber; e != a {
    63  				t.Errorf("expect %v, got %v", e, a)
    64  			}
    65  			if e, a := tokenCode, *in.TokenCode; e != a {
    66  				t.Errorf("expect %v, got %v", e, a)
    67  			}
    68  		},
    69  	}
    70  	p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
    71  		options.SerialNumber = aws.String("0123456789")
    72  		options.TokenProvider = func() (string, error) {
    73  			return tokenCode, nil
    74  		}
    75  	})
    76  
    77  	creds, err := p.Retrieve(context.Background())
    78  	if err != nil {
    79  		t.Fatalf("Expect no error, %v", err)
    80  	}
    81  
    82  	if e, a := roleARN, creds.AccessKeyID; e != a {
    83  		t.Errorf("Expect access key ID to be reflected role ARN")
    84  	}
    85  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
    86  		t.Errorf("Expect secret access key to match")
    87  	}
    88  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
    89  		t.Errorf("Expect session token to match")
    90  	}
    91  }
    92  
    93  func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
    94  	stub := &mockAssumeRole{
    95  		TestInput: func(in *sts.AssumeRoleInput) {
    96  			t.Fatalf("API request should not of been called")
    97  		},
    98  	}
    99  	p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
   100  		options.SerialNumber = aws.String("0123456789")
   101  		options.TokenProvider = func() (string, error) {
   102  			return "", fmt.Errorf("error occurred")
   103  		}
   104  	})
   105  
   106  	creds, err := p.Retrieve(context.Background())
   107  	if err == nil {
   108  		t.Fatalf("expect error, got none")
   109  	}
   110  
   111  	if v := creds.AccessKeyID; len(v) != 0 {
   112  		t.Errorf("expect zero, got %v", v)
   113  	}
   114  	if v := creds.SecretAccessKey; len(v) != 0 {
   115  		t.Errorf("expect zero, got %v", v)
   116  	}
   117  	if v := creds.SessionToken; len(v) != 0 {
   118  		t.Errorf("expect zero, got %v", v)
   119  	}
   120  }
   121  
   122  func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
   123  	stub := &mockAssumeRole{
   124  		TestInput: func(in *sts.AssumeRoleInput) {
   125  			t.Fatalf("API request should not of been called")
   126  		},
   127  	}
   128  	p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
   129  		options.SerialNumber = aws.String("0123456789")
   130  	})
   131  
   132  	creds, err := p.Retrieve(context.Background())
   133  	if err == nil {
   134  		t.Fatalf("expect error, got none")
   135  	}
   136  
   137  	if v := creds.AccessKeyID; len(v) != 0 {
   138  		t.Errorf("expect zero, got %v", v)
   139  	}
   140  	if v := creds.SecretAccessKey; len(v) != 0 {
   141  		t.Errorf("expect zero, got %v", v)
   142  	}
   143  	if v := creds.SessionToken; len(v) != 0 {
   144  		t.Errorf("expect zero, got %v", v)
   145  	}
   146  }
   147  
   148  func TestAssumeRoleProvider_WithSourceIdentity(t *testing.T) {
   149  	const sourceIdentity = "Source-Identity"
   150  
   151  	stub := &mockAssumeRole{
   152  		TestInput: func(in *sts.AssumeRoleInput) {
   153  			if e, a := sourceIdentity, *in.SourceIdentity; e != a {
   154  				t.Fatalf("expect %v, got %v", e, a)
   155  			}
   156  		},
   157  	}
   158  	p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
   159  		options.SourceIdentity = aws.String(sourceIdentity)
   160  	})
   161  
   162  	creds, err := p.Retrieve(context.Background())
   163  	if err != nil {
   164  		t.Fatalf("Expect no error, %v", err)
   165  	}
   166  
   167  	if e, a := roleARN, creds.AccessKeyID; e != a {
   168  		t.Errorf("Expect access key ID to be reflected role ARN")
   169  	}
   170  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
   171  		t.Errorf("Expect secret access key to match")
   172  	}
   173  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
   174  		t.Errorf("Expect session token to match")
   175  	}
   176  }
   177  
   178  func TestAssumeRoleProvider_WithTags(t *testing.T) {
   179  	stub := &mockAssumeRole{
   180  		TestInput: func(in *sts.AssumeRoleInput) {
   181  			if e, a := 1, len(in.Tags); e != a {
   182  				t.Fatalf("expect %v, got %v", e, a)
   183  			}
   184  			tag := in.Tags[0]
   185  			if e, a := "KEY", *tag.Key; e != a {
   186  				t.Errorf("expect %v, got %v", e, a)
   187  			}
   188  			if e, a := "value", *tag.Value; e != a {
   189  				t.Errorf("expect %v, got %v", e, a)
   190  			}
   191  		},
   192  	}
   193  	p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
   194  		options.Tags = []types.Tag{
   195  			{
   196  				Key:   aws.String("KEY"),
   197  				Value: aws.String("value"),
   198  			},
   199  		}
   200  	})
   201  
   202  	creds, err := p.Retrieve(context.Background())
   203  	if err != nil {
   204  		t.Fatalf("Expect no error, %v", err)
   205  	}
   206  
   207  	if e, a := roleARN, creds.AccessKeyID; e != a {
   208  		t.Errorf("Expect access key ID to be reflected role ARN")
   209  	}
   210  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
   211  		t.Errorf("Expect secret access key to match")
   212  	}
   213  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
   214  		t.Errorf("Expect session token to match")
   215  	}
   216  }
   217  
   218  func TestAssumeRoleProvider_WithTransitiveTagKeys(t *testing.T) {
   219  	stub := &mockAssumeRole{
   220  		TestInput: func(in *sts.AssumeRoleInput) {
   221  			if e, a := 1, len(in.TransitiveTagKeys); e != a {
   222  				t.Fatalf("expect %v, got %v", e, a)
   223  			}
   224  			if e, a := "KEY", in.TransitiveTagKeys[0]; e != a {
   225  				t.Errorf("expect %v, got %v", e, a)
   226  			}
   227  		},
   228  	}
   229  	p := stscreds.NewAssumeRoleProvider(stub, roleARN, func(options *stscreds.AssumeRoleOptions) {
   230  		options.Tags = []types.Tag{
   231  			{
   232  				Key:   aws.String("KEY"),
   233  				Value: aws.String("value"),
   234  			},
   235  		}
   236  		options.TransitiveTagKeys = []string{"KEY"}
   237  	})
   238  
   239  	creds, err := p.Retrieve(context.Background())
   240  	if err != nil {
   241  		t.Fatalf("Expect no error, %v", err)
   242  	}
   243  
   244  	if e, a := roleARN, creds.AccessKeyID; e != a {
   245  		t.Errorf("Expect access key ID to be reflected role ARN")
   246  	}
   247  	if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
   248  		t.Errorf("Expect secret access key to match")
   249  	}
   250  	if e, a := "assumedSessionToken", creds.SessionToken; e != a {
   251  		t.Errorf("Expect session token to match")
   252  	}
   253  }
   254  
   255  func BenchmarkAssumeRoleProvider(b *testing.B) {
   256  	stub := &mockAssumeRole{}
   257  	p := stscreds.NewAssumeRoleProvider(stub, roleARN)
   258  
   259  	b.ResetTimer()
   260  	for i := 0; i < b.N; i++ {
   261  		if _, err := p.Retrieve(context.Background()); err != nil {
   262  			b.Fatal(err)
   263  		}
   264  	}
   265  }
   266  

View as plain text