...
1
2
3
4 package grpc_validator
5
6 import (
7 "context"
8
9 "google.golang.org/grpc"
10 "google.golang.org/grpc/codes"
11 "google.golang.org/grpc/status"
12 )
13
14
15
type (
	// validator is satisfied by request messages whose generated Validate
	// method takes a boolean flag. validate below always passes false.
	// NOTE(review): the flag presumably selects "report all violations" vs
	// "stop at the first one" — confirm against the code generator in use.
	validator interface {
		Validate(all bool) error
	}

	// validatorLegacy is satisfied by request messages whose Validate method
	// takes no arguments. Because Go has no method overloading, a single type
	// can satisfy at most one of validator and validatorLegacy.
	validatorLegacy interface {
		Validate() error
	}
)
24
25 func validate(req interface{}) error {
26 switch v := req.(type) {
27 case validatorLegacy:
28 if err := v.Validate(); err != nil {
29 return status.Error(codes.InvalidArgument, err.Error())
30 }
31 case validator:
32 if err := v.Validate(false); err != nil {
33 return status.Error(codes.InvalidArgument, err.Error())
34 }
35 }
36 return nil
37 }
38
39
40
41
42 func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
43 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
44 if err := validate(req); err != nil {
45 return nil, err
46 }
47 return handler(ctx, req)
48 }
49 }
50
51
52
53
54 func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
55 return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
56 if err := validate(req); err != nil {
57 return err
58 }
59 return invoker(ctx, method, req, reply, cc, opts...)
60 }
61 }
62
63
64
65
66
67
68
69 func StreamServerInterceptor() grpc.StreamServerInterceptor {
70 return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
71 wrapper := &recvWrapper{stream}
72 return handler(srv, wrapper)
73 }
74 }
75
76 type recvWrapper struct {
77 grpc.ServerStream
78 }
79
80 func (s *recvWrapper) RecvMsg(m interface{}) error {
81 if err := s.ServerStream.RecvMsg(m); err != nil {
82 return err
83 }
84
85 if err := validate(m); err != nil {
86 return err
87 }
88
89 return nil
90 }
91
View as plain text