1
18
19 package orca_test
20
21 import (
22 "context"
23 "errors"
24 "io"
25 "testing"
26
27 "github.com/google/go-cmp/cmp"
28 "google.golang.org/grpc"
29 "google.golang.org/grpc/credentials/insecure"
30 "google.golang.org/grpc/internal/pretty"
31 "google.golang.org/grpc/internal/stubserver"
32 "google.golang.org/grpc/metadata"
33 "google.golang.org/grpc/orca"
34 "google.golang.org/grpc/orca/internal"
35 "google.golang.org/protobuf/proto"
36
37 v3orcapb "github.com/cncf/xds/go/xds/data/orca/v3"
38 testgrpc "google.golang.org/grpc/interop/grpc_testing"
39 testpb "google.golang.org/grpc/interop/grpc_testing"
40 )
41
42
43
44
45 func (s) TestE2ECallMetricsUnary(t *testing.T) {
46 tests := []struct {
47 desc string
48 injectMetrics bool
49 wantProto *v3orcapb.OrcaLoadReport
50 }{
51 {
52 desc: "with custom backend metrics",
53 injectMetrics: true,
54 wantProto: &v3orcapb.OrcaLoadReport{
55 CpuUtilization: 1.0,
56 MemUtilization: 0.9,
57 RequestCost: map[string]float64{"queryCost": 25.0},
58 Utilization: map[string]float64{"queueSize": 0.75},
59 },
60 },
61 {
62 desc: "with no custom backend metrics",
63 injectMetrics: false,
64 },
65 }
66
67 for _, test := range tests {
68 t.Run(test.desc, func(t *testing.T) {
69
70 smr := orca.NewServerMetricsRecorder()
71 callMetricsServerOption := orca.CallMetricsServerOption(smr)
72 smr.SetCPUUtilization(1.0)
73
74
75
76 injectingInterceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
77 recorder := orca.CallMetricsRecorderFromContext(ctx)
78 if recorder == nil {
79 err := errors.New("Failed to retrieve per-RPC custom metrics recorder from the RPC context")
80 t.Error(err)
81 return nil, err
82 }
83 recorder.SetMemoryUtilization(0.9)
84
85
86 recorder.SetNamedUtilization("queueSize", 1.0)
87 return handler(ctx, req)
88 }
89
90
91
92
93 srv := stubserver.StubServer{
94 EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
95 if !test.injectMetrics {
96 return &testpb.Empty{}, nil
97 }
98 recorder := orca.CallMetricsRecorderFromContext(ctx)
99 if recorder == nil {
100 err := errors.New("Failed to retrieve per-RPC custom metrics recorder from the RPC context")
101 t.Error(err)
102 return nil, err
103 }
104 recorder.SetRequestCost("queryCost", 25.0)
105 recorder.SetNamedUtilization("queueSize", 0.75)
106 return &testpb.Empty{}, nil
107 },
108 }
109
110
111 sopts := []grpc.ServerOption{callMetricsServerOption}
112 if test.injectMetrics {
113 sopts = append(sopts, grpc.ChainUnaryInterceptor(injectingInterceptor))
114 }
115 if err := srv.StartServer(sopts...); err != nil {
116 t.Fatalf("Failed to start server: %v", err)
117 }
118 defer srv.Stop()
119
120
121 cc, err := grpc.NewClient(srv.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
122 if err != nil {
123 t.Fatalf("grpc.NewClient(%s) failed: %v", srv.Address, err)
124 }
125 defer cc.Close()
126
127
128
129 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
130 defer cancel()
131 client := testgrpc.NewTestServiceClient(cc)
132 trailer := metadata.MD{}
133 if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
134 t.Fatalf("EmptyCall failed: %v", err)
135 }
136
137 gotProto, err := internal.ToLoadReport(trailer)
138 if err != nil {
139 t.Fatalf("When retrieving load report, got error: %v, want: <nil>", err)
140 }
141 if test.wantProto != nil && !cmp.Equal(gotProto, test.wantProto, cmp.Comparer(proto.Equal)) {
142 t.Fatalf("Received load report in trailer: %s, want: %s", pretty.ToJSON(gotProto), pretty.ToJSON(test.wantProto))
143 }
144 })
145 }
146 }
147
148
149
150
151 func (s) TestE2ECallMetricsStreaming(t *testing.T) {
152 tests := []struct {
153 desc string
154 injectMetrics bool
155 wantProto *v3orcapb.OrcaLoadReport
156 }{
157 {
158 desc: "with custom backend metrics",
159 injectMetrics: true,
160 wantProto: &v3orcapb.OrcaLoadReport{
161 CpuUtilization: 1.0,
162 MemUtilization: 0.5,
163 RequestCost: map[string]float64{"queryCost": 0.25},
164 Utilization: map[string]float64{"queueSize": 0.75},
165 },
166 },
167 {
168 desc: "with no custom backend metrics",
169 injectMetrics: false,
170 },
171 }
172
173 for _, test := range tests {
174 t.Run(test.desc, func(t *testing.T) {
175
176 smr := orca.NewServerMetricsRecorder()
177 callMetricsServerOption := orca.CallMetricsServerOption(smr)
178 smr.SetCPUUtilization(1.0)
179
180
181
182 injectingInterceptor := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
183 recorder := orca.CallMetricsRecorderFromContext(ss.Context())
184 if recorder == nil {
185 err := errors.New("Failed to retrieve per-RPC custom metrics recorder from the RPC context")
186 t.Error(err)
187 return err
188 }
189 recorder.SetMemoryUtilization(0.5)
190
191
192 recorder.SetNamedUtilization("queueSize", 1.0)
193 return handler(srv, ss)
194 }
195
196
197
198
199 srv := stubserver.StubServer{
200 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
201 if test.injectMetrics {
202 recorder := orca.CallMetricsRecorderFromContext(stream.Context())
203 if recorder == nil {
204 err := errors.New("Failed to retrieve per-RPC custom metrics recorder from the RPC context")
205 t.Error(err)
206 return err
207 }
208 recorder.SetRequestCost("queryCost", 0.25)
209 recorder.SetNamedUtilization("queueSize", 0.75)
210 }
211
212
213
214
215 for {
216 _, err := stream.Recv()
217 if err == io.EOF {
218 return nil
219 }
220 if err != nil {
221 return err
222 }
223 payload := &testpb.Payload{Body: make([]byte, 32)}
224 if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: payload}); err != nil {
225 return err
226 }
227 }
228 },
229 }
230
231
232 sopts := []grpc.ServerOption{callMetricsServerOption}
233 if test.injectMetrics {
234 sopts = append(sopts, grpc.ChainStreamInterceptor(injectingInterceptor))
235 }
236 if err := srv.StartServer(sopts...); err != nil {
237 t.Fatalf("Failed to start server: %v", err)
238 }
239 defer srv.Stop()
240
241
242 cc, err := grpc.NewClient(srv.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
243 if err != nil {
244 t.Fatalf("grpc.NewClient(%s) failed: %v", srv.Address, err)
245 }
246 defer cc.Close()
247
248
249 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
250 defer cancel()
251 tc := testgrpc.NewTestServiceClient(cc)
252 stream, err := tc.FullDuplexCall(ctx)
253 if err != nil {
254 t.Fatalf("FullDuplexCall failed: %v", err)
255 }
256
257
258 payload := &testpb.Payload{Body: make([]byte, 32)}
259 req := &testpb.StreamingOutputCallRequest{Payload: payload}
260 if err := stream.Send(req); err != nil {
261 t.Fatalf("stream.Send() failed: %v", err)
262 }
263
264 if _, err := stream.Recv(); err != nil {
265 t.Fatalf("stream.Recv() failed: %v", err)
266 }
267
268 if err := stream.CloseSend(); err != nil {
269 t.Fatalf("stream.CloseSend() failed: %v", err)
270 }
271
272 for {
273 if _, err := stream.Recv(); err != nil {
274 break
275 }
276 }
277
278 gotProto, err := internal.ToLoadReport(stream.Trailer())
279 if err != nil {
280 t.Fatalf("When retrieving load report, got error: %v, want: <nil>", err)
281 }
282 if test.wantProto != nil && !cmp.Equal(gotProto, test.wantProto, cmp.Comparer(proto.Equal)) {
283 t.Fatalf("Received load report in trailer: %s, want: %s", pretty.ToJSON(gotProto), pretty.ToJSON(test.wantProto))
284 }
285 })
286 }
287 }
288
View as plain text