1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package testutil
16
17 import (
18 "bytes"
19 "context"
20 "errors"
21 "fmt"
22 "log"
23 "os"
24 "strings"
25
26 "google.golang.org/api/option"
27 "google.golang.org/grpc"
28 "google.golang.org/grpc/metadata"
29 )
30
31
// HeaderChecker pairs an outgoing gRPC metadata key with a validator for the
// values sent under that key. HeadersEnforcer reports a failure when the key
// is absent from the outgoing metadata or when ValuesValidator returns an
// error.
type HeaderChecker struct {
	// Key is the metadata key looked up in the outgoing context.
	// NOTE(review): gRPC stores metadata keys in lower case, so a lowercase
	// Key is the safe choice.
	Key string

	// ValuesValidator receives every value recorded under Key (possibly
	// none, if the key is present with no values) and returns a non-nil
	// error to mark the header as invalid.
	ValuesValidator func(values ...string) error
}
40
41
42
43
44
45
46
47
// HeadersEnforcer is a client-side gRPC interceptor that checks the outgoing
// metadata of each unary and streaming call against a set of HeaderCheckers
// before the call is issued. Install it through DialOptions, CallOptions,
// StreamInterceptors or UnaryInterceptors.
type HeadersEnforcer struct {
	// Checkers lists the header checks applied to every intercepted call.
	// When empty, XGoogClientHeaderChecker is used instead.
	Checkers []*HeaderChecker

	// OnFailure is invoked at most once per call, with a printf-style
	// message, when the outgoing metadata is missing or any checker fails.
	// When nil, the message is written to stderr via log.Fatalf, which
	// exits the process. If OnFailure returns, the RPC still proceeds.
	OnFailure func(fmt_ string, args ...interface{})
}
64
65
66
67
68
69
70
71
72
73
74 func (h *HeadersEnforcer) StreamInterceptors() []grpc.StreamClientInterceptor {
75 return []grpc.StreamClientInterceptor{h.interceptStream}
76 }
77
78
79
80
81
82
83
84
85
86
87 func (h *HeadersEnforcer) UnaryInterceptors() []grpc.UnaryClientInterceptor {
88 return []grpc.UnaryClientInterceptor{h.interceptUnary}
89 }
90
91
92
93 func (h *HeadersEnforcer) DialOptions() []grpc.DialOption {
94 return []grpc.DialOption{
95 grpc.WithChainStreamInterceptor(h.interceptStream),
96 grpc.WithChainUnaryInterceptor(h.interceptUnary),
97 }
98 }
99
100
101
102 func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) {
103 dopts := h.DialOptions()
104 for _, dopt := range dopts {
105 copts = append(copts, option.WithGRPCDialOption(dopt))
106 }
107 return
108 }
109
110 func (h *HeadersEnforcer) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
111 h.checkMetadata(ctx, method)
112 return invoker(ctx, method, req, res, cc, opts...)
113 }
114
115 func (h *HeadersEnforcer) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
116 h.checkMetadata(ctx, method)
117 return streamer(ctx, desc, cc, method, opts...)
118 }
119
120
121
122 var XGoogClientHeaderChecker = &HeaderChecker{
123 Key: "x-goog-api-client",
124 ValuesValidator: func(values ...string) error {
125 if len(values) == 0 {
126 return errors.New("expecting values")
127 }
128 for _, value := range values {
129 switch {
130 case strings.Contains(value, "gl-go/"):
131
132 return nil
133
134 default:
135 }
136 }
137 return errors.New("unmatched values")
138 },
139 }
140
141
142
143
144 func DefaultHeadersEnforcer() *HeadersEnforcer {
145 return &HeadersEnforcer{
146 Checkers: []*HeaderChecker{XGoogClientHeaderChecker},
147 }
148 }
149
150 func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) {
151 onFailure := h.OnFailure
152 if onFailure == nil {
153 lgr := log.New(os.Stderr, "", 0)
154 onFailure = func(fmt_ string, args ...interface{}) {
155 lgr.Fatalf(fmt_, args...)
156 }
157 }
158
159 md, ok := metadata.FromOutgoingContext(ctx)
160 if !ok {
161 onFailure("Missing metadata for method %q", method)
162 return
163 }
164 checkers := h.Checkers
165 if len(checkers) == 0 {
166
167 checkers = append(checkers, XGoogClientHeaderChecker)
168 }
169
170 errBuf := new(bytes.Buffer)
171 for _, checker := range checkers {
172 hdrKey := checker.Key
173 outHdrValues, ok := md[hdrKey]
174 if !ok {
175 fmt.Fprintf(errBuf, "missing header %q\n", hdrKey)
176 continue
177 }
178 if err := checker.ValuesValidator(outHdrValues...); err != nil {
179 fmt.Fprintf(errBuf, "header %q: %v\n", hdrKey, err)
180 }
181 }
182
183 if errBuf.Len() != 0 {
184 onFailure("For method %q, errors:\n%s", method, errBuf)
185 return
186 }
187 }
188
View as plain text