...

Source file src/cloud.google.com/go/auth/auth_test.go

Documentation: cloud.google.com/go/auth

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package auth
    16  
    17  import (
    18  	"context"
    19  	"encoding/base64"
    20  	"encoding/json"
    21  	"fmt"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"cloud.google.com/go/auth/internal/jwt"
    29  	"github.com/google/go-cmp/cmp"
    30  )
    31  
// fakePrivateKey is a PEM-encoded RSA private key ("RSA PRIVATE KEY" block)
// that the New2LOTokenProvider tests below pass as Options2LO.PrivateKey so
// the provider can sign the JWT assertion it posts to a local httptest
// server. As the name says, it is a fake key for tests only and must never be
// treated as a real credential.
var fakePrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE
DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY
fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK
1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr
k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9
/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt
3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn
2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3
nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK
6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf
5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e
DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1
M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g
z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y
1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK
J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U
f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx
QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA
cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr
Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw
5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg
KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84
OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd
mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ
5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg==
-----END RSA PRIVATE KEY-----`)
    59  
    60  func TestError_Temporary(t *testing.T) {
    61  	tests := []struct {
    62  		name string
    63  		code int
    64  		want bool
    65  	}{
    66  		{
    67  			name: "temporary with 500",
    68  			code: http.StatusInternalServerError,
    69  			want: true,
    70  		},
    71  		{
    72  			name: "temporary with 503",
    73  			code: http.StatusServiceUnavailable,
    74  			want: true,
    75  		},
    76  		{
    77  			name: "temporary with 408",
    78  			code: http.StatusRequestTimeout,
    79  			want: true,
    80  		},
    81  		{
    82  			name: "temporary with 429",
    83  			code: http.StatusTooManyRequests,
    84  			want: true,
    85  		},
    86  		{
    87  			name: "temporary with 418",
    88  			code: http.StatusTeapot,
    89  			want: false,
    90  		},
    91  	}
    92  	for _, tt := range tests {
    93  		t.Run(tt.name, func(t *testing.T) {
    94  			ae := &Error{
    95  				Response: &http.Response{
    96  					StatusCode: tt.code,
    97  				},
    98  			}
    99  			if got := ae.Temporary(); got != tt.want {
   100  				t.Errorf("Temporary() = %v; want %v", got, tt.want)
   101  			}
   102  		})
   103  	}
   104  }
   105  
   106  func TestToken_isValidWithEarlyExpiry(t *testing.T) {
   107  	now := time.Now()
   108  	timeNow = func() time.Time { return now }
   109  	defer func() { timeNow = time.Now }()
   110  
   111  	cases := []struct {
   112  		name   string
   113  		tok    *Token
   114  		expiry time.Duration
   115  		want   bool
   116  	}{
   117  		{name: "4 minutes", tok: &Token{Expiry: now.Add(4 * 60 * time.Second)}, expiry: defaultExpiryDelta, want: true},
   118  		{name: "3 minutes and 45 seconds", tok: &Token{Expiry: now.Add(defaultExpiryDelta)}, expiry: defaultExpiryDelta, want: true},
   119  		{name: "3 minutes and 45 seconds-1ns", tok: &Token{Expiry: now.Add(defaultExpiryDelta - 1*time.Nanosecond)}, expiry: defaultExpiryDelta, want: false},
   120  		{name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, expiry: defaultExpiryDelta, want: false},
   121  		{name: "12 seconds, custom expiryDelta", tok: &Token{Expiry: now.Add(12 * time.Second)}, expiry: time.Second * 5, want: true},
   122  		{name: "5 seconds, custom expiryDelta", tok: &Token{Expiry: now.Add(time.Second * 5)}, expiry: time.Second * 5, want: true},
   123  		{name: "5 seconds-1ns, custom expiryDelta", tok: &Token{Expiry: now.Add(time.Second*5 - 1*time.Nanosecond)}, expiry: time.Second * 5, want: false},
   124  		{name: "-1 hour, custom expiryDelta", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, expiry: time.Second * 5, want: false},
   125  	}
   126  	for _, tc := range cases {
   127  		tc.tok.Value = "tok"
   128  		if got, want := tc.tok.isValidWithEarlyExpiry(tc.expiry), tc.want; got != want {
   129  			t.Errorf("expired (%q) = %v; want %v", tc.name, got, want)
   130  		}
   131  	}
   132  }
   133  
   134  func TestError_Error(t *testing.T) {
   135  
   136  	tests := []struct {
   137  		name string
   138  
   139  		Response    *http.Response
   140  		Body        []byte
   141  		Err         error
   142  		code        string
   143  		description string
   144  		uri         string
   145  
   146  		want string
   147  	}{
   148  		{
   149  			name: "basic",
   150  			Response: &http.Response{
   151  				StatusCode: http.StatusTeapot,
   152  			},
   153  			Body: []byte("I'm a teapot"),
   154  			want: "auth: cannot fetch token: 418\nResponse: I'm a teapot",
   155  		},
   156  		{
   157  			name:        "from query",
   158  			code:        fmt.Sprint(http.StatusTeapot),
   159  			description: "I'm a teapot",
   160  			uri:         "somewhere",
   161  			want:        "auth: \"418\" \"I'm a teapot\" \"somewhere\"",
   162  		},
   163  	}
   164  	for _, tt := range tests {
   165  		t.Run(tt.name, func(t *testing.T) {
   166  			r := &Error{
   167  				Response:    tt.Response,
   168  				Body:        tt.Body,
   169  				Err:         tt.Err,
   170  				code:        tt.code,
   171  				description: tt.description,
   172  				uri:         tt.uri,
   173  			}
   174  			if got := r.Error(); got != tt.want {
   175  				t.Errorf("Error.Error() = %v, want %v", got, tt.want)
   176  			}
   177  		})
   178  	}
   179  }
   180  
   181  func TestNew2LOTokenProvider_JSONResponse(t *testing.T) {
   182  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   183  		w.Header().Set("Content-Type", "application/json")
   184  		w.Write([]byte(`{
   185  			"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
   186  			"scope": "user",
   187  			"token_type": "bearer",
   188  			"expires_in": 3600
   189  		}`))
   190  	}))
   191  	defer ts.Close()
   192  
   193  	opts := &Options2LO{
   194  		Email:      "aaa@example.com",
   195  		PrivateKey: fakePrivateKey,
   196  		TokenURL:   ts.URL,
   197  	}
   198  	tp, err := New2LOTokenProvider(opts)
   199  	if err != nil {
   200  		t.Fatal(err)
   201  	}
   202  	tok, err := tp.Token(context.Background())
   203  	if err != nil {
   204  		t.Fatal(err)
   205  	}
   206  	if !tok.IsValid() {
   207  		t.Errorf("got invalid token: %v", tok)
   208  	}
   209  	if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
   210  		t.Errorf("access token = %q; want %q", got, want)
   211  	}
   212  	if got, want := tok.Type, "bearer"; got != want {
   213  		t.Errorf("token type = %q; want %q", got, want)
   214  	}
   215  	if got := tok.Expiry.IsZero(); got {
   216  		t.Errorf("token expiry = %v, want none", got)
   217  	}
   218  	scope := tok.Metadata["scope"].(string)
   219  	if got, want := scope, "user"; got != want {
   220  		t.Errorf("scope = %q; want %q", got, want)
   221  	}
   222  }
   223  
   224  func TestNew2LOTokenProvider_BadResponse(t *testing.T) {
   225  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   226  		w.Header().Set("Content-Type", "application/json")
   227  		w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
   228  	}))
   229  	defer ts.Close()
   230  
   231  	opts := &Options2LO{
   232  		Email:      "aaa@example.com",
   233  		PrivateKey: fakePrivateKey,
   234  		TokenURL:   ts.URL,
   235  	}
   236  	tp, err := New2LOTokenProvider(opts)
   237  	if err != nil {
   238  		t.Fatal(err)
   239  	}
   240  	tok, err := tp.Token(context.Background())
   241  	if err != nil {
   242  		t.Fatal(err)
   243  	}
   244  	if tok == nil {
   245  		t.Fatalf("got nil token; want token")
   246  	}
   247  	if tok.IsValid() {
   248  		t.Errorf("got invalid token: %v", tok)
   249  	}
   250  	if got, want := tok.Value, ""; got != want {
   251  		t.Errorf("access token = %q; want %q", got, want)
   252  	}
   253  	if got, want := tok.Type, "bearer"; got != want {
   254  		t.Errorf("token type = %q; want %q", got, want)
   255  	}
   256  	scope := tok.Metadata["scope"].(string)
   257  	if got, want := scope, "user"; got != want {
   258  		t.Errorf("token scope = %q; want %q", got, want)
   259  	}
   260  }
   261  
   262  func TestNew2LOTokenProvider_BadResponseType(t *testing.T) {
   263  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   264  		w.Header().Set("Content-Type", "application/json")
   265  		w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
   266  	}))
   267  	defer ts.Close()
   268  	opts := &Options2LO{
   269  		Email:      "aaa@example.com",
   270  		PrivateKey: fakePrivateKey,
   271  		TokenURL:   ts.URL,
   272  	}
   273  	tp, err := New2LOTokenProvider(opts)
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  	tok, err := tp.Token(context.Background())
   278  	if err == nil {
   279  		t.Error("got a token; expected error")
   280  		if got, want := tok.Value, ""; got != want {
   281  			t.Errorf("access token = %q; want %q", got, want)
   282  		}
   283  	}
   284  }
   285  
   286  func TestNew2LOTokenProvider_Assertion(t *testing.T) {
   287  	var assertion string
   288  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   289  		r.ParseForm()
   290  		assertion = r.Form.Get("assertion")
   291  
   292  		w.Header().Set("Content-Type", "application/json")
   293  		w.Write([]byte(`{
   294  			"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
   295  			"scope": "user",
   296  			"token_type": "bearer",
   297  			"expires_in": 3600
   298  		}`))
   299  	}))
   300  	defer ts.Close()
   301  
   302  	opts := &Options2LO{
   303  		Email:        "aaa@example.com",
   304  		PrivateKey:   fakePrivateKey,
   305  		PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   306  		TokenURL:     ts.URL,
   307  	}
   308  
   309  	tp, err := New2LOTokenProvider(opts)
   310  	if err != nil {
   311  		t.Fatal(err)
   312  	}
   313  	_, err = tp.Token(context.Background())
   314  	if err != nil {
   315  		t.Fatalf("Failed to fetch token: %v", err)
   316  	}
   317  
   318  	parts := strings.Split(assertion, ".")
   319  	if len(parts) != 3 {
   320  		t.Fatalf("assertion = %q; want 3 parts", assertion)
   321  	}
   322  	gotjson, err := base64.RawURLEncoding.DecodeString(parts[0])
   323  	if err != nil {
   324  		t.Fatalf("invalid token header; err = %v", err)
   325  	}
   326  
   327  	got := jwt.Header{}
   328  	if err := json.Unmarshal(gotjson, &got); err != nil {
   329  		t.Errorf("failed to unmarshal json token header = %q; err = %v", gotjson, err)
   330  	}
   331  
   332  	want := jwt.Header{
   333  		Algorithm: "RS256",
   334  		Type:      "JWT",
   335  		KeyID:     "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   336  	}
   337  	if got != want {
   338  		t.Errorf("access token header = %q; want %q", got, want)
   339  	}
   340  }
   341  
   342  func TestNew2LOTokenProvider_AssertionPayload(t *testing.T) {
   343  	var assertion string
   344  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   345  		r.ParseForm()
   346  		assertion = r.Form.Get("assertion")
   347  
   348  		w.Header().Set("Content-Type", "application/json")
   349  		w.Write([]byte(`{
   350  			"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
   351  			"scope": "user",
   352  			"token_type": "bearer",
   353  			"expires_in": 3600
   354  		}`))
   355  	}))
   356  	defer ts.Close()
   357  
   358  	for _, opts := range []*Options2LO{
   359  		{
   360  			Email:        "aaa1@example.com",
   361  			PrivateKey:   fakePrivateKey,
   362  			PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   363  			TokenURL:     ts.URL,
   364  		},
   365  		{
   366  			Email:        "aaa2@example.com",
   367  			PrivateKey:   fakePrivateKey,
   368  			PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   369  			TokenURL:     ts.URL,
   370  			Audience:     "https://example.com",
   371  		},
   372  		{
   373  			Email:        "aaa2@example.com",
   374  			PrivateKey:   fakePrivateKey,
   375  			PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
   376  			TokenURL:     ts.URL,
   377  			PrivateClaims: map[string]interface{}{
   378  				"private0": "claim0",
   379  				"private1": "claim1",
   380  			},
   381  		},
   382  	} {
   383  		t.Run(opts.Email, func(t *testing.T) {
   384  			tp, err := New2LOTokenProvider(opts)
   385  			if err != nil {
   386  				t.Fatal(err)
   387  			}
   388  			_, err = tp.Token(context.Background())
   389  			if err != nil {
   390  				t.Fatalf("Failed to fetch token: %v", err)
   391  			}
   392  
   393  			parts := strings.Split(assertion, ".")
   394  			if len(parts) != 3 {
   395  				t.Fatalf("assertion = %q; want 3 parts", assertion)
   396  			}
   397  			gotjson, err := base64.RawURLEncoding.DecodeString(parts[1])
   398  			if err != nil {
   399  				t.Fatalf("invalid token payload; err = %v", err)
   400  			}
   401  
   402  			claimSet := jwt.Claims{}
   403  			if err := json.Unmarshal(gotjson, &claimSet); err != nil {
   404  				t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err)
   405  			}
   406  
   407  			if got, want := claimSet.Iss, opts.Email; got != want {
   408  				t.Errorf("payload email = %q; want %q", got, want)
   409  			}
   410  			if got, want := claimSet.Scope, strings.Join(opts.Scopes, " "); got != want {
   411  				t.Errorf("payload scope = %q; want %q", got, want)
   412  			}
   413  			aud := opts.TokenURL
   414  			if opts.Audience != "" {
   415  				aud = opts.Audience
   416  			}
   417  			if got, want := claimSet.Aud, aud; got != want {
   418  				t.Errorf("payload audience = %q; want %q", got, want)
   419  			}
   420  			if got, want := claimSet.Sub, opts.Subject; got != want {
   421  				t.Errorf("payload subject = %q; want %q", got, want)
   422  			}
   423  			if len(opts.PrivateClaims) > 0 {
   424  				var got interface{}
   425  				if err := json.Unmarshal(gotjson, &got); err != nil {
   426  					t.Errorf("failed to parse payload; err = %q", err)
   427  				}
   428  				m := got.(map[string]interface{})
   429  				for v, k := range opts.PrivateClaims {
   430  					if !cmp.Equal(m[v], k) {
   431  						t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k)
   432  					}
   433  				}
   434  			}
   435  		})
   436  	}
   437  }
   438  
   439  func TestNew2LOTokenProvider_TokenError(t *testing.T) {
   440  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   441  		w.Header().Set("Content-type", "application/json")
   442  		w.WriteHeader(http.StatusBadRequest)
   443  		w.Write([]byte(`{"error": "invalid_grant"}`))
   444  	}))
   445  	defer ts.Close()
   446  
   447  	opts := &Options2LO{
   448  		Email:      "aaa@example.com",
   449  		PrivateKey: fakePrivateKey,
   450  		TokenURL:   ts.URL,
   451  	}
   452  
   453  	tp, err := New2LOTokenProvider(opts)
   454  	if err != nil {
   455  		t.Fatal(err)
   456  	}
   457  	_, err = tp.Token(context.Background())
   458  	if err == nil {
   459  		t.Fatalf("got no error, expected one")
   460  	}
   461  	_, ok := err.(*Error)
   462  	if !ok {
   463  		t.Fatalf("got %T error, expected *Error", err)
   464  	}
   465  	expected := fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", "400", `{"error": "invalid_grant"}`)
   466  	if errStr := err.Error(); errStr != expected {
   467  		t.Fatalf("got %#v, expected %#v", errStr, expected)
   468  	}
   469  }
   470  
   471  func TestNew2LOTokenProvider_Validate(t *testing.T) {
   472  	tests := []struct {
   473  		name string
   474  		opts *Options2LO
   475  	}{
   476  		{
   477  			name: "missing options",
   478  		},
   479  		{
   480  			name: "missing email",
   481  			opts: &Options2LO{
   482  				PrivateKey: []byte("key"),
   483  				TokenURL:   "url",
   484  			},
   485  		},
   486  		{
   487  			name: "missing key",
   488  			opts: &Options2LO{
   489  				Email:    "email",
   490  				TokenURL: "url",
   491  			},
   492  		},
   493  		{
   494  			name: "missing URL",
   495  			opts: &Options2LO{
   496  				Email:      "email",
   497  				PrivateKey: []byte("key"),
   498  			},
   499  		},
   500  	}
   501  	for _, tt := range tests {
   502  		t.Run(tt.name, func(t *testing.T) {
   503  			if _, err := New2LOTokenProvider(tt.opts); err == nil {
   504  				t.Error("got nil, want an error")
   505  			}
   506  		})
   507  	}
   508  }
   509  

View as plain text