1 package main
2
3 import (
4 "bytes"
5 "context"
6 "crypto/tls"
7 "crypto/x509"
8 "encoding/base64"
9 "encoding/binary"
10 "encoding/json"
11 "flag"
12 "fmt"
13 "io"
14 "net/http"
15 "net/url"
16 "os"
17 "strconv"
18 "strings"
19 "syscall"
20 "time"
21
22 "github.com/gorilla/websocket"
23 "google.golang.org/grpc"
24 "google.golang.org/grpc/credentials/insecure"
25 "google.golang.org/grpc/metadata"
26 "google.golang.org/grpc/status"
27 "google.golang.org/protobuf/proto"
28
29 "github.com/datawire/dlib/dlog"
30 grpc_echo_pb "github.com/emissary-ingress/emissary/v3/pkg/api/kat"
31 )
32
33
// debug_grpc_web, when true, makes the grpc-web-text handling log its
// intermediate data: the base64 body on each decode cycle, the decoded bytes,
// and every frame found while splitting. main sets it to false, so this
// logging is off unless the code is changed.
var debug_grpc_web bool
35
36
37
38
// Semaphore is a counting semaphore built on a buffered channel: every value
// currently sitting in the channel buffer is one available permit.
type Semaphore chan bool

// NewSemaphore returns a Semaphore with capacity n whose n permits are all
// initially available.
func NewSemaphore(n int) Semaphore {
	sem := Semaphore(make(chan bool, n))
	for remaining := n; remaining > 0; remaining-- {
		sem.Release()
	}
	return sem
}

// Acquire takes one permit, blocking until one is available.
func (s Semaphore) Acquire() {
	_ = <-s
}

// Release returns one permit to the semaphore. It blocks if every permit is
// already available (the buffer is full).
func (s Semaphore) Release() {
	s <- true
}
59
60
61
62 func rlimit(ctx context.Context) {
63 var rLimit syscall.Rlimit
64 err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
65 if err != nil {
66 dlog.Println(ctx, "Error getting rlimit:", err)
67 } else {
68 dlog.Println(ctx, "Initial rlimit:", rLimit)
69 }
70
71 rLimit.Max = 999999
72 rLimit.Cur = 999999
73 err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit)
74 if err != nil {
75 dlog.Println(ctx, "Error setting rlimit:", err)
76 }
77
78 err = syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
79 if err != nil {
80 dlog.Println(ctx, "Error getting rlimit:", err)
81 } else {
82 dlog.Println(ctx, "Final rlimit", rLimit)
83 }
84 }
85
86
87
88
89
// Query is one KAT request specification decoded from the input JSON. Its keys
// describe what to send ("url", "method", "headers", TLS options, ...); once the
// query has run, its outcome is stored in the same map under the "result" key.
type Query map[string]interface{}

