...

Source file src/github.com/go-openapi/runtime/client/runtime.go

Documentation: github.com/go-openapi/runtime/client

     1  // Copyright 2015 go-swagger maintainers
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package client
    16  
import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"mime"
	"net/http"
	"net/http/httputil"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/go-openapi/strfmt"
	"github.com/opentracing/opentracing-go"

	"github.com/go-openapi/runtime"
	"github.com/go-openapi/runtime/logger"
	"github.com/go-openapi/runtime/middleware"
	"github.com/go-openapi/runtime/yamlpc"
)
    43  
    44  const (
    45  	schemeHTTP  = "http"
    46  	schemeHTTPS = "https"
    47  )
    48  
    49  // TLSClientOptions to configure client authentication with mutual TLS
    50  type TLSClientOptions struct {
    51  	// Certificate is the path to a PEM-encoded certificate to be used for
    52  	// client authentication. If set then Key must also be set.
    53  	Certificate string
    54  
    55  	// LoadedCertificate is the certificate to be used for client authentication.
    56  	// This field is ignored if Certificate is set. If this field is set, LoadedKey
    57  	// is also required.
    58  	LoadedCertificate *x509.Certificate
    59  
    60  	// Key is the path to an unencrypted PEM-encoded private key for client
    61  	// authentication. This field is required if Certificate is set.
    62  	Key string
    63  
    64  	// LoadedKey is the key for client authentication. This field is required if
    65  	// LoadedCertificate is set.
    66  	LoadedKey crypto.PrivateKey
    67  
    68  	// CA is a path to a PEM-encoded certificate that specifies the root certificate
    69  	// to use when validating the TLS certificate presented by the server. If this field
    70  	// (and LoadedCA) is not set, the system certificate pool is used. This field is ignored if LoadedCA
    71  	// is set.
    72  	CA string
    73  
    74  	// LoadedCA specifies the root certificate to use when validating the server's TLS certificate.
    75  	// If this field (and CA) is not set, the system certificate pool is used.
    76  	LoadedCA *x509.Certificate
    77  
    78  	// LoadedCAPool specifies a pool of RootCAs to use when validating the server's TLS certificate.
    79  	// If set, it will be combined with the other loaded certificates (see LoadedCA and CA).
    80  	// If neither LoadedCA or CA is set, the provided pool with override the system
    81  	// certificate pool.
    82  	// The caller must not use the supplied pool after calling TLSClientAuth.
    83  	LoadedCAPool *x509.CertPool
    84  
    85  	// ServerName specifies the hostname to use when verifying the server certificate.
    86  	// If this field is set then InsecureSkipVerify will be ignored and treated as
    87  	// false.
    88  	ServerName string
    89  
    90  	// InsecureSkipVerify controls whether the certificate chain and hostname presented
    91  	// by the server are validated. If true, any certificate is accepted.
    92  	InsecureSkipVerify bool
    93  
    94  	// VerifyPeerCertificate, if not nil, is called after normal
    95  	// certificate verification. It receives the raw ASN.1 certificates
    96  	// provided by the peer and also any verified chains that normal processing found.
    97  	// If it returns a non-nil error, the handshake is aborted and that error results.
    98  	//
    99  	// If normal verification fails then the handshake will abort before
   100  	// considering this callback. If normal verification is disabled by
   101  	// setting InsecureSkipVerify then this callback will be considered but
   102  	// the verifiedChains argument will always be nil.
   103  	VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
   104  
   105  	// SessionTicketsDisabled may be set to true to disable session ticket and
   106  	// PSK (resumption) support. Note that on clients, session ticket support is
   107  	// also disabled if ClientSessionCache is nil.
   108  	SessionTicketsDisabled bool
   109  
   110  	// ClientSessionCache is a cache of ClientSessionState entries for TLS
   111  	// session resumption. It is only used by clients.
   112  	ClientSessionCache tls.ClientSessionCache
   113  
   114  	// Prevents callers using unkeyed fields.
   115  	_ struct{}
   116  }
   117  
   118  // TLSClientAuth creates a tls.Config for mutual auth
   119  func TLSClientAuth(opts TLSClientOptions) (*tls.Config, error) {
   120  	// create client tls config
   121  	cfg := &tls.Config{
   122  		MinVersion: tls.VersionTLS12,
   123  	}
   124  
   125  	// load client cert if specified
   126  	if opts.Certificate != "" {
   127  		cert, err := tls.LoadX509KeyPair(opts.Certificate, opts.Key)
   128  		if err != nil {
   129  			return nil, fmt.Errorf("tls client cert: %v", err)
   130  		}
   131  		cfg.Certificates = []tls.Certificate{cert}
   132  	} else if opts.LoadedCertificate != nil {
   133  		block := pem.Block{Type: "CERTIFICATE", Bytes: opts.LoadedCertificate.Raw}
   134  		certPem := pem.EncodeToMemory(&block)
   135  
   136  		var keyBytes []byte
   137  		switch k := opts.LoadedKey.(type) {
   138  		case *rsa.PrivateKey:
   139  			keyBytes = x509.MarshalPKCS1PrivateKey(k)
   140  		case *ecdsa.PrivateKey:
   141  			var err error
   142  			keyBytes, err = x509.MarshalECPrivateKey(k)
   143  			if err != nil {
   144  				return nil, fmt.Errorf("tls client priv key: %v", err)
   145  			}
   146  		default:
   147  			return nil, errors.New("tls client priv key: unsupported key type")
   148  		}
   149  
   150  		block = pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes}
   151  		keyPem := pem.EncodeToMemory(&block)
   152  
   153  		cert, err := tls.X509KeyPair(certPem, keyPem)
   154  		if err != nil {
   155  			return nil, fmt.Errorf("tls client cert: %v", err)
   156  		}
   157  		cfg.Certificates = []tls.Certificate{cert}
   158  	}
   159  
   160  	cfg.InsecureSkipVerify = opts.InsecureSkipVerify
   161  
   162  	cfg.VerifyPeerCertificate = opts.VerifyPeerCertificate
   163  	cfg.SessionTicketsDisabled = opts.SessionTicketsDisabled
   164  	cfg.ClientSessionCache = opts.ClientSessionCache
   165  
   166  	// When no CA certificate is provided, default to the system cert pool
   167  	// that way when a request is made to a server known by the system trust store,
   168  	// the name is still verified
   169  	switch {
   170  	case opts.LoadedCA != nil:
   171  		caCertPool := basePool(opts.LoadedCAPool)
   172  		caCertPool.AddCert(opts.LoadedCA)
   173  		cfg.RootCAs = caCertPool
   174  	case opts.CA != "":
   175  		// load ca cert
   176  		caCert, err := os.ReadFile(opts.CA)
   177  		if err != nil {
   178  			return nil, fmt.Errorf("tls client ca: %v", err)
   179  		}
   180  		caCertPool := basePool(opts.LoadedCAPool)
   181  		caCertPool.AppendCertsFromPEM(caCert)
   182  		cfg.RootCAs = caCertPool
   183  	case opts.LoadedCAPool != nil:
   184  		cfg.RootCAs = opts.LoadedCAPool
   185  	}
   186  
   187  	// apply servername overrride
   188  	if opts.ServerName != "" {
   189  		cfg.InsecureSkipVerify = false
   190  		cfg.ServerName = opts.ServerName
   191  	}
   192  
   193  	return cfg, nil
   194  }
   195  
   196  func basePool(pool *x509.CertPool) *x509.CertPool {
   197  	if pool == nil {
   198  		return x509.NewCertPool()
   199  	}
   200  	return pool
   201  }
   202  
   203  // TLSTransport creates a http client transport suitable for mutual tls auth
   204  func TLSTransport(opts TLSClientOptions) (http.RoundTripper, error) {
   205  	cfg, err := TLSClientAuth(opts)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	return &http.Transport{TLSClientConfig: cfg}, nil
   211  }
   212  
   213  // TLSClient creates a http.Client for mutual auth
   214  func TLSClient(opts TLSClientOptions) (*http.Client, error) {
   215  	transport, err := TLSTransport(opts)
   216  	if err != nil {
   217  		return nil, err
   218  	}
   219  	return &http.Client{Transport: transport}, nil
   220  }
   221  
   222  // DefaultTimeout the default request timeout
   223  var DefaultTimeout = 30 * time.Second
   224  
