1 package pool
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "strconv"
8 "sync/atomic"
9 "testing"
10 "time"
11
12 "github.com/stretchr/testify/assert"
13 "github.com/stretchr/testify/require"
14 )
15
16 func ExampleContextPool_WithCancelOnError() {
17 p := New().
18 WithMaxGoroutines(4).
19 WithContext(context.Background()).
20 WithCancelOnError()
21 for i := 0; i < 3; i++ {
22 i := i
23 p.Go(func(ctx context.Context) error {
24 if i == 2 {
25 return errors.New("I will cancel all other tasks!")
26 }
27 <-ctx.Done()
28 return nil
29 })
30 }
31 err := p.Wait()
32 fmt.Println(err)
33
34
35 }
36
37 func TestContextPool(t *testing.T) {
38 t.Parallel()
39
40 err1 := errors.New("err1")
41 err2 := errors.New("err2")
42 bgctx := context.Background()
43
44 t.Run("panics on configuration after init", func(t *testing.T) {
45 t.Run("before wait", func(t *testing.T) {
46 t.Parallel()
47 g := New().WithContext(context.Background())
48 g.Go(func(context.Context) error { return nil })
49 require.Panics(t, func() { g.WithMaxGoroutines(10) })
50 })
51
52 t.Run("after wait", func(t *testing.T) {
53 t.Parallel()
54 g := New().WithContext(context.Background())
55 g.Go(func(context.Context) error { return nil })
56 _ = g.Wait()
57 require.Panics(t, func() { g.WithMaxGoroutines(10) })
58 })
59 })
60
61 t.Run("behaves the same as ErrorGroup", func(t *testing.T) {
62 t.Parallel()
63
64 t.Run("wait returns no error if no errors", func(t *testing.T) {
65 t.Parallel()
66 p := New().WithContext(bgctx)
67 p.Go(func(context.Context) error { return nil })
68 require.NoError(t, p.Wait())
69 })
70
71 t.Run("wait errors if func returns error", func(t *testing.T) {
72 t.Parallel()
73 p := New().WithContext(bgctx)
74 p.Go(func(context.Context) error { return err1 })
75 require.ErrorIs(t, p.Wait(), err1)
76 })
77
78 t.Run("wait error is all returned errors", func(t *testing.T) {
79 t.Parallel()
80 p := New().WithErrors().WithContext(bgctx)
81 p.Go(func(context.Context) error { return err1 })
82 p.Go(func(context.Context) error { return nil })
83 p.Go(func(context.Context) error { return err2 })
84 err := p.Wait()
85 require.ErrorIs(t, err, err1)
86 require.ErrorIs(t, err, err2)
87 })
88 })
89
90 t.Run("context error propagates", func(t *testing.T) {
91 t.Parallel()
92
93 t.Run("canceled", func(t *testing.T) {
94 t.Parallel()
95 ctx, cancel := context.WithCancel(bgctx)
96 p := New().WithContext(ctx)
97 p.Go(func(ctx context.Context) error {
98 <-ctx.Done()
99 return ctx.Err()
100 })
101 cancel()
102 require.ErrorIs(t, p.Wait(), context.Canceled)
103 })
104
105 t.Run("timed out", func(t *testing.T) {
106 t.Parallel()
107 ctx, cancel := context.WithTimeout(bgctx, time.Millisecond)
108 defer cancel()
109 p := New().WithContext(ctx)
110 p.Go(func(ctx context.Context) error {
111 <-ctx.Done()
112 return ctx.Err()
113 })
114 require.ErrorIs(t, p.Wait(), context.DeadlineExceeded)
115 })
116 })
117
118 t.Run("WithCancelOnError", func(t *testing.T) {
119 t.Parallel()
120 p := New().WithContext(bgctx).WithCancelOnError()
121 p.Go(func(ctx context.Context) error {
122 <-ctx.Done()
123 return ctx.Err()
124 })
125 p.Go(func(ctx context.Context) error {
126 return err1
127 })
128 err := p.Wait()
129 require.ErrorIs(t, err, context.Canceled)
130 require.ErrorIs(t, err, err1)
131 })
132
133 t.Run("no WithCancelOnError", func(t *testing.T) {
134 t.Parallel()
135 p := New().WithContext(bgctx)
136 p.Go(func(ctx context.Context) error {
137 select {
138 case <-ctx.Done():
139 return ctx.Err()
140 case <-time.After(10 * time.Millisecond):
141 return nil
142 }
143 })
144 p.Go(func(ctx context.Context) error {
145 return err1
146 })
147 err := p.Wait()
148 require.ErrorIs(t, err, err1)
149 require.NotErrorIs(t, err, context.Canceled)
150 })
151
152 t.Run("WithFirstError", func(t *testing.T) {
153 t.Parallel()
154 p := New().WithContext(bgctx).WithFirstError()
155 sync := make(chan struct{})
156 p.Go(func(ctx context.Context) error {
157 defer close(sync)
158 return err1
159 })
160 p.Go(func(ctx context.Context) error {
161
162
163
164
165
166
167 <-sync
168 time.Sleep(10 * time.Millisecond)
169 return err2
170 })
171 err := p.Wait()
172 require.ErrorIs(t, err, err1)
173 require.NotErrorIs(t, err, err2)
174 })
175
176 t.Run("WithFirstError and WithCancelOnError", func(t *testing.T) {
177 t.Parallel()
178 p := New().WithContext(bgctx).WithFirstError().WithCancelOnError()
179 p.Go(func(ctx context.Context) error {
180 return err1
181 })
182 p.Go(func(ctx context.Context) error {
183 <-ctx.Done()
184 return ctx.Err()
185 })
186 err := p.Wait()
187 require.ErrorIs(t, err, err1)
188 require.NotErrorIs(t, err, context.Canceled)
189 })
190
191 t.Run("WithCancelOnError and panic", func(t *testing.T) {
192 t.Parallel()
193 p := New().WithContext(bgctx).WithCancelOnError()
194 var cancelledTasks atomic.Int64
195 p.Go(func(ctx context.Context) error {
196 <-ctx.Done()
197 cancelledTasks.Add(1)
198 return ctx.Err()
199 })
200 p.Go(func(ctx context.Context) error {
201 <-ctx.Done()
202 cancelledTasks.Add(1)
203 return ctx.Err()
204 })
205 p.Go(func(ctx context.Context) error {
206 panic("abort!")
207 })
208 assert.Panics(t, func() { _ = p.Wait() })
209 assert.EqualValues(t, 2, cancelledTasks.Load())
210 })
211
212 t.Run("limit", func(t *testing.T) {
213 t.Parallel()
214 for _, maxConcurrent := range []int{1, 10, 100} {
215 t.Run(strconv.Itoa(maxConcurrent), func(t *testing.T) {
216 maxConcurrent := maxConcurrent
217
218 t.Parallel()
219 p := New().WithContext(bgctx).WithMaxGoroutines(maxConcurrent)
220
221 var currentConcurrent atomic.Int64
222 for i := 0; i < 100; i++ {
223 p.Go(func(context.Context) error {
224 cur := currentConcurrent.Add(1)
225 if cur > int64(maxConcurrent) {
226 return fmt.Errorf("expected no more than %d concurrent goroutine", maxConcurrent)
227 }
228 time.Sleep(time.Millisecond)
229 currentConcurrent.Add(-1)
230 return nil
231 })
232 }
233 require.NoError(t, p.Wait())
234 require.Equal(t, int64(0), currentConcurrent.Load())
235 })
236 }
237 })
238 }
239