// CACert returns the PEM CA bundle from the "ca_cert" key, or "" when that key
// is absent. It panics if the value is present but is not a string.
func (q Query) CACert() string {
	pem, found := q["ca_cert"]
	if !found {
		return ""
	}
	return pem.(string)
}
100
101
102 func (q Query) ClientCert() string {
103 val, ok := q["client_cert"]
104 if ok {
105 return val.(string)
106 }
107 return ""
108 }
109
110
111 func (q Query) ClientKey() string {
112 val, ok := q["client_key"]
113 if ok {
114 return val.(string)
115 }
116 return ""
117 }
118
119
120 func (q Query) Insecure() bool {
121 val, ok := q["insecure"]
122 return ok && val.(bool)
123 }
124
125
126 func (q Query) SNI() bool {
127 val, ok := q["sni"]
128 return ok && val.(bool)
129 }
130
131
132 func (q Query) IsWebsocket() bool {
133 return strings.HasPrefix(q.URL(), "ws:")
134 }
135
136
137 func (q Query) URL() string {
138 return q["url"].(string)
139 }
140
141
142 func (q Query) MinTLSVersion() uint16 {
143 switch q["minTLSv"].(string) {
144 case "v1.0":
145 return tls.VersionTLS10
146 case "v1.1":
147 return tls.VersionTLS11
148 case "v1.2":
149 return tls.VersionTLS12
150 case "v1.3":
151 return tls.VersionTLS13
152 default:
153 return 0
154 }
155 }
156
157
158 func (q Query) MaxTLSVersion() uint16 {
159 switch q["maxTLSv"].(string) {
160 case "v1.0":
161 return tls.VersionTLS10
162 case "v1.1":
163 return tls.VersionTLS11
164 case "v1.2":
165 return tls.VersionTLS12
166 case "v1.3":
167 return tls.VersionTLS13
168 default:
169 return 0
170 }
171 }
172
173
174 func (q Query) CipherSuites() []uint16 {
175 val, ok := q["cipherSuites"]
176 if !ok {
177 return []uint16{}
178 }
179 cs := []uint16{}
180 for _, s := range val.([]interface{}) {
181 switch s.(string) {
182
183 case "TLS_RSA_WITH_RC4_128_SHA":
184 cs = append(cs, tls.TLS_RSA_WITH_RC4_128_SHA)
185 case "TLS_RSA_WITH_3DES_EDE_CBC_SHA":
186 cs = append(cs, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA)
187 case "TLS_RSA_WITH_AES_128_CBC_SHA":
188 cs = append(cs, tls.TLS_RSA_WITH_AES_128_CBC_SHA)
189 case "TLS_RSA_WITH_AES_256_CBC_SHA":
190 cs = append(cs, tls.TLS_RSA_WITH_AES_256_CBC_SHA)
191 case "TLS_RSA_WITH_AES_128_CBC_SHA256":
192 cs = append(cs, tls.TLS_RSA_WITH_AES_128_CBC_SHA256)
193 case "TLS_RSA_WITH_AES_128_GCM_SHA256":
194 cs = append(cs, tls.TLS_RSA_WITH_AES_128_GCM_SHA256)
195 case "TLS_RSA_WITH_AES_256_GCM_SHA384":
196 cs = append(cs, tls.TLS_RSA_WITH_AES_256_GCM_SHA384)
197 case "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA":
198 cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA)
199 case "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA":
200 cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA)
201 case "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA":
202 cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA)
203 case "TLS_ECDHE_RSA_WITH_RC4_128_SHA":
204 cs = append(cs, tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA)
205 case "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA":
206 cs = append(cs, tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA)
207 case "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA":
208 cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA)
209 case "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA":
210 cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA)
211 case "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256":
212 cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256)
213 case "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256":
214 cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256)
215 case "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256":
216 cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)
217 case "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256":
218 cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)
219 case "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384":
220 cs = append(cs, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384)
221 case "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384":
222 cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384)
223 case "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305":
224 cs = append(cs, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305)
225 case "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305":
226 cs = append(cs, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305)
227
228
229
230
231
232
233
234
235 case "TLS_FALLBACK_SCSV":
236 cs = append(cs, tls.TLS_FALLBACK_SCSV)
237 default:
238 }
239 }
240 return cs
241 }
242
243
244 func (q Query) ECDHCurves() []tls.CurveID {
245 val, ok := q["ecdhCurves"]
246 if !ok {
247 return []tls.CurveID{}
248 }
249 cs := []tls.CurveID{}
250 for _, s := range val.([]interface{}) {
251 switch s.(string) {
252
253 case "CurveP256":
254 cs = append(cs, tls.CurveP256)
255 case "CurveP384":
256 cs = append(cs, tls.CurveP384)
257 case "CurveP521":
258 cs = append(cs, tls.CurveP521)
259 case "X25519":
260 cs = append(cs, tls.X25519)
261 default:
262 }
263 }
264 return cs
265 }
266
267
268 func (q Query) Method() string {
269 val, ok := q["method"]
270 if ok {
271 return val.(string)
272 }
273 return "GET"
274 }
275
276
277
278 func (q Query) Headers() (result http.Header) {
279 result = make(http.Header)
280 headers, ok := q["headers"]
281 if ok {
282 for key, val := range headers.(map[string]interface{}) {
283 result.Add(key, val.(string))
284 }
285 }
286 return result
287 }
288
289
290
291 func (q Query) Body() io.Reader {
292 body, ok := q["body"]
293 if ok {
294 buf, err := base64.StdEncoding.DecodeString(body.(string))
295 if err != nil {
296 panic(err)
297 }
298 return bytes.NewReader(buf)
299 } else {
300 return nil
301 }
302 }
303
304
305 func (q Query) GrpcType() string {
306 val, ok := q["grpc_type"]
307 if ok {
308 return val.(string)
309 }
310 return ""
311 }
312
313
314
315 func (q Query) Cookies() (result []http.Cookie) {
316 result = []http.Cookie{}
317 cookies, ok := q["cookies"]
318 if ok {
319 for _, c := range cookies.([]interface{}) {
320 cookie := http.Cookie{
321 Name: c.(map[string]interface{})["name"].(string),
322 Value: c.(map[string]interface{})["value"].(string),
323 }
324 result = append(result, cookie)
325 }
326 }
327 return result
328 }
329
330
331
332 type Result map[string]interface{}
333
334
335
336
337 func (q Query) Result() Result {
338 val, ok := q["result"]
339 if !ok {
340 val = make(Result)
341 q["result"] = val
342 }
343 return val.(Result)
344 }
345
346
347
348 func (q Query) CheckErr(ctx context.Context, err error) bool {
349 if err != nil {
350 dlog.Printf(ctx, "%v: %v", q.URL(), err)
351 q.Result()["error"] = err.Error()
352 return true
353 }
354 return false
355 }
356
357
358
359 func DecodeGrpcWebTextBody(ctx context.Context, body []byte) ([]byte, http.Header, error) {
360
361
362
363
364
365
366
367
368
369 var raw []byte
370
371 cycle := 0
372
373 for {
374 if debug_grpc_web {
375 dlog.Printf(ctx, "%v: base64 body '%v'", cycle, body)
376 }
377
378 cycle++
379
380 if len(body) <= 0 {
381 break
382 }
383
384 chunk := make([]byte, base64.StdEncoding.DecodedLen(len(body)))
385 n, err := base64.StdEncoding.Decode(chunk, body)
386
387 if err != nil && n <= 0 {
388 dlog.Printf(ctx, "Failed to process body: %v\n", err)
389 return nil, nil, err
390 }
391
392 raw = append(raw, chunk[:n]...)
393
394 consumed := base64.StdEncoding.EncodedLen(n)
395
396 body = body[consumed:]
397 }
398
399
400
401
402
403
404
405
406
407 trailers := make(http.Header)
408 var proto []byte
409
410 var frame_start, frame_len uint32
411 var frame_type byte
412 var frame []byte
413
414 frame_start = 0
415
416 if debug_grpc_web {
417 dlog.Printf(ctx, "starting frame split, len %v: %v", len(raw), raw)
418 }
419
420 for (frame_start + 5) < uint32(len(raw)) {
421 frame_type = raw[frame_start]
422 frame_len = binary.BigEndian.Uint32(raw[frame_start+1 : frame_start+5])
423
424 frame = raw[frame_start+5 : frame_start+5+frame_len]
425
426 if (frame_type & 128) > 0 {
427
428 if debug_grpc_web {
429 dlog.Printf(ctx, " trailers @%v (len %v, type %v) %v - %v", frame_start, frame_len, frame_type, len(frame), frame)
430 }
431
432 lines := strings.Split(string(frame), "\n")
433
434 for _, line := range lines {
435 split := strings.SplitN(strings.TrimSpace(line), ":", 2)
436 if len(split) == 2 {
437 key := strings.TrimSpace(split[0])
438 value := strings.TrimSpace(split[1])
439 trailers.Add(key, value)
440 }
441 }
442 } else {
443
444 if debug_grpc_web {
445 dlog.Printf(ctx, " protobuf @%v (len %v, type %v) %v - %v", frame_start, frame_len, frame_type, len(frame), frame)
446 }
447
448 proto = frame
449 }
450
451 frame_start += frame_len + 5
452 }
453
454 return proto, trailers, nil
455 }
456
457
458
459
460
461
462 func (q Query) AddResponse(ctx context.Context, resp *http.Response) {
463 result := q.Result()
464 result["status"] = resp.StatusCode
465 result["headers"] = resp.Header
466
467 headers := result["headers"].(http.Header)
468
469 if headers != nil {
470
471 cstart := q["client-start-date"]
472
473
474
475
476 if cstart != nil {
477 headers.Add("Client-Start-Date", q["client-start-date"].(string))
478
479
480 headers.Add("Client-End-Date", time.Now().Format(time.RFC3339Nano))
481 }
482 }
483
484 if resp.TLS != nil {
485 result["tls_version"] = resp.TLS.Version
486 result["tls"] = resp.TLS.PeerCertificates
487 result["cipher_suite"] = resp.TLS.CipherSuite
488 }
489 body, err := io.ReadAll(resp.Body)
490 if !q.CheckErr(ctx, err) {
491 dlog.Printf(ctx, "%v: %v", q.URL(), resp.Status)
492 result["body"] = body
493 if q.GrpcType() != "" && len(body) > 5 {
494 if q.GrpcType() == "web" {
495
496
497 decodedBody, trailers, err := DecodeGrpcWebTextBody(ctx, body)
498 if q.CheckErr(ctx, err) {
499 dlog.Printf(ctx, "Failed to decode grpc-web-text body: %v", err)
500 return
501 }
502 body = decodedBody
503
504 if debug_grpc_web {
505 dlog.Printf(ctx, "decodedBody '%v'", body)
506 }
507
508 for key, values := range trailers {
509 for _, value := range values {
510 headers.Add(key, value)
511 }
512 }
513
514 } else {
515
516
517 body = body[5:]
518 }
519
520 response := &grpc_echo_pb.EchoResponse{}
521 err := proto.Unmarshal(body, response)
522 if q.CheckErr(ctx, err) {
523 dlog.Printf(ctx, "Failed to unmarshal proto: %v", err)
524 return
525 }
526 result["text"] = response
527 return
528 }
529 var jsonBody interface{}
530 err = json.Unmarshal(body, &jsonBody)
531 if err == nil {
532 result["json"] = jsonBody
533 } else {
534 result["text"] = string(body)
535 }
536 }
537 }
538
539
540
541
542 func ExecuteWebsocketQuery(ctx context.Context, query Query) {
543 url := query.URL()
544 c, resp, err := websocket.DefaultDialer.Dial(url, query.Headers())
545 if query.CheckErr(ctx, err) {
546 return
547 }
548 defer c.Close()
549 query.AddResponse(ctx, resp)
550 messages := query["messages"].([]interface{})
551 for _, msg := range messages {
552 err = c.WriteMessage(websocket.TextMessage, []byte(msg.(string)))
553 if query.CheckErr(ctx, err) {
554 return
555 }
556 }
557
558 err = c.WriteMessage(websocket.CloseMessage,
559 websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
560 if query.CheckErr(ctx, err) {
561 return
562 }
563
564 answers := []string{}
565
566 result := query.Result()
567 defer func() {
568 result["messages"] = answers
569 }()
570
571 for {
572 _, message, err := c.ReadMessage()
573 if err != nil {
574 if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
575 query.CheckErr(ctx, err)
576 }
577 return
578 }
579 answers = append(answers, string(message))
580 }
581 }
582
583
584
585
586 func GetGRPCReqBody(ctx context.Context) (*bytes.Buffer, error) {
587
588
589
590
591 buf := &bytes.Buffer{}
592 if err := binary.Write(buf, binary.BigEndian, uint8(0)); err != nil {
593 dlog.Printf(ctx, "error when packing first byte: %v", err)
594 return nil, err
595 }
596
597 m := &grpc_echo_pb.EchoRequest{}
598 m.Data = "foo"
599
600 bs, err := proto.Marshal(m)
601 if err != nil {
602 dlog.Printf(ctx, "error when serializing the gRPC message: %v", err)
603 return nil, err
604 }
605
606 if err := binary.Write(buf, binary.BigEndian, uint32(len(bs))); err != nil {
607 dlog.Printf(ctx, "error when packing message length: %v", err)
608 return nil, err
609 }
610
611 for i := 0; i < len(bs); i++ {
612 if err := binary.Write(buf, binary.BigEndian, bs[i]); err != nil {
613 dlog.Printf(ctx, "error when packing message: %v", err)
614 return nil, err
615 }
616 }
617
618 return buf, nil
619 }
620
621
622
623 func CallRealGRPC(ctx context.Context, query Query) {
624 qURL, err := url.Parse(query.URL())
625 if query.CheckErr(ctx, err) {
626 dlog.Printf(ctx, "grpc url parse failed: %v", err)
627 return
628 }
629
630 const requiredPath = "/echo.EchoService/Echo"
631 if qURL.Path != requiredPath {
632 query.Result()["error"] = fmt.Sprintf("GRPC path %s is not %s", qURL.Path, requiredPath)
633 return
634 }
635
636 dialHost := qURL.Host
637 if !strings.Contains(dialHost, ":") {
638
639 if qURL.Scheme == "https" {
640 dialHost = dialHost + ":443"
641 } else {
642 dialHost = dialHost + ":80"
643 }
644 }
645
646 ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
647 defer cancel()
648
649
650
651
652
653
654
655
656 var dialOptions []grpc.DialOption
657 if qURL.Scheme != "https" {
658 dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
659 }
660 conn, err := grpc.DialContext(ctx, dialHost, dialOptions...)
661 if query.CheckErr(ctx, err) {
662 dlog.Printf(ctx, "grpc dial failed: %v", err)
663 return
664 }
665 defer conn.Close()
666
667 client := grpc_echo_pb.NewEchoServiceClient(conn)
668 request := &grpc_echo_pb.EchoRequest{Data: "real gRPC"}
669
670
671 md := metadata.MD{}
672 headers, ok := query["headers"]
673 if ok {
674 for key, val := range headers.(map[string]interface{}) {
675 md.Set(key, val.(string))
676 }
677 }
678 ctx = metadata.NewOutgoingContext(ctx, md)
679
680 response, err := client.Echo(ctx, request)
681 stat, ok := status.FromError(err)
682 if !ok {
683 query.CheckErr(ctx, err)
684 dlog.Printf(ctx, "grpc echo request failed: %v", err)
685 return
686 }
687
688
689
690
691
692 grpcCode := int(stat.Code())
693 if grpcCode == 14 {
694 query.CheckErr(ctx, err)
695 dlog.Printf(ctx, "grpc echo request connection failed: %v", err)
696 return
697 }
698
699
700
701
702 resHeader := make(http.Header)
703 resHeader.Add("Grpc-Status", fmt.Sprint(grpcCode))
704 resHeader.Add("Grpc-Message", stat.Message())
705
706 result := query.Result()
707 result["headers"] = resHeader
708 result["body"] = ""
709 result["status"] = 200
710 if err == nil {
711 result["text"] = response
712 }
713
714
715
716
717
718
719
720
721
722 }
723
724
725
726 func ExecuteQuery(ctx context.Context, query Query) error {
727
728 if query.IsWebsocket() {
729 ExecuteWebsocketQuery(ctx, query)
730 return nil
731 }
732
733
734 if query.GrpcType() == "real" {
735 CallRealGRPC(ctx, query)
736 return nil
737 }
738
739
740 transport := &http.Transport{
741 MaxIdleConns: 10,
742 IdleConnTimeout: 30 * time.Second,
743 TLSClientConfig: &tls.Config{},
744 }
745 if query.Insecure() {
746 transport.TLSClientConfig.InsecureSkipVerify = true
747 }
748 if caCert := query.CACert(); len(caCert) > 0 {
749 caCertPool := x509.NewCertPool()
750 caCertPool.AppendCertsFromPEM([]byte(caCert))
751 transport.TLSClientConfig.RootCAs = caCertPool
752 }
753 if query.ClientCert() != "" || query.ClientKey() != "" {
754 clientCert, err := tls.X509KeyPair([]byte(query.ClientCert()), []byte(query.ClientKey()))
755 if err != nil {
756 dlog.Error(ctx, err)
757 return err
758 }
759 transport.TLSClientConfig.Certificates = []tls.Certificate{clientCert}
760 }
761 if query.MinTLSVersion() != 0 {
762 transport.TLSClientConfig.MinVersion = query.MinTLSVersion()
763 }
764 if query.MaxTLSVersion() != 0 {
765 transport.TLSClientConfig.MaxVersion = query.MaxTLSVersion()
766 }
767 if len(query.CipherSuites()) > 0 {
768 transport.TLSClientConfig.CipherSuites = query.CipherSuites()
769 }
770 if len(query.ECDHCurves()) > 0 {
771 transport.TLSClientConfig.CurvePreferences = query.ECDHCurves()
772 }
773
774
775 var body io.Reader
776 method := query.Method()
777 if query.GrpcType() != "" {
778
779 buf, err := GetGRPCReqBody(ctx)
780 if query.CheckErr(ctx, err) {
781 dlog.Printf(ctx, "gRPC buffer error: %v", err)
782 return nil
783 }
784 if query.GrpcType() == "web" {
785 result := make([]byte, base64.StdEncoding.EncodedLen(buf.Len()))
786 base64.StdEncoding.Encode(result, buf.Bytes())
787 buf = bytes.NewBuffer(result)
788 }
789 body = buf
790 method = "POST"
791 } else {
792 body = query.Body()
793 }
794 req, err := http.NewRequest(method, query.URL(), body)
795 if query.CheckErr(ctx, err) {
796 dlog.Printf(ctx, "request error: %v", err)
797 return nil
798 }
799 req.Header = query.Headers()
800 for _, cookie := range query.Cookies() {
801 req.AddCookie(&cookie)
802 }
803
804
805 query["client-start-date"] = time.Now().Format(time.RFC3339Nano)
806
807
808 host := req.Header.Get("Host")
809 if host != "" {
810 if query.SNI() {
811 transport.TLSClientConfig.ServerName = host
812 }
813 req.Host = host
814 }
815
816
817 client := &http.Client{
818 Transport: transport,
819 Timeout: time.Duration(10 * time.Second),
820 CheckRedirect: func(req *http.Request, via []*http.Request) error {
821 return http.ErrUseLastResponse
822 },
823 }
824 resp, err := client.Do(req)
825 if query.CheckErr(ctx, err) {
826 return nil
827 }
828 query.AddResponse(ctx, resp)
829 return nil
830 }
831
// Args holds kat-client's command-line options.
type Args struct {
	input  string // file to read the JSON query list from; "" means stdin
	output string // file to write the JSON results to; "" means stdout
}

