1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package externalaccountuser
16
17 import (
18 "context"
19 "encoding/json"
20 "io"
21 "net/http"
22 "net/http/httptest"
23 "testing"
24
25 "cloud.google.com/go/auth/credentials/internal/stsexchange"
26 "cloud.google.com/go/auth/internal"
27 )
28
29 type testTokenServer struct {
30 URL string
31 Authorization string
32 ContentType string
33 Body string
34 ResponsePayload *stsexchange.TokenResponse
35 Response string
36 server *httptest.Server
37 }
38
39 func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInResponse(t *testing.T) {
40 s := &testTokenServer{
41 URL: "/",
42 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
43 ContentType: "application/x-www-form-urlencoded",
44 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
45 ResponsePayload: &stsexchange.TokenResponse{
46 ExpiresIn: 3600,
47 AccessToken: "AAAAAAA",
48 RefreshToken: "CCCCCCC",
49 },
50 }
51
52 s.startTestServer(t)
53 defer s.server.Close()
54
55 opts := &Options{
56 RefreshToken: "BBBBBBBBB",
57 TokenURL: s.server.URL,
58 ClientID: "CLIENT_ID",
59 ClientSecret: "CLIENT_SECRET",
60 Client: internal.CloneDefaultClient(),
61 }
62 tp, err := NewTokenProvider(opts)
63 if err != nil {
64 t.Fatalf("NewTokenProvider() = %v", err)
65 }
66
67 token, err := tp.Token(context.Background())
68 if err != nil {
69 t.Fatalf("Token() = %v", err)
70 }
71 if got, want := token.Value, "AAAAAAA"; got != want {
72 t.Fatalf("got %v, want %v", got, want)
73 }
74 if got, want := opts.RefreshToken, "CCCCCCC"; got != want {
75 t.Fatalf("got %v, want %v", got, want)
76 }
77 }
78
79 func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
80 s := &testTokenServer{
81 URL: "/",
82 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
83 ContentType: "application/x-www-form-urlencoded",
84 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
85 ResponsePayload: &stsexchange.TokenResponse{
86 ExpiresIn: 3600,
87 AccessToken: "AAAAAAA",
88 },
89 }
90
91 s.startTestServer(t)
92 defer s.server.Close()
93
94 opts := &Options{
95 RefreshToken: "BBBBBBBBB",
96 TokenURL: s.server.URL,
97 ClientID: "CLIENT_ID",
98 ClientSecret: "CLIENT_SECRET",
99 Client: internal.CloneDefaultClient(),
100 }
101 ts, err := NewTokenProvider(opts)
102 if err != nil {
103 t.Fatalf("NewTokenProvider() = %v", err)
104 }
105
106 token, err := ts.Token(context.Background())
107 if err != nil {
108 t.Fatalf("Token() = %v", err)
109 }
110 if got, want := token.Value, "AAAAAAA"; got != want {
111 t.Fatalf("got %v, want %v", got, want)
112 }
113 }
114
115 func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
116 s := &testTokenServer{
117 URL: "/",
118 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
119 ContentType: "application/x-www-form-urlencoded",
120 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
121 ResponsePayload: &stsexchange.TokenResponse{
122 ExpiresIn: 3600,
123 AccessToken: "AAAAAAA",
124 },
125 }
126
127 s.startTestServer(t)
128 defer s.server.Close()
129 testCases := []struct {
130 name string
131 opts *Options
132 }{
133 {
134 name: "empty config",
135 opts: &Options{},
136 },
137 {
138 name: "missing refresh token",
139 opts: &Options{
140 TokenURL: s.server.URL,
141 ClientID: "CLIENT_ID",
142 ClientSecret: "CLIENT_SECRET",
143 },
144 },
145 {
146 name: "missing token url",
147 opts: &Options{
148 RefreshToken: "BBBBBBBBB",
149 ClientID: "CLIENT_ID",
150 ClientSecret: "CLIENT_SECRET",
151 },
152 },
153 {
154 name: "missing client id",
155 opts: &Options{
156 RefreshToken: "BBBBBBBBB",
157 TokenURL: s.server.URL,
158 ClientSecret: "CLIENT_SECRET",
159 },
160 },
161 {
162 name: "missing client secrect",
163 opts: &Options{
164 RefreshToken: "BBBBBBBBB",
165 TokenURL: s.server.URL,
166 ClientID: "CLIENT_ID",
167 },
168 },
169 }
170 for _, tt := range testCases {
171 t.Run(tt.name, func(t *testing.T) {
172 if _, err := NewTokenProvider(tt.opts); err == nil {
173 t.Fatalf("got nil, want an error")
174 }
175 })
176 }
177 }
178
179 func (s *testTokenServer) startTestServer(t *testing.T) {
180 t.Helper()
181 s.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
182 if got, want := r.URL.String(), s.URL; got != want {
183 t.Errorf("got %v, want %v", got, want)
184 }
185 headerAuth := r.Header.Get("Authorization")
186 if got, want := headerAuth, s.Authorization; got != want {
187 t.Errorf("got %v, want %v", got, want)
188 }
189 headerContentType := r.Header.Get("Content-Type")
190 if got, want := headerContentType, s.ContentType; got != want {
191 t.Errorf("got %v. want %v", got, want)
192 }
193 body, err := io.ReadAll(r.Body)
194 if err != nil {
195 t.Error(err)
196 }
197 if got, want := string(body), s.Body; got != want {
198 t.Errorf("got %q, want %q", got, want)
199 }
200 w.Header().Set("Content-Type", "application/json")
201 if s.ResponsePayload != nil {
202 content, err := json.Marshal(s.ResponsePayload)
203 if err != nil {
204 t.Error(err)
205 }
206 w.Write(content)
207 } else {
208 w.Write([]byte(s.Response))
209 }
210 }))
211 }
212
View as plain text