1 package services
2
3 import (
4
5 "context"
6 "crypto/tls"
7 "encoding/json"
8 "fmt"
9 "net/http"
10 "strings"
11
12
13 "google.golang.org/grpc"
14 "google.golang.org/protobuf/types/known/wrapperspb"
15
16
17 core "github.com/emissary-ingress/emissary/v3/pkg/api/envoy/config/core/v3"
18 pb "github.com/emissary-ingress/emissary/v3/pkg/api/envoy/service/ratelimit/v3"
19
20
21 "github.com/datawire/dlib/dgroup"
22 "github.com/datawire/dlib/dhttp"
23 "github.com/datawire/dlib/dlog"
24 )
25
26
// GRPCRLSV3 is a test rate-limit service that implements the Envoy v3
// RateLimitService gRPC API (it registers itself via
// pb.RegisterRateLimitServiceServer in Start).
//
// NOTE(review): Port and SecurePort are int16, so ports above 32767 cannot be
// represented — confirm callers never need them.
type GRPCRLSV3 struct {
	// Port is the cleartext listening port used by Start.
	Port int16
	// Backend identifies this service; it is logged by Start and echoed in
	// the JSON body of OVER_LIMIT responses.
	Backend string
	// SecurePort is the TLS listening port used by Start.
	SecurePort int16
	// SecureBackend is not read by any code in this file.
	// NOTE(review): confirm where (or whether) it is used.
	SecureBackend string
	// Cert is the certificate file path passed to tls.LoadX509KeyPair.
	Cert string
	// Key is the private-key file path passed to tls.LoadX509KeyPair.
	Key string
	// ProtocolVersion is echoed back in the kat-resp-rls-protocol-version
	// response header.
	ProtocolVersion string
}
36
37
38 func (g *GRPCRLSV3) Start(ctx context.Context) <-chan bool {
39 dlog.Printf(ctx, "GRPCRLSV3: %s listening on %d/%d", g.Backend, g.Port, g.SecurePort)
40
41 grpcHandler := grpc.NewServer()
42 dlog.Printf(ctx, "registering v3 service")
43 pb.RegisterRateLimitServiceServer(grpcHandler, g)
44
45 cer, err := tls.LoadX509KeyPair(g.Cert, g.Key)
46 if err != nil {
47 dlog.Error(ctx, err)
48 panic(err)
49 }
50
51 sc := &dhttp.ServerConfig{
52 Handler: grpcHandler,
53 TLSConfig: &tls.Config{
54 Certificates: []tls.Certificate{cer},
55 },
56 }
57
58 grp := dgroup.NewGroup(ctx, dgroup.GroupConfig{})
59 grp.Go("cleartext", func(ctx context.Context) error {
60 return sc.ListenAndServe(ctx, fmt.Sprintf(":%v", g.Port))
61 })
62 grp.Go("tls", func(ctx context.Context) error {
63 return sc.ListenAndServeTLS(ctx, fmt.Sprintf(":%v", g.SecurePort), "", "")
64 })
65
66 dlog.Print(ctx, "starting gRPC rls service")
67
68 exited := make(chan bool)
69 go func() {
70 if err := grp.Wait(); err != nil {
71 dlog.Error(ctx, err)
72 panic(err)
73 }
74 close(exited)
75 }()
76 return exited
77 }
78
79
80 func (g *GRPCRLSV3) ShouldRateLimit(ctx context.Context, r *pb.RateLimitRequest) (*pb.RateLimitResponse, error) {
81 rs := &RLSResponseV3{}
82
83 dlog.Printf(ctx, "shouldRateLimit descriptors: %v\n", r.Descriptors)
84
85 descEntries := make(map[string]string)
86 for _, desc := range r.Descriptors {
87 for _, entry := range desc.Entries {
88 descEntries[entry.Key] = entry.Value
89 }
90 }
91
92
93
94 if allowValue := descEntries["kat-req-rls-allow"]; allowValue == "true" {
95 rs.SetOverallCode(pb.RateLimitResponse_OK)
96 } else {
97 rs.SetOverallCode(pb.RateLimitResponse_OVER_LIMIT)
98
99
100
101
102
103 for _, token := range strings.Split(descEntries["kat-req-rls-headers-append"], ";") {
104 header := strings.Split(strings.TrimSpace(token), "=")
105 if len(header) > 1 {
106 dlog.Printf(ctx, "appending header %s : %s", header[0], header[1])
107 rs.AddHeader(true, header[0], header[1])
108 }
109 }
110
111
112 rs.AddHeader(true, "content-type", "application/json")
113 rs.AddHeader(true, "kat-resp-rls-protocol-version", g.ProtocolVersion)
114
115
116 results := make(map[string]interface{})
117
118 results["descriptors"] = ""
119 results["backend"] = g.Backend
120 results["status"] = rs.GetOverallCode()
121 if rs.GetHTTPHeaderMap() != nil {
122 results["headers"] = *rs.GetHTTPHeaderMap()
123 }
124 body, err := json.MarshalIndent(results, "", " ")
125 if err != nil {
126 body = []byte(fmt.Sprintf("Error: %v", err))
127 }
128
129
130 dlog.Printf(ctx, "setting response body: %s", string(body))
131 rs.SetBody(string(body))
132 }
133
134 return rs.GetResponse(), nil
135 }
136
137
// RLSResponseV3 accumulates the parts of a v3 RateLimitResponse (overall
// code, headers, raw body) and assembles the protobuf message in
// GetResponse. The zero value is ready to use.
type RLSResponseV3 struct {
	// headers holds the options recorded by AddHeader, in insertion order.
	headers []*core.HeaderValueOption
	// body is emitted as the response's RawBody by GetResponse.
	body string
	// overallCode is emitted as the response's OverallCode by GetResponse.
	overallCode pb.RateLimitResponse_Code
}
143
144
145
146 func (r *RLSResponseV3) AddHeader(a bool, k, v string) {
147 val := &core.HeaderValueOption{
148 Header: &core.HeaderValue{
149 Key: k,
150 Value: v,
151 },
152 Append: &wrapperspb.BoolValue{Value: a},
153 }
154 r.headers = append(r.headers, val)
155 }
156
157
158 func (r *RLSResponseV3) GetHTTPHeaderMap() *http.Header {
159 h := &http.Header{}
160 for _, v := range r.headers {
161 h.Add(v.Header.Key, v.Header.Value)
162 }
163 return h
164 }
165
166
167 func (r *RLSResponseV3) SetBody(s string) {
168 r.body = s
169 }
170
171
172 func (r *RLSResponseV3) SetOverallCode(code pb.RateLimitResponse_Code) {
173 r.overallCode = code
174 }
175
176
177 func (r *RLSResponseV3) GetOverallCode() pb.RateLimitResponse_Code {
178 return r.overallCode
179 }
180
181
182 func (r *RLSResponseV3) GetResponse() *pb.RateLimitResponse {
183 rs := &pb.RateLimitResponse{}
184 rs.OverallCode = r.overallCode
185 rs.RawBody = []byte(r.body)
186 for _, h := range r.headers {
187 hdr := h.Header
188 if hdr != nil {
189 rs.ResponseHeadersToAdd = append(rs.ResponseHeadersToAdd,
190 &core.HeaderValue{
191 Key: hdr.Key,
192 Value: hdr.Value,
193 },
194 )
195 }
196 }
197 return rs
198 }
199
View as plain text