1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33 package internal
34
35 import (
36 "context"
37 "crypto/tls"
38 "errors"
39 "net"
40 "net/url"
41 "os"
42 "strings"
43
44 "github.com/google/s2a-go"
45 "github.com/google/s2a-go/fallback"
46 "google.golang.org/api/internal/cert"
47 "google.golang.org/grpc/credentials"
48 )
49
// Recognized values of the GOOGLE_API_USE_MTLS_ENDPOINT / GOOGLE_API_USE_MTLS
// environment variables (the value read is lower-cased before comparison).
const (
	mTLSModeAlways = "always"
	mTLSModeNever  = "never"
	mTLSModeAuto   = "auto"
)

// googleAPIUseS2AEnv names the environment variable that opts in to the
// experimental S2A transport when set to "true" (case-insensitive).
const googleAPIUseS2AEnv = "EXPERIMENTAL_GOOGLE_API_USE_S2A"

// universeDomainPlaceholder is the token inside DefaultEndpointTemplate that is
// substituted with the configured universe domain.
const universeDomainPlaceholder = "UNIVERSE_DOMAIN"

// errUniverseNotSupportedMTLS is returned when an mTLS or S2A transport is
// selected but the settings are not in the default universe domain
// (IsUniverseDomainGDU reports false).
var errUniverseNotSupportedMTLS = errors.New("mTLS is not supported in any universe other than googleapis.com")
64
65
66
67
68 func getClientCertificateSourceAndEndpoint(settings *DialSettings) (cert.Source, string, error) {
69 clientCertSource, err := getClientCertificateSource(settings)
70 if err != nil {
71 return nil, "", err
72 }
73 endpoint, err := getEndpoint(settings, clientCertSource)
74 if err != nil {
75 return nil, "", err
76 }
77
78 if settings.Endpoint == "" && !settings.IsUniverseDomainGDU() && settings.DefaultEndpointTemplate != "" {
79
80
81
82
83 endpoint = resolvedDefaultEndpoint(settings)
84 }
85 return clientCertSource, endpoint, nil
86 }
87
88 type transportConfig struct {
89 clientCertSource cert.Source
90 endpoint string
91 s2aAddress string
92 s2aMTLSEndpoint string
93 }
94
95 func getTransportConfig(settings *DialSettings) (*transportConfig, error) {
96 clientCertSource, endpoint, err := getClientCertificateSourceAndEndpoint(settings)
97 if err != nil {
98 return nil, err
99 }
100 defaultTransportConfig := transportConfig{
101 clientCertSource: clientCertSource,
102 endpoint: endpoint,
103 s2aAddress: "",
104 s2aMTLSEndpoint: "",
105 }
106
107 if !shouldUseS2A(clientCertSource, settings) {
108 return &defaultTransportConfig, nil
109 }
110 if !settings.IsUniverseDomainGDU() {
111 return nil, errUniverseNotSupportedMTLS
112 }
113
114 s2aAddress := GetS2AAddress()
115 if s2aAddress == "" {
116 return &defaultTransportConfig, nil
117 }
118 return &transportConfig{
119 clientCertSource: clientCertSource,
120 endpoint: endpoint,
121 s2aAddress: s2aAddress,
122 s2aMTLSEndpoint: settings.DefaultMTLSEndpoint,
123 }, nil
124 }
125
126
127
128
129
130
131
132
133
134
135
136 func getClientCertificateSource(settings *DialSettings) (cert.Source, error) {
137 if !isClientCertificateEnabled() {
138 return nil, nil
139 } else if settings.ClientCertSource != nil {
140 return settings.ClientCertSource, nil
141 } else {
142 return cert.DefaultSource()
143 }
144 }
145
// isClientCertificateEnabled reports whether GOOGLE_API_USE_CLIENT_CERTIFICATE
// equals "true", compared case-insensitively. An unset variable or any other
// value disables client certificates.
func isClientCertificateEnabled() bool {
	return strings.EqualFold(os.Getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE"), "true")
}
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165 func getEndpoint(settings *DialSettings, clientCertSource cert.Source) (string, error) {
166 if settings.Endpoint == "" {
167 if isMTLS(clientCertSource) {
168 if !settings.IsUniverseDomainGDU() {
169 return "", errUniverseNotSupportedMTLS
170 }
171 return settings.DefaultMTLSEndpoint, nil
172 }
173 return resolvedDefaultEndpoint(settings), nil
174 }
175 if strings.Contains(settings.Endpoint, "://") {
176
177 return settings.Endpoint, nil
178 }
179 if resolvedDefaultEndpoint(settings) == "" {
180
181
182 return settings.Endpoint, nil
183 }
184
185
186 return mergeEndpoints(resolvedDefaultEndpoint(settings), settings.Endpoint)
187 }
188
189 func isMTLS(clientCertSource cert.Source) bool {
190 mtlsMode := getMTLSMode()
191 return mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto)
192 }
193
194
195
196
197 func resolvedDefaultEndpoint(settings *DialSettings) string {
198 if settings.DefaultEndpointTemplate == "" {
199 return settings.DefaultEndpoint
200 }
201 return strings.Replace(settings.DefaultEndpointTemplate, universeDomainPlaceholder, settings.GetUniverseDomain(), 1)
202 }
203
204 func getMTLSMode() string {
205 mode := os.Getenv("GOOGLE_API_USE_MTLS_ENDPOINT")
206 if mode == "" {
207 mode = os.Getenv("GOOGLE_API_USE_MTLS")
208 }
209 if mode == "" {
210 return mTLSModeAuto
211 }
212 return strings.ToLower(mode)
213 }
214
// mergeEndpoints returns baseURL with its host (including any port) replaced
// by newHost, keeping the rest of baseURL (scheme, path) unchanged.
//
// baseURL may omit the scheme; it is then parsed as "https://" + baseURL. An
// error is returned only when baseURL cannot be parsed.
func mergeEndpoints(baseURL, newHost string) (string, error) {
	u, err := url.Parse(fixScheme(baseURL))
	if err != nil {
		return "", err
	}
	// Search for the host only after an explicit "scheme://" prefix. Searching
	// the whole string would match a host that is a substring of the scheme
	// (e.g. host "http" in "http://http/") and corrupt the scheme instead.
	prefix, rest := "", baseURL
	if i := strings.Index(baseURL, "://"); i >= 0 && i == len(u.Scheme) {
		prefix, rest = baseURL[:i+len("://")], baseURL[i+len("://"):]
	}
	return prefix + strings.Replace(rest, u.Host, newHost, 1), nil
}

