...

Source file src/cloud.google.com/go/auth/internal/transport/cba.go

Documentation: cloud.google.com/go/auth/internal/transport

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transport
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"errors"
    21  	"net"
    22  	"net/http"
    23  	"net/url"
    24  	"os"
    25  	"strconv"
    26  	"strings"
    27  
    28  	"cloud.google.com/go/auth/internal"
    29  	"cloud.google.com/go/auth/internal/transport/cert"
    30  	"github.com/google/s2a-go"
    31  	"github.com/google/s2a-go/fallback"
    32  	"google.golang.org/grpc/credentials"
    33  )
    34  
    35  const (
    36  	mTLSModeAlways = "always"
    37  	mTLSModeNever  = "never"
    38  	mTLSModeAuto   = "auto"
    39  
    40  	// Experimental: if true, the code will try MTLS with S2A as the default for transport security. Default value is false.
    41  	googleAPIUseS2AEnv     = "EXPERIMENTAL_GOOGLE_API_USE_S2A"
    42  	googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
    43  	googleAPIUseMTLS       = "GOOGLE_API_USE_MTLS_ENDPOINT"
    44  	googleAPIUseMTLSOld    = "GOOGLE_API_USE_MTLS"
    45  
    46  	universeDomainPlaceholder = "UNIVERSE_DOMAIN"
    47  )
    48  
    49  var (
    50  	mdsMTLSAutoConfigSource     mtlsConfigSource
    51  	errUniverseNotSupportedMTLS = errors.New("mTLS is not supported in any universe other than googleapis.com")
    52  )
    53  
