1
18
19 package test
20
21 import (
22 "context"
23 "fmt"
24 "testing"
25 "time"
26
27 "github.com/google/go-cmp/cmp"
28 "github.com/google/go-cmp/cmp/cmpopts"
29 "google.golang.org/grpc/codes"
30 iresolver "google.golang.org/grpc/internal/resolver"
31 "google.golang.org/grpc/internal/serviceconfig"
32 "google.golang.org/grpc/internal/stubserver"
33 "google.golang.org/grpc/internal/testutils"
34 testpb "google.golang.org/grpc/interop/grpc_testing"
35 "google.golang.org/grpc/metadata"
36 "google.golang.org/grpc/resolver"
37 "google.golang.org/grpc/resolver/manual"
38 "google.golang.org/grpc/status"
39 )
40
41 type funcConfigSelector struct {
42 f func(iresolver.RPCInfo) (*iresolver.RPCConfig, error)
43 }
44
45 func (f funcConfigSelector) SelectConfig(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) {
46 return f.f(i)
47 }
48
49 func (s) TestConfigSelector(t *testing.T) {
50 gotContextChan := testutils.NewChannelWithSize(1)
51
52 ss := &stubserver.StubServer{
53 EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
54 gotContextChan.SendContext(ctx, ctx)
55 return &testpb.Empty{}, nil
56 },
57 }
58 ss.R = manual.NewBuilderWithScheme("confSel")
59
60 if err := ss.Start(nil); err != nil {
61 t.Fatalf("Error starting endpoint server: %v", err)
62 }
63 defer ss.Stop()
64
65 const normalTimeout = 10 * time.Second
66 ctxDeadline := time.Now().Add(normalTimeout)
67 ctx, cancel := context.WithTimeout(context.Background(), normalTimeout)
68 defer cancel()
69
70 const longTimeout = 30 * time.Second
71 longCtxDeadline := time.Now().Add(longTimeout)
72 longdeadlineCtx, cancel := context.WithTimeout(context.Background(), longTimeout)
73 defer cancel()
74 shorterTimeout := 3 * time.Second
75
76 testMD := metadata.MD{"footest": []string{"bazbar"}}
77 mdOut := metadata.MD{"handler": []string{"value"}}
78
79 var onCommittedCalled bool
80
81 testCases := []struct {
82 name string
83 md metadata.MD
84 config *iresolver.RPCConfig
85 csErr error
86
87 wantMD metadata.MD
88 wantDeadline time.Time
89 wantTimeout time.Duration
90 wantErr error
91 }{{
92 name: "basic",
93 md: testMD,
94 config: &iresolver.RPCConfig{},
95 wantMD: testMD,
96 wantDeadline: ctxDeadline,
97 }, {
98 name: "alter MD",
99 md: testMD,
100 config: &iresolver.RPCConfig{
101 Context: metadata.NewOutgoingContext(ctx, mdOut),
102 },
103 wantMD: mdOut,
104 wantDeadline: ctxDeadline,
105 }, {
106 name: "erroring SelectConfig",
107 csErr: status.Errorf(codes.Unavailable, "cannot send RPC"),
108 wantErr: status.Errorf(codes.Unavailable, "cannot send RPC"),
109 }, {
110 name: "alter timeout; remove MD",
111 md: testMD,
112 config: &iresolver.RPCConfig{
113 Context: longdeadlineCtx,
114 },
115 wantMD: nil,
116 wantDeadline: longCtxDeadline,
117 }, {
118 name: "nil config",
119 md: metadata.MD{},
120 config: nil,
121 wantMD: nil,
122 wantDeadline: ctxDeadline,
123 }, {
124 name: "alter timeout via method config; remove MD",
125 md: testMD,
126 config: &iresolver.RPCConfig{
127 MethodConfig: serviceconfig.MethodConfig{
128 Timeout: &shorterTimeout,
129 },
130 },
131 wantMD: nil,
132 wantTimeout: shorterTimeout,
133 }, {
134 name: "onCommitted callback",
135 md: testMD,
136 config: &iresolver.RPCConfig{
137 OnCommitted: func() {
138 onCommittedCalled = true
139 },
140 },
141 wantMD: testMD,
142 wantDeadline: ctxDeadline,
143 }}
144
145 for _, tc := range testCases {
146 t.Run(tc.name, func(t *testing.T) {
147 var gotInfo *iresolver.RPCInfo
148 state := iresolver.SetConfigSelector(resolver.State{
149 Addresses: []resolver.Address{{Addr: ss.Address}},
150 ServiceConfig: parseServiceConfig(t, ss.R, "{}"),
151 }, funcConfigSelector{
152 f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) {
153 gotInfo = &i
154 cfg := tc.config
155 if cfg != nil && cfg.Context == nil {
156 cfg.Context = i.Context
157 }
158 return cfg, tc.csErr
159 },
160 })
161 ss.R.UpdateState(state)
162
163 onCommittedCalled = false
164 ctx := metadata.NewOutgoingContext(ctx, tc.md)
165 startTime := time.Now()
166 if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); fmt.Sprint(err) != fmt.Sprint(tc.wantErr) {
167 t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.wantErr)
168 } else if err != nil {
169 return
170 }
171
172 if gotInfo == nil {
173 t.Fatalf("no config selector data")
174 }
175
176 if want := "/grpc.testing.TestService/EmptyCall"; gotInfo.Method != want {
177 t.Errorf("gotInfo.Method = %q; want %q", gotInfo.Method, want)
178 }
179
180 gotContextI, ok := gotContextChan.ReceiveOrFail()
181 if !ok {
182 t.Fatalf("no context received")
183 }
184 gotContext := gotContextI.(context.Context)
185
186 gotMD, _ := metadata.FromOutgoingContext(gotInfo.Context)
187 if diff := cmp.Diff(tc.md, gotMD); diff != "" {
188 t.Errorf("gotInfo.Context contains MD %v; want %v\nDiffs: %v", gotMD, tc.md, diff)
189 }
190
191 gotMD, _ = metadata.FromIncomingContext(gotContext)
192
193 for k := range gotMD {
194 if _, ok := tc.wantMD[k]; !ok {
195 delete(gotMD, k)
196 }
197 }
198 if diff := cmp.Diff(tc.wantMD, gotMD, cmpopts.EquateEmpty()); diff != "" {
199 t.Errorf("received md = %v; want %v\nDiffs: %v", gotMD, tc.wantMD, diff)
200 }
201
202 wantDeadline := tc.wantDeadline
203 if wantDeadline == (time.Time{}) {
204 wantDeadline = startTime.Add(tc.wantTimeout)
205 }
206 deadlineGot, _ := gotContext.Deadline()
207 if diff := deadlineGot.Sub(wantDeadline); diff > time.Second || diff < -time.Second {
208 t.Errorf("received deadline = %v; want ~%v", deadlineGot, wantDeadline)
209 }
210
211 if tc.config != nil && tc.config.OnCommitted != nil && !onCommittedCalled {
212 t.Errorf("OnCommitted callback not called")
213 }
214 })
215 }
216 }
217
View as plain text