1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package remotecmd
18
19 import (
20 "crypto/tls"
21 "errors"
22 "fmt"
23 "io"
24 "io/ioutil"
25 "net"
26 "net/http"
27 "net/url"
28 "os"
29 "strings"
30 "time"
31
32 "github.com/sassoftware/relic/cmdline/shared"
33 "github.com/sassoftware/relic/config"
34 "github.com/sassoftware/relic/lib/compresshttp"
35 "github.com/sassoftware/relic/lib/x509tools"
36 "golang.org/x/net/http2"
37 )
38
// ReaderGetter supplies the body of an outgoing request.
//
// GetReader is invoked each time a request is built, including once per
// retry attempt against another server. Implementations presumably need to
// hand back a reader positioned at the start of the body on every call.
type ReaderGetter interface {
	// GetReader returns a reader over the request body, or an error if the
	// body cannot be produced.
	GetReader() (io.Reader, error)
}
42
43
44 func CallRemote(endpoint, method string, query *url.Values, body ReaderGetter) (*http.Response, error) {
45 if err := shared.InitClientConfig(); err != nil {
46 return nil, err
47 }
48 if shared.CurrentConfig.Remote == nil {
49 return nil, errors.New("config file has no \"remote\" section")
50 }
51 encodings := compresshttp.AcceptedEncodings
52 bases := []string{shared.CurrentConfig.Remote.URL}
53 if dirurl := shared.CurrentConfig.Remote.DirectoryURL; dirurl != "" {
54 newBases, serverEncodings, err := getDirectory(dirurl)
55 if err != nil {
56 return nil, err
57 } else if len(newBases) > 0 {
58 bases = newBases
59 }
60 encodings = serverEncodings
61 }
62 return doRequest(bases, endpoint, method, encodings, query, body)
63 }
64
65
66
67 func getDirectory(dirurl string) ([]string, string, error) {
68 response, err := doRequest([]string{dirurl}, "directory", "GET", "", nil, nil)
69 if err != nil {
70 return nil, "", err
71 }
72 encodings := response.Header.Get("Accept-Encoding")
73 bodybytes, err := ioutil.ReadAll(response.Body)
74 if err != nil {
75 return nil, "", err
76 }
77 response.Body.Close()
78 text := strings.Trim(string(bodybytes), "\r\n")
79 if len(text) == 0 {
80 return nil, encodings, nil
81 }
82 return strings.Split(text, "\r\n"), encodings, nil
83 }
84
85
86 func buildRequest(base, endpoint, method, encoding string, query *url.Values, bodyFile ReaderGetter) (*http.Request, error) {
87 eurl, err := url.Parse(endpoint)
88 if err != nil {
89 return nil, err
90 }
91 url, err := url.Parse(base)
92 if err != nil {
93 return nil, fmt.Errorf("Failed to parse remote URL: %s", err)
94 }
95 url = url.ResolveReference(eurl)
96 if query != nil {
97 url.RawQuery = query.Encode()
98 }
99 request := &http.Request{
100 Method: method,
101 URL: url,
102 Header: http.Header{"User-Agent": []string{config.UserAgent}},
103 }
104 if encoding != "" {
105 request.Header.Set("Accept-Encoding", encoding)
106 }
107 if bodyFile != nil {
108 stream, err := bodyFile.GetReader()
109 if err != nil {
110 return nil, err
111 }
112 request.Body = ioutil.NopCloser(stream)
113 if err := compresshttp.CompressRequest(request, encoding); err != nil {
114 return nil, err
115 }
116 }
117 return request, nil
118 }
119
120
121 func makeTLSConfig() (*tls.Config, error) {
122 err := shared.InitClientConfig()
123 if err != nil {
124 return nil, err
125 }
126 config := shared.CurrentConfig
127 if config.Remote == nil {
128 return nil, errors.New("Missing remote section in config file")
129 } else if config.Remote.URL == "" && config.Remote.DirectoryURL == "" {
130 return nil, errors.New("url or directoryUrl must be set in 'remote' section of configuration")
131 } else if config.Remote.CertFile == "" || config.Remote.KeyFile == "" {
132 return nil, errors.New("certfile and keyfile are required settings in 'remote' section of configuration")
133 }
134 tlscert, err := tls.LoadX509KeyPair(config.Remote.CertFile, config.Remote.KeyFile)
135 if err != nil {
136 return nil, err
137 }
138 tconf := &tls.Config{Certificates: []tls.Certificate{tlscert}}
139 x509tools.SetKeyLogFile(tconf)
140 if err := x509tools.LoadCertPool(config.Remote.CaCert, tconf); err != nil {
141 return nil, err
142 }
143 return tconf, nil
144 }
145
146
147 func doRequest(bases []string, endpoint, method, encodings string, query *url.Values, bodyFile ReaderGetter) (response *http.Response, err error) {
148 tconf, err := makeTLSConfig()
149 if err != nil {
150 return nil, err
151 }
152 dialer := &net.Dialer{
153 Timeout: time.Duration(shared.CurrentConfig.Remote.ConnectTimeout) * time.Second,
154 }
155 transport := &http.Transport{TLSClientConfig: tconf, DialContext: dialer.DialContext}
156 if err := http2.ConfigureTransport(transport); err != nil {
157 return nil, err
158 }
159 client := &http.Client{Transport: transport}
160
161 minAttempts := shared.CurrentConfig.Remote.Retries
162 if len(bases) < minAttempts {
163 var repeated []string
164 for len(repeated) < minAttempts {
165 repeated = append(repeated, bases...)
166 }
167 bases = repeated
168 }
169
170 loop:
171 for i, base := range bases {
172 var request *http.Request
173 request, err = buildRequest(base, endpoint, method, encodings, query, bodyFile)
174 if err != nil {
175 return nil, err
176 }
177 response, err = client.Do(request)
178 if request.Body != nil {
179 request.Body.Close()
180 }
181 if err == nil {
182 if response.StatusCode < 300 {
183 if i != 0 {
184 fmt.Printf("successfully contacted %s\n", request.URL)
185 }
186 break loop
187 }
188
189 body, _ := ioutil.ReadAll(response.Body)
190 response.Body.Close()
191 err = ResponseError{method, request.URL.String(), response.Status, response.StatusCode, string(body)}
192 }
193 if response != nil && response.StatusCode == http.StatusNotAcceptable && encodings != "" {
194
195 encodings = ""
196 goto loop
197 } else if isTemporary(err) && i+1 < len(bases) {
198 fmt.Printf("%s\nunable to connect to %s; trying next server\n", err, request.URL)
199 } else {
200 return nil, err
201 }
202 }
203 if response != nil {
204 if err := compresshttp.DecompressResponse(response); err != nil {
205 return nil, err
206 }
207 }
208 return
209 }
210
211 func setDigestQueryParam(query url.Values) error {
212 if shared.ArgDigest == "" {
213 return nil
214 }
215 if _, err := shared.GetDigest(); err != nil {
216 return err
217 }
218 query.Add("digest", shared.ArgDigest)
219 return nil
220 }
221
222
223
224
// isTemporary reports whether err should be treated as transient, so that
// another server may be tried. This is the case when either:
//   - err itself has a Temporary method that returns true (as ResponseError
//     and *url.Error do); or
//   - after peeling off at most one *url.Error and then at most one
//     *net.OpError, the remaining error is an *os.SyscallError.
func isTemporary(err error) bool {
	if t, ok := err.(temporary); ok && t.Temporary() {
		return true
	}
	cause := err
	if ue, ok := cause.(*url.Error); ok {
		cause = ue.Err
	}
	if oe, ok := cause.(*net.OpError); ok {
		cause = oe.Err
	}
	// Syscall-level failures are treated as "this server is unreachable".
	_, isSyscall := cause.(*os.SyscallError)
	return isSyscall
}

// temporary is satisfied by errors that can classify themselves as transient.
type temporary interface {
	Temporary() bool
}
246
// ResponseError records an HTTP reply that was not accepted as a success.
// doRequest produces one for any reply with status 300 or above.
type ResponseError struct {
	Method     string // request method
	URL        string // full request URL
	Status     string // status line text, e.g. "503 Service Unavailable"
	StatusCode int    // numeric status code
	BodyText   string // response body as read from the server
}

// Error renders the request and the server's reply on separate lines.
func (r ResponseError) Error() string {
	return "HTTP error:\n" + r.Method + " " + r.URL + "\n" + r.Status + "\n" + r.BodyText
}

// Temporary reports whether the status is one of the server-side failures
// (500, 502, 503, 504, 507) that callers treat as transient.
func (r ResponseError) Temporary() bool {
	switch r.StatusCode {
	case http.StatusInternalServerError, // 500
		http.StatusBadGateway,          // 502
		http.StatusServiceUnavailable,  // 503
		http.StatusGatewayTimeout,      // 504
		http.StatusInsufficientStorage: // 507
		return true
	}
	return false
}
271
View as plain text