...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/ssocreds/sso_token_provider_test.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/ssocreds

     1  //go:build go1.16
     2  // +build go1.16
     3  
     4  package ssocreds
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"io/ioutil"
    10  	"os"
    11  	"path/filepath"
    12  	"reflect"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/aws/aws-sdk-go-v2/aws"
    18  	"github.com/aws/aws-sdk-go-v2/internal/sdk"
    19  	"github.com/aws/aws-sdk-go-v2/service/ssooidc"
    20  	smithybearer "github.com/aws/smithy-go/auth/bearer"
    21  )
    22  
    23  func TestSSOTokenProvider(t *testing.T) {
    24  	restoreTime := sdk.TestingUseReferenceTime(time.Date(2021, 12, 21, 12, 21, 1, 0, time.UTC))
    25  	defer restoreTime()
    26  
    27  	tempDir, err := ioutil.TempDir(os.TempDir(), "aws-sdk-go-v2-"+t.Name())
    28  	if err != nil {
    29  		t.Fatalf("failed to create temporary test directory, %v", err)
    30  	}
    31  	defer func() {
    32  		if err := os.RemoveAll(tempDir); err != nil {
    33  			t.Errorf("failed to cleanup temporary test directory, %v", err)
    34  		}
    35  	}()
    36  
    37  	cases := map[string]struct {
    38  		setup         func() error
    39  		postRetrieve  func() error
    40  		client        CreateTokenAPIClient
    41  		cacheFilePath string
    42  		optFns        []func(*SSOTokenProviderOptions)
    43  
    44  		expectToken smithybearer.Token
    45  		expectErr   string
    46  	}{
    47  		"no cache file": {
    48  			cacheFilePath: filepath.Join("testdata", "file_not_exists"),
    49  			expectErr:     "failed to read cached SSO token file",
    50  		},
    51  		"invalid json cache file": {
    52  			cacheFilePath: filepath.Join("testdata", "invalid_json.json"),
    53  			expectErr:     "failed to parse cached SSO token file",
    54  		},
    55  		"missing accessToken": {
    56  			cacheFilePath: filepath.Join("testdata", "missing_accessToken.json"),
    57  			expectErr:     "must contain accessToken and expiresAt fields",
    58  		},
    59  		"missing expiresAt": {
    60  			cacheFilePath: filepath.Join("testdata", "missing_expiresAt.json"),
    61  			expectErr:     "must contain accessToken and expiresAt fields",
    62  		},
    63  		"expired no clientSecret": {
    64  			cacheFilePath: filepath.Join("testdata", "missing_clientSecret.json"),
    65  			expectErr:     "cached SSO token is expired, or not present",
    66  		},
    67  		"expired no clientId": {
    68  			cacheFilePath: filepath.Join("testdata", "missing_clientId.json"),
    69  			expectErr:     "cached SSO token is expired, or not present",
    70  		},
    71  		"expired no refreshToken": {
    72  			cacheFilePath: filepath.Join("testdata", "missing_refreshToken.json"),
    73  			expectErr:     "cached SSO token is expired, or not present",
    74  		},
    75  		"valid sso token": {
    76  			cacheFilePath: filepath.Join("testdata", "valid_token.json"),
    77  			expectToken: smithybearer.Token{
    78  				Value:     "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
    79  				CanExpire: true,
    80  				Expires:   time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC),
    81  			},
    82  		},
    83  		"refresh expired token": {
    84  			setup: func() error {
    85  				testFile, err := os.ReadFile(filepath.Join("testdata", "expired_token.json"))
    86  				if err != nil {
    87  					return err
    88  				}
    89  
    90  				return os.WriteFile(filepath.Join(tempDir, "expired_token.json"), testFile, 0600)
    91  			},
    92  			postRetrieve: func() error {
    93  				actual, err := loadCachedToken(filepath.Join(tempDir, "expired_token.json"))
    94  				if err != nil {
    95  					return err
    96  
    97  				}
    98  				expect := token{
    99  					tokenKnownFields: tokenKnownFields{
   100  						AccessToken: "updated access token",
   101  						ExpiresAt:   (*rfc3339)(aws.Time(time.Date(2021, 12, 21, 12, 31, 1, 0, time.UTC))),
   102  
   103  						RefreshToken: "updated refresh token",
   104  						ClientID:     "client id",
   105  						ClientSecret: "client secret",
   106  					},
   107  					UnknownFields: map[string]interface{}{
   108  						"unknownField": "some value",
   109  					},
   110  				}
   111  
   112  				if diff := cmpDiff(expect, actual); diff != "" {
   113  					return fmt.Errorf("expect token file match\n%s", diff)
   114  				}
   115  				return nil
   116  			},
   117  			cacheFilePath: filepath.Join(tempDir, "expired_token.json"),
   118  			client: &mockCreateTokenAPIClient{
   119  				expectInput: &ssooidc.CreateTokenInput{
   120  					ClientId:     aws.String("client id"),
   121  					ClientSecret: aws.String("client secret"),
   122  					RefreshToken: aws.String("refresh token"),
   123  					GrantType:    aws.String("refresh_token"),
   124  				},
   125  				output: &ssooidc.CreateTokenOutput{
   126  					AccessToken:  aws.String("updated access token"),
   127  					ExpiresIn:    600,
   128  					RefreshToken: aws.String("updated refresh token"),
   129  				},
   130  			},
   131  			expectToken: smithybearer.Token{
   132  				Value:     "updated access token",
   133  				CanExpire: true,
   134  				Expires:   time.Date(2021, 12, 21, 12, 31, 1, 0, time.UTC),
   135  			},
   136  		},
   137  		"fail refresh expired token": {
   138  			setup: func() error {
   139  				testFile, err := os.ReadFile(filepath.Join("testdata", "expired_token.json"))
   140  				if err != nil {
   141  					return err
   142  				}
   143  				return os.WriteFile(filepath.Join(tempDir, "expired_token.json"), testFile, 0600)
   144  			},
   145  			postRetrieve: func() error {
   146  				actual, err := loadCachedToken(filepath.Join(tempDir, "expired_token.json"))
   147  				if err != nil {
   148  					return err
   149  
   150  				}
   151  				expect := token{
   152  					tokenKnownFields: tokenKnownFields{
   153  						AccessToken: "access token",
   154  						ExpiresAt:   (*rfc3339)(aws.Time(time.Date(2021, 12, 21, 12, 21, 1, 0, time.UTC))),
   155  
   156  						RefreshToken: "refresh token",
   157  						ClientID:     "client id",
   158  						ClientSecret: "client secret",
   159  					},
   160  				}
   161  
   162  				if diff := cmpDiff(expect, actual); diff != "" {
   163  					return fmt.Errorf("expect token file match\n%s", diff)
   164  				}
   165  				return nil
   166  			},
   167  			cacheFilePath: filepath.Join(tempDir, "expired_token.json"),
   168  			client: &mockCreateTokenAPIClient{
   169  				err: fmt.Errorf("sky is falling"),
   170  			},
   171  			expectErr: "unable to refresh SSO token, sky is falling",
   172  		},
   173  	}
   174  
   175  	for name, c := range cases {
   176  		t.Run(name, func(t *testing.T) {
   177  			if c.setup != nil {
   178  				if err := c.setup(); err != nil {
   179  					t.Fatalf("failed to setup test, %v", err)
   180  				}
   181  			}
   182  			provider := NewSSOTokenProvider(c.client, c.cacheFilePath, c.optFns...)
   183  
   184  			token, err := provider.RetrieveBearerToken(context.Background())
   185  			if c.expectErr != "" {
   186  				if err == nil {
   187  					t.Fatalf("expect %v error, got none", c.expectErr)
   188  				}
   189  				if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) {
   190  					t.Fatalf("expect %v error, got %v", e, a)
   191  				}
   192  				return
   193  			}
   194  			if err != nil {
   195  				t.Fatalf("expect no error, got %v", err)
   196  			}
   197  
   198  			if diff := cmpDiff(c.expectToken, token); diff != "" {
   199  				t.Errorf("expect token match\n%s", diff)
   200  			}
   201  
   202  			if c.postRetrieve != nil {
   203  				if err := c.postRetrieve(); err != nil {
   204  					t.Fatalf("post retrieve failed, %v", err)
   205  				}
   206  			}
   207  		})
   208  	}
   209  }
   210  
   211  type mockCreateTokenAPIClient struct {
   212  	expectInput *ssooidc.CreateTokenInput
   213  	output      *ssooidc.CreateTokenOutput
   214  	err         error
   215  }
   216  
   217  func (c *mockCreateTokenAPIClient) CreateToken(
   218  	ctx context.Context, input *ssooidc.CreateTokenInput, optFns ...func(*ssooidc.Options)) (
   219  	*ssooidc.CreateTokenOutput, error,
   220  ) {
   221  	if c.expectInput != nil {
   222  		if diff := cmpDiff(c.expectInput, input); diff != "" {
   223  			return nil, fmt.Errorf("expect input match\n%s", diff)
   224  		}
   225  	}
   226  
   227  	return c.output, c.err
   228  }
   229  
   230  func cmpDiff(e, a interface{}) string {
   231  	if !reflect.DeepEqual(e, a) {
   232  		return fmt.Sprintf("%v != %v", e, a)
   233  	}
   234  	return ""
   235  }
   236  

View as plain text