...

Source file src/cloud.google.com/go/cloudsqlconn/internal/cloudsql/lazy_test.go

Documentation: cloud.google.com/go/cloudsqlconn/internal/cloudsql

     1  // Copyright 2024 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cloudsql
    16  
    17  import (
    18  	"context"
    19  	"sync"
    20  	"testing"
    21  	"time"
    22  
    23  	"cloud.google.com/go/cloudsqlconn/instance"
    24  	"cloud.google.com/go/cloudsqlconn/internal/mock"
    25  	"golang.org/x/oauth2"
    26  )
    27  
    28  func TestLazyRefreshCacheConnectionInfo(t *testing.T) {
    29  	cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
    30  	inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
    31  	client, cleanup, err := mock.NewSQLAdminService(
    32  		context.Background(),
    33  		mock.InstanceGetSuccess(inst, 1),
    34  		mock.CreateEphemeralSuccess(inst, 1),
    35  	)
    36  	if err != nil {
    37  		t.Fatal(err)
    38  	}
    39  	defer func() {
    40  		if err := cleanup(); err != nil {
    41  			t.Fatalf("%v", err)
    42  		}
    43  	}()
    44  	c := NewLazyRefreshCache(
    45  		testInstanceConnName(), nullLogger{}, client,
    46  		RSAKey, 30*time.Second, nil, "", false,
    47  	)
    48  
    49  	ci, err := c.ConnectionInfo(context.Background())
    50  	if err != nil {
    51  		t.Fatal(err)
    52  	}
    53  	if ci.ConnectionName != cn {
    54  		t.Fatalf("want = %v, got = %v", cn, ci.ConnectionName)
    55  	}
    56  	// Request connection info again to ensure it uses the cache and doesn't
    57  	// send another API call.
    58  	_, err = c.ConnectionInfo(context.Background())
    59  	if err != nil {
    60  		t.Fatal(err)
    61  	}
    62  }
    63  
    64  func TestLazyRefreshCacheForceRefresh(t *testing.T) {
    65  	cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
    66  	inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
    67  	client, cleanup, err := mock.NewSQLAdminService(
    68  		context.Background(),
    69  		mock.InstanceGetSuccess(inst, 2),
    70  		mock.CreateEphemeralSuccess(inst, 2),
    71  	)
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  	defer func() {
    76  		if err := cleanup(); err != nil {
    77  			t.Fatalf("%v", err)
    78  		}
    79  	}()
    80  	c := NewLazyRefreshCache(
    81  		testInstanceConnName(), nullLogger{}, client,
    82  		RSAKey, 30*time.Second, nil, "", false,
    83  	)
    84  
    85  	_, err = c.ConnectionInfo(context.Background())
    86  	if err != nil {
    87  		t.Fatal(err)
    88  	}
    89  
    90  	c.ForceRefresh()
    91  
    92  	_, err = c.ConnectionInfo(context.Background())
    93  	if err != nil {
    94  		t.Fatal(err)
    95  	}
    96  }
    97  
    98  // spyTokenSource is a non-threadsafe spy for tracking token source usage
    99  type spyTokenSource struct {
   100  	mu    sync.Mutex
   101  	count int
   102  }
   103  
   104  func (s *spyTokenSource) Token() (*oauth2.Token, error) {
   105  	s.mu.Lock()
   106  	defer s.mu.Unlock()
   107  	s.count++
   108  	return &oauth2.Token{}, nil
   109  }
   110  
   111  func (s *spyTokenSource) callCount() int {
   112  	s.mu.Lock()
   113  	defer s.mu.Unlock()
   114  	return s.count
   115  }
   116  
   117  func TestLazyRefreshCacheUpdateRefresh(t *testing.T) {
   118  	cn, _ := instance.ParseConnName("my-project:my-region:my-instance")
   119  	inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
   120  	client, cleanup, err := mock.NewSQLAdminService(
   121  		context.Background(),
   122  		mock.InstanceGetSuccess(inst, 2),
   123  		mock.CreateEphemeralSuccess(inst, 2),
   124  	)
   125  	if err != nil {
   126  		t.Fatal(err)
   127  	}
   128  	defer func() {
   129  		if err := cleanup(); err != nil {
   130  			t.Fatalf("%v", err)
   131  		}
   132  	}()
   133  
   134  	spy := &spyTokenSource{}
   135  	c := NewLazyRefreshCache(
   136  		testInstanceConnName(), nullLogger{}, client,
   137  		RSAKey, 30*time.Second, spy, "", false, // disable IAM AuthN at first
   138  	)
   139  
   140  	_, err = c.ConnectionInfo(context.Background())
   141  	if err != nil {
   142  		t.Fatal(err)
   143  	}
   144  
   145  	if got := spy.callCount(); got != 0 {
   146  		t.Fatal("oauth2.TokenSource was called, but should not have been")
   147  	}
   148  
   149  	c.UpdateRefresh(ptr(true))
   150  
   151  	_, err = c.ConnectionInfo(context.Background())
   152  	if err != nil {
   153  		t.Fatal(err)
   154  	}
   155  
   156  	// Q: Why should the token source be called twice?
   157  	// A: Because the refresh code retrieves a token first (1 call) and then
   158  	//    refreshes it (1 call) for a total of 2 calls.
   159  	if got, want := spy.callCount(), 2; got != want {
   160  		t.Fatalf(
   161  			"oauth2.TokenSource call count, got = %v, want = %v",
   162  			got, want,
   163  		)
   164  	}
   165  }
   166  
   167  func ptr(val bool) *bool {
   168  	return &val
   169  }
   170  

View as plain text