1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package transport
16
17 import (
18 "context"
19 "crypto/tls"
20 "errors"
21 "net"
22 "net/http"
23 "net/url"
24 "os"
25 "strconv"
26 "strings"
27
28 "cloud.google.com/go/auth/internal"
29 "cloud.google.com/go/auth/internal/transport/cert"
30 "github.com/google/s2a-go"
31 "github.com/google/s2a-go/fallback"
32 "google.golang.org/grpc/credentials"
33 )
34
// Accepted values of the mTLS-endpoint environment variables (see
// getMTLSMode). getEndpoint only acts on "always" and "auto"; "never" (or any
// unrecognized value) selects the regular endpoint.
const (
	mTLSModeAlways = "always"
	mTLSModeNever  = "never"
	mTLSModeAuto   = "auto"
)

// Environment variables consulted when choosing transport security.
const (
	// googleAPIUseS2AEnv is not read in the code visible here.
	// NOTE(review): presumably consulted by the S2A enablement check
	// (shouldUseS2A) elsewhere in the package — confirm.
	googleAPIUseS2AEnv = "EXPERIMENTAL_GOOGLE_API_USE_S2A"
	// googleAPIUseCertSource toggles client-certificate use; it is parsed
	// with strconv.ParseBool by isClientCertificateEnabled.
	googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
	// googleAPIUseMTLS selects the mTLS mode; it takes priority over
	// googleAPIUseMTLSOld.
	googleAPIUseMTLS = "GOOGLE_API_USE_MTLS_ENDPOINT"
	// googleAPIUseMTLSOld is the legacy name, read only when
	// googleAPIUseMTLS is empty.
	googleAPIUseMTLSOld = "GOOGLE_API_USE_MTLS"
)

