1 package pool
2
3 import (
4 "errors"
5 "fmt"
6 "strconv"
7 "sync/atomic"
8 "testing"
9 "time"
10
11 "github.com/stretchr/testify/require"
12 )
13
14 func TestResultErrorGroup(t *testing.T) {
15 t.Parallel()
16
17 err1 := errors.New("err1")
18 err2 := errors.New("err2")
19
20 t.Run("panics on configuration after init", func(t *testing.T) {
21 t.Run("before wait", func(t *testing.T) {
22 t.Parallel()
23 g := NewWithResults[int]().WithErrors()
24 g.Go(func() (int, error) { return 0, nil })
25 require.Panics(t, func() { g.WithMaxGoroutines(10) })
26 })
27
28 t.Run("after wait", func(t *testing.T) {
29 t.Parallel()
30 g := NewWithResults[int]().WithErrors()
31 g.Go(func() (int, error) { return 0, nil })
32 _, _ = g.Wait()
33 require.Panics(t, func() { g.WithMaxGoroutines(10) })
34 })
35 })
36
37 t.Run("wait returns no error if no errors", func(t *testing.T) {
38 t.Parallel()
39 g := NewWithResults[int]().WithErrors()
40 g.Go(func() (int, error) { return 1, nil })
41 res, err := g.Wait()
42 require.NoError(t, err)
43 require.Equal(t, []int{1}, res)
44 })
45
46 t.Run("wait error if func returns error", func(t *testing.T) {
47 t.Parallel()
48 g := NewWithResults[int]().WithErrors()
49 g.Go(func() (int, error) { return 0, err1 })
50 res, err := g.Wait()
51 require.Len(t, res, 0)
52 require.ErrorIs(t, err, err1)
53 })
54
55 t.Run("WithCollectErrored", func(t *testing.T) {
56 t.Parallel()
57 g := NewWithResults[int]().WithErrors().WithCollectErrored()
58 g.Go(func() (int, error) { return 0, err1 })
59 res, err := g.Wait()
60 require.Len(t, res, 1)
61 require.ErrorIs(t, err, err1)
62 })
63
64 t.Run("WithFirstError", func(t *testing.T) {
65 t.Parallel()
66 g := NewWithResults[int]().WithErrors().WithFirstError()
67 synchronizer := make(chan struct{})
68 g.Go(func() (int, error) {
69 <-synchronizer
70
71
72
73
74
75
76
77 time.Sleep(100 * time.Millisecond)
78 return 0, err1
79 })
80 g.Go(func() (int, error) {
81 defer close(synchronizer)
82 return 0, err2
83 })
84 res, err := g.Wait()
85 require.Len(t, res, 0)
86 require.ErrorIs(t, err, err2)
87 require.NotErrorIs(t, err, err1)
88 })
89
90 t.Run("wait error is all returned errors", func(t *testing.T) {
91 t.Parallel()
92 g := NewWithResults[int]().WithErrors()
93 g.Go(func() (int, error) { return 0, err1 })
94 g.Go(func() (int, error) { return 0, nil })
95 g.Go(func() (int, error) { return 0, err2 })
96 res, err := g.Wait()
97 require.Len(t, res, 1)
98 require.ErrorIs(t, err, err1)
99 require.ErrorIs(t, err, err2)
100 })
101
102 t.Run("limit", func(t *testing.T) {
103 t.Parallel()
104 for _, maxConcurrency := range []int{1, 10, 100} {
105 t.Run(strconv.Itoa(maxConcurrency), func(t *testing.T) {
106 maxConcurrency := maxConcurrency
107
108 t.Parallel()
109 g := NewWithResults[int]().WithErrors().WithMaxGoroutines(maxConcurrency)
110
111 var currentConcurrent atomic.Int64
112 taskCount := maxConcurrency * 10
113 for i := 0; i < taskCount; i++ {
114 g.Go(func() (int, error) {
115 cur := currentConcurrent.Add(1)
116 if cur > int64(maxConcurrency) {
117 return 0, fmt.Errorf("expected no more than %d concurrent goroutine", maxConcurrency)
118 }
119 time.Sleep(time.Millisecond)
120 currentConcurrent.Add(-1)
121 return 0, nil
122 })
123 }
124 res, err := g.Wait()
125 require.Len(t, res, taskCount)
126 require.NoError(t, err)
127 require.Equal(t, int64(0), currentConcurrent.Load())
128 })
129 }
130 })
131 }
132
View as plain text