// parseArgs parses rawArgs (typically os.Args[1:]) with the -input and -output
// flags. Parse problems, such as an unknown flag, are returned as an error
// instead of exiting. The returned Args still reflects any flags parsed before
// the failure.
func parseArgs(rawArgs ...string) (Args, error) {
	fs := flag.NewFlagSet("kat-client", flag.ContinueOnError)
	input := fs.String("input", "", "input filename")
	output := fs.String("output", "", "output filename")
	err := fs.Parse(rawArgs)
	return Args{input: *input, output: *output}, err
}
845
846 func main() {
847 ctx := context.Background()
848 debug_grpc_web = false
849
850 rlimit(ctx)
851
852 args, err := parseArgs(os.Args[1:]...)
853 if err != nil {
854 panic(err)
855 }
856
857 var data []byte
858
859
860 if args.input == "" {
861 dlog.Printf(ctx, "processing queries from stdin")
862 data, err = io.ReadAll(os.Stdin)
863 } else {
864
865 data, err = os.ReadFile(args.input)
866 }
867 if err != nil {
868 panic(err)
869 }
870
871
872 var specs []Query
873 err = json.Unmarshal(data, &specs)
874 if err != nil {
875 panic(err)
876 }
877
878
879 limitStr := os.Getenv("KAT_QUERY_LIMIT")
880 limit, err := strconv.Atoi(limitStr)
881 if err != nil {
882 limit = 25
883 }
884 sem := NewSemaphore(limit)
885
886
887 count := len(specs)
888 queries := make(chan bool)
889 for i := 0; i < count; i++ {
890 go func(idx int) {
891 sem.Acquire()
892 defer func() {
893 queries <- true
894 sem.Release()
895 }()
896 if err := ExecuteQuery(ctx, specs[idx]); err != nil {
897 dlog.Errorf(ctx, "an error occurred executing query %d, kat-client will panic: %s", idx, err.Error())
898 panic(err)
899 }
900 }(i)
901 }
902
903
904 for i := 0; i < count; i++ {
905 <-queries
906 }
907
908
909 bytes, err := json.MarshalIndent(specs, "", " ")
910 if err != nil {
911 dlog.Print(ctx, err)
912 } else if args.output == "" {
913 dlog.Printf(ctx, "writing results to stdout")
914 fmt.Print(string(bytes))
915 } else {
916 dlog.Printf(ctx, "writing results to output file: %s", args.output)
917 err = os.WriteFile(args.output, bytes, 0644)
918 if err != nil {
919 dlog.Print(ctx, err)
920 }
921 }
922 }
923
View as plain text