...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/ssocreds/sso_credentials_provider_test.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/ssocreds

     1  package ssocreds
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"path/filepath"
     7  	"strings"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/aws/aws-sdk-go-v2/aws"
    12  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
    13  	"github.com/aws/aws-sdk-go-v2/service/sso"
    14  	"github.com/aws/aws-sdk-go-v2/service/sso/types"
    15  )
    16  
    17  type mockClient struct {
    18  	t *testing.T
    19  
    20  	Output *sso.GetRoleCredentialsOutput
    21  	Err    error
    22  
    23  	ExpectedAccountID   string
    24  	ExpectedAccessToken string
    25  	ExpectedRoleName    string
    26  
    27  	Response func(mockClient) (*sso.GetRoleCredentialsOutput, error)
    28  }
    29  
    30  func (m mockClient) GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(options *sso.Options)) (out *sso.GetRoleCredentialsOutput, err error) {
    31  	m.t.Helper()
    32  
    33  	if len(m.ExpectedAccountID) > 0 {
    34  		if diff := cmpDiff(m.ExpectedAccountID, aws.ToString(params.AccountId)); len(diff) > 0 {
    35  			m.t.Error(diff)
    36  		}
    37  	}
    38  
    39  	if len(m.ExpectedAccessToken) > 0 {
    40  		if diff := cmpDiff(m.ExpectedAccessToken, aws.ToString(params.AccessToken)); len(diff) > 0 {
    41  			m.t.Error(diff)
    42  		}
    43  	}
    44  
    45  	if len(m.ExpectedRoleName) > 0 {
    46  		if diff := cmpDiff(m.ExpectedRoleName, aws.ToString(params.RoleName)); len(diff) > 0 {
    47  			m.t.Error(diff)
    48  		}
    49  	}
    50  
    51  	if m.Response == nil {
    52  		return out, err
    53  	}
    54  	return m.Response(m)
    55  }
    56  
    57  func TestProvider(t *testing.T) {
    58  	origHomeDir := osUserHomeDur
    59  	defer func() {
    60  		osUserHomeDur = origHomeDir
    61  	}()
    62  
    63  	osUserHomeDur = func() string {
    64  		return "testdata"
    65  	}
    66  
    67  	restoreTime := sdk.TestingUseReferenceTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC))
    68  	defer restoreTime()
    69  
    70  	cases := map[string]struct {
    71  		Client    mockClient
    72  		AccountID string
    73  		Region    string
    74  		RoleName  string
    75  		StartURL  string
    76  		Options   []func(*Options)
    77  
    78  		ExpectedErr         string
    79  		ExpectedCredentials aws.Credentials
    80  	}{
    81  		"missing required parameter values": {
    82  			StartURL:    "https://invalid-required",
    83  			ExpectedErr: "cached SSO token must contain accessToken and expiresAt fields",
    84  		},
    85  		"valid required parameter values": {
    86  			Client: mockClient{
    87  				ExpectedAccountID:   "012345678901",
    88  				ExpectedRoleName:    "TestRole",
    89  				ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
    90  				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
    91  					return &sso.GetRoleCredentialsOutput{
    92  						RoleCredentials: &types.RoleCredentials{
    93  							AccessKeyId:     aws.String("AccessKey"),
    94  							SecretAccessKey: aws.String("SecretKey"),
    95  							SessionToken:    aws.String("SessionToken"),
    96  							Expiration:      1611177743123,
    97  						},
    98  					}, nil
    99  				},
   100  			},
   101  			AccountID: "012345678901",
   102  			Region:    "us-west-2",
   103  			RoleName:  "TestRole",
   104  			StartURL:  "https://valid-required-only",
   105  			ExpectedCredentials: aws.Credentials{
   106  				AccessKeyID:     "AccessKey",
   107  				SecretAccessKey: "SecretKey",
   108  				SessionToken:    "SessionToken",
   109  				CanExpire:       true,
   110  				Expires:         time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
   111  				Source:          ProviderName,
   112  			},
   113  		},
   114  		"custom cached token file": {
   115  			Client: mockClient{
   116  				ExpectedAccountID:   "012345678901",
   117  				ExpectedRoleName:    "TestRole",
   118  				ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
   119  				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
   120  					return &sso.GetRoleCredentialsOutput{
   121  						RoleCredentials: &types.RoleCredentials{
   122  							AccessKeyId:     aws.String("AccessKey"),
   123  							SecretAccessKey: aws.String("SecretKey"),
   124  							SessionToken:    aws.String("SessionToken"),
   125  							Expiration:      1611177743123,
   126  						},
   127  					}, nil
   128  				},
   129  			},
   130  			Options: []func(*Options){
   131  				func(o *Options) {
   132  					o.CachedTokenFilepath = filepath.Join("testdata", "valid_token.json")
   133  				},
   134  			},
   135  			AccountID: "012345678901",
   136  			Region:    "us-west-2",
   137  			RoleName:  "TestRole",
   138  			StartURL:  "ignored value",
   139  			ExpectedCredentials: aws.Credentials{
   140  				AccessKeyID:     "AccessKey",
   141  				SecretAccessKey: "SecretKey",
   142  				SessionToken:    "SessionToken",
   143  				CanExpire:       true,
   144  				Expires:         time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
   145  				Source:          ProviderName,
   146  			},
   147  		},
   148  		"expired access token": {
   149  			StartURL:    "https://expired",
   150  			ExpectedErr: "SSO session has expired or is invalid",
   151  		},
   152  		"api error": {
   153  			Client: mockClient{
   154  				ExpectedAccountID:   "012345678901",
   155  				ExpectedRoleName:    "TestRole",
   156  				ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
   157  				Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
   158  					return nil, fmt.Errorf("api error")
   159  				},
   160  			},
   161  			AccountID:   "012345678901",
   162  			Region:      "us-west-2",
   163  			RoleName:    "TestRole",
   164  			StartURL:    "https://valid-required-only",
   165  			ExpectedErr: "api error",
   166  		},
   167  	}
   168  
   169  	for name, tt := range cases {
   170  		t.Run(name, func(t *testing.T) {
   171  			tt.Client.t = t
   172  
   173  			provider := New(tt.Client, tt.AccountID, tt.RoleName, tt.StartURL, tt.Options...)
   174  
   175  			credentials, err := provider.Retrieve(context.Background())
   176  			if tt.ExpectedErr != "" {
   177  				if err == nil {
   178  					t.Fatalf("expect %v error, got none", tt.ExpectedErr)
   179  				}
   180  				if e, a := tt.ExpectedErr, err.Error(); !strings.Contains(a, e) {
   181  					t.Fatalf("expect %v error, got %v", e, a)
   182  				}
   183  				return
   184  			}
   185  			if err != nil {
   186  				t.Fatalf("expect no error, got %v", err)
   187  			}
   188  
   189  			if diff := cmpDiff(tt.ExpectedCredentials, credentials); len(diff) > 0 {
   190  				t.Errorf(diff)
   191  			}
   192  		})
   193  	}
   194  }
   195  

View as plain text