1 package pool
2
3 import (
4 "context"
5 "errors"
6 "fmt"
7 "sort"
8 "strconv"
9 "sync/atomic"
10 "testing"
11 "time"
12
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15 )
16
17 func TestResultContextPool(t *testing.T) {
18 t.Parallel()
19
20 err1 := errors.New("err1")
21 err2 := errors.New("err2")
22
23 t.Run("panics on configuration after init", func(t *testing.T) {
24 t.Run("before wait", func(t *testing.T) {
25 t.Parallel()
26 g := NewWithResults[int]().WithContext(context.Background())
27 g.Go(func(context.Context) (int, error) { return 0, nil })
28 require.Panics(t, func() { g.WithMaxGoroutines(10) })
29 })
30
31 t.Run("after wait", func(t *testing.T) {
32 t.Parallel()
33 g := NewWithResults[int]().WithContext(context.Background())
34 g.Go(func(context.Context) (int, error) { return 0, nil })
35 _, _ = g.Wait()
36 require.Panics(t, func() { g.WithMaxGoroutines(10) })
37 })
38 })
39
40 t.Run("behaves the same as ErrorGroup", func(t *testing.T) {
41 t.Parallel()
42 bgctx := context.Background()
43 t.Run("wait returns no error if no errors", func(t *testing.T) {
44 t.Parallel()
45 g := NewWithResults[int]().WithContext(bgctx)
46 g.Go(func(context.Context) (int, error) { return 0, nil })
47 res, err := g.Wait()
48 require.Len(t, res, 1)
49 require.NoError(t, err)
50 })
51
52 t.Run("wait error if func returns error", func(t *testing.T) {
53 t.Parallel()
54 g := NewWithResults[int]().WithContext(bgctx)
55 g.Go(func(context.Context) (int, error) { return 0, err1 })
56 res, err := g.Wait()
57 require.Len(t, res, 0)
58 require.ErrorIs(t, err, err1)
59 })
60
61 t.Run("wait error is all returned errors", func(t *testing.T) {
62 t.Parallel()
63 g := NewWithResults[int]().WithErrors().WithContext(bgctx)
64 g.Go(func(context.Context) (int, error) { return 0, err1 })
65 g.Go(func(context.Context) (int, error) { return 0, nil })
66 g.Go(func(context.Context) (int, error) { return 0, err2 })
67 res, err := g.Wait()
68 require.Len(t, res, 1)
69 require.ErrorIs(t, err, err1)
70 require.ErrorIs(t, err, err2)
71 })
72 })
73
74 t.Run("context cancel propagates", func(t *testing.T) {
75 t.Parallel()
76 ctx, cancel := context.WithCancel(context.Background())
77 g := NewWithResults[int]().WithContext(ctx)
78 g.Go(func(ctx context.Context) (int, error) {
79 <-ctx.Done()
80 return 0, ctx.Err()
81 })
82 cancel()
83 res, err := g.Wait()
84 require.Len(t, res, 0)
85 require.ErrorIs(t, err, context.Canceled)
86 })
87
88 t.Run("WithCancelOnError", func(t *testing.T) {
89 t.Parallel()
90 g := NewWithResults[int]().WithContext(context.Background()).WithCancelOnError()
91 g.Go(func(ctx context.Context) (int, error) {
92 <-ctx.Done()
93 return 0, ctx.Err()
94 })
95 g.Go(func(ctx context.Context) (int, error) {
96 return 0, err1
97 })
98 res, err := g.Wait()
99 require.Len(t, res, 0)
100 require.ErrorIs(t, err, context.Canceled)
101 require.ErrorIs(t, err, err1)
102 })
103
104 t.Run("WithCancelOnError and panic", func(t *testing.T) {
105 t.Parallel()
106 p := NewWithResults[int]().
107 WithContext(context.Background()).
108 WithCancelOnError()
109 var cancelledTasks atomic.Int64
110 p.Go(func(ctx context.Context) (int, error) {
111 <-ctx.Done()
112 cancelledTasks.Add(1)
113 return 0, ctx.Err()
114 })
115 p.Go(func(ctx context.Context) (int, error) {
116 <-ctx.Done()
117 cancelledTasks.Add(1)
118 return 0, ctx.Err()
119 })
120 p.Go(func(ctx context.Context) (int, error) {
121 panic("abort!")
122 })
123 assert.Panics(t, func() { _, _ = p.Wait() })
124 assert.EqualValues(t, 2, cancelledTasks.Load())
125 })
126
127 t.Run("no WithCancelOnError", func(t *testing.T) {
128 t.Parallel()
129 g := NewWithResults[int]().WithContext(context.Background())
130 g.Go(func(ctx context.Context) (int, error) {
131 select {
132 case <-ctx.Done():
133 return 0, ctx.Err()
134 case <-time.After(10 * time.Millisecond):
135 return 0, nil
136 }
137 })
138 g.Go(func(ctx context.Context) (int, error) {
139 return 0, err1
140 })
141 res, err := g.Wait()
142 require.Len(t, res, 1)
143 require.NotErrorIs(t, err, context.Canceled)
144 require.ErrorIs(t, err, err1)
145 })
146
147 t.Run("WithCollectErrored", func(t *testing.T) {
148 t.Parallel()
149 g := NewWithResults[int]().WithContext(context.Background()).WithCollectErrored()
150 g.Go(func(context.Context) (int, error) { return 0, err1 })
151 res, err := g.Wait()
152 require.Len(t, res, 1)
153 require.ErrorIs(t, err, err1)
154 })
155
156 t.Run("WithFirstError", func(t *testing.T) {
157 t.Parallel()
158 g := NewWithResults[int]().WithContext(context.Background()).WithFirstError()
159 sync := make(chan struct{})
160 g.Go(func(ctx context.Context) (int, error) {
161 defer close(sync)
162 return 0, err1
163 })
164 g.Go(func(ctx context.Context) (int, error) {
165
166
167
168
169
170
171 <-sync
172 time.Sleep(10 * time.Millisecond)
173 return 0, err2
174 })
175 res, err := g.Wait()
176 require.Len(t, res, 0)
177 require.ErrorIs(t, err, err1)
178 require.NotErrorIs(t, err, err2)
179 })
180
181 t.Run("limit", func(t *testing.T) {
182 t.Parallel()
183 for _, maxConcurrency := range []int{1, 10, 100} {
184 t.Run(strconv.Itoa(maxConcurrency), func(t *testing.T) {
185 maxConcurrency := maxConcurrency
186
187 t.Parallel()
188 ctx := context.Background()
189 g := NewWithResults[int]().WithContext(ctx).WithMaxGoroutines(maxConcurrency)
190
191 var currentConcurrent atomic.Int64
192 taskCount := maxConcurrency * 10
193 expected := make([]int, taskCount)
194 for i := 0; i < taskCount; i++ {
195 i := i
196 expected[i] = i
197 g.Go(func(context.Context) (int, error) {
198 cur := currentConcurrent.Add(1)
199 if cur > int64(maxConcurrency) {
200 return 0, fmt.Errorf("expected no more than %d concurrent goroutines", maxConcurrency)
201 }
202 time.Sleep(time.Millisecond)
203 currentConcurrent.Add(-1)
204 return i, nil
205 })
206 }
207 res, err := g.Wait()
208 sort.Ints(res)
209 require.Equal(t, expected, res)
210 require.NoError(t, err)
211 require.Equal(t, int64(0), currentConcurrent.Load())
212 })
213 }
214 })
215 }
216
View as plain text