...

Source file src/cloud.google.com/go/auth/httptransport/transport.go

Documentation: cloud.google.com/go/auth/httptransport

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package httptransport
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"net"
    21  	"net/http"
    22  	"time"
    23  
    24  	"cloud.google.com/go/auth"
    25  	"cloud.google.com/go/auth/credentials"
    26  	"cloud.google.com/go/auth/internal"
    27  	"cloud.google.com/go/auth/internal/transport"
    28  	"cloud.google.com/go/auth/internal/transport/cert"
    29  	"go.opencensus.io/plugin/ochttp"
    30  	"golang.org/x/net/http2"
    31  )
    32  
    33  const (
    34  	quotaProjectHeaderKey = "X-Goog-User-Project"
    35  )
    36  
    37  func newTransport(base http.RoundTripper, opts *Options) (http.RoundTripper, error) {
    38  	var headers = opts.Headers
    39  	ht := &headerTransport{
    40  		base:    base,
    41  		headers: headers,
    42  	}
    43  	var trans http.RoundTripper = ht
    44  	trans = addOCTransport(trans, opts)
    45  	switch {
    46  	case opts.DisableAuthentication:
    47  		// Do nothing.
    48  	case opts.APIKey != "":
    49  		qp := internal.GetQuotaProject(nil, opts.Headers.Get(quotaProjectHeaderKey))
    50  		if qp != "" {
    51  			if headers == nil {
    52  				headers = make(map[string][]string, 1)
    53  			}
    54  			headers.Set(quotaProjectHeaderKey, qp)
    55  		}
    56  		trans = &apiKeyTransport{
    57  			Transport: trans,
    58  			Key:       opts.APIKey,
    59  		}
    60  	default:
    61  		var creds *auth.Credentials
    62  		if opts.Credentials != nil {
    63  			creds = opts.Credentials
    64  		} else {
    65  			var err error
    66  			creds, err = credentials.DetectDefault(opts.resolveDetectOptions())
    67  			if err != nil {
    68  				return nil, err
    69  			}
    70  		}
    71  		qp, err := creds.QuotaProjectID(context.Background())
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  		if qp != "" {
    76  			if headers == nil {
    77  				headers = make(map[string][]string, 1)
    78  			}
    79  			headers.Set(quotaProjectHeaderKey, qp)
    80  		}
    81  		creds.TokenProvider = auth.NewCachedTokenProvider(creds.TokenProvider, nil)
    82  		trans = &authTransport{
    83  			base:                 trans,
    84  			creds:                creds,
    85  			clientUniverseDomain: opts.UniverseDomain,
    86  		}
    87  	}
    88  	return trans, nil
    89  }
    90  
    91  // defaultBaseTransport returns the base HTTP transport.
    92  // On App Engine, this is urlfetch.Transport.
    93  // Otherwise, use a default transport, taking most defaults from
    94  // http.DefaultTransport.
    95  // If TLSCertificate is available, set TLSClientConfig as well.
    96  func defaultBaseTransport(clientCertSource cert.Provider, dialTLSContext func(context.Context, string, string) (net.Conn, error)) http.RoundTripper {
    97  	trans := http.DefaultTransport.(*http.Transport).Clone()
    98  	trans.MaxIdleConnsPerHost = 100
    99  
   100  	if clientCertSource != nil {
   101  		trans.TLSClientConfig = &tls.Config{
   102  			GetClientCertificate: clientCertSource,
   103  		}
   104  	}
   105  	if dialTLSContext != nil {
   106  		// If DialTLSContext is set, TLSClientConfig wil be ignored
   107  		trans.DialTLSContext = dialTLSContext
   108  	}
   109  
   110  	// Configures the ReadIdleTimeout HTTP/2 option for the
   111  	// transport. This allows broken idle connections to be pruned more quickly,
   112  	// preventing the client from attempting to re-use connections that will no
   113  	// longer work.
   114  	http2Trans, err := http2.ConfigureTransports(trans)
   115  	if err == nil {
   116  		http2Trans.ReadIdleTimeout = time.Second * 31
   117  	}
   118  
   119  	return trans
   120  }
   121  
   122  type apiKeyTransport struct {
   123  	// Key is the API Key to set on requests.
   124  	Key string
   125  	// Transport is the underlying HTTP transport.
   126  	// If nil, http.DefaultTransport is used.
   127  	Transport http.RoundTripper
   128  }
   129  
   130  func (t *apiKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   131  	newReq := *req
   132  	args := newReq.URL.Query()
   133  	args.Set("key", t.Key)
   134  	newReq.URL.RawQuery = args.Encode()
   135  	return t.Transport.RoundTrip(&newReq)
   136  }
   137  
   138  type headerTransport struct {
   139  	headers http.Header
   140  	base    http.RoundTripper
   141  }
   142  
   143  func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   144  	rt := t.base
   145  	newReq := *req
   146  	newReq.Header = make(http.Header)
   147  	for k, vv := range req.Header {
   148  		newReq.Header[k] = vv
   149  	}
   150  
   151  	for k, v := range t.headers {
   152  		newReq.Header[k] = v
   153  	}
   154  
   155  	return rt.RoundTrip(&newReq)
   156  }
   157  
   158  func addOCTransport(trans http.RoundTripper, opts *Options) http.RoundTripper {
   159  	if opts.DisableTelemetry {
   160  		return trans
   161  	}
   162  	return &ochttp.Transport{
   163  		Base:        trans,
   164  		Propagation: &httpFormat{},
   165  	}
   166  }
   167  
   168  type authTransport struct {
   169  	creds                *auth.Credentials
   170  	base                 http.RoundTripper
   171  	clientUniverseDomain string
   172  }
   173  
   174  // getClientUniverseDomain returns the universe domain configured for the client.
   175  // The default value is "googleapis.com".
   176  func (t *authTransport) getClientUniverseDomain() string {
   177  	if t.clientUniverseDomain == "" {
   178  		return internal.DefaultUniverseDomain
   179  	}
   180  	return t.clientUniverseDomain
   181  }
   182  
   183  // RoundTrip authorizes and authenticates the request with an
   184  // access token from Transport's Source. Per the RoundTripper contract we must
   185  // not modify the initial request, so we clone it, and we must close the body
   186  // on any errors that happens during our token logic.
   187  func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   188  	reqBodyClosed := false
   189  	if req.Body != nil {
   190  		defer func() {
   191  			if !reqBodyClosed {
   192  				req.Body.Close()
   193  			}
   194  		}()
   195  	}
   196  	credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context())
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  	if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
   201  		return nil, err
   202  	}
   203  	token, err := t.creds.Token(req.Context())
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  	req2 := req.Clone(req.Context())
   208  	SetAuthHeader(token, req2)
   209  	reqBodyClosed = true
   210  	return t.base.RoundTrip(req2)
   211  }
   212  

View as plain text