...

Source file src/cloud.google.com/go/auth/oauth2adapt/oauth2adapt_test.go

Documentation: cloud.google.com/go/auth/oauth2adapt

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package oauth2adapt
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"net/http"
    21  	"testing"
    22  
    23  	"cloud.google.com/go/auth"
    24  	"github.com/google/go-cmp/cmp"
    25  	"golang.org/x/oauth2"
    26  	"golang.org/x/oauth2/google"
    27  )
    28  
    29  func TestTokenProviderFromTokenSource(t *testing.T) {
    30  	tests := []struct {
    31  		name  string
    32  		token *oauth2.Token
    33  		err   error
    34  	}{
    35  		{
    36  			name:  "working token",
    37  			token: &oauth2.Token{AccessToken: "fakeToken", TokenType: "Basic"},
    38  			err:   nil,
    39  		},
    40  		{
    41  			name: "coverts err",
    42  			err: &oauth2.RetrieveError{
    43  				Body:      []byte("some bytes"),
    44  				ErrorCode: "412",
    45  				Response: &http.Response{
    46  					StatusCode: http.StatusTeapot,
    47  				},
    48  			},
    49  		},
    50  	}
    51  	for _, tt := range tests {
    52  		t.Run(tt.name, func(t *testing.T) {
    53  			tp := TokenProviderFromTokenSource(tokenSource{
    54  				token: tt.token,
    55  				err:   tt.err,
    56  			})
    57  			tok, err := tp.Token(context.Background())
    58  			if tt.err != nil {
    59  				aErr := &auth.Error{}
    60  				if !errors.As(err, &aErr) {
    61  					t.Fatalf("error not of correct type: %T", err)
    62  				}
    63  				err := tt.err.(*oauth2.RetrieveError)
    64  				if !cmp.Equal(aErr.Body, err.Body) {
    65  					t.Errorf("got %s, want %s", aErr.Body, err.Body)
    66  				}
    67  				if !cmp.Equal(aErr.Err, err) {
    68  					t.Errorf("got %s, want %s", aErr.Err, err)
    69  				}
    70  				if !cmp.Equal(aErr.Response, err.Response) {
    71  					t.Errorf("got %s, want %s", aErr.Err, err)
    72  				}
    73  				return
    74  			}
    75  			if tok.Value != tt.token.AccessToken {
    76  				t.Errorf("got %q, want %q", tok.Value, tt.token.AccessToken)
    77  			}
    78  			if tok.Type != tt.token.TokenType {
    79  				t.Errorf("got %q, want %q", tok.Type, tt.token.TokenType)
    80  			}
    81  		})
    82  	}
    83  }
    84  
    85  func TestTokenSourceFromTokenProvider(t *testing.T) {
    86  	tests := []struct {
    87  		name  string
    88  		token *auth.Token
    89  		err   error
    90  	}{
    91  		{
    92  			name: "working token",
    93  			token: &auth.Token{
    94  				Value: "fakeToken",
    95  				Type:  "Basic",
    96  			},
    97  			err: nil,
    98  		},
    99  		{
   100  			name: "coverts err",
   101  			err: &auth.Error{
   102  				Body: []byte("some bytes"),
   103  				Response: &http.Response{
   104  					StatusCode: http.StatusTeapot,
   105  				},
   106  			},
   107  		},
   108  	}
   109  	for _, tt := range tests {
   110  		t.Run(tt.name, func(t *testing.T) {
   111  			ts := TokenSourceFromTokenProvider(tokenProvider{
   112  				token: tt.token,
   113  				err:   tt.err,
   114  			})
   115  			tok, err := ts.Token()
   116  			if tt.err != nil {
   117  				// Should be able to be an auth.Error
   118  				aErr := &auth.Error{}
   119  				if !errors.As(err, &aErr) {
   120  					t.Fatalf("error not of correct type: %T", err)
   121  				}
   122  				err := tt.err.(*auth.Error)
   123  				if !cmp.Equal(aErr.Body, err.Body) {
   124  					t.Errorf("got %s, want %s", aErr.Body, err.Body)
   125  				}
   126  				if !cmp.Equal(aErr.Response, err.Response) {
   127  					t.Errorf("got %s, want %s", aErr.Err, err)
   128  				}
   129  
   130  				// Should be able to be an oauth2.RetrieveError
   131  				rErr := &oauth2.RetrieveError{}
   132  				if !errors.As(err, &rErr) {
   133  					t.Fatalf("error not of correct type: %T", err)
   134  				}
   135  				if !cmp.Equal(rErr.Body, err.Body) {
   136  					t.Errorf("got %s, want %s", aErr.Body, err.Body)
   137  				}
   138  				if !cmp.Equal(rErr.Response, err.Response) {
   139  					t.Errorf("got %s, want %s", aErr.Err, err)
   140  				}
   141  				return
   142  			}
   143  			if tok.AccessToken != tt.token.Value {
   144  				t.Errorf("got %q, want %q", tok.AccessToken, tt.token.Value)
   145  			}
   146  			if tok.TokenType != tt.token.Type {
   147  				t.Errorf("got %q, want %q", tok.TokenType, tt.token.Type)
   148  			}
   149  		})
   150  	}
   151  }
   152  
   153  func TestAuthCredentialsFromOauth2Credentials(t *testing.T) {
   154  	ctx := context.Background()
   155  	inputCreds := &google.Credentials{
   156  		ProjectID:   "test_project",
   157  		TokenSource: tokenSource{token: &oauth2.Token{AccessToken: "token"}},
   158  		JSON:        []byte("json"),
   159  		UniverseDomainProvider: func() (string, error) {
   160  			return "domain", nil
   161  		},
   162  	}
   163  	outCreds := AuthCredentialsFromOauth2Credentials(inputCreds)
   164  
   165  	gotProject, err := outCreds.ProjectID(ctx)
   166  	if err != nil {
   167  		t.Fatalf("outCreds.ProjectID() = %v", err)
   168  	}
   169  	if want := inputCreds.ProjectID; gotProject != want {
   170  		t.Fatalf("got %q, want %q", gotProject, want)
   171  	}
   172  
   173  	gotToken, err := outCreds.Token(ctx)
   174  	if err != nil {
   175  		t.Fatalf("outCreds.Token() = %v", err)
   176  	}
   177  	wantTok, err := inputCreds.TokenSource.Token()
   178  	if err != nil {
   179  		t.Fatalf("inputCreds.TokenSource.Token() = %v", err)
   180  	}
   181  	if gotToken.Value != wantTok.AccessToken {
   182  		t.Fatalf("got %q, want %q", gotToken.Value, wantTok.AccessToken)
   183  	}
   184  
   185  	gotJSON := outCreds.JSON()
   186  	if want := inputCreds.JSON; !cmp.Equal(gotJSON, want) {
   187  		t.Fatalf("got %s, want %s", gotJSON, want)
   188  	}
   189  
   190  	gotUD, err := outCreds.UniverseDomain(ctx)
   191  	if err != nil {
   192  		t.Fatalf("outCreds.UniverseDomain() = %v", err)
   193  	}
   194  	wantUD, err := inputCreds.GetUniverseDomain()
   195  	if err != nil {
   196  		t.Fatalf("inputCreds.GetUniverseDomain() = %v", err)
   197  	}
   198  	if gotUD != wantUD {
   199  		t.Fatalf("got %q, want %q", wantUD, wantUD)
   200  	}
   201  }
   202  
   203  func TestOauth2CredentialsFromAuthCredentials(t *testing.T) {
   204  	ctx := context.Background()
   205  	inputCreds := auth.NewCredentials(&auth.CredentialsOptions{
   206  		ProjectIDProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
   207  			return "project", nil
   208  		}),
   209  		TokenProvider: tokenProvider{token: &auth.Token{Value: "token"}},
   210  		JSON:          []byte("json"),
   211  		UniverseDomainProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
   212  			return "domain", nil
   213  		}),
   214  	})
   215  	outCreds := Oauth2CredentialsFromAuthCredentials(inputCreds)
   216  
   217  	wantProject, err := inputCreds.ProjectID(ctx)
   218  	if err != nil {
   219  		t.Fatalf("inputCreds.ProjectID() = %v", err)
   220  	}
   221  	if outCreds.ProjectID != wantProject {
   222  		t.Fatalf("got %q, want %q", outCreds.ProjectID, wantProject)
   223  	}
   224  
   225  	gotToken, err := inputCreds.Token(ctx)
   226  	if err != nil {
   227  		t.Fatalf("inputCreds.Token() = %v", err)
   228  	}
   229  	wantTok, err := outCreds.TokenSource.Token()
   230  	if err != nil {
   231  		t.Fatalf("outCreds.TokenSource.Token() = %v", err)
   232  	}
   233  	if gotToken.Value != wantTok.AccessToken {
   234  		t.Fatalf("got %q, want %q", gotToken.Value, wantTok.AccessToken)
   235  	}
   236  
   237  	wantJSON := inputCreds.JSON()
   238  	if !cmp.Equal(outCreds.JSON, wantJSON) {
   239  		t.Fatalf("got %s, want %s", outCreds.JSON, wantJSON)
   240  	}
   241  
   242  	wantUD, err := inputCreds.UniverseDomain(ctx)
   243  	if err != nil {
   244  		t.Fatalf("outCreds.UniverseDomain() = %v", err)
   245  	}
   246  	gotUD, err := outCreds.GetUniverseDomain()
   247  	if err != nil {
   248  		t.Fatalf("inputCreds.GetUniverseDomain() = %v", err)
   249  	}
   250  	if gotUD != wantUD {
   251  		t.Fatalf("got %q, want %q", wantUD, wantUD)
   252  	}
   253  }
   254  
   255  type tokenSource struct {
   256  	token *oauth2.Token
   257  	err   error
   258  }
   259  
   260  func (ts tokenSource) Token() (*oauth2.Token, error) {
   261  	if ts.err != nil {
   262  		return nil, ts.err
   263  	}
   264  	return &oauth2.Token{
   265  		AccessToken: ts.token.AccessToken,
   266  		TokenType:   ts.token.TokenType,
   267  	}, nil
   268  }
   269  
   270  type tokenProvider struct {
   271  	token *auth.Token
   272  	err   error
   273  }
   274  
   275  func (tp tokenProvider) Token(context.Context) (*auth.Token, error) {
   276  	if tp.err != nil {
   277  		return nil, tp.err
   278  	}
   279  	return &auth.Token{
   280  		Value: tp.token.Value,
   281  		Type:  tp.token.Type,
   282  	}, nil
   283  }
   284  

View as plain text