1 package services
2
import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"

	"google.golang.org/grpc"
	"google.golang.org/protobuf/types/known/wrapperspb"

	core "github.com/datawire/ambassador/v2/pkg/api/envoy/config/core/v3"
	pb "github.com/datawire/ambassador/v2/pkg/api/envoy/service/ratelimit/v3"
	"github.com/datawire/dlib/dlog"
)
19
20
// GRPCRLSV3 serves the Envoy v3 RateLimitService over two listeners: a
// plaintext TCP one and a TLS one (see Start). Its decisions are driven by
// "x-ambassador-test-*" descriptor entries, so it appears to be a test
// double rather than a real rate limiter — NOTE(review): confirm.
type GRPCRLSV3 struct {
	// Port is the plaintext gRPC listen port.
	// NOTE(review): int16 cannot represent ports above 32767.
	Port int16
	// Backend is a name for this instance; it is logged by Start and echoed
	// in the JSON body of OVER_LIMIT responses.
	Backend string
	// SecurePort is the TLS gRPC listen port.
	SecurePort int16
	// SecureBackend is not read anywhere in this file.
	// NOTE(review): presumably a name for the TLS side — confirm usage.
	SecureBackend string
	// Cert and Key are the certificate and private-key file paths passed to
	// tls.LoadX509KeyPair for the TLS listener.
	Cert string
	Key  string
	// ProtocolVersion is echoed in the "x-grpc-service-protocol-version"
	// header of OVER_LIMIT responses.
	ProtocolVersion string
}
30
31
32 func (g *GRPCRLSV3) Start(ctx context.Context) <-chan bool {
33 dlog.Printf(ctx, "GRPCRLSV3: %s listening on %d/%d", g.Backend, g.Port, g.SecurePort)
34
35 exited := make(chan bool)
36 proto := "tcp"
37
38 go func() {
39 port := fmt.Sprintf(":%v", g.Port)
40
41 ln, err := net.Listen(proto, port)
42 if err != nil {
43 dlog.Error(ctx, err)
44 panic(err)
45 }
46
47 s := grpc.NewServer()
48 dlog.Printf(ctx, "registering v3 service")
49 pb.RegisterRateLimitServiceServer(s, g)
50 if err := s.Serve(ln); err != nil {
51 panic(err)
52 }
53
54 defer ln.Close()
55 close(exited)
56 }()
57
58 go func() {
59 cer, err := tls.LoadX509KeyPair(g.Cert, g.Key)
60 if err != nil {
61 dlog.Error(ctx, err)
62 panic(err)
63 }
64
65 config := &tls.Config{Certificates: []tls.Certificate{cer}}
66 port := fmt.Sprintf(":%v", g.SecurePort)
67 ln, err := tls.Listen(proto, port, config)
68 if err != nil {
69 dlog.Error(ctx, err)
70 panic(err)
71
72 }
73
74 s := grpc.NewServer()
75 dlog.Printf(ctx, "registering v3 service")
76 pb.RegisterRateLimitServiceServer(s, g)
77 if err := s.Serve(ln); err != nil {
78 panic(err)
79 }
80
81 defer ln.Close()
82 close(exited)
83 }()
84
85 dlog.Print(ctx, "starting gRPC rls service")
86 return exited
87 }
88
89
90 func (g *GRPCRLSV3) ShouldRateLimit(ctx context.Context, r *pb.RateLimitRequest) (*pb.RateLimitResponse, error) {
91 rs := &RLSResponseV3{}
92
93 dlog.Printf(ctx, "shouldRateLimit descriptors: %v\n", r.Descriptors)
94
95 descEntries := make(map[string]string)
96 for _, desc := range r.Descriptors {
97 for _, entry := range desc.Entries {
98 descEntries[entry.Key] = entry.Value
99 }
100 }
101
102
103
104 if allowValue := descEntries["x-ambassador-test-allow"]; allowValue == "true" {
105 rs.SetOverallCode(pb.RateLimitResponse_OK)
106 } else {
107 rs.SetOverallCode(pb.RateLimitResponse_OVER_LIMIT)
108
109
110
111
112
113 for _, token := range strings.Split(descEntries["x-ambassador-test-headers-append"], ";") {
114 header := strings.Split(strings.TrimSpace(token), "=")
115 if len(header) > 1 {
116 dlog.Printf(ctx, "appending header %s : %s", header[0], header[1])
117 rs.AddHeader(true, header[0], header[1])
118 }
119 }
120
121
122 rs.AddHeader(true, "content-type", "application/json")
123 rs.AddHeader(true, "x-grpc-service-protocol-version", g.ProtocolVersion)
124
125
126 results := make(map[string]interface{})
127
128 results["descriptors"] = ""
129 results["backend"] = g.Backend
130 results["status"] = rs.GetOverallCode()
131 if rs.GetHTTPHeaderMap() != nil {
132 results["headers"] = *rs.GetHTTPHeaderMap()
133 }
134 body, err := json.MarshalIndent(results, "", " ")
135 if err != nil {
136 body = []byte(fmt.Sprintf("Error: %v", err))
137 }
138
139
140 dlog.Printf(ctx, "setting response body: %s", string(body))
141 rs.SetBody(string(body))
142 }
143
144 return rs.GetResponse(), nil
145 }
146
147
// RLSResponseV3 accumulates the parts of a v3 RateLimitResponse until
// GetResponse assembles the final message. The parts are the queued headers,
// the raw body and the overall code.
type RLSResponseV3 struct {
	// headers holds the options queued by AddHeader, in insertion order.
	headers []*core.HeaderValueOption
	// body becomes the response's RawBody.
	body string
	// overallCode becomes the response's OverallCode.
	overallCode pb.RateLimitResponse_Code
}
153
154
155
156 func (r *RLSResponseV3) AddHeader(a bool, k, v string) {
157 val := &core.HeaderValueOption{
158 Header: &core.HeaderValue{
159 Key: k,
160 Value: v,
161 },
162 Append: &wrapperspb.BoolValue{Value: a},
163 }
164 r.headers = append(r.headers, val)
165 }
166
167
168 func (r *RLSResponseV3) GetHTTPHeaderMap() *http.Header {
169 h := &http.Header{}
170 for _, v := range r.headers {
171 h.Add(v.Header.Key, v.Header.Value)
172 }
173 return h
174 }
175
176
177 func (r *RLSResponseV3) SetBody(s string) {
178 r.body = s
179 }
180
181
182 func (r *RLSResponseV3) SetOverallCode(code pb.RateLimitResponse_Code) {
183 r.overallCode = code
184 }
185
186
// GetOverallCode returns the code most recently stored by SetOverallCode. If
// SetOverallCode was never called, it returns the zero value.
func (r *RLSResponseV3) GetOverallCode() pb.RateLimitResponse_Code {
	return r.overallCode
}
190
191
192 func (r *RLSResponseV3) GetResponse() *pb.RateLimitResponse {
193 rs := &pb.RateLimitResponse{}
194 rs.OverallCode = r.overallCode
195 rs.RawBody = []byte(r.body)
196 for _, h := range r.headers {
197 hdr := h.Header
198 if hdr != nil {
199 rs.ResponseHeadersToAdd = append(rs.ResponseHeadersToAdd,
200 &core.HeaderValue{
201 Key: hdr.Key,
202 Value: hdr.Value,
203 },
204 )
205 }
206 }
207 return rs
208 }
209
View as plain text