...

Source file src/github.com/aws/aws-sdk-go-v2/credentials/ssocreds/sso_cached_token_test.go

Documentation: github.com/aws/aws-sdk-go-v2/credentials/ssocreds

     1  package ssocreds
     2  
     3  import (
     4  	"io/ioutil"
     5  	"os"
     6  	"path/filepath"
     7  	"strings"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/aws/aws-sdk-go-v2/aws"
    12  )
    13  
    14  func TestStandardSSOCacheTokenFilepath(t *testing.T) {
    15  	origHomeDur := osUserHomeDur
    16  	defer func() {
    17  		osUserHomeDur = origHomeDur
    18  	}()
    19  
    20  	cases := map[string]struct {
    21  		key            string
    22  		osUserHomeDir  func() string
    23  		expectFilename string
    24  		expectErr      string
    25  	}{
    26  		"success": {
    27  			key: "https://example.awsapps.com/start",
    28  			osUserHomeDir: func() string {
    29  				return os.TempDir()
    30  			},
    31  			expectFilename: filepath.Join(os.TempDir(), ".aws", "sso", "cache",
    32  				"e8be5486177c5b5392bd9aa76563515b29358e6e.json"),
    33  		},
    34  		"failure": {
    35  			key: "https://example.awsapps.com/start",
    36  			osUserHomeDir: func() string {
    37  				return ""
    38  			},
    39  			expectErr: "some error",
    40  		},
    41  	}
    42  
    43  	for name, c := range cases {
    44  		t.Run(name, func(t *testing.T) {
    45  			osUserHomeDur = c.osUserHomeDir
    46  
    47  			actual, err := StandardCachedTokenFilepath(c.key)
    48  			if c.expectErr != "" {
    49  				if err == nil {
    50  					t.Fatalf("expect error, got none")
    51  				}
    52  				return
    53  			}
    54  			if err != nil {
    55  				t.Fatalf("expect no error, got %v", err)
    56  			}
    57  
    58  			if e, a := c.expectFilename, actual; e != a {
    59  				t.Errorf("expect %v filename, got %v", e, a)
    60  			}
    61  		})
    62  	}
    63  }
    64  
    65  func TestLoadCachedToken(t *testing.T) {
    66  	cases := map[string]struct {
    67  		filename    string
    68  		expectToken token
    69  		expectErr   string
    70  	}{
    71  		"file not found": {
    72  			filename:  filepath.Join("testdata", "does_not_exist.json"),
    73  			expectErr: "failed to read cached SSO token file",
    74  		},
    75  		"invalid json": {
    76  			filename:  filepath.Join("testdata", "invalid_json.json"),
    77  			expectErr: "failed to parse cached SSO token file",
    78  		},
    79  		"missing accessToken": {
    80  			filename:  filepath.Join("testdata", "missing_accessToken.json"),
    81  			expectErr: "must contain accessToken and expiresAt fields",
    82  		},
    83  		"missing expiresAt": {
    84  			filename:  filepath.Join("testdata", "missing_expiresAt.json"),
    85  			expectErr: "must contain accessToken and expiresAt fields",
    86  		},
    87  		"standard token": {
    88  			filename: filepath.Join("testdata", "valid_token.json"),
    89  			expectToken: token{
    90  				tokenKnownFields: tokenKnownFields{
    91  					AccessToken:  "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
    92  					ExpiresAt:    (*rfc3339)(aws.Time(time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC))),
    93  					ClientID:     "client id",
    94  					ClientSecret: "client secret",
    95  					RefreshToken: "refresh token",
    96  				},
    97  				UnknownFields: map[string]interface{}{
    98  					"unknownField":          "some value",
    99  					"registrationExpiresAt": "2044-04-04T07:00:01Z",
   100  					"region":                "region",
   101  					"startURL":              "start URL",
   102  				},
   103  			},
   104  		},
   105  	}
   106  
   107  	for name, c := range cases {
   108  		t.Run(name, func(t *testing.T) {
   109  			actualToken, err := loadCachedToken(c.filename)
   110  			if c.expectErr != "" {
   111  				if err == nil {
   112  					t.Fatalf("expect %v error, got none", c.expectErr)
   113  				}
   114  				if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) {
   115  					t.Fatalf("expect %v error, got %v", e, a)
   116  				}
   117  				return
   118  			}
   119  			if err != nil {
   120  				t.Fatalf("expect no error, got %v", err)
   121  			}
   122  
   123  			if diff := cmpDiff(c.expectToken, actualToken); diff != "" {
   124  				t.Errorf("expect tokens match\n%s", diff)
   125  			}
   126  		})
   127  	}
   128  }
   129  
   130  func TestStoreCachedToken(t *testing.T) {
   131  	tempDir, err := ioutil.TempDir(os.TempDir(), "aws-sdk-go-v2-"+t.Name())
   132  	if err != nil {
   133  		t.Fatalf("failed to create temporary test directory, %v", err)
   134  	}
   135  	defer func() {
   136  		if err := os.RemoveAll(tempDir); err != nil {
   137  			t.Errorf("failed to cleanup temporary test directory, %v", err)
   138  		}
   139  	}()
   140  
   141  	cases := map[string]struct {
   142  		token    token
   143  		filename string
   144  		fileMode os.FileMode
   145  	}{
   146  		"standard token": {
   147  			filename: filepath.Join(tempDir, "token_file.json"),
   148  			fileMode: 0600,
   149  			token: token{
   150  				tokenKnownFields: tokenKnownFields{
   151  					AccessToken:  "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
   152  					ExpiresAt:    (*rfc3339)(aws.Time(time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC))),
   153  					ClientID:     "client id",
   154  					ClientSecret: "client secret",
   155  					RefreshToken: "refresh token",
   156  				},
   157  				UnknownFields: map[string]interface{}{
   158  					"unknownField":          "some value",
   159  					"registrationExpiresAt": "2044-04-04T07:00:01Z",
   160  					"region":                "region",
   161  					"startURL":              "start URL",
   162  				},
   163  			},
   164  		},
   165  	}
   166  
   167  	for name, c := range cases {
   168  		t.Run(name, func(t *testing.T) {
   169  			err := storeCachedToken(c.filename, c.token, c.fileMode)
   170  			if err != nil {
   171  				t.Fatalf("expect no error, got %v", err)
   172  			}
   173  
   174  			actual, err := loadCachedToken(c.filename)
   175  			if err != nil {
   176  				t.Fatalf("failed to load stored token, %v", err)
   177  			}
   178  
   179  			if diff := cmpDiff(c.token, actual); diff != "" {
   180  				t.Errorf("expect tokens match\n%s", diff)
   181  			}
   182  		})
   183  	}
   184  }
   185  

View as plain text