...

Source file src/github.com/emissary-ingress/emissary/v3/cmd/kat-client/client.go

Documentation: github.com/emissary-ingress/emissary/v3/cmd/kat-client

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"encoding/json"
    11  	"flag"
    12  	"fmt"
    13  	"io"
    14  	"net/http"
    15  	"net/url"
    16  	"os"
    17  	"strconv"
    18  	"strings"
    19  	"syscall"
    20  	"time"
    21  
    22  	"github.com/gorilla/websocket"
    23  	"google.golang.org/grpc"
    24  	"google.golang.org/grpc/credentials/insecure"
    25  	"google.golang.org/grpc/metadata"
    26  	"google.golang.org/grpc/status"
    27  	"google.golang.org/protobuf/proto"
    28  
    29  	"github.com/datawire/dlib/dlog"
    30  	grpc_echo_pb "github.com/emissary-ingress/emissary/v3/pkg/api/kat"
    31  )
    32  
    33  // Should we output GRPCWeb debugging?
    34  var debug_grpc_web bool // We set this value in main()   XXX This is a hack
    35  
    36  // Limit concurrency
    37  
    38  // Semaphore is a counting semaphore that can be used to limit concurrency.
    39  type Semaphore chan bool
    40  
    41  // NewSemaphore returns a new Semaphore with the specified capacity.
    42  func NewSemaphore(n int) Semaphore {
    43  	sem := make(Semaphore, n)
    44  	for i := 0; i < n; i++ {
    45  		sem.Release()
    46  	}
    47  	return sem
    48  }
    49  
    50  // Acquire blocks until a slot/token is available.
    51  func (s Semaphore) Acquire() {
    52  	<-s
    53  }
    54  
    55  // Release returns a slot/token to the pool.
    56  func (s Semaphore) Release() {
    57  	s <- true
    58  }
    59  
    60  // rlimit frobnicates the interplexing beacon. Or maybe it reverses the polarity
    61  // of the neutron flow. I'm not sure. FIXME.
    62  func rlimit(ctx context.Context) {
    63  	var rLimit syscall.Rlimit
    64  	err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
    65  	if err != nil {
    66  		dlog.Println(ctx, "Error getting rlimit:", err)
    67  	} else {
    68  		dlog.Println(ctx, "Initial rlimit:", rLimit)
    69  	}
    70  
    71  	rLimit.Max = 999999
    72  	rLimit.Cur = 999999
    73  	err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit)
    74  	if err != nil {
    75  		dlog.Println(ctx, "Error setting rlimit:", err)
    76  	}
    77  
    78  	err = syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
    79  	if err != nil {
    80  		dlog.Println(ctx, "Error getting rlimit:", err)
    81  	} else {
    82  		dlog.Println(ctx, "Final rlimit", rLimit)
    83  	}
    84  }
    85  
    86  // Query and Result manipulation
    87  
