...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 package main
22
23 import (
24 "flag"
25 "fmt"
26 "io"
27 "log"
28 "net"
29 "net/http"
30 "os"
31 "os/signal"
32 "strings"
33
34 "cloud.google.com/go/httpreplay/internal/proxy"
35 "github.com/google/martian/v3/martianhttp"
36 )
37
38 var (
39 port = flag.Int("port", 8080, "port of the proxy")
40 controlPort = flag.Int("control-port", 8181, "port for controlling the proxy")
41 record = flag.String("record", "", "record traffic and save to filename")
42 replay = flag.String("replay", "", "read filename and replay traffic")
43 debugHeaders = flag.Bool("debug-headers", false, "log header mismatches")
44 ignoreHeaders repeatedString
45 )
46
47 func main() {
48 flag.Var(&ignoreHeaders, "ignore-header", "header key(s) to ignore when matching")
49
50 flag.Parse()
51 if *record == "" && *replay == "" {
52 log.Fatal("provide either -record or -replay")
53 }
54 if *record != "" && *replay != "" {
55 log.Fatal("provide only one of -record and -replay")
56 }
57 log.Printf("httpr: starting proxy on port %d and control on port %d", *port, *controlPort)
58
59 var pr *proxy.Proxy
60 var err error
61 if *record != "" {
62 pr, err = proxy.ForRecording(*record, *port)
63 } else {
64 pr, err = proxy.ForReplaying(*replay, *port)
65 }
66 if err != nil {
67 log.Fatal(err)
68 }
69 proxy.DebugHeaders = *debugHeaders
70 for _, key := range ignoreHeaders {
71 pr.IgnoreHeader(key)
72 }
73
74
75 mux := http.NewServeMux()
76 mux.Handle("/authority.cer", martianhttp.NewAuthorityHandler(pr.CACert))
77 mux.HandleFunc("/initial", handleInitial(pr))
78 lControl, err := net.Listen("tcp", fmt.Sprintf(":%d", *controlPort))
79 if err != nil {
80 log.Fatal(err)
81 }
82 go http.Serve(lControl, mux)
83
84 sigc := make(chan os.Signal, 1)
85 signal.Notify(sigc, os.Interrupt)
86
87 <-sigc
88
89 log.Println("httpr: shutting down")
90 if err := pr.Close(); err != nil {
91 log.Fatal(err)
92 }
93 }
94
95 func handleInitial(pr *proxy.Proxy) http.HandlerFunc {
96 return func(w http.ResponseWriter, req *http.Request) {
97 switch req.Method {
98 case "GET":
99 if pr.Initial != nil {
100 w.Write(pr.Initial)
101 }
102
103 case "POST":
104 bytes, err := io.ReadAll(req.Body)
105 req.Body.Close()
106 if err != nil {
107 w.WriteHeader(http.StatusInternalServerError)
108 fmt.Fprintf(w, "reading body: %v", err)
109 }
110 pr.Initial = bytes
111
112 default:
113 w.WriteHeader(http.StatusBadRequest)
114 fmt.Fprint(w, "use GET to retrieve initial or POST to set it")
115 }
116 }
117 }
118
// repeatedString implements flag.Value for a flag that may appear more than
// once on the command line; each occurrence is appended in the order given.
type repeatedString []string

// String reports the accumulated values as one comma-separated string.
// The flag package may call String on a zero-valued receiver, so a nil
// pointer is handled and yields "".
func (r *repeatedString) String() string {
	if r == nil {
		return ""
	}
	return strings.Join(*r, ",")
}

// Set appends value to the collected list. It never returns an error.
func (r *repeatedString) Set(value string) error {
	*r = append(*r, value)
	return nil
}
133
View as plain text