// fixScheme returns baseURL unchanged if it already contains "://". Otherwise
// it prefixes baseURL with "https://", so that url.Parse treats the leading
// component as a host rather than as a path.
func fixScheme(baseURL string) string {
	if strings.Contains(baseURL, "://") {
		return baseURL
	}
	return "https://" + baseURL
}
229
230
231
232 func GetGRPCTransportConfigAndEndpoint(settings *DialSettings) (credentials.TransportCredentials, string, error) {
233 config, err := getTransportConfig(settings)
234 if err != nil {
235 return nil, "", err
236 }
237
238 defaultTransportCreds := credentials.NewTLS(&tls.Config{
239 GetClientCertificate: config.clientCertSource,
240 })
241 if config.s2aAddress == "" {
242 return defaultTransportCreds, config.endpoint, nil
243 }
244
245 var fallbackOpts *s2a.FallbackOptions
246
247 if fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(config.endpoint); err == nil {
248 fallbackOpts = &s2a.FallbackOptions{
249 FallbackClientHandshakeFunc: fallbackHandshake,
250 }
251 }
252
253 s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
254 S2AAddress: config.s2aAddress,
255 FallbackOpts: fallbackOpts,
256 })
257 if err != nil {
258
259 return defaultTransportCreds, config.endpoint, nil
260 }
261 return s2aTransportCreds, config.s2aMTLSEndpoint, nil
262 }
263
264
265
266 func GetHTTPTransportConfigAndEndpoint(settings *DialSettings) (cert.Source, func(context.Context, string, string) (net.Conn, error), string, error) {
267 config, err := getTransportConfig(settings)
268 if err != nil {
269 return nil, nil, "", err
270 }
271
272 if config.s2aAddress == "" {
273 return config.clientCertSource, nil, config.endpoint, nil
274 }
275
276 var fallbackOpts *s2a.FallbackOptions
277
278 if fallbackURL, err := url.Parse(config.endpoint); err == nil {
279 if fallbackDialer, fallbackServerAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackURL.Hostname()); err == nil {
280 fallbackOpts = &s2a.FallbackOptions{
281 FallbackDialer: &s2a.FallbackDialer{
282 Dialer: fallbackDialer,
283 ServerAddr: fallbackServerAddr,
284 },
285 }
286 }
287 }
288
289 dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
290 S2AAddress: config.s2aAddress,
291 FallbackOpts: fallbackOpts,
292 })
293 return nil, dialTLSContextFunc, config.s2aMTLSEndpoint, nil
294 }
295
296 func shouldUseS2A(clientCertSource cert.Source, settings *DialSettings) bool {
297
298 if clientCertSource != nil {
299 return false
300 }
301
302 if !isGoogleS2AEnabled() {
303 return false
304 }
305
306 if settings.DefaultMTLSEndpoint == "" || settings.Endpoint != "" {
307 return false
308 }
309
310 if settings.HTTPClient != nil {
311 return false
312 }
313 return !settings.EnableDirectPath && !settings.EnableDirectPathXds
314 }
315
316 func isGoogleS2AEnabled() bool {
317 return strings.ToLower(os.Getenv(googleAPIUseS2AEnv)) == "true"
318 }
319
View as plain text