...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package httptransport
16
17 import (
18 "context"
19 "crypto/tls"
20 "net"
21 "net/http"
22 "time"
23
24 "cloud.google.com/go/auth"
25 "cloud.google.com/go/auth/credentials"
26 "cloud.google.com/go/auth/internal"
27 "cloud.google.com/go/auth/internal/transport"
28 "cloud.google.com/go/auth/internal/transport/cert"
29 "go.opencensus.io/plugin/ochttp"
30 "golang.org/x/net/http2"
31 )
32
// quotaProjectHeaderKey is the HTTP header that carries the quota project
// attached to outgoing requests.
const quotaProjectHeaderKey = "X-Goog-User-Project"
36
37 func newTransport(base http.RoundTripper, opts *Options) (http.RoundTripper, error) {
38 var headers = opts.Headers
39 ht := &headerTransport{
40 base: base,
41 headers: headers,
42 }
43 var trans http.RoundTripper = ht
44 trans = addOCTransport(trans, opts)
45 switch {
46 case opts.DisableAuthentication:
47
48 case opts.APIKey != "":
49 qp := internal.GetQuotaProject(nil, opts.Headers.Get(quotaProjectHeaderKey))
50 if qp != "" {
51 if headers == nil {
52 headers = make(map[string][]string, 1)
53 }
54 headers.Set(quotaProjectHeaderKey, qp)
55 }
56 trans = &apiKeyTransport{
57 Transport: trans,
58 Key: opts.APIKey,
59 }
60 default:
61 var creds *auth.Credentials
62 if opts.Credentials != nil {
63 creds = opts.Credentials
64 } else {
65 var err error
66 creds, err = credentials.DetectDefault(opts.resolveDetectOptions())
67 if err != nil {
68 return nil, err
69 }
70 }
71 qp, err := creds.QuotaProjectID(context.Background())
72 if err != nil {
73 return nil, err
74 }
75 if qp != "" {
76 if headers == nil {
77 headers = make(map[string][]string, 1)
78 }
79 headers.Set(quotaProjectHeaderKey, qp)
80 }
81 creds.TokenProvider = auth.NewCachedTokenProvider(creds.TokenProvider, nil)
82 trans = &authTransport{
83 base: trans,
84 creds: creds,
85 clientUniverseDomain: opts.UniverseDomain,
86 }
87 }
88 return trans, nil
89 }
90
91
92
93
94
95
96 func defaultBaseTransport(clientCertSource cert.Provider, dialTLSContext func(context.Context, string, string) (net.Conn, error)) http.RoundTripper {
97 trans := http.DefaultTransport.(*http.Transport).Clone()
98 trans.MaxIdleConnsPerHost = 100
99
100 if clientCertSource != nil {
101 trans.TLSClientConfig = &tls.Config{
102 GetClientCertificate: clientCertSource,
103 }
104 }
105 if dialTLSContext != nil {
106
107 trans.DialTLSContext = dialTLSContext
108 }
109
110
111
112
113
114 http2Trans, err := http2.ConfigureTransports(trans)
115 if err == nil {
116 http2Trans.ReadIdleTimeout = time.Second * 31
117 }
118
119 return trans
120 }
121
// apiKeyTransport is an http.RoundTripper that authenticates requests by
// adding an API key as the "key" URL query parameter.
type apiKeyTransport struct {
	// Key is the API key added to every request.
	Key string

	// Transport performs the actual request. If nil, http.DefaultTransport
	// is used.
	Transport http.RoundTripper
}

// RoundTrip sends a copy of req whose query string carries t.Key, replacing
// any existing "key" parameter.
//
// The http.RoundTripper contract forbids modifying the caller's request.
// The request is therefore deep-copied with Clone, which also copies the
// *url.URL, before the query is rewritten. A plain struct copy would share the
// URL pointer and write the API key back into req.
func (t *apiKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	rt := t.Transport
	if rt == nil {
		rt = http.DefaultTransport
	}
	newReq := req.Clone(req.Context())
	args := newReq.URL.Query()
	args.Set("key", t.Key)
	newReq.URL.RawQuery = args.Encode()
	return rt.RoundTrip(newReq)
}
137
// headerTransport is an http.RoundTripper that overlays a fixed set of
// headers onto every request before handing it to base.
type headerTransport struct {
	// headers are set on each outgoing request, replacing any values the
	// request already carries for the same keys. May be nil.
	headers http.Header
	// base performs the actual round trip.
	base http.RoundTripper
}

// RoundTrip forwards a shallow copy of req to t.base. The copy gets a fresh
// header map containing req's headers overlaid with t.headers, so the
// caller's req.Header map is not modified. The value slices themselves are
// shared with req.Header and t.headers.
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	merged := make(http.Header, len(req.Header)+len(t.headers))
	for name, values := range req.Header {
		merged[name] = values
	}
	// Entries from t.headers win over those already on the request.
	for name, values := range t.headers {
		merged[name] = values
	}
	out := *req
	out.Header = merged
	return t.base.RoundTrip(&out)
}
157
158 func addOCTransport(trans http.RoundTripper, opts *Options) http.RoundTripper {
159 if opts.DisableTelemetry {
160 return trans
161 }
162 return &ochttp.Transport{
163 Base: trans,
164 Propagation: &httpFormat{},
165 }
166 }
167
168 type authTransport struct {
169 creds *auth.Credentials
170 base http.RoundTripper
171 clientUniverseDomain string
172 }
173
174
175
176 func (t *authTransport) getClientUniverseDomain() string {
177 if t.clientUniverseDomain == "" {
178 return internal.DefaultUniverseDomain
179 }
180 return t.clientUniverseDomain
181 }
182
183
184
185
186
187 func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
188 reqBodyClosed := false
189 if req.Body != nil {
190 defer func() {
191 if !reqBodyClosed {
192 req.Body.Close()
193 }
194 }()
195 }
196 credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context())
197 if err != nil {
198 return nil, err
199 }
200 if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
201 return nil, err
202 }
203 token, err := t.creds.Token(req.Context())
204 if err != nil {
205 return nil, err
206 }
207 req2 := req.Clone(req.Context())
208 SetAuthHeader(token, req2)
209 reqBodyClosed = true
210 return t.base.RoundTrip(req2)
211 }
212
View as plain text