1 package server
2
3 import (
4 "context"
5 "fmt"
6 "math/rand"
7 "testing"
8 "time"
9
10 "cloud.google.com/go/pubsub"
11 "cloud.google.com/go/pubsub/pstest"
12 "github.com/stretchr/testify/require"
13 "google.golang.org/api/option"
14 "google.golang.org/grpc"
15 "google.golang.org/grpc/credentials/insecure"
16 )
17
18
19
20 func TestDiffReceiver(t *testing.T) {
21 var srv = pstest.NewServer()
22 defer srv.Close()
23
24 var grpcOpt = grpc.WithTransportCredentials(insecure.NewCredentials())
25 conn, err := grpc.NewClient(srv.Addr, grpcOpt)
26 if err != nil {
27 t.Fatal(err)
28 }
29 defer conn.Close()
30
31 var ctx, cancel = context.WithCancel(context.Background())
32 defer cancel()
33
34 client, err := pubsub.NewClient(ctx, "project", option.WithGRPCConn(conn))
35 if err != nil {
36 t.Fatal(err)
37 }
38 defer client.Close()
39
40 var rm = &ReceiverMux{
41
42 cfg: &ReceiverMuxConfig{
43 Handler: func(_ context.Context, _ *pubsub.Message) error {
44 return nil
45 },
46 PollSubscriptionExistsPeriod: time.Hour,
47 SubscriptionID: "fake",
48 },
49 client: client,
50 receivers: make(map[string]*Receiver),
51 }
52
53 var dtcs = []*DiffTestCase{
54
55 NewDiffTestCase(t, 0, 0, 0),
56 NewDiffTestCase(t, 1, 0, 0),
57 NewDiffTestCase(t, 50, 0, 0),
58
59 NewDiffTestCase(t, 0, 1, 0),
60 NewDiffTestCase(t, 1, 1, 0),
61 NewDiffTestCase(t, 20, 1, 0),
62
63 NewDiffTestCase(t, 0, 10, 0),
64 NewDiffTestCase(t, 1, 20, 0),
65 NewDiffTestCase(t, 20, 20, 0),
66
67 NewDiffTestCase(t, 1, 0, 1),
68 NewDiffTestCase(t, 2, 0, 1),
69 NewDiffTestCase(t, 20, 0, 1),
70
71 NewDiffTestCase(t, 2, 0, 2),
72 NewDiffTestCase(t, 10, 0, 10),
73 NewDiffTestCase(t, 20, 0, 10),
74
75 NewDiffTestCase(t, 1, 1, 1),
76 NewDiffTestCase(t, 10, 1, 1),
77
78 NewDiffTestCase(t, 2, 2, 2),
79 NewDiffTestCase(t, 10, 5, 5),
80 NewDiffTestCase(t, 10, 2, 7),
81 NewDiffTestCase(t, 10, 7, 2),
82 }
83
84 for i, dtc := range dtcs {
85 t.Logf("Test case %d: current=%d added=%d, dropped=%d", i, len(dtc.Before), len(dtc.Added), len(dtc.Dropped))
86 dtc.Test(t, rm)
87 }
88 }
89
// DiffTestCase is one scenario for ReceiverMux.diffReceivers. Every set-valued
// field maps an ID to true when that ID is a member of the set.
type DiffTestCase struct {
	// Added holds the IDs the diff should report as new receivers, and
	// Dropped holds the pre-existing IDs that are left out of the poll.
	Added, Dropped map[string]bool

	// Before holds the IDs that already have receivers when the diff runs.
	Before map[string]bool
	// Polled is the ID list passed to diffReceivers.
	Polled []string

	// After holds the IDs expected in the receiver map once the diff has run.
	After map[string]bool
}
103
104 func (dtc *DiffTestCase) Test(t *testing.T, rm *ReceiverMux) {
105
106 rm.receivers = make(map[string]*Receiver)
107 for b := range dtc.Before {
108 rm.receivers[b] = &Receiver{
109 projectID: b,
110 }
111 }
112 require.Equal(t, len(dtc.Before), len(rm.receivers))
113
114
115 var added, dropped = rm.diffReceivers(dtc.Polled...)
116
117 require.Equal(t, len(added), len(dtc.Added))
118 for _, r := range added {
119 require.True(t, dtc.Added[r.projectID])
120 }
121
122 require.Equal(t, len(dropped), len(dtc.Dropped))
123 for _, r := range dropped {
124 require.True(t, dtc.Dropped[r.projectID])
125 }
126
127 require.Equal(t, len(rm.receivers), len(dtc.After))
128
129 require.Equal(t, len(rm.receivers), len(dtc.Before)+len(dtc.Added))
130 require.Equal(t, len(rm.receivers), len(dtc.Polled)+len(dtc.Dropped))
131 for _, r := range rm.receivers {
132 require.True(t, dtc.After[r.projectID])
133 }
134 }
135
136 func NewDiffTestCase(t *testing.T, current, add, drop int) *DiffTestCase {
137 if drop > current {
138 panic("current must be greater than the amount dropped")
139 }
140
141 var dtc = DiffTestCase{
142 Added: make(map[string]bool),
143 Dropped: make(map[string]bool),
144 Before: make(map[string]bool),
145 After: make(map[string]bool),
146 }
147
148 for range add {
149 var b = fmt.Sprintf("b%x", rand.Int63())
150 dtc.Added[b] = true
151 dtc.After[b] = true
152 dtc.Polled = append(dtc.Polled, b)
153 }
154
155 var dropped = make(map[int]bool)
156 var perm = rand.Perm(current)
157 for _, j := range perm[:drop] {
158 dropped[j] = true
159 }
160
161 for i := range current {
162 var b = fmt.Sprintf("b%x", rand.Int63())
163 dtc.Before[b] = true
164 dtc.After[b] = true
165
166 if dropped[i] {
167 dtc.Dropped[b] = true
168 } else {
169 dtc.Polled = append(dtc.Polled, b)
170 }
171 }
172
173
174 require.Equal(t, current, len(dtc.Before))
175 require.Equal(t, add, len(dtc.Added))
176 require.Equal(t, drop, len(dtc.Dropped))
177 require.Equal(t, current+add, len(dtc.After))
178 require.Equal(t, current+add-drop, len(dtc.Polled))
179 for dropped := range dtc.Dropped {
180 require.True(t, dtc.After[dropped])
181 require.True(t, dtc.Before[dropped])
182 require.False(t, dtc.Added[dropped])
183 }
184 for added := range dtc.Added {
185 require.True(t, dtc.After[added])
186 require.False(t, dtc.Before[added])
187 require.False(t, dtc.Dropped[added])
188 }
189 for _, polled := range dtc.Polled {
190 require.True(t, dtc.After[polled])
191 require.True(t, dtc.Added[polled] || dtc.Before[polled])
192 require.False(t, dtc.Dropped[polled])
193 }
194
195 return &dtc
196 }
197
View as plain text