// universeDomainPlaceholder is the token in Options.DefaultEndpointTemplate
// that is replaced by the effective universe domain.
const universeDomainPlaceholder = "UNIVERSE_DOMAIN"
48
49 var (
50 mdsMTLSAutoConfigSource mtlsConfigSource
51 errUniverseNotSupportedMTLS = errors.New("mTLS is not supported in any universe other than googleapis.com")
52 )
53
54
55
56
// Options carries the endpoint and credential settings that drive transport
// selection in this file: endpoint resolution, client-certificate mTLS and
// S2A. NOTE(review): it appears to mirror options exposed by the individual
// transport packages — confirm against the callers.
type Options struct {
	// Endpoint is a user-supplied override. A value containing "://" is used
	// verbatim. Otherwise it is treated as host[:port] and merged into the
	// default endpoint when a template is set. Empty means "choose
	// automatically".
	Endpoint string
	// DefaultMTLSEndpoint is returned by endpoint resolution when mTLS is
	// selected, and is the endpoint used when S2A is in effect.
	DefaultMTLSEndpoint string
	// DefaultEndpointTemplate is the default endpoint, with the literal
	// "UNIVERSE_DOMAIN" standing in for the universe domain. Empty means
	// there is no default endpoint.
	DefaultEndpointTemplate string
	// ClientCertProvider, if non-nil, is used as the client certificate
	// source instead of cert.DefaultProvider. It is ignored when client
	// certificates are disabled via GOOGLE_API_USE_CLIENT_CERTIFICATE.
	ClientCertProvider cert.Provider
	// Client is not read in the code visible here.
	// NOTE(review): presumably consumed elsewhere in the package — confirm.
	Client *http.Client
	// UniverseDomain is the Cloud universe domain. Empty means
	// internal.DefaultUniverseDomain. mTLS and S2A are only permitted in the
	// default universe.
	UniverseDomain string
	// EnableDirectPath and EnableDirectPathXds are not read in the code
	// visible here. NOTE(review): presumably consulted by shouldUseS2A —
	// confirm.
	EnableDirectPath    bool
	EnableDirectPathXds bool
}
67
68
69
70 func (o *Options) getUniverseDomain() string {
71 if o.UniverseDomain == "" {
72 return internal.DefaultUniverseDomain
73 }
74 return o.UniverseDomain
75 }
76
77
78
79 func (o *Options) isUniverseDomainGDU() bool {
80 return o.getUniverseDomain() == internal.DefaultUniverseDomain
81 }
82
83
84
85
86 func (o *Options) defaultEndpoint() string {
87 if o.DefaultEndpointTemplate == "" {
88 return ""
89 }
90 return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1)
91 }
92
93
94
95 func (o *Options) mergedEndpoint() (string, error) {
96 defaultEndpoint := o.defaultEndpoint()
97 u, err := url.Parse(fixScheme(defaultEndpoint))
98 if err != nil {
99 return "", err
100 }
101 return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil
102 }
103
// fixScheme prefixes baseURL with "https://" unless it already contains a
// "://" scheme separator. This lets url.Parse treat a bare host[/path] as a
// host rather than as a path or scheme.
func fixScheme(baseURL string) string {
	if strings.Contains(baseURL, "://") {
		return baseURL
	}
	return "https://" + baseURL
}
110
111
112
113
114 func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCredentials, string, error) {
115 config, err := getTransportConfig(opts)
116 if err != nil {
117 return nil, "", err
118 }
119
120 defaultTransportCreds := credentials.NewTLS(&tls.Config{
121 GetClientCertificate: config.clientCertSource,
122 })
123 if config.s2aAddress == "" {
124 return defaultTransportCreds, config.endpoint, nil
125 }
126
127 var fallbackOpts *s2a.FallbackOptions
128
129 if fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(config.endpoint); err == nil {
130 fallbackOpts = &s2a.FallbackOptions{
131 FallbackClientHandshakeFunc: fallbackHandshake,
132 }
133 }
134
135 s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
136 S2AAddress: config.s2aAddress,
137 FallbackOpts: fallbackOpts,
138 })
139 if err != nil {
140
141 return defaultTransportCreds, config.endpoint, nil
142 }
143 return s2aTransportCreds, config.s2aMTLSEndpoint, nil
144 }
145
146
147
148 func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, string, string) (net.Conn, error), error) {
149 config, err := getTransportConfig(opts)
150 if err != nil {
151 return nil, nil, err
152 }
153
154 if config.s2aAddress == "" {
155 return config.clientCertSource, nil, nil
156 }
157
158 var fallbackOpts *s2a.FallbackOptions
159
160 if fallbackURL, err := url.Parse(config.endpoint); err == nil {
161 if fallbackDialer, fallbackServerAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackURL.Hostname()); err == nil {
162 fallbackOpts = &s2a.FallbackOptions{
163 FallbackDialer: &s2a.FallbackDialer{
164 Dialer: fallbackDialer,
165 ServerAddr: fallbackServerAddr,
166 },
167 }
168 }
169 }
170
171 dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
172 S2AAddress: config.s2aAddress,
173 FallbackOpts: fallbackOpts,
174 })
175 return nil, dialTLSContextFunc, nil
176 }
177
178 func getTransportConfig(opts *Options) (*transportConfig, error) {
179 clientCertSource, err := getClientCertificateSource(opts)
180 if err != nil {
181 return nil, err
182 }
183 endpoint, err := getEndpoint(opts, clientCertSource)
184 if err != nil {
185 return nil, err
186 }
187 defaultTransportConfig := transportConfig{
188 clientCertSource: clientCertSource,
189 endpoint: endpoint,
190 }
191
192 if !shouldUseS2A(clientCertSource, opts) {
193 return &defaultTransportConfig, nil
194 }
195 if !opts.isUniverseDomainGDU() {
196 return nil, errUniverseNotSupportedMTLS
197 }
198
199 s2aMTLSEndpoint := opts.DefaultMTLSEndpoint
200
201 s2aAddress := GetS2AAddress()
202 if s2aAddress == "" {
203 return &defaultTransportConfig, nil
204 }
205 return &transportConfig{
206 clientCertSource: clientCertSource,
207 endpoint: endpoint,
208 s2aAddress: s2aAddress,
209 s2aMTLSEndpoint: s2aMTLSEndpoint,
210 }, nil
211 }
212
213
214
215
216
217
218
219 func getClientCertificateSource(opts *Options) (cert.Provider, error) {
220 if !isClientCertificateEnabled() {
221 return nil, nil
222 } else if opts.ClientCertProvider != nil {
223 return opts.ClientCertProvider, nil
224 }
225 return cert.DefaultProvider()
226
227 }
228
229
230 func isClientCertificateEnabled() bool {
231 if value, ok := os.LookupEnv(googleAPIUseCertSource); ok {
232
233 b, _ := strconv.ParseBool(value)
234 return b
235 }
236 return true
237 }
238
// transportConfig is the transport-security decision produced by
// getTransportConfig.
type transportConfig struct {
	// clientCertSource is the client certificate source. It may be nil,
	// e.g. when client certificates are disabled.
	clientCertSource cert.Provider
	// endpoint is the resolved endpoint used when S2A is not in effect.
	endpoint string
	// s2aAddress is the S2A address when S2A should be used, otherwise "".
	s2aAddress string
	// s2aMTLSEndpoint is the endpoint to use with S2A. It is only set
	// together with s2aAddress, from Options.DefaultMTLSEndpoint.
	s2aMTLSEndpoint string
}
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263 func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) {
264 if opts.Endpoint == "" {
265 mtlsMode := getMTLSMode()
266 if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
267 if !opts.isUniverseDomainGDU() {
268 return "", errUniverseNotSupportedMTLS
269 }
270 return opts.DefaultMTLSEndpoint, nil
271 }
272 return opts.defaultEndpoint(), nil
273 }
274 if strings.Contains(opts.Endpoint, "://") {
275
276 return opts.Endpoint, nil
277 }
278 if opts.defaultEndpoint() == "" {
279
280
281
282 return opts.Endpoint, nil
283 }
284
285
286 return opts.mergedEndpoint()
287 }
288
289 func getMTLSMode() string {
290 mode := os.Getenv(googleAPIUseMTLS)
291 if mode == "" {
292 mode = os.Getenv(googleAPIUseMTLSOld)
293 }
294 if mode == "" {
295 return mTLSModeAuto
296 }
297 return strings.ToLower(mode)
298 }
299
View as plain text