1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package oauth2adapt
16
17 import (
18 "context"
19 "errors"
20 "net/http"
21 "testing"
22
23 "cloud.google.com/go/auth"
24 "github.com/google/go-cmp/cmp"
25 "golang.org/x/oauth2"
26 "golang.org/x/oauth2/google"
27 )
28
29 func TestTokenProviderFromTokenSource(t *testing.T) {
30 tests := []struct {
31 name string
32 token *oauth2.Token
33 err error
34 }{
35 {
36 name: "working token",
37 token: &oauth2.Token{AccessToken: "fakeToken", TokenType: "Basic"},
38 err: nil,
39 },
40 {
41 name: "coverts err",
42 err: &oauth2.RetrieveError{
43 Body: []byte("some bytes"),
44 ErrorCode: "412",
45 Response: &http.Response{
46 StatusCode: http.StatusTeapot,
47 },
48 },
49 },
50 }
51 for _, tt := range tests {
52 t.Run(tt.name, func(t *testing.T) {
53 tp := TokenProviderFromTokenSource(tokenSource{
54 token: tt.token,
55 err: tt.err,
56 })
57 tok, err := tp.Token(context.Background())
58 if tt.err != nil {
59 aErr := &auth.Error{}
60 if !errors.As(err, &aErr) {
61 t.Fatalf("error not of correct type: %T", err)
62 }
63 err := tt.err.(*oauth2.RetrieveError)
64 if !cmp.Equal(aErr.Body, err.Body) {
65 t.Errorf("got %s, want %s", aErr.Body, err.Body)
66 }
67 if !cmp.Equal(aErr.Err, err) {
68 t.Errorf("got %s, want %s", aErr.Err, err)
69 }
70 if !cmp.Equal(aErr.Response, err.Response) {
71 t.Errorf("got %s, want %s", aErr.Err, err)
72 }
73 return
74 }
75 if tok.Value != tt.token.AccessToken {
76 t.Errorf("got %q, want %q", tok.Value, tt.token.AccessToken)
77 }
78 if tok.Type != tt.token.TokenType {
79 t.Errorf("got %q, want %q", tok.Type, tt.token.TokenType)
80 }
81 })
82 }
83 }
84
85 func TestTokenSourceFromTokenProvider(t *testing.T) {
86 tests := []struct {
87 name string
88 token *auth.Token
89 err error
90 }{
91 {
92 name: "working token",
93 token: &auth.Token{
94 Value: "fakeToken",
95 Type: "Basic",
96 },
97 err: nil,
98 },
99 {
100 name: "coverts err",
101 err: &auth.Error{
102 Body: []byte("some bytes"),
103 Response: &http.Response{
104 StatusCode: http.StatusTeapot,
105 },
106 },
107 },
108 }
109 for _, tt := range tests {
110 t.Run(tt.name, func(t *testing.T) {
111 ts := TokenSourceFromTokenProvider(tokenProvider{
112 token: tt.token,
113 err: tt.err,
114 })
115 tok, err := ts.Token()
116 if tt.err != nil {
117
118 aErr := &auth.Error{}
119 if !errors.As(err, &aErr) {
120 t.Fatalf("error not of correct type: %T", err)
121 }
122 err := tt.err.(*auth.Error)
123 if !cmp.Equal(aErr.Body, err.Body) {
124 t.Errorf("got %s, want %s", aErr.Body, err.Body)
125 }
126 if !cmp.Equal(aErr.Response, err.Response) {
127 t.Errorf("got %s, want %s", aErr.Err, err)
128 }
129
130
131 rErr := &oauth2.RetrieveError{}
132 if !errors.As(err, &rErr) {
133 t.Fatalf("error not of correct type: %T", err)
134 }
135 if !cmp.Equal(rErr.Body, err.Body) {
136 t.Errorf("got %s, want %s", aErr.Body, err.Body)
137 }
138 if !cmp.Equal(rErr.Response, err.Response) {
139 t.Errorf("got %s, want %s", aErr.Err, err)
140 }
141 return
142 }
143 if tok.AccessToken != tt.token.Value {
144 t.Errorf("got %q, want %q", tok.AccessToken, tt.token.Value)
145 }
146 if tok.TokenType != tt.token.Type {
147 t.Errorf("got %q, want %q", tok.TokenType, tt.token.Type)
148 }
149 })
150 }
151 }
152
153 func TestAuthCredentialsFromOauth2Credentials(t *testing.T) {
154 ctx := context.Background()
155 inputCreds := &google.Credentials{
156 ProjectID: "test_project",
157 TokenSource: tokenSource{token: &oauth2.Token{AccessToken: "token"}},
158 JSON: []byte("json"),
159 UniverseDomainProvider: func() (string, error) {
160 return "domain", nil
161 },
162 }
163 outCreds := AuthCredentialsFromOauth2Credentials(inputCreds)
164
165 gotProject, err := outCreds.ProjectID(ctx)
166 if err != nil {
167 t.Fatalf("outCreds.ProjectID() = %v", err)
168 }
169 if want := inputCreds.ProjectID; gotProject != want {
170 t.Fatalf("got %q, want %q", gotProject, want)
171 }
172
173 gotToken, err := outCreds.Token(ctx)
174 if err != nil {
175 t.Fatalf("outCreds.Token() = %v", err)
176 }
177 wantTok, err := inputCreds.TokenSource.Token()
178 if err != nil {
179 t.Fatalf("inputCreds.TokenSource.Token() = %v", err)
180 }
181 if gotToken.Value != wantTok.AccessToken {
182 t.Fatalf("got %q, want %q", gotToken.Value, wantTok.AccessToken)
183 }
184
185 gotJSON := outCreds.JSON()
186 if want := inputCreds.JSON; !cmp.Equal(gotJSON, want) {
187 t.Fatalf("got %s, want %s", gotJSON, want)
188 }
189
190 gotUD, err := outCreds.UniverseDomain(ctx)
191 if err != nil {
192 t.Fatalf("outCreds.UniverseDomain() = %v", err)
193 }
194 wantUD, err := inputCreds.GetUniverseDomain()
195 if err != nil {
196 t.Fatalf("inputCreds.GetUniverseDomain() = %v", err)
197 }
198 if gotUD != wantUD {
199 t.Fatalf("got %q, want %q", wantUD, wantUD)
200 }
201 }
202
203 func TestOauth2CredentialsFromAuthCredentials(t *testing.T) {
204 ctx := context.Background()
205 inputCreds := auth.NewCredentials(&auth.CredentialsOptions{
206 ProjectIDProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
207 return "project", nil
208 }),
209 TokenProvider: tokenProvider{token: &auth.Token{Value: "token"}},
210 JSON: []byte("json"),
211 UniverseDomainProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
212 return "domain", nil
213 }),
214 })
215 outCreds := Oauth2CredentialsFromAuthCredentials(inputCreds)
216
217 wantProject, err := inputCreds.ProjectID(ctx)
218 if err != nil {
219 t.Fatalf("inputCreds.ProjectID() = %v", err)
220 }
221 if outCreds.ProjectID != wantProject {
222 t.Fatalf("got %q, want %q", outCreds.ProjectID, wantProject)
223 }
224
225 gotToken, err := inputCreds.Token(ctx)
226 if err != nil {
227 t.Fatalf("inputCreds.Token() = %v", err)
228 }
229 wantTok, err := outCreds.TokenSource.Token()
230 if err != nil {
231 t.Fatalf("outCreds.TokenSource.Token() = %v", err)
232 }
233 if gotToken.Value != wantTok.AccessToken {
234 t.Fatalf("got %q, want %q", gotToken.Value, wantTok.AccessToken)
235 }
236
237 wantJSON := inputCreds.JSON()
238 if !cmp.Equal(outCreds.JSON, wantJSON) {
239 t.Fatalf("got %s, want %s", outCreds.JSON, wantJSON)
240 }
241
242 wantUD, err := inputCreds.UniverseDomain(ctx)
243 if err != nil {
244 t.Fatalf("outCreds.UniverseDomain() = %v", err)
245 }
246 gotUD, err := outCreds.GetUniverseDomain()
247 if err != nil {
248 t.Fatalf("inputCreds.GetUniverseDomain() = %v", err)
249 }
250 if gotUD != wantUD {
251 t.Fatalf("got %q, want %q", wantUD, wantUD)
252 }
253 }
254
255 type tokenSource struct {
256 token *oauth2.Token
257 err error
258 }
259
260 func (ts tokenSource) Token() (*oauth2.Token, error) {
261 if ts.err != nil {
262 return nil, ts.err
263 }
264 return &oauth2.Token{
265 AccessToken: ts.token.AccessToken,
266 TokenType: ts.token.TokenType,
267 }, nil
268 }
269
270 type tokenProvider struct {
271 token *auth.Token
272 err error
273 }
274
275 func (tp tokenProvider) Token(context.Context) (*auth.Token, error) {
276 if tp.err != nil {
277 return nil, tp.err
278 }
279 return &auth.Token{
280 Value: tp.token.Value,
281 Type: tp.token.Type,
282 }, nil
283 }
284
View as plain text