// Runtime represents an API client that uses the transport
// to make http requests based on a swagger specification.
//
// Construct it with New or NewWithClient. Submit calls Do on clientOnce, a
// *sync.Once that only those constructors allocate, so using a zero Runtime
// panics.
type Runtime struct {
	// DefaultMediaType is the request media type used when an operation lists
	// no non-empty consumes media type. It is also the media type assumed for a
	// response that has no Content-Type header. New sets it to runtime.JSONMime.
	DefaultMediaType string
	// DefaultAuthentication authenticates requests whose operation carries no
	// AuthInfo. It is skipped when the request already has an Authorization
	// header.
	DefaultAuthentication runtime.ClientAuthInfoWriter
	// Consumers decode response bodies, keyed by media type. Submit falls back
	// to the "*/*" entry when the response media type is not registered.
	Consumers map[string]runtime.Consumer
	// Producers encode request bodies, keyed by media type. The request media
	// type chosen for an operation must be registered here unless it is
	// multipart or URL-encoded form data.
	Producers map[string]runtime.Producer

	// Transport is the round tripper of the http.Client that Submit creates on
	// first use. Submit does not read it once a client exists (e.g. one
	// installed by NewWithClient).
	Transport http.RoundTripper
	// Jar is the cookie jar of the http.Client that Submit creates on first use.
	Jar http.CookieJar
	// Host is written to both the request URL host and http.Request.Host.
	Host string
	// BasePath is handed to request building (presumably as the path prefix —
	// confirm in request.go). New ensures it starts with "/".
	BasePath string
	// Formats is the strfmt registry handed to request building.
	Formats strfmt.Registry
	// Context is the parent context for requests whose operation has no Context.
	Context context.Context //nolint:containedctx  // we precisely want this type to contain the request context

	// Debug makes Submit dump each request and response through logger.
	// Response bodies are not dumped when the Content-Type header equals
	// runtime.DefaultMime.
	Debug bool
	// logger receives debug dumps; see SetLogger.
	logger logger.Logger

	// clientOnce guards the lazy creation of client. NewWithClient consumes it
	// to install a caller-supplied client instead.
	clientOnce *sync.Once
	// client is the shared client used when an operation does not supply its own.
	client *http.Client
	// schemes are the schemes given to New. pickScheme prefers them over the
	// operation's own schemes.
	schemes []string
	// response wraps the *http.Response before it reaches the operation's
	// reader; see SetResponseReader.
	response ClientResponseFunc
}
   249  
   250  // New creates a new default runtime for a swagger api runtime.Client
   251  func New(host, basePath string, schemes []string) *Runtime {
   252  	var rt Runtime
   253  	rt.DefaultMediaType = runtime.JSONMime
   254  
   255  	// TODO: actually infer this stuff from the spec
   256  	rt.Consumers = map[string]runtime.Consumer{
   257  		runtime.YAMLMime:    yamlpc.YAMLConsumer(),
   258  		runtime.JSONMime:    runtime.JSONConsumer(),
   259  		runtime.XMLMime:     runtime.XMLConsumer(),
   260  		runtime.TextMime:    runtime.TextConsumer(),
   261  		runtime.HTMLMime:    runtime.TextConsumer(),
   262  		runtime.CSVMime:     runtime.CSVConsumer(),
   263  		runtime.DefaultMime: runtime.ByteStreamConsumer(),
   264  	}
   265  	rt.Producers = map[string]runtime.Producer{
   266  		runtime.YAMLMime:    yamlpc.YAMLProducer(),
   267  		runtime.JSONMime:    runtime.JSONProducer(),
   268  		runtime.XMLMime:     runtime.XMLProducer(),
   269  		runtime.TextMime:    runtime.TextProducer(),
   270  		runtime.HTMLMime:    runtime.TextProducer(),
   271  		runtime.CSVMime:     runtime.CSVProducer(),
   272  		runtime.DefaultMime: runtime.ByteStreamProducer(),
   273  	}
   274  	rt.Transport = http.DefaultTransport
   275  	rt.Jar = nil
   276  	rt.Host = host
   277  	rt.BasePath = basePath
   278  	rt.Context = context.Background()
   279  	rt.clientOnce = new(sync.Once)
   280  	if !strings.HasPrefix(rt.BasePath, "/") {
   281  		rt.BasePath = "/" + rt.BasePath
   282  	}
   283  
   284  	rt.Debug = logger.DebugEnabled()
   285  	rt.logger = logger.StandardLogger{}
   286  	rt.response = newResponse
   287  
   288  	if len(schemes) > 0 {
   289  		rt.schemes = schemes
   290  	}
   291  	return &rt
   292  }
   293  
   294  // NewWithClient allows you to create a new transport with a configured http.Client
   295  func NewWithClient(host, basePath string, schemes []string, client *http.Client) *Runtime {
   296  	rt := New(host, basePath, schemes)
   297  	if client != nil {
   298  		rt.clientOnce.Do(func() {
   299  			rt.client = client
   300  		})
   301  	}
   302  	return rt
   303  }
   304  
   305  // WithOpenTracing adds opentracing support to the provided runtime.
   306  // A new client span is created for each request.
   307  // If the context of the client operation does not contain an active span, no span is created.
   308  // The provided opts are applied to each spans - for example to add global tags.
   309  func (r *Runtime) WithOpenTracing(opts ...opentracing.StartSpanOption) runtime.ClientTransport {
   310  	return newOpenTracingTransport(r, r.Host, opts)
   311  }
   312  
   313  // WithOpenTelemetry adds opentelemetry support to the provided runtime.
   314  // A new client span is created for each request.
   315  // If the context of the client operation does not contain an active span, no span is created.
   316  // The provided opts are applied to each spans - for example to add global tags.
   317  func (r *Runtime) WithOpenTelemetry(opts ...OpenTelemetryOpt) runtime.ClientTransport {
   318  	return newOpenTelemetryTransport(r, r.Host, opts)
   319  }
   320  
   321  func (r *Runtime) pickScheme(schemes []string) string {
   322  	if v := r.selectScheme(r.schemes); v != "" {
   323  		return v
   324  	}
   325  	if v := r.selectScheme(schemes); v != "" {
   326  		return v
   327  	}
   328  	return schemeHTTP
   329  }
   330  
   331  func (r *Runtime) selectScheme(schemes []string) string {
   332  	schLen := len(schemes)
   333  	if schLen == 0 {
   334  		return ""
   335  	}
   336  
   337  	scheme := schemes[0]
   338  	// prefer https, but skip when not possible
   339  	if scheme != schemeHTTPS && schLen > 1 {
   340  		for _, sch := range schemes {
   341  			if sch == schemeHTTPS {
   342  				scheme = sch
   343  				break
   344  			}
   345  		}
   346  	}
   347  	return scheme
   348  }
   349  
   350  func transportOrDefault(left, right http.RoundTripper) http.RoundTripper {
   351  	if left == nil {
   352  		return right
   353  	}
   354  	return left
   355  }
   356  
   357  // EnableConnectionReuse drains the remaining body from a response
   358  // so that go will reuse the TCP connections.
   359  //
   360  // This is not enabled by default because there are servers where
   361  // the response never gets closed and that would make the code hang forever.
   362  // So instead it's provided as a http client middleware that can be used to override
   363  // any request.
   364  func (r *Runtime) EnableConnectionReuse() {
   365  	if r.client == nil {
   366  		r.Transport = KeepAliveTransport(
   367  			transportOrDefault(r.Transport, http.DefaultTransport),
   368  		)
   369  		return
   370  	}
   371  
   372  	r.client.Transport = KeepAliveTransport(
   373  		transportOrDefault(r.client.Transport,
   374  			transportOrDefault(r.Transport, http.DefaultTransport),
   375  		),
   376  	)
   377  }
   378  
   379  // takes a client operation and creates equivalent http.Request
   380  func (r *Runtime) createHttpRequest(operation *runtime.ClientOperation) (*request, *http.Request, error) { //nolint:revive,stylecheck
   381  	params, _, auth := operation.Params, operation.Reader, operation.AuthInfo
   382  
   383  	request := newRequest(operation.Method, operation.PathPattern, params)
   384  
   385  	var accept []string
   386  	accept = append(accept, operation.ProducesMediaTypes...)
   387  	if err := request.SetHeaderParam(runtime.HeaderAccept, accept...); err != nil {
   388  		return nil, nil, err
   389  	}
   390  
   391  	if auth == nil && r.DefaultAuthentication != nil {
   392  		auth = runtime.ClientAuthInfoWriterFunc(func(req runtime.ClientRequest, reg strfmt.Registry) error {
   393  			if req.GetHeaderParams().Get(runtime.HeaderAuthorization) != "" {
   394  				return nil
   395  			}
   396  			return r.DefaultAuthentication.AuthenticateRequest(req, reg)
   397  		})
   398  	}
   399  	// if auth != nil {
   400  	//	if err := auth.AuthenticateRequest(request, r.Formats); err != nil {
   401  	//		return nil, err
   402  	//	}
   403  	//}
   404  
   405  	// TODO: pick appropriate media type
   406  	cmt := r.DefaultMediaType
   407  	for _, mediaType := range operation.ConsumesMediaTypes {
   408  		// Pick first non-empty media type
   409  		if mediaType != "" {
   410  			cmt = mediaType
   411  			break
   412  		}
   413  	}
   414  
   415  	if _, ok := r.Producers[cmt]; !ok && cmt != runtime.MultipartFormMime && cmt != runtime.URLencodedFormMime {
   416  		return nil, nil, fmt.Errorf("none of producers: %v registered. try %s", r.Producers, cmt)
   417  	}
   418  
   419  	req, err := request.buildHTTP(cmt, r.BasePath, r.Producers, r.Formats, auth)
   420  	if err != nil {
   421  		return nil, nil, err
   422  	}
   423  	req.URL.Scheme = r.pickScheme(operation.Schemes)
   424  	req.URL.Host = r.Host
   425  	req.Host = r.Host
   426  	return request, req, nil
   427  }
   428  
   429  func (r *Runtime) CreateHttpRequest(operation *runtime.ClientOperation) (req *http.Request, err error) { //nolint:revive,stylecheck
   430  	_, req, err = r.createHttpRequest(operation)
   431  	return
   432  }
   433  
   434  // Submit a request and when there is a body on success it will turn that into the result
   435  // all other things are turned into an api error for swagger which retains the status code
   436  func (r *Runtime) Submit(operation *runtime.ClientOperation) (interface{}, error) {
   437  	_, readResponse, _ := operation.Params, operation.Reader, operation.AuthInfo
   438  
   439  	request, req, err := r.createHttpRequest(operation)
   440  	if err != nil {
   441  		return nil, err
   442  	}
   443  
   444  	r.clientOnce.Do(func() {
   445  		r.client = &http.Client{
   446  			Transport: r.Transport,
   447  			Jar:       r.Jar,
   448  		}
   449  	})
   450  
   451  	if r.Debug {
   452  		b, err2 := httputil.DumpRequestOut(req, true)
   453  		if err2 != nil {
   454  			return nil, err2
   455  		}
   456  		r.logger.Debugf("%s\n", string(b))
   457  	}
   458  
   459  	var parentCtx context.Context
   460  	switch {
   461  	case operation.Context != nil:
   462  		parentCtx = operation.Context
   463  	case r.Context != nil:
   464  		parentCtx = r.Context
   465  	default:
   466  		parentCtx = context.Background()
   467  	}
   468  
   469  	var (
   470  		ctx    context.Context
   471  		cancel context.CancelFunc
   472  	)
   473  	if request.timeout == 0 {
   474  		// There may be a deadline in the context passed to the operation.
   475  		// Otherwise, there is no timeout set.
   476  		ctx, cancel = context.WithCancel(parentCtx)
   477  	} else {
   478  		// Sets the timeout passed from request params (by default runtime.DefaultTimeout).
   479  		// If there is already a deadline in the parent context, the shortest will
   480  		// apply.
   481  		ctx, cancel = context.WithTimeout(parentCtx, request.timeout)
   482  	}
   483  	defer cancel()
   484  
   485  	var client *http.Client
   486  	if operation.Client != nil {
   487  		client = operation.Client
   488  	} else {
   489  		client = r.client
   490  	}
   491  	req = req.WithContext(ctx)
   492  	res, err := client.Do(req) // make requests, by default follows 10 redirects before failing
   493  	if err != nil {
   494  		return nil, err
   495  	}
   496  	defer res.Body.Close()
   497  
   498  	ct := res.Header.Get(runtime.HeaderContentType)
   499  	if ct == "" { // this should really never occur
   500  		ct = r.DefaultMediaType
   501  	}
   502  
   503  	if r.Debug {
   504  		printBody := true
   505  		if ct == runtime.DefaultMime {
   506  			printBody = false // Spare the terminal from a binary blob.
   507  		}
   508  		b, err2 := httputil.DumpResponse(res, printBody)
   509  		if err2 != nil {
   510  			return nil, err2
   511  		}
   512  		r.logger.Debugf("%s\n", string(b))
   513  	}
   514  
   515  	mt, _, err := mime.ParseMediaType(ct)
   516  	if err != nil {
   517  		return nil, fmt.Errorf("parse content type: %s", err)
   518  	}
   519  
   520  	cons, ok := r.Consumers[mt]
   521  	if !ok {
   522  		if cons, ok = r.Consumers["*/*"]; !ok {
   523  			// scream about not knowing what to do
   524  			return nil, fmt.Errorf("no consumer: %q", ct)
   525  		}
   526  	}
   527  	return readResponse.ReadResponse(r.response(res), cons)
   528  }
   529  
   530  // SetDebug changes the debug flag.
   531  // It ensures that client and middlewares have the set debug level.
   532  func (r *Runtime) SetDebug(debug bool) {
   533  	r.Debug = debug
   534  	middleware.Debug = debug
   535  }
   536  
   537  // SetLogger changes the logger stream.
   538  // It ensures that client and middlewares use the same logger.
   539  func (r *Runtime) SetLogger(logger logger.Logger) {
   540  	r.logger = logger
   541  	middleware.Logger = logger
   542  }
   543  
   544  type ClientResponseFunc = func(*http.Response) runtime.ClientResponse //nolint:revive
   545  
   546  // SetResponseReader changes the response reader implementation.
   547  func (r *Runtime) SetResponseReader(f ClientResponseFunc) {
   548  	if f == nil {
   549  		return
   550  	}
   551  	r.response = f
   552  }
   553  

View as plain text