...

Source file src/github.com/cli/go-gh/v2/pkg/api/http_client.go

Documentation: github.com/cli/go-gh/v2/pkg/api

     1  package api
     2  
import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"regexp"
	"runtime/debug"
	"strings"
	"time"

	"github.com/cli/go-gh/v2/pkg/asciisanitizer"
	"github.com/cli/go-gh/v2/pkg/config"
	"github.com/cli/go-gh/v2/pkg/term"
	"github.com/henvic/httpretty"
	"github.com/thlib/go-timezone-local/tzlocal"
	"golang.org/x/text/transform"
)
    21  
    22  const (
    23  	accept          = "Accept"
    24  	authorization   = "Authorization"
    25  	contentType     = "Content-Type"
    26  	github          = "github.com"
    27  	jsonContentType = "application/json; charset=utf-8"
    28  	localhost       = "github.localhost"
    29  	modulePath      = "github.com/cli/go-gh"
    30  	timeZone        = "Time-Zone"
    31  	userAgent       = "User-Agent"
    32  )
    33  
    34  var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
    35  
    36  func DefaultHTTPClient() (*http.Client, error) {
    37  	return NewHTTPClient(ClientOptions{})
    38  }
    39  
    40  // HTTPClient builds a client that can be passed to another library.
    41  // As part of the configuration a hostname, auth token, default set of headers,
    42  // and unix domain socket are resolved from the gh environment configuration.
    43  // These behaviors can be overridden using the opts argument. In this instance
    44  // providing opts.Host will not change the destination of your request as it is
    45  // the responsibility of the consumer to configure this. However, if opts.Host
    46  // does not match the request host, the auth token will not be added to the headers.
    47  // This is to protect against the case where tokens could be sent to an arbitrary
    48  // host.
    49  func NewHTTPClient(opts ClientOptions) (*http.Client, error) {
    50  	if optionsNeedResolution(opts) {
    51  		var err error
    52  		opts, err = resolveOptions(opts)
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  	}
    57  
    58  	transport := http.DefaultTransport
    59  
    60  	if opts.UnixDomainSocket != "" {
    61  		transport = newUnixDomainSocketRoundTripper(opts.UnixDomainSocket)
    62  	}
    63  
    64  	if opts.Transport != nil {
    65  		transport = opts.Transport
    66  	}
    67  
    68  	transport = newSanitizerRoundTripper(transport)
    69  
    70  	if opts.CacheDir == "" {
    71  		opts.CacheDir = config.CacheDir()
    72  	}
    73  	if opts.EnableCache && opts.CacheTTL == 0 {
    74  		opts.CacheTTL = time.Hour * 24
    75  	}
    76  	c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL}
    77  	transport = c.RoundTripper(transport)
    78  
    79  	if opts.Log == nil && !opts.LogIgnoreEnv {
    80  		ghDebug := os.Getenv("GH_DEBUG")
    81  		switch ghDebug {
    82  		case "", "0", "false", "no":
    83  			// no logging
    84  		default:
    85  			opts.Log = os.Stderr
    86  			opts.LogColorize = !term.IsColorDisabled() && term.IsTerminal(os.Stderr)
    87  			opts.LogVerboseHTTP = strings.Contains(ghDebug, "api")
    88  		}
    89  	}
    90  
    91  	if opts.Log != nil {
    92  		logger := &httpretty.Logger{
    93  			Time:            true,
    94  			TLS:             false,
    95  			Colors:          opts.LogColorize,
    96  			RequestHeader:   opts.LogVerboseHTTP,
    97  			RequestBody:     opts.LogVerboseHTTP,
    98  			ResponseHeader:  opts.LogVerboseHTTP,
    99  			ResponseBody:    opts.LogVerboseHTTP,
   100  			Formatters:      []httpretty.Formatter{&jsonFormatter{colorize: opts.LogColorize}},
   101  			MaxResponseBody: 100000,
   102  		}
   103  		logger.SetOutput(opts.Log)
   104  		logger.SetBodyFilter(func(h http.Header) (skip bool, err error) {
   105  			return !inspectableMIMEType(h.Get(contentType)), nil
   106  		})
   107  		transport = logger.RoundTripper(transport)
   108  	}
   109  
   110  	if opts.Headers == nil {
   111  		opts.Headers = map[string]string{}
   112  	}
   113  	if !opts.SkipDefaultHeaders {
   114  		resolveHeaders(opts.Headers)
   115  	}
   116  	transport = newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport)
   117  
   118  	return &http.Client{Transport: transport, Timeout: opts.Timeout}, nil
   119  }
   120  
   121  func inspectableMIMEType(t string) bool {
   122  	return strings.HasPrefix(t, "text/") ||
   123  		strings.HasPrefix(t, "application/x-www-form-urlencoded") ||
   124  		jsonTypeRE.MatchString(t)
   125  }
   126  
   127  func isSameDomain(requestHost, domain string) bool {
   128  	requestHost = strings.ToLower(requestHost)
   129  	domain = strings.ToLower(domain)
   130  	return (requestHost == domain) || strings.HasSuffix(requestHost, "."+domain)
   131  }
   132  
   133  func isGarage(host string) bool {
   134  	return strings.EqualFold(host, "garage.github.com")
   135  }
   136  
   137  type headerRoundTripper struct {
   138  	headers map[string]string
   139  	host    string
   140  	rt      http.RoundTripper
   141  }
   142  
   143  func resolveHeaders(headers map[string]string) {
   144  	if _, ok := headers[contentType]; !ok {
   145  		headers[contentType] = jsonContentType
   146  	}
   147  	if _, ok := headers[userAgent]; !ok {
   148  		headers[userAgent] = "go-gh"
   149  		info, ok := debug.ReadBuildInfo()
   150  		if ok {
   151  			for _, dep := range info.Deps {
   152  				if dep.Path == modulePath {
   153  					headers[userAgent] += fmt.Sprintf(" %s", dep.Version)
   154  					break
   155  				}
   156  			}
   157  		}
   158  	}
   159  	if _, ok := headers[timeZone]; !ok {
   160  		tz := currentTimeZone()
   161  		if tz != "" {
   162  			headers[timeZone] = tz
   163  		}
   164  	}
   165  	if _, ok := headers[accept]; !ok {
   166  		// Preview for PullRequest.mergeStateStatus.
   167  		a := "application/vnd.github.merge-info-preview+json"
   168  		// Preview for visibility when RESTing repos into an org.
   169  		a += ", application/vnd.github.nebula-preview"
   170  		headers[accept] = a
   171  	}
   172  }
   173  
   174  func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper {
   175  	if _, ok := headers[authorization]; !ok && authToken != "" {
   176  		headers[authorization] = fmt.Sprintf("token %s", authToken)
   177  	}
   178  	if len(headers) == 0 {
   179  		return rt
   180  	}
   181  	return headerRoundTripper{host: host, headers: headers, rt: rt}
   182  }
   183  
   184  func (hrt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   185  	for k, v := range hrt.headers {
   186  		// If the authorization header has been set and the request
   187  		// host is not in the same domain that was specified in the ClientOptions
   188  		// then do not add the authorization header to the request.
   189  		if k == authorization && !isSameDomain(req.URL.Hostname(), hrt.host) {
   190  			continue
   191  		}
   192  
   193  		// If the header is already set in the request, don't overwrite it.
   194  		if req.Header.Get(k) == "" {
   195  			req.Header.Set(k, v)
   196  		}
   197  	}
   198  
   199  	return hrt.rt.RoundTrip(req)
   200  }
   201  
   202  func newUnixDomainSocketRoundTripper(socketPath string) http.RoundTripper {
   203  	dial := func(network, addr string) (net.Conn, error) {
   204  		return net.Dial("unix", socketPath)
   205  	}
   206  
   207  	return &http.Transport{
   208  		Dial:              dial,
   209  		DialTLS:           dial,
   210  		DisableKeepAlives: true,
   211  	}
   212  }
   213  
   214  type sanitizerRoundTripper struct {
   215  	rt http.RoundTripper
   216  }
   217  
   218  func newSanitizerRoundTripper(rt http.RoundTripper) http.RoundTripper {
   219  	return sanitizerRoundTripper{rt: rt}
   220  }
   221  
   222  func (srt sanitizerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   223  	resp, err := srt.rt.RoundTrip(req)
   224  	if err != nil || !jsonTypeRE.MatchString(resp.Header.Get(contentType)) {
   225  		return resp, err
   226  	}
   227  	sanitizedReadCloser := struct {
   228  		io.Reader
   229  		io.Closer
   230  	}{
   231  		Reader: transform.NewReader(resp.Body, &asciisanitizer.Sanitizer{JSON: true}),
   232  		Closer: resp.Body,
   233  	}
   234  	resp.Body = sanitizedReadCloser
   235  	return resp, err
   236  }
   237  
   238  func currentTimeZone() string {
   239  	tz, err := tzlocal.RuntimeTZ()
   240  	if err != nil {
   241  		return ""
   242  	}
   243  	return tz
   244  }
   245  

View as plain text