...

Source file src/google.golang.org/api/idtoken/cache.go

Documentation: google.golang.org/api/idtoken

     1  // Copyright 2020 Google LLC.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package idtoken
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"net/http"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  )
    17  
    18  type cachingClient struct {
    19  	client *http.Client
    20  
    21  	// clock optionally specifies a func to return the current time.
    22  	// If nil, time.Now is used.
    23  	clock func() time.Time
    24  
    25  	mu    sync.Mutex
    26  	certs map[string]*cachedResponse
    27  }
    28  
    29  func newCachingClient(client *http.Client) *cachingClient {
    30  	return &cachingClient{
    31  		client: client,
    32  		certs:  make(map[string]*cachedResponse, 2),
    33  	}
    34  }
    35  
    36  type cachedResponse struct {
    37  	resp *certResponse
    38  	exp  time.Time
    39  }
    40  
    41  func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) {
    42  	if response, ok := c.get(url); ok {
    43  		return response, nil
    44  	}
    45  	req, err := http.NewRequest(http.MethodGet, url, nil)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	req = req.WithContext(ctx)
    50  	resp, err := c.client.Do(req)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	defer resp.Body.Close()
    55  	if resp.StatusCode != http.StatusOK {
    56  		return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
    57  	}
    58  
    59  	certResp := &certResponse{}
    60  	if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
    61  		return nil, err
    62  
    63  	}
    64  	c.set(url, certResp, resp.Header)
    65  	return certResp, nil
    66  }
    67  
    68  func (c *cachingClient) now() time.Time {
    69  	if c.clock != nil {
    70  		return c.clock()
    71  	}
    72  	return time.Now()
    73  }
    74  
    75  func (c *cachingClient) get(url string) (*certResponse, bool) {
    76  	c.mu.Lock()
    77  	defer c.mu.Unlock()
    78  	cachedResp, ok := c.certs[url]
    79  	if !ok {
    80  		return nil, false
    81  	}
    82  	if c.now().After(cachedResp.exp) {
    83  		return nil, false
    84  	}
    85  	return cachedResp.resp, true
    86  }
    87  
    88  func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) {
    89  	exp := c.calculateExpireTime(headers)
    90  	c.mu.Lock()
    91  	c.certs[url] = &cachedResponse{resp: resp, exp: exp}
    92  	c.mu.Unlock()
    93  }
    94  
    95  // calculateExpireTime will determine the expire time for the cache based on
    96  // HTTP headers. If there is any difficulty reading the headers the fallback is
    97  // to set the cache to expire now.
    98  func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time {
    99  	var maxAge int
   100  	cc := strings.Split(headers.Get("cache-control"), ",")
   101  	for _, v := range cc {
   102  		if strings.Contains(v, "max-age") {
   103  			ss := strings.Split(v, "=")
   104  			if len(ss) < 2 {
   105  				return c.now()
   106  			}
   107  			ma, err := strconv.Atoi(ss[1])
   108  			if err != nil {
   109  				return c.now()
   110  			}
   111  			maxAge = ma
   112  		}
   113  	}
   114  	a := headers.Get("age")
   115  	if a == "" {
   116  		return c.now().Add(time.Duration(maxAge) * time.Second)
   117  	}
   118  	age, err := strconv.Atoi(a)
   119  	if err != nil {
   120  		return c.now()
   121  	}
   122  	return c.now().Add(time.Duration(maxAge-age) * time.Second)
   123  }
   124  

View as plain text