// Query represents one kat query as read from the supplied input. main decodes
// the input JSON array into a []Query, so its values are the generic types
// produced by encoding/json. A Query is mutated in place to record the outcome
// of running it, chiefly under its "result" key.
type Query map[string]interface{}
    91  
    92  // CACert returns the "ca_cert" field as a string or returns the empty string.
    93  func (q Query) CACert() string {
    94  	val, ok := q["ca_cert"]
    95  	if ok {
    96  		return val.(string)
    97  	}
    98  	return ""
    99  }
   100  
   101  // ClientCert returns the "client_cert" field as a string or returns the empty string.
   102  func (q Query) ClientCert() string {
   103  	val, ok := q["client_cert"]
   104  	if ok {
   105  		return val.(string)
   106  	}
   107  	return ""
   108  }
   109  
   110  // ClientKey returns the "client_key" field as a string or returns the empty string.
   111  func (q Query) ClientKey() string {
   112  	val, ok := q["client_key"]
   113  	if ok {
   114  		return val.(string)
   115  	}
   116  	return ""
   117  }
   118  
   119  // Insecure returns whether the query has a field called "insecure" whose value is true.
   120  func (q Query) Insecure() bool {
   121  	val, ok := q["insecure"]
   122  	return ok && val.(bool)
   123  }
   124  
   125  // SNI returns whether the query has a field called "sni" whose value is true.
   126  func (q Query) SNI() bool {
   127  	val, ok := q["sni"]
   128  	return ok && val.(bool)
   129  }
   130  
   131  // IsWebsocket returns whether the query's URL starts with "ws:".
   132  func (q Query) IsWebsocket() bool {
   133  	return strings.HasPrefix(q.URL(), "ws:")
   134  }
   135  
   136  // URL returns the query's URL.
   137  func (q Query) URL() string {
   138  	return q["url"].(string)
   139  }
   140  
   141  // MinTLSVersion returns the minimun TLS protocol version.
   142  func (q Query) MinTLSVersion() uint16 {
   143  	switch q["minTLSv"].(string) {
   144  	case "v1.0":
   145  		return tls.VersionTLS10
   146  	case "v1.1":
   147  		return tls.VersionTLS11
   148  	case "v1.2":
   149  		return tls.VersionTLS12
   150  	case "v1.3":
   151  		return tls.VersionTLS13
   152  	default:
   153  		return 0
   154  	}
   155  }
   156  
   157  // MaxTLSVersion returns the maximum TLS protocol version.
   158  func (q Query) MaxTLSVersion() uint16 {
   159  	switch q["maxTLSv"].(string) {
   160  	case "v1.0":
   161  		return tls.VersionTLS10
   162  	case "v1.1":
   163  		return tls.VersionTLS11
   164  	case "v1.2":
   165  		return tls.VersionTLS12
   166  	case "v1.3":
   167  		return tls.VersionTLS13
   168  	default:
   169  		return 0
   170  	}
   171  }
   172  
   173  // CipherSuites returns the list of configured Cipher Suites
   174  func (q Query) CipherSuites() []uint16 {
   175  	val, ok := q["cipherSuites"]
   176  	if !ok {
   177  		return []uint16{}
   178  	}
   179  	cs := []uint16{}
   180  	for _, s := range val.([]interface{}) {
   181  		switch s.(string) {
   182  		// TLS 1.0 - 1.2 cipher suites.
   183  		case "TLS_RSA_WITH_RC4_128_SHA":
   184  			cs = append(cs, tls.TLS_RSA_WITH_RC4_128_SHA)
   185  		case "TLS_RSA_WITH_3DES_EDE_CBC_SHA":
   186  			cs = append(cs, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA)
   187  		case "TLS_RSA_WITH_AES_128_CBC_SHA":
   188  			cs = append(cs, tls.TLS_RSA_WITH_AES_128_CBC_SHA)
   189  		case "TLS_RSA_WITH_AES_256_CBC_SHA":
   190  			cs = append(cs, tls.TLS_RSA_WITH_AES_256_CBC_SHA)
   191  		case "TLS_RSA_WITH_AES_128_CBC_SHA256":
   192  			cs = append(cs, tls.TLS_RSA_WITH_AES_128_CBC_SHA256)
   193  		case "TLS_RSA_WITH_AES_128_GCM_SHA256":
   194  			cs = append(cs, tls.TLS_RSA_WITH_AES_128_GCM_SHA256)
   195  		case "TLS_RSA_WITH_AES_256_GCM_SHA384":
   196  			cs = append(cs, tls.TLS_RSA_WITH_AES_256_GCM_SHA384)
   197  		case "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA":
   198  			cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA)
   199  		case "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA":
   200  			cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA)
   201  		case "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA":
   202  			cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA)
   203  		case "TLS_ECDHE_RSA_WITH_RC4_128_SHA":
   204  			cs = append(cs, tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA)
   205  		case "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA":
   206  			cs = append(cs, tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA)
   207  		case "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA":
   208  			cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA)
   209  		case "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA":
   210  			cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA)
   211  		case "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256":
   212  			cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256)
   213  		case "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256":
   214  			cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256)
   215  		case "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256":
   216  			cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)
   217  		case "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256":
   218  			cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)
   219  		case "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384":
   220  			cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384)
   221  		case "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384":
   222  			cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384)
   223  		case "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305":
   224  			cs = append(cs, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305)
   225  		case "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305":
   226  			cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305)
   227  
   228  		// TLS 1.3 cipher suites are not tunable
   229  		// TLS_RSA_WITH_RC4_128_SHA
   230  		// TLS_ECDHE_RSA_WITH_RC4_128_SHA
   231  		// TLS_ECDHE_ECDSA_WITH_RC4_128_SHA
   232  
   233  		// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
   234  		// that the client is doing version fallback. See RFC 7507.
   235  		case "TLS_FALLBACK_SCSV":
   236  			cs = append(cs, tls.TLS_FALLBACK_SCSV)
   237  		default:
   238  		}
   239  	}
   240  	return cs
   241  }
   242  
   243  // ECDHCurves returns the list of configured ECDH CurveIDs
   244  func (q Query) ECDHCurves() []tls.CurveID {
   245  	val, ok := q["ecdhCurves"]
   246  	if !ok {
   247  		return []tls.CurveID{}
   248  	}
   249  	cs := []tls.CurveID{}
   250  	for _, s := range val.([]interface{}) {
   251  		switch s.(string) {
   252  		// TLS 1.0 - 1.2 cipher suites.
   253  		case "CurveP256":
   254  			cs = append(cs, tls.CurveP256)
   255  		case "CurveP384":
   256  			cs = append(cs, tls.CurveP384)
   257  		case "CurveP521":
   258  			cs = append(cs, tls.CurveP521)
   259  		case "X25519":
   260  			cs = append(cs, tls.X25519)
   261  		default:
   262  		}
   263  	}
   264  	return cs
   265  }
   266  
   267  // Method returns the query's method or "GET" if unspecified.
   268  func (q Query) Method() string {
   269  	val, ok := q["method"]
   270  	if ok {
   271  		return val.(string)
   272  	}
   273  	return "GET"
   274  }
   275  
   276  // Headers returns the an http.Header object populated with any headers passed
   277  // in as part of the query.
   278  func (q Query) Headers() (result http.Header) {
   279  	result = make(http.Header)
   280  	headers, ok := q["headers"]
   281  	if ok {
   282  		for key, val := range headers.(map[string]interface{}) {
   283  			result.Add(key, val.(string))
   284  		}
   285  	}
   286  	return result
   287  }
   288  
   289  // Body returns an io.Reader for the base64 encoded body supplied in
   290  // the query.
   291  func (q Query) Body() io.Reader {
   292  	body, ok := q["body"]
   293  	if ok {
   294  		buf, err := base64.StdEncoding.DecodeString(body.(string))
   295  		if err != nil {
   296  			panic(err)
   297  		}
   298  		return bytes.NewReader(buf)
   299  	} else {
   300  		return nil
   301  	}
   302  }
   303  
   304  // GrpcType returns the query's grpc_type field or the empty string.
   305  func (q Query) GrpcType() string {
   306  	val, ok := q["grpc_type"]
   307  	if ok {
   308  		return val.(string)
   309  	}
   310  	return ""
   311  }
   312  
   313  // Cookies returns a slice of http.Cookie objects populated with any cookies
   314  // passed in as part of the query.
   315  func (q Query) Cookies() (result []http.Cookie) {
   316  	result = []http.Cookie{}
   317  	cookies, ok := q["cookies"]
   318  	if ok {
   319  		for _, c := range cookies.([]interface{}) {
   320  			cookie := http.Cookie{
   321  				Name:  c.(map[string]interface{})["name"].(string),
   322  				Value: c.(map[string]interface{})["value"].(string),
   323  			}
   324  			result = append(result, cookie)
   325  		}
   326  	}
   327  	return result
   328  }
   329  
   330  // Result represents the result of one kat query. Upon first access to a query's
   331  // result field, the Result object will be created and added to the query.
   332  type Result map[string]interface{}
   333  
   334  // Result returns the query's result field as a Result object. If the field
   335  // doesn't exist, a new Result object is created and placed in that field. If
   336  // the field exists and contains something else, panic!
   337  func (q Query) Result() Result {
   338  	val, ok := q["result"]
   339  	if !ok {
   340  		val = make(Result)
   341  		q["result"] = val
   342  	}
   343  	return val.(Result)
   344  }
   345  
   346  // CheckErr populates the query result with error information if an error is
   347  // passed in (and logs the error).
   348  func (q Query) CheckErr(ctx context.Context, err error) bool {
   349  	if err != nil {
   350  		dlog.Printf(ctx, "%v: %v", q.URL(), err)
   351  		q.Result()["error"] = err.Error()
   352  		return true
   353  	}
   354  	return false
   355  }
   356  
   357  // DecodeGrpcWebTextBody treats the body as a series of base64-encode chunks. It
   358  // returns the decoded proto and trailers.
   359  func DecodeGrpcWebTextBody(ctx context.Context, body []byte) ([]byte, http.Header, error) {
   360  	// First, decode all the base64 stuff coming in. An annoyance here
   361  	// is that while the data coming over the wire are encoded in
   362  	// multiple chunks, we can't rely on seeing that framing when
   363  	// decoding: a chunk that's the right length to not need any base-64
   364  	// padding will just run into the next chunk.
   365  	//
   366  	// So we loop to grab all the chunks, but we just serialize it into
   367  	// a single raw byte array.
   368  
   369  	var raw []byte
   370  
   371  	cycle := 0
   372  
   373  	for {
   374  		if debug_grpc_web {
   375  			dlog.Printf(ctx, "%v: base64 body '%v'", cycle, body)
   376  		}
   377  
   378  		cycle++
   379  
   380  		if len(body) <= 0 {
   381  			break
   382  		}
   383  
   384  		chunk := make([]byte, base64.StdEncoding.DecodedLen(len(body)))
   385  		n, err := base64.StdEncoding.Decode(chunk, body)
   386  
   387  		if err != nil && n <= 0 {
   388  			dlog.Printf(ctx, "Failed to process body: %v\n", err)
   389  			return nil, nil, err
   390  		}
   391  
   392  		raw = append(raw, chunk[:n]...)
   393  
   394  		consumed := base64.StdEncoding.EncodedLen(n)
   395  
   396  		body = body[consumed:]
   397  	}
   398  
   399  	// Next up, we need to split this into protobuf data and trailers. We
   400  	// do this using grpc-web framing information for this -- each frame
   401  	// consists of one byte of type, four bytes of length, then the data
   402  	// itself.
   403  	//
   404  	// For our use case here, a type of 0 is the protobuf frame, and a type
   405  	// of 0x80 is the trailers.
   406  
   407  	trailers := make(http.Header) // the trailers will get saved here
   408  	var proto []byte              // this is what we hand off to protobuf decode
   409  
   410  	var frame_start, frame_len uint32
   411  	var frame_type byte
   412  	var frame []byte
   413  
   414  	frame_start = 0
   415  
   416  	if debug_grpc_web {
   417  		dlog.Printf(ctx, "starting frame split, len %v: %v", len(raw), raw)
   418  	}
   419  
   420  	for (frame_start + 5) < uint32(len(raw)) {
   421  		frame_type = raw[frame_start]
   422  		frame_len = binary.BigEndian.Uint32(raw[frame_start+1 : frame_start+5])
   423  
   424  		frame = raw[frame_start+5 : frame_start+5+frame_len]
   425  
   426  		if (frame_type & 128) > 0 {
   427  			// Trailers frame
   428  			if debug_grpc_web {
   429  				dlog.Printf(ctx, "  trailers @%v (len %v, type %v) %v - %v", frame_start, frame_len, frame_type, len(frame), frame)
   430  			}
   431  
   432  			lines := strings.Split(string(frame), "\n")
   433  
   434  			for _, line := range lines {
   435  				split := strings.SplitN(strings.TrimSpace(line), ":", 2)
   436  				if len(split) == 2 {
   437  					key := strings.TrimSpace(split[0])
   438  					value := strings.TrimSpace(split[1])
   439  					trailers.Add(key, value)
   440  				}
   441  			}
   442  		} else {
   443  			// Protobuf frame
   444  			if debug_grpc_web {
   445  				dlog.Printf(ctx, "  protobuf @%v (len %v, type %v) %v - %v", frame_start, frame_len, frame_type, len(frame), frame)
   446  			}
   447  
   448  			proto = frame
   449  		}
   450  
   451  		frame_start += frame_len + 5
   452  	}
   453  
   454  	return proto, trailers, nil
   455  }
   456  
   457  // AddResponse populates a query's result with data from the query's HTTP
   458  // response object.
   459  //
   460  // This is not called for websockets or real GRPC. It _is_ called for
   461  // GRPC-bridge, GRPC-web, and (of course) HTTP(s).
   462  func (q Query) AddResponse(ctx context.Context, resp *http.Response) {
   463  	result := q.Result()
   464  	result["status"] = resp.StatusCode
   465  	result["headers"] = resp.Header
   466  
   467  	headers := result["headers"].(http.Header)
   468  
   469  	if headers != nil {
   470  		// Copy in the client's start date.
   471  		cstart := q["client-start-date"]
   472  
   473  		// We'll only have a client-start-date if we're doing plain old HTTP, at
   474  		// present -- so not for WebSockets or gRPC or the like. Don't try to
   475  		// save the start and end dates if we have no start date.
   476  		if cstart != nil {
   477  			headers.Add("Client-Start-Date", q["client-start-date"].(string))
   478  
   479  			// Add the client's end date.
   480  			headers.Add("Client-End-Date", time.Now().Format(time.RFC3339Nano))
   481  		}
   482  	}
   483  
   484  	if resp.TLS != nil {
   485  		result["tls_version"] = resp.TLS.Version
   486  		result["tls"] = resp.TLS.PeerCertificates
   487  		result["cipher_suite"] = resp.TLS.CipherSuite
   488  	}
   489  	body, err := io.ReadAll(resp.Body)
   490  	if !q.CheckErr(ctx, err) {
   491  		dlog.Printf(ctx, "%v: %v", q.URL(), resp.Status)
   492  		result["body"] = body
   493  		if q.GrpcType() != "" && len(body) > 5 {
   494  			if q.GrpcType() == "web" {
   495  				// This is the GRPC-web case. Go forth and decode the base64'd
   496  				// GRPC-web body madness.
   497  				decodedBody, trailers, err := DecodeGrpcWebTextBody(ctx, body)
   498  				if q.CheckErr(ctx, err) {
   499  					dlog.Printf(ctx, "Failed to decode grpc-web-text body: %v", err)
   500  					return
   501  				}
   502  				body = decodedBody
   503  
   504  				if debug_grpc_web {
   505  					dlog.Printf(ctx, "decodedBody '%v'", body)
   506  				}
   507  
   508  				for key, values := range trailers {
   509  					for _, value := range values {
   510  						headers.Add(key, value)
   511  					}
   512  				}
   513  
   514  			} else {
   515  				// This is the GRPC-bridge case -- throw away the five-byte type/length
   516  				// framing at the start, and just leave the protobuf.
   517  				body = body[5:]
   518  			}
   519  
   520  			response := &grpc_echo_pb.EchoResponse{}
   521  			err := proto.Unmarshal(body, response)
   522  			if q.CheckErr(ctx, err) {
   523  				dlog.Printf(ctx, "Failed to unmarshal proto: %v", err)
   524  				return
   525  			}
   526  			result["text"] = response // q.r.json needs a different format
   527  			return
   528  		}
   529  		var jsonBody interface{}
   530  		err = json.Unmarshal(body, &jsonBody)
   531  		if err == nil {
   532  			result["json"] = jsonBody
   533  		} else {
   534  			result["text"] = string(body)
   535  		}
   536  	}
   537  }
   538  
   539  // Request processing
   540  
   541  // ExecuteWebsocketQuery handles Websocket queries
   542  func ExecuteWebsocketQuery(ctx context.Context, query Query) {
   543  	url := query.URL()
   544  	c, resp, err := websocket.DefaultDialer.Dial(url, query.Headers())
   545  	if query.CheckErr(ctx, err) {
   546  		return
   547  	}
   548  	defer c.Close()
   549  	query.AddResponse(ctx, resp)
   550  	messages := query["messages"].([]interface{})
   551  	for _, msg := range messages {
   552  		err = c.WriteMessage(websocket.TextMessage, []byte(msg.(string)))
   553  		if query.CheckErr(ctx, err) {
   554  			return
   555  		}
   556  	}
   557  
   558  	err = c.WriteMessage(websocket.CloseMessage,
   559  		websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
   560  	if query.CheckErr(ctx, err) {
   561  		return
   562  	}
   563  
   564  	answers := []string{}
   565  
   566  	result := query.Result()
   567  	defer func() {
   568  		result["messages"] = answers
   569  	}()
   570  
   571  	for {
   572  		_, message, err := c.ReadMessage()
   573  		if err != nil {
   574  			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
   575  				query.CheckErr(ctx, err)
   576  			}
   577  			return
   578  		}
   579  		answers = append(answers, string(message))
   580  	}
   581  }
   582  
   583  // GetGRPCReqBody returns the body of the HTTP request using the
   584  // HTTP/1.1-gRPC bridge format as described in the Envoy docs
   585  // https://www.envoyproxy.io/docs/envoy/v1.9.0/configuration/http_filters/grpc_http1_bridge_filter
   586  func GetGRPCReqBody(ctx context.Context) (*bytes.Buffer, error) {
   587  	// Protocol:
   588  	// 	. 1 byte of zero (not compressed).
   589  	// 	. network order (big-endian) of proto message length.
   590  	// 	. serialized proto message.
   591  	buf := &bytes.Buffer{}
   592  	if err := binary.Write(buf, binary.BigEndian, uint8(0)); err != nil {
   593  		dlog.Printf(ctx, "error when packing first byte: %v", err)
   594  		return nil, err
   595  	}
   596  
   597  	m := &grpc_echo_pb.EchoRequest{}
   598  	m.Data = "foo"
   599  
   600  	bs, err := proto.Marshal(m)
   601  	if err != nil {
   602  		dlog.Printf(ctx, "error when serializing the gRPC message: %v", err)
   603  		return nil, err
   604  	}
   605  
   606  	if err := binary.Write(buf, binary.BigEndian, uint32(len(bs))); err != nil {
   607  		dlog.Printf(ctx, "error when packing message length: %v", err)
   608  		return nil, err
   609  	}
   610  
   611  	for i := 0; i < len(bs); i++ {
   612  		if err := binary.Write(buf, binary.BigEndian, bs[i]); err != nil {
   613  			dlog.Printf(ctx, "error when packing message: %v", err)
   614  			return nil, err
   615  		}
   616  	}
   617  
   618  	return buf, nil
   619  }
   620  
   621  // CallRealGRPC handles real gRPC queries, i.e. queries that use the normal gRPC
   622  // generated code and the normal HTTP/2-based transport.
   623  func CallRealGRPC(ctx context.Context, query Query) {
   624  	qURL, err := url.Parse(query.URL())
   625  	if query.CheckErr(ctx, err) {
   626  		dlog.Printf(ctx, "grpc url parse failed: %v", err)
   627  		return
   628  	}
   629  
   630  	const requiredPath = "/echo.EchoService/Echo"
   631  	if qURL.Path != requiredPath {
   632  		query.Result()["error"] = fmt.Sprintf("GRPC path %s is not %s", qURL.Path, requiredPath)
   633  		return
   634  	}
   635  
   636  	dialHost := qURL.Host
   637  	if !strings.Contains(dialHost, ":") {
   638  		// There is no port number in the URL, but grpc.Dial wants host:port.
   639  		if qURL.Scheme == "https" {
   640  			dialHost = dialHost + ":443"
   641  		} else {
   642  			dialHost = dialHost + ":80"
   643  		}
   644  	}
   645  
   646  	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
   647  	defer cancel()
   648  
   649  	// Dial runs in the background and thus always appears to succeed. If you
   650  	// pass grpc.WithBlock() to make it wait for a connection, failures just hit
   651  	// the deadline rather than returning a useful error like "no such host" or
   652  	// "connection refused" or whatever. Perhaps they are considered "transient"
   653  	// and there's some retry logic we need to turn off. Anyhow, we don't pass
   654  	// grpc.WithBlock(), instead letting the error happen at the request below.
   655  	// This makes useful error messages visible in most cases.
   656  	var dialOptions []grpc.DialOption
   657  	if qURL.Scheme != "https" {
   658  		dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
   659  	}
   660  	conn, err := grpc.DialContext(ctx, dialHost, dialOptions...)
   661  	if query.CheckErr(ctx, err) {
   662  		dlog.Printf(ctx, "grpc dial failed: %v", err)
   663  		return
   664  	}
   665  	defer conn.Close()
   666  
   667  	client := grpc_echo_pb.NewEchoServiceClient(conn)
   668  	request := &grpc_echo_pb.EchoRequest{Data: "real gRPC"}
   669  
   670  	// Prepare outgoing headers, which are passed via Context
   671  	md := metadata.MD{}
   672  	headers, ok := query["headers"]
   673  	if ok {
   674  		for key, val := range headers.(map[string]interface{}) {
   675  			md.Set(key, val.(string))
   676  		}
   677  	}
   678  	ctx = metadata.NewOutgoingContext(ctx, md)
   679  
   680  	response, err := client.Echo(ctx, request)
   681  	stat, ok := status.FromError(err)
   682  	if !ok { // err is not nil and not a grpc Status
   683  		query.CheckErr(ctx, err)
   684  		dlog.Printf(ctx, "grpc echo request failed: %v", err)
   685  		return
   686  	}
   687  
   688  	// It's hard to tell the difference between a failed connection and a
   689  	// successful connection that set an error code. We'll use the
   690  	// heuristic that DNS errors and Connection Refused both appear to
   691  	// return code 14 (Code.Unavailable).
   692  	grpcCode := int(stat.Code())
   693  	if grpcCode == 14 {
   694  		query.CheckErr(ctx, err)
   695  		dlog.Printf(ctx, "grpc echo request connection failed: %v", err)
   696  		return
   697  	}
   698  
   699  	// Now process the response and synthesize the requisite result values.
   700  	// Note: Don't set result.body to anything that cannot be decoded as base64,
   701  	// or the kat harness will fail.
   702  	resHeader := make(http.Header)
   703  	resHeader.Add("Grpc-Status", fmt.Sprint(grpcCode))
   704  	resHeader.Add("Grpc-Message", stat.Message())
   705  
   706  	result := query.Result()
   707  	result["headers"] = resHeader
   708  	result["body"] = ""
   709  	result["status"] = 200
   710  	if err == nil {
   711  		result["text"] = response // q.r.json needs a different format
   712  	}
   713  
   714  	// Stuff that's not available:
   715  	// - query.result.status (the HTTP status -- synthesized as 200)
   716  	// - query.result.headers (the HTTP response headers -- we're just putting
   717  	//   in grpc-status and grpc-message as the former is required by the
   718  	//   tests and the latter can be handy)
   719  	// - query.result.body (the raw HTTP body)
   720  	// - query.result.json or query.result.text (the parsed HTTP body -- we're
   721  	//   emitting the full EchoResponse object in the text field)
   722  }
   723  
   724  // ExecuteQuery constructs the appropriate request, executes it, and records the
   725  // response and related information in query.result.
   726  func ExecuteQuery(ctx context.Context, query Query) error {
   727  	// Websocket stuff is handled elsewhere
   728  	if query.IsWebsocket() {
   729  		ExecuteWebsocketQuery(ctx, query)
   730  		return nil
   731  	}
   732  
   733  	// Real gRPC is handled elsewhere
   734  	if query.GrpcType() == "real" {
   735  		CallRealGRPC(ctx, query)
   736  		return nil
   737  	}
   738  
   739  	// Prepare an http.Transport with customized TLS settings.
   740  	transport := &http.Transport{
   741  		MaxIdleConns:    10,
   742  		IdleConnTimeout: 30 * time.Second,
   743  		TLSClientConfig: &tls.Config{},
   744  	}
   745  	if query.Insecure() {
   746  		transport.TLSClientConfig.InsecureSkipVerify = true
   747  	}
   748  	if caCert := query.CACert(); len(caCert) > 0 {
   749  		caCertPool := x509.NewCertPool()
   750  		caCertPool.AppendCertsFromPEM([]byte(caCert))
   751  		transport.TLSClientConfig.RootCAs = caCertPool
   752  	}
   753  	if query.ClientCert() != "" || query.ClientKey() != "" {
   754  		clientCert, err := tls.X509KeyPair([]byte(query.ClientCert()), []byte(query.ClientKey()))
   755  		if err != nil {
   756  			dlog.Error(ctx, err)
   757  			return err
   758  		}
   759  		transport.TLSClientConfig.Certificates = []tls.Certificate{clientCert}
   760  	}
   761  	if query.MinTLSVersion() != 0 {
   762  		transport.TLSClientConfig.MinVersion = query.MinTLSVersion()
   763  	}
   764  	if query.MaxTLSVersion() != 0 {
   765  		transport.TLSClientConfig.MaxVersion = query.MaxTLSVersion()
   766  	}
   767  	if len(query.CipherSuites()) > 0 {
   768  		transport.TLSClientConfig.CipherSuites = query.CipherSuites()
   769  	}
   770  	if len(query.ECDHCurves()) > 0 {
   771  		transport.TLSClientConfig.CurvePreferences = query.ECDHCurves()
   772  	}
   773  
   774  	// Prepare the HTTP request
   775  	var body io.Reader
   776  	method := query.Method()
   777  	if query.GrpcType() != "" {
   778  		// Perform special handling for gRPC-bridge and gRPC-web
   779  		buf, err := GetGRPCReqBody(ctx)
   780  		if query.CheckErr(ctx, err) {
   781  			dlog.Printf(ctx, "gRPC buffer error: %v", err)
   782  			return nil
   783  		}
   784  		if query.GrpcType() == "web" {
   785  			result := make([]byte, base64.StdEncoding.EncodedLen(buf.Len()))
   786  			base64.StdEncoding.Encode(result, buf.Bytes())
   787  			buf = bytes.NewBuffer(result)
   788  		}
   789  		body = buf
   790  		method = "POST"
   791  	} else {
   792  		body = query.Body()
   793  	}
   794  	req, err := http.NewRequest(method, query.URL(), body)
   795  	if query.CheckErr(ctx, err) {
   796  		dlog.Printf(ctx, "request error: %v", err)
   797  		return nil
   798  	}
   799  	req.Header = query.Headers()
   800  	for _, cookie := range query.Cookies() {
   801  		req.AddCookie(&cookie)
   802  	}
   803  
   804  	// Save the client's start date.
   805  	query["client-start-date"] = time.Now().Format(time.RFC3339Nano)
   806  
   807  	// Handle host and SNI
   808  	host := req.Header.Get("Host")
   809  	if host != "" {
   810  		if query.SNI() {
   811  			transport.TLSClientConfig.ServerName = host
   812  		}
   813  		req.Host = host
   814  	}
   815  
   816  	// Perform the request and save the results.
   817  	client := &http.Client{
   818  		Transport: transport,
   819  		Timeout:   time.Duration(10 * time.Second),
   820  		CheckRedirect: func(req *http.Request, via []*http.Request) error {
   821  			return http.ErrUseLastResponse
   822  		},
   823  	}
   824  	resp, err := client.Do(req)
   825  	if query.CheckErr(ctx, err) {
   826  		return nil
   827  	}
   828  	query.AddResponse(ctx, resp)
   829  	return nil
   830  }
   831  
   832  type Args struct {
   833  	input  string
   834  	output string
   835  }
   836  
   837  func parseArgs(rawArgs ...string) (Args, error) {
   838  	var args Args
   839  	flagset := flag.NewFlagSet("kat-client", flag.ContinueOnError)
   840  	flagset.StringVar(&args.input, "input", "", "input filename")
   841  	flagset.StringVar(&args.output, "output", "", "output filename")
   842  	err := flagset.Parse(rawArgs)
   843  	return args, err
   844  }
   845  
   846  func main() {
   847  	ctx := context.Background() // first line in main()
   848  	debug_grpc_web = false
   849  
   850  	rlimit(ctx)
   851  
   852  	args, err := parseArgs(os.Args[1:]...)
   853  	if err != nil {
   854  		panic(err)
   855  	}
   856  
   857  	var data []byte
   858  
   859  	// Read input file
   860  	if args.input == "" {
   861  		dlog.Printf(ctx, "processing queries from stdin")
   862  		data, err = io.ReadAll(os.Stdin)
   863  	} else {
   864  
   865  		data, err = os.ReadFile(args.input)
   866  	}
   867  	if err != nil {
   868  		panic(err)
   869  	}
   870  
   871  	// Parse input file
   872  	var specs []Query
   873  	err = json.Unmarshal(data, &specs)
   874  	if err != nil {
   875  		panic(err)
   876  	}
   877  
   878  	// Prep semaphore to limit concurrency
   879  	limitStr := os.Getenv("KAT_QUERY_LIMIT")
   880  	limit, err := strconv.Atoi(limitStr)
   881  	if err != nil {
   882  		limit = 25
   883  	}
   884  	sem := NewSemaphore(limit)
   885  
   886  	// Launch queries concurrently
   887  	count := len(specs)
   888  	queries := make(chan bool)
   889  	for i := 0; i < count; i++ {
   890  		go func(idx int) {
   891  			sem.Acquire()
   892  			defer func() {
   893  				queries <- true
   894  				sem.Release()
   895  			}()
   896  			if err := ExecuteQuery(ctx, specs[idx]); err != nil {
   897  				dlog.Errorf(ctx, "an error occurred executing query %d, kat-client will panic: %s", idx, err.Error())
   898  				panic(err) // TODO: do something better
   899  			}
   900  		}(i)
   901  	}
   902  
   903  	// Wait for all the answers
   904  	for i := 0; i < count; i++ {
   905  		<-queries
   906  	}
   907  
   908  	// Generate the output file
   909  	bytes, err := json.MarshalIndent(specs, "", "  ")
   910  	if err != nil {
   911  		dlog.Print(ctx, err)
   912  	} else if args.output == "" {
   913  		dlog.Printf(ctx, "writing results to stdout")
   914  		fmt.Print(string(bytes))
   915  	} else {
   916  		dlog.Printf(ctx, "writing results to output file: %s", args.output)
   917  		err = os.WriteFile(args.output, bytes, 0644)
   918  		if err != nil {
   919  			dlog.Print(ctx, err)
   920  		}
   921  	}
   922  }
   923  

View as plain text