// Options is a struct that is duplicated information from the individual
// transport packages in order to avoid cyclic deps. It correlates 1:1 with
// fields on httptransport.Options and grpctransport.Options.
type Options struct {
	// Endpoint is the user-provided endpoint override. A value containing
	// "://" is used verbatim. A host[:port] value is merged into the default
	// endpoint when DefaultEndpointTemplate is set, and used verbatim otherwise.
	Endpoint                string
	// DefaultMTLSEndpoint is the endpoint chosen when mTLS is selected and
	// the endpoint used together with S2A.
	DefaultMTLSEndpoint     string
	// DefaultEndpointTemplate is the default endpoint in which the first
	// "UNIVERSE_DOMAIN" placeholder is replaced by the universe domain.
	DefaultEndpointTemplate string
	// ClientCertProvider, if non-nil, is used as the client certificate
	// source instead of the default provider, unless client certificates are
	// disabled via GOOGLE_API_USE_CLIENT_CERTIFICATE.
	ClientCertProvider      cert.Provider
	// Client is not read in this file. NOTE(review): presumably mirrors the
	// user-supplied HTTP client of httptransport.Options — confirm.
	Client                  *http.Client
	// UniverseDomain is the Cloud universe domain; empty means
	// internal.DefaultUniverseDomain.
	UniverseDomain          string
	// EnableDirectPath and EnableDirectPathXds are not read in this file.
	// NOTE(review): presumably mirror the grpctransport DirectPath options — confirm.
	EnableDirectPath        bool
	EnableDirectPathXds     bool
}
    67  
    68  // getUniverseDomain returns the default service domain for a given Cloud
    69  // universe.
    70  func (o *Options) getUniverseDomain() string {
    71  	if o.UniverseDomain == "" {
    72  		return internal.DefaultUniverseDomain
    73  	}
    74  	return o.UniverseDomain
    75  }
    76  
    77  // isUniverseDomainGDU returns true if the universe domain is the default Google
    78  // universe.
    79  func (o *Options) isUniverseDomainGDU() bool {
    80  	return o.getUniverseDomain() == internal.DefaultUniverseDomain
    81  }
    82  
    83  // defaultEndpoint returns the DefaultEndpointTemplate merged with the
    84  // universe domain if the DefaultEndpointTemplate is set, otherwise returns an
    85  // empty string.
    86  func (o *Options) defaultEndpoint() string {
    87  	if o.DefaultEndpointTemplate == "" {
    88  		return ""
    89  	}
    90  	return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1)
    91  }
    92  
    93  // mergedEndpoint merges a user-provided Endpoint of format host[:port] with the
    94  // default endpoint.
    95  func (o *Options) mergedEndpoint() (string, error) {
    96  	defaultEndpoint := o.defaultEndpoint()
    97  	u, err := url.Parse(fixScheme(defaultEndpoint))
    98  	if err != nil {
    99  		return "", err
   100  	}
   101  	return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil
   102  }
   103  
   104  func fixScheme(baseURL string) string {
   105  	if !strings.Contains(baseURL, "://") {
   106  		baseURL = "https://" + baseURL
   107  	}
   108  	return baseURL
   109  }
   110  
   111  // GetGRPCTransportCredsAndEndpoint returns an instance of
   112  // [google.golang.org/grpc/credentials.TransportCredentials], and the
   113  // corresponding endpoint to use for GRPC client.
   114  func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCredentials, string, error) {
   115  	config, err := getTransportConfig(opts)
   116  	if err != nil {
   117  		return nil, "", err
   118  	}
   119  
   120  	defaultTransportCreds := credentials.NewTLS(&tls.Config{
   121  		GetClientCertificate: config.clientCertSource,
   122  	})
   123  	if config.s2aAddress == "" {
   124  		return defaultTransportCreds, config.endpoint, nil
   125  	}
   126  
   127  	var fallbackOpts *s2a.FallbackOptions
   128  	// In case of S2A failure, fall back to the endpoint that would've been used without S2A.
   129  	if fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(config.endpoint); err == nil {
   130  		fallbackOpts = &s2a.FallbackOptions{
   131  			FallbackClientHandshakeFunc: fallbackHandshake,
   132  		}
   133  	}
   134  
   135  	s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
   136  		S2AAddress:   config.s2aAddress,
   137  		FallbackOpts: fallbackOpts,
   138  	})
   139  	if err != nil {
   140  		// Use default if we cannot initialize S2A client transport credentials.
   141  		return defaultTransportCreds, config.endpoint, nil
   142  	}
   143  	return s2aTransportCreds, config.s2aMTLSEndpoint, nil
   144  }
   145  
   146  // GetHTTPTransportConfig returns a client certificate source and a function for
   147  // dialing MTLS with S2A.
   148  func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, string, string) (net.Conn, error), error) {
   149  	config, err := getTransportConfig(opts)
   150  	if err != nil {
   151  		return nil, nil, err
   152  	}
   153  
   154  	if config.s2aAddress == "" {
   155  		return config.clientCertSource, nil, nil
   156  	}
   157  
   158  	var fallbackOpts *s2a.FallbackOptions
   159  	// In case of S2A failure, fall back to the endpoint that would've been used without S2A.
   160  	if fallbackURL, err := url.Parse(config.endpoint); err == nil {
   161  		if fallbackDialer, fallbackServerAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackURL.Hostname()); err == nil {
   162  			fallbackOpts = &s2a.FallbackOptions{
   163  				FallbackDialer: &s2a.FallbackDialer{
   164  					Dialer:     fallbackDialer,
   165  					ServerAddr: fallbackServerAddr,
   166  				},
   167  			}
   168  		}
   169  	}
   170  
   171  	dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
   172  		S2AAddress:   config.s2aAddress,
   173  		FallbackOpts: fallbackOpts,
   174  	})
   175  	return nil, dialTLSContextFunc, nil
   176  }
   177  
   178  func getTransportConfig(opts *Options) (*transportConfig, error) {
   179  	clientCertSource, err := getClientCertificateSource(opts)
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  	endpoint, err := getEndpoint(opts, clientCertSource)
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  	defaultTransportConfig := transportConfig{
   188  		clientCertSource: clientCertSource,
   189  		endpoint:         endpoint,
   190  	}
   191  
   192  	if !shouldUseS2A(clientCertSource, opts) {
   193  		return &defaultTransportConfig, nil
   194  	}
   195  	if !opts.isUniverseDomainGDU() {
   196  		return nil, errUniverseNotSupportedMTLS
   197  	}
   198  
   199  	s2aMTLSEndpoint := opts.DefaultMTLSEndpoint
   200  
   201  	s2aAddress := GetS2AAddress()
   202  	if s2aAddress == "" {
   203  		return &defaultTransportConfig, nil
   204  	}
   205  	return &transportConfig{
   206  		clientCertSource: clientCertSource,
   207  		endpoint:         endpoint,
   208  		s2aAddress:       s2aAddress,
   209  		s2aMTLSEndpoint:  s2aMTLSEndpoint,
   210  	}, nil
   211  }
   212  
   213  // getClientCertificateSource returns a default client certificate source, if
   214  // not provided by the user.
   215  //
   216  // A nil default source can be returned if the source does not exist. Any exceptions
   217  // encountered while initializing the default source will be reported as client
   218  // error (ex. corrupt metadata file).
   219  func getClientCertificateSource(opts *Options) (cert.Provider, error) {
   220  	if !isClientCertificateEnabled() {
   221  		return nil, nil
   222  	} else if opts.ClientCertProvider != nil {
   223  		return opts.ClientCertProvider, nil
   224  	}
   225  	return cert.DefaultProvider()
   226  
   227  }
   228  
   229  // isClientCertificateEnabled returns true by default, unless explicitly set to false via env var.
   230  func isClientCertificateEnabled() bool {
   231  	if value, ok := os.LookupEnv(googleAPIUseCertSource); ok {
   232  		// error as false is OK
   233  		b, _ := strconv.ParseBool(value)
   234  		return b
   235  	}
   236  	return true
   237  }
   238  
// transportConfig is the resolved transport configuration produced by
// getTransportConfig.
type transportConfig struct {
	// The client certificate source. May be nil when client certificates are
	// disabled or no default source exists.
	clientCertSource cert.Provider
	// The corresponding endpoint to use based on client certificate source.
	endpoint string
	// The S2A address if it can be used, otherwise an empty string.
	s2aAddress string
	// The MTLS endpoint to use with S2A. getTransportConfig only sets it
	// together with a non-empty s2aAddress.
	s2aMTLSEndpoint string
}
   249  
   250  // getEndpoint returns the endpoint for the service, taking into account the
   251  // user-provided endpoint override "settings.Endpoint".
   252  //
   253  // If no endpoint override is specified, we will either return the default endpoint or
   254  // the default mTLS endpoint if a client certificate is available.
   255  //
   256  // You can override the default endpoint choice (mtls vs. regular) by setting the
   257  // GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
   258  //
   259  // If the endpoint override is an address (host:port) rather than full base
   260  // URL (ex. https://...), then the user-provided address will be merged into
   261  // the default endpoint. For example, WithEndpoint("myhost:8000") and
   262  // DefaultEndpointTemplate("https://UNIVERSE_DOMAIN/bar/baz") will return "https://myhost:8080/bar/baz"
   263  func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) {
   264  	if opts.Endpoint == "" {
   265  		mtlsMode := getMTLSMode()
   266  		if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
   267  			if !opts.isUniverseDomainGDU() {
   268  				return "", errUniverseNotSupportedMTLS
   269  			}
   270  			return opts.DefaultMTLSEndpoint, nil
   271  		}
   272  		return opts.defaultEndpoint(), nil
   273  	}
   274  	if strings.Contains(opts.Endpoint, "://") {
   275  		// User passed in a full URL path, use it verbatim.
   276  		return opts.Endpoint, nil
   277  	}
   278  	if opts.defaultEndpoint() == "" {
   279  		// If DefaultEndpointTemplate is not configured,
   280  		// use the user provided endpoint verbatim. This allows a naked
   281  		// "host[:port]" URL to be used with GRPC Direct Path.
   282  		return opts.Endpoint, nil
   283  	}
   284  
   285  	// Assume user-provided endpoint is host[:port], merge it with the default endpoint.
   286  	return opts.mergedEndpoint()
   287  }
   288  
   289  func getMTLSMode() string {
   290  	mode := os.Getenv(googleAPIUseMTLS)
   291  	if mode == "" {
   292  		mode = os.Getenv(googleAPIUseMTLSOld) // Deprecated.
   293  	}
   294  	if mode == "" {
   295  		return mTLSModeAuto
   296  	}
   297  	return strings.ToLower(mode)
   298  }
   299  

View as plain text