1
18
19 package audit_test
20
21 import (
22 "context"
23 "crypto/tls"
24 "crypto/x509"
25 "encoding/json"
26 "io"
27 "net"
28 "os"
29 "testing"
30 "time"
31
32 "github.com/google/go-cmp/cmp"
33 "google.golang.org/grpc"
34 "google.golang.org/grpc/authz"
35 "google.golang.org/grpc/authz/audit"
36 "google.golang.org/grpc/codes"
37 "google.golang.org/grpc/credentials"
38 "google.golang.org/grpc/internal/grpctest"
39 "google.golang.org/grpc/internal/stubserver"
40 testgrpc "google.golang.org/grpc/interop/grpc_testing"
41 testpb "google.golang.org/grpc/interop/grpc_testing"
42 "google.golang.org/grpc/status"
43 "google.golang.org/grpc/testdata"
44
45 _ "google.golang.org/grpc/authz/audit/stdout"
46 )
47
48 type s struct {
49 grpctest.Tester
50 }
51
52 func Test(t *testing.T) {
53 grpctest.RunSubTests(t, s{})
54 }
55
56 type statAuditLogger struct {
57 authzDecisionStat map[bool]int
58 lastEvent *audit.Event
59 }
60
61 func (s *statAuditLogger) Log(event *audit.Event) {
62 s.authzDecisionStat[event.Authorized]++
63 *s.lastEvent = *event
64 }
65
66 type loggerBuilder struct {
67 authzDecisionStat map[bool]int
68 lastEvent *audit.Event
69 }
70
71 func (loggerBuilder) Name() string {
72 return "stat_logger"
73 }
74
75 func (lb *loggerBuilder) Build(audit.LoggerConfig) audit.Logger {
76 return &statAuditLogger{
77 authzDecisionStat: lb.authzDecisionStat,
78 lastEvent: lb.lastEvent,
79 }
80 }
81
82 func (*loggerBuilder) ParseLoggerConfig(config json.RawMessage) (audit.LoggerConfig, error) {
83 return nil, nil
84 }
85
86
87
88
89
90
91 func (s) TestAuditLogger(t *testing.T) {
92
93
94
95
96
97 tests := []struct {
98 name string
99 authzPolicy string
100 wantAuthzOutcomes map[bool]int
101 eventContent *audit.Event
102 wantUnaryCallCode codes.Code
103 wantStreamingCallCode codes.Code
104 }{
105 {
106 name: "No audit",
107 authzPolicy: `{
108 "name": "authz",
109 "allow_rules": [
110 {
111 "name": "allow_UnaryCall",
112 "request": {
113 "paths": [
114 "/grpc.testing.TestService/UnaryCall"
115 ]
116 }
117 }
118 ],
119 "audit_logging_options": {
120 "audit_condition": "NONE",
121 "audit_loggers": [
122 {
123 "name": "stat_logger",
124 "config": {},
125 "is_optional": false
126 }
127 ]
128 }
129 }`,
130 wantAuthzOutcomes: map[bool]int{true: 0, false: 0},
131 wantUnaryCallCode: codes.OK,
132 wantStreamingCallCode: codes.PermissionDenied,
133 },
134 {
135 name: "Allow All Deny Streaming - Audit All",
136 authzPolicy: `{
137 "name": "authz",
138 "allow_rules": [
139 {
140 "name": "allow_all",
141 "request": {
142 "paths": [
143 "*"
144 ]
145 }
146 }
147 ],
148 "deny_rules": [
149 {
150 "name": "deny_all",
151 "request": {
152 "paths": [
153 "/grpc.testing.TestService/StreamingInputCall"
154 ]
155 }
156 }
157 ],
158 "audit_logging_options": {
159 "audit_condition": "ON_DENY_AND_ALLOW",
160 "audit_loggers": [
161 {
162 "name": "stat_logger",
163 "config": {},
164 "is_optional": false
165 },
166 {
167 "name": "stdout_logger",
168 "is_optional": false
169 }
170 ]
171 }
172 }`,
173 wantAuthzOutcomes: map[bool]int{true: 2, false: 1},
174 eventContent: &audit.Event{
175 FullMethodName: "/grpc.testing.TestService/StreamingInputCall",
176 Principal: "spiffe://foo.bar.com/client/workload/1",
177 PolicyName: "authz",
178 MatchedRule: "authz_deny_all",
179 Authorized: false,
180 },
181 wantUnaryCallCode: codes.OK,
182 wantStreamingCallCode: codes.PermissionDenied,
183 },
184 {
185 name: "Allow Unary - Audit Allow",
186 authzPolicy: `{
187 "name": "authz",
188 "allow_rules": [
189 {
190 "name": "allow_UnaryCall",
191 "request": {
192 "paths": [
193 "/grpc.testing.TestService/UnaryCall"
194 ]
195 }
196 }
197 ],
198 "audit_logging_options": {
199 "audit_condition": "ON_ALLOW",
200 "audit_loggers": [
201 {
202 "name": "stat_logger",
203 "config": {},
204 "is_optional": false
205 }
206 ]
207 }
208 }`,
209 wantAuthzOutcomes: map[bool]int{true: 2, false: 0},
210 wantUnaryCallCode: codes.OK,
211 wantStreamingCallCode: codes.PermissionDenied,
212 },
213 {
214 name: "Allow Typo - Audit Deny",
215 authzPolicy: `{
216 "name": "authz",
217 "allow_rules": [
218 {
219 "name": "allow_UnaryCall",
220 "request": {
221 "paths": [
222 "/grpc.testing.TestService/UnaryCall_Z"
223 ]
224 }
225 }
226 ],
227 "audit_logging_options": {
228 "audit_condition": "ON_DENY",
229 "audit_loggers": [
230 {
231 "name": "stat_logger",
232 "config": {},
233 "is_optional": false
234 }
235 ]
236 }
237 }`,
238 wantAuthzOutcomes: map[bool]int{true: 0, false: 3},
239 wantUnaryCallCode: codes.PermissionDenied,
240 wantStreamingCallCode: codes.PermissionDenied,
241 },
242 }
243
244 serverCreds := loadServerCreds(t)
245 clientCreds := loadClientCreds(t)
246 ss := &stubserver.StubServer{
247 UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
248 return &testpb.SimpleResponse{}, nil
249 },
250 FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
251 _, err := stream.Recv()
252 if err != io.EOF {
253 return err
254 }
255 return nil
256 },
257 }
258 for _, test := range tests {
259 t.Run(test.name, func(t *testing.T) {
260
261
262 lb := &loggerBuilder{
263 authzDecisionStat: map[bool]int{true: 0, false: 0},
264 lastEvent: &audit.Event{},
265 }
266 audit.RegisterLoggerBuilder(lb)
267 i, _ := authz.NewStatic(test.authzPolicy)
268
269 s := grpc.NewServer(
270 grpc.Creds(serverCreds),
271 grpc.ChainUnaryInterceptor(i.UnaryInterceptor),
272 grpc.ChainStreamInterceptor(i.StreamInterceptor))
273 defer s.Stop()
274 testgrpc.RegisterTestServiceServer(s, ss)
275 lis, err := net.Listen("tcp", "localhost:0")
276 if err != nil {
277 t.Fatalf("Error listening: %v", err)
278 }
279 go s.Serve(lis)
280
281
282 clientConn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(clientCreds))
283 if err != nil {
284 t.Fatalf("grpc.NewClient(%v) failed: %v", lis.Addr().String(), err)
285 }
286 defer clientConn.Close()
287 client := testgrpc.NewTestServiceClient(clientConn)
288 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
289 defer cancel()
290
291 if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != test.wantUnaryCallCode {
292 t.Errorf("Unexpected UnaryCall fail: got %v want %v", err, test.wantUnaryCallCode)
293 }
294 if _, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != test.wantUnaryCallCode {
295 t.Errorf("Unexpected UnaryCall fail: got %v want %v", err, test.wantUnaryCallCode)
296 }
297 stream, err := client.StreamingInputCall(ctx)
298 if err != nil {
299 t.Fatalf("StreamingInputCall failed:%v", err)
300 }
301 req := &testpb.StreamingInputCallRequest{
302 Payload: &testpb.Payload{
303 Body: []byte("hi"),
304 },
305 }
306 if err := stream.Send(req); err != nil && err != io.EOF {
307 t.Fatalf("stream.Send failed:%v", err)
308 }
309 if _, err := stream.CloseAndRecv(); status.Code(err) != test.wantStreamingCallCode {
310 t.Errorf("Unexpected stream.CloseAndRecv fail: got %v want %v", err, test.wantStreamingCallCode)
311 }
312
313
314
315 if diff := cmp.Diff(lb.authzDecisionStat, test.wantAuthzOutcomes); diff != "" {
316 t.Errorf("Authorization decisions do not match\ndiff (-got +want):\n%s", diff)
317 }
318
319 if test.eventContent != nil {
320 if diff := cmp.Diff(lb.lastEvent, test.eventContent); diff != "" {
321 t.Errorf("Unexpected message\ndiff (-got +want):\n%s", diff)
322 }
323 }
324 })
325 }
326 }
327
328
329 func loadServerCreds(t *testing.T) credentials.TransportCredentials {
330 t.Helper()
331 cert := loadKeys(t, "x509/server1_cert.pem", "x509/server1_key.pem")
332 certPool := loadCACerts(t, "x509/client_ca_cert.pem")
333 return credentials.NewTLS(&tls.Config{
334 ClientAuth: tls.RequireAndVerifyClientCert,
335 Certificates: []tls.Certificate{cert},
336 ClientCAs: certPool,
337 })
338 }
339
340
341 func loadClientCreds(t *testing.T) credentials.TransportCredentials {
342 t.Helper()
343 cert := loadKeys(t, "x509/client_with_spiffe_cert.pem", "x509/client_with_spiffe_key.pem")
344 roots := loadCACerts(t, "x509/server_ca_cert.pem")
345 return credentials.NewTLS(&tls.Config{
346 Certificates: []tls.Certificate{cert},
347 RootCAs: roots,
348 ServerName: "x.test.example.com",
349 })
350
351 }
352
353
354
355 func loadKeys(t *testing.T, certPath, key string) tls.Certificate {
356 t.Helper()
357 cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(key))
358 if err != nil {
359 t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, key, err)
360 }
361 return cert
362 }
363
364
365
366 func loadCACerts(t *testing.T, certPath string) *x509.CertPool {
367 t.Helper()
368 ca, err := os.ReadFile(testdata.Path(certPath))
369 if err != nil {
370 t.Fatalf("os.ReadFile(%q) failed: %v", certPath, err)
371 }
372 roots := x509.NewCertPool()
373 if !roots.AppendCertsFromPEM(ca) {
374 t.Fatal("Failed to append certificates")
375 }
376 return roots
377 }
378
View as plain text