1 package ssocreds
2
3 import (
4 "io/ioutil"
5 "os"
6 "path/filepath"
7 "strings"
8 "testing"
9 "time"
10
11 "github.com/aws/aws-sdk-go-v2/aws"
12 )
13
14 func TestStandardSSOCacheTokenFilepath(t *testing.T) {
15 origHomeDur := osUserHomeDur
16 defer func() {
17 osUserHomeDur = origHomeDur
18 }()
19
20 cases := map[string]struct {
21 key string
22 osUserHomeDir func() string
23 expectFilename string
24 expectErr string
25 }{
26 "success": {
27 key: "https://example.awsapps.com/start",
28 osUserHomeDir: func() string {
29 return os.TempDir()
30 },
31 expectFilename: filepath.Join(os.TempDir(), ".aws", "sso", "cache",
32 "e8be5486177c5b5392bd9aa76563515b29358e6e.json"),
33 },
34 "failure": {
35 key: "https://example.awsapps.com/start",
36 osUserHomeDir: func() string {
37 return ""
38 },
39 expectErr: "some error",
40 },
41 }
42
43 for name, c := range cases {
44 t.Run(name, func(t *testing.T) {
45 osUserHomeDur = c.osUserHomeDir
46
47 actual, err := StandardCachedTokenFilepath(c.key)
48 if c.expectErr != "" {
49 if err == nil {
50 t.Fatalf("expect error, got none")
51 }
52 return
53 }
54 if err != nil {
55 t.Fatalf("expect no error, got %v", err)
56 }
57
58 if e, a := c.expectFilename, actual; e != a {
59 t.Errorf("expect %v filename, got %v", e, a)
60 }
61 })
62 }
63 }
64
65 func TestLoadCachedToken(t *testing.T) {
66 cases := map[string]struct {
67 filename string
68 expectToken token
69 expectErr string
70 }{
71 "file not found": {
72 filename: filepath.Join("testdata", "does_not_exist.json"),
73 expectErr: "failed to read cached SSO token file",
74 },
75 "invalid json": {
76 filename: filepath.Join("testdata", "invalid_json.json"),
77 expectErr: "failed to parse cached SSO token file",
78 },
79 "missing accessToken": {
80 filename: filepath.Join("testdata", "missing_accessToken.json"),
81 expectErr: "must contain accessToken and expiresAt fields",
82 },
83 "missing expiresAt": {
84 filename: filepath.Join("testdata", "missing_expiresAt.json"),
85 expectErr: "must contain accessToken and expiresAt fields",
86 },
87 "standard token": {
88 filename: filepath.Join("testdata", "valid_token.json"),
89 expectToken: token{
90 tokenKnownFields: tokenKnownFields{
91 AccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
92 ExpiresAt: (*rfc3339)(aws.Time(time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC))),
93 ClientID: "client id",
94 ClientSecret: "client secret",
95 RefreshToken: "refresh token",
96 },
97 UnknownFields: map[string]interface{}{
98 "unknownField": "some value",
99 "registrationExpiresAt": "2044-04-04T07:00:01Z",
100 "region": "region",
101 "startURL": "start URL",
102 },
103 },
104 },
105 }
106
107 for name, c := range cases {
108 t.Run(name, func(t *testing.T) {
109 actualToken, err := loadCachedToken(c.filename)
110 if c.expectErr != "" {
111 if err == nil {
112 t.Fatalf("expect %v error, got none", c.expectErr)
113 }
114 if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) {
115 t.Fatalf("expect %v error, got %v", e, a)
116 }
117 return
118 }
119 if err != nil {
120 t.Fatalf("expect no error, got %v", err)
121 }
122
123 if diff := cmpDiff(c.expectToken, actualToken); diff != "" {
124 t.Errorf("expect tokens match\n%s", diff)
125 }
126 })
127 }
128 }
129
130 func TestStoreCachedToken(t *testing.T) {
131 tempDir, err := ioutil.TempDir(os.TempDir(), "aws-sdk-go-v2-"+t.Name())
132 if err != nil {
133 t.Fatalf("failed to create temporary test directory, %v", err)
134 }
135 defer func() {
136 if err := os.RemoveAll(tempDir); err != nil {
137 t.Errorf("failed to cleanup temporary test directory, %v", err)
138 }
139 }()
140
141 cases := map[string]struct {
142 token token
143 filename string
144 fileMode os.FileMode
145 }{
146 "standard token": {
147 filename: filepath.Join(tempDir, "token_file.json"),
148 fileMode: 0600,
149 token: token{
150 tokenKnownFields: tokenKnownFields{
151 AccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl",
152 ExpiresAt: (*rfc3339)(aws.Time(time.Date(2044, 4, 4, 7, 0, 1, 0, time.UTC))),
153 ClientID: "client id",
154 ClientSecret: "client secret",
155 RefreshToken: "refresh token",
156 },
157 UnknownFields: map[string]interface{}{
158 "unknownField": "some value",
159 "registrationExpiresAt": "2044-04-04T07:00:01Z",
160 "region": "region",
161 "startURL": "start URL",
162 },
163 },
164 },
165 }
166
167 for name, c := range cases {
168 t.Run(name, func(t *testing.T) {
169 err := storeCachedToken(c.filename, c.token, c.fileMode)
170 if err != nil {
171 t.Fatalf("expect no error, got %v", err)
172 }
173
174 actual, err := loadCachedToken(c.filename)
175 if err != nil {
176 t.Fatalf("failed to load stored token, %v", err)
177 }
178
179 if diff := cmpDiff(c.token, actual); diff != "" {
180 t.Errorf("expect tokens match\n%s", diff)
181 }
182 })
183 }
184 }
185
View as plain text