...

Source file src/github.com/sassoftware/relic/cmdline/remotecmd/client.go

Documentation: github.com/sassoftware/relic/cmdline/remotecmd

     1  //
     2  // Copyright (c) SAS Institute Inc.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  //
    16  
    17  package remotecmd
    18  
    19  import (
    20  	"crypto/tls"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"io/ioutil"
    25  	"net"
    26  	"net/http"
    27  	"net/url"
    28  	"os"
    29  	"strings"
    30  	"time"
    31  
    32  	"github.com/sassoftware/relic/cmdline/shared"
    33  	"github.com/sassoftware/relic/config"
    34  	"github.com/sassoftware/relic/lib/compresshttp"
    35  	"github.com/sassoftware/relic/lib/x509tools"
    36  	"golang.org/x/net/http2"
    37  )
    38  
    39  type ReaderGetter interface {
    40  	GetReader() (io.Reader, error)
    41  }
    42  
    43  // Make a single API request to a named endpoint, handling directory lookup and failover automatically.
    44  func CallRemote(endpoint, method string, query *url.Values, body ReaderGetter) (*http.Response, error) {
    45  	if err := shared.InitClientConfig(); err != nil {
    46  		return nil, err
    47  	}
    48  	if shared.CurrentConfig.Remote == nil {
    49  		return nil, errors.New("config file has no \"remote\" section")
    50  	}
    51  	encodings := compresshttp.AcceptedEncodings
    52  	bases := []string{shared.CurrentConfig.Remote.URL}
    53  	if dirurl := shared.CurrentConfig.Remote.DirectoryURL; dirurl != "" {
    54  		newBases, serverEncodings, err := getDirectory(dirurl)
    55  		if err != nil {
    56  			return nil, err
    57  		} else if len(newBases) > 0 {
    58  			bases = newBases
    59  		}
    60  		encodings = serverEncodings
    61  	}
    62  	return doRequest(bases, endpoint, method, encodings, query, body)
    63  }
    64  
    65  // Call the configured directory URL to get a list of servers to try.
    66  // callRemote() calls this automatically, use that instead.
    67  func getDirectory(dirurl string) ([]string, string, error) {
    68  	response, err := doRequest([]string{dirurl}, "directory", "GET", "", nil, nil)
    69  	if err != nil {
    70  		return nil, "", err
    71  	}
    72  	encodings := response.Header.Get("Accept-Encoding")
    73  	bodybytes, err := ioutil.ReadAll(response.Body)
    74  	if err != nil {
    75  		return nil, "", err
    76  	}
    77  	response.Body.Close()
    78  	text := strings.Trim(string(bodybytes), "\r\n")
    79  	if len(text) == 0 {
    80  		return nil, encodings, nil
    81  	}
    82  	return strings.Split(text, "\r\n"), encodings, nil
    83  }
    84  
    85  // Build a HTTP request from various bits and pieces
    86  func buildRequest(base, endpoint, method, encoding string, query *url.Values, bodyFile ReaderGetter) (*http.Request, error) {
    87  	eurl, err := url.Parse(endpoint)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	url, err := url.Parse(base)
    92  	if err != nil {
    93  		return nil, fmt.Errorf("Failed to parse remote URL: %s", err)
    94  	}
    95  	url = url.ResolveReference(eurl)
    96  	if query != nil {
    97  		url.RawQuery = query.Encode()
    98  	}
    99  	request := &http.Request{
   100  		Method: method,
   101  		URL:    url,
   102  		Header: http.Header{"User-Agent": []string{config.UserAgent}},
   103  	}
   104  	if encoding != "" {
   105  		request.Header.Set("Accept-Encoding", encoding)
   106  	}
   107  	if bodyFile != nil {
   108  		stream, err := bodyFile.GetReader()
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		request.Body = ioutil.NopCloser(stream)
   113  		if err := compresshttp.CompressRequest(request, encoding); err != nil {
   114  			return nil, err
   115  		}
   116  	}
   117  	return request, nil
   118  }
   119  
   120  // Build TLS config based on client configuration
   121  func makeTLSConfig() (*tls.Config, error) {
   122  	err := shared.InitClientConfig()
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	config := shared.CurrentConfig
   127  	if config.Remote == nil {
   128  		return nil, errors.New("Missing remote section in config file")
   129  	} else if config.Remote.URL == "" && config.Remote.DirectoryURL == "" {
   130  		return nil, errors.New("url or directoryUrl must be set in 'remote' section of configuration")
   131  	} else if config.Remote.CertFile == "" || config.Remote.KeyFile == "" {
   132  		return nil, errors.New("certfile and keyfile are required settings in 'remote' section of configuration")
   133  	}
   134  	tlscert, err := tls.LoadX509KeyPair(config.Remote.CertFile, config.Remote.KeyFile)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	tconf := &tls.Config{Certificates: []tls.Certificate{tlscert}}
   139  	x509tools.SetKeyLogFile(tconf)
   140  	if err := x509tools.LoadCertPool(config.Remote.CaCert, tconf); err != nil {
   141  		return nil, err
   142  	}
   143  	return tconf, nil
   144  }
   145  
   146  // Transact one request, trying multiple servers if necessary. Internal use only.
   147  func doRequest(bases []string, endpoint, method, encodings string, query *url.Values, bodyFile ReaderGetter) (response *http.Response, err error) {
   148  	tconf, err := makeTLSConfig()
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	dialer := &net.Dialer{
   153  		Timeout: time.Duration(shared.CurrentConfig.Remote.ConnectTimeout) * time.Second,
   154  	}
   155  	transport := &http.Transport{TLSClientConfig: tconf, DialContext: dialer.DialContext}
   156  	if err := http2.ConfigureTransport(transport); err != nil {
   157  		return nil, err
   158  	}
   159  	client := &http.Client{Transport: transport}
   160  
   161  	minAttempts := shared.CurrentConfig.Remote.Retries
   162  	if len(bases) < minAttempts {
   163  		var repeated []string
   164  		for len(repeated) < minAttempts {
   165  			repeated = append(repeated, bases...)
   166  		}
   167  		bases = repeated
   168  	}
   169  
   170  loop:
   171  	for i, base := range bases {
   172  		var request *http.Request
   173  		request, err = buildRequest(base, endpoint, method, encodings, query, bodyFile)
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  		response, err = client.Do(request)
   178  		if request.Body != nil {
   179  			request.Body.Close()
   180  		}
   181  		if err == nil {
   182  			if response.StatusCode < 300 {
   183  				if i != 0 {
   184  					fmt.Printf("successfully contacted %s\n", request.URL)
   185  				}
   186  				break loop
   187  			}
   188  			// HTTP error, probably a 503
   189  			body, _ := ioutil.ReadAll(response.Body)
   190  			response.Body.Close()
   191  			err = ResponseError{method, request.URL.String(), response.Status, response.StatusCode, string(body)}
   192  		}
   193  		if response != nil && response.StatusCode == http.StatusNotAcceptable && encodings != "" {
   194  			// try again without compression
   195  			encodings = ""
   196  			goto loop
   197  		} else if isTemporary(err) && i+1 < len(bases) {
   198  			fmt.Printf("%s\nunable to connect to %s; trying next server\n", err, request.URL)
   199  		} else {
   200  			return nil, err
   201  		}
   202  	}
   203  	if response != nil {
   204  		if err := compresshttp.DecompressResponse(response); err != nil {
   205  			return nil, err
   206  		}
   207  	}
   208  	return
   209  }
   210  
   211  func setDigestQueryParam(query url.Values) error {
   212  	if shared.ArgDigest == "" {
   213  		return nil
   214  	}
   215  	if _, err := shared.GetDigest(); err != nil {
   216  		return err
   217  	}
   218  	query.Add("digest", shared.ArgDigest)
   219  	return nil
   220  }
   221  
   222  // Check if an error is something recoverable, i.e. if we should continue to
   223  // try another server. In practice, anything other than a HTTP 4XX status will
   224  // result in a retry.
   225  func isTemporary(err error) bool {
   226  	if e, ok := err.(temporary); ok && e.Temporary() {
   227  		return true
   228  	}
   229  	// unpack error wrappers
   230  	if e, ok := err.(*url.Error); ok {
   231  		err = e.Err
   232  	}
   233  	if e, ok := err.(*net.OpError); ok {
   234  		err = e.Err
   235  	}
   236  	// treat any syscall error as something recoverable
   237  	if _, ok := err.(*os.SyscallError); ok {
   238  		return true
   239  	}
   240  	return false
   241  }
   242  
   243  type temporary interface {
   244  	Temporary() bool
   245  }
   246  
   247  type ResponseError struct {
   248  	Method     string
   249  	URL        string
   250  	Status     string
   251  	StatusCode int
   252  	BodyText   string
   253  }
   254  
   255  func (e ResponseError) Error() string {
   256  	return fmt.Sprintf("HTTP error:\n%s %s\n%s\n%s", e.Method, e.URL, e.Status, e.BodyText)
   257  }
   258  
   259  func (e ResponseError) Temporary() bool {
   260  	switch e.StatusCode {
   261  	case http.StatusGatewayTimeout,
   262  		http.StatusBadGateway,
   263  		http.StatusServiceUnavailable,
   264  		http.StatusInsufficientStorage,
   265  		http.StatusInternalServerError:
   266  		return true
   267  	default:
   268  		return false
   269  	}
   270  }
   271  

View as plain text