...

Source file src/github.com/sourcegraph/conc/pool/result_context_pool_test.go

Documentation: github.com/sourcegraph/conc/pool

     1  package pool
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sort"
     8  	"strconv"
     9  	"sync/atomic"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func TestResultContextPool(t *testing.T) {
    18  	t.Parallel()
    19  
    20  	err1 := errors.New("err1")
    21  	err2 := errors.New("err2")
    22  
    23  	t.Run("panics on configuration after init", func(t *testing.T) {
    24  		t.Run("before wait", func(t *testing.T) {
    25  			t.Parallel()
    26  			g := NewWithResults[int]().WithContext(context.Background())
    27  			g.Go(func(context.Context) (int, error) { return 0, nil })
    28  			require.Panics(t, func() { g.WithMaxGoroutines(10) })
    29  		})
    30  
    31  		t.Run("after wait", func(t *testing.T) {
    32  			t.Parallel()
    33  			g := NewWithResults[int]().WithContext(context.Background())
    34  			g.Go(func(context.Context) (int, error) { return 0, nil })
    35  			_, _ = g.Wait()
    36  			require.Panics(t, func() { g.WithMaxGoroutines(10) })
    37  		})
    38  	})
    39  
    40  	t.Run("behaves the same as ErrorGroup", func(t *testing.T) {
    41  		t.Parallel()
    42  		bgctx := context.Background()
    43  		t.Run("wait returns no error if no errors", func(t *testing.T) {
    44  			t.Parallel()
    45  			g := NewWithResults[int]().WithContext(bgctx)
    46  			g.Go(func(context.Context) (int, error) { return 0, nil })
    47  			res, err := g.Wait()
    48  			require.Len(t, res, 1)
    49  			require.NoError(t, err)
    50  		})
    51  
    52  		t.Run("wait error if func returns error", func(t *testing.T) {
    53  			t.Parallel()
    54  			g := NewWithResults[int]().WithContext(bgctx)
    55  			g.Go(func(context.Context) (int, error) { return 0, err1 })
    56  			res, err := g.Wait()
    57  			require.Len(t, res, 0)
    58  			require.ErrorIs(t, err, err1)
    59  		})
    60  
    61  		t.Run("wait error is all returned errors", func(t *testing.T) {
    62  			t.Parallel()
    63  			g := NewWithResults[int]().WithErrors().WithContext(bgctx)
    64  			g.Go(func(context.Context) (int, error) { return 0, err1 })
    65  			g.Go(func(context.Context) (int, error) { return 0, nil })
    66  			g.Go(func(context.Context) (int, error) { return 0, err2 })
    67  			res, err := g.Wait()
    68  			require.Len(t, res, 1)
    69  			require.ErrorIs(t, err, err1)
    70  			require.ErrorIs(t, err, err2)
    71  		})
    72  	})
    73  
    74  	t.Run("context cancel propagates", func(t *testing.T) {
    75  		t.Parallel()
    76  		ctx, cancel := context.WithCancel(context.Background())
    77  		g := NewWithResults[int]().WithContext(ctx)
    78  		g.Go(func(ctx context.Context) (int, error) {
    79  			<-ctx.Done()
    80  			return 0, ctx.Err()
    81  		})
    82  		cancel()
    83  		res, err := g.Wait()
    84  		require.Len(t, res, 0)
    85  		require.ErrorIs(t, err, context.Canceled)
    86  	})
    87  
    88  	t.Run("WithCancelOnError", func(t *testing.T) {
    89  		t.Parallel()
    90  		g := NewWithResults[int]().WithContext(context.Background()).WithCancelOnError()
    91  		g.Go(func(ctx context.Context) (int, error) {
    92  			<-ctx.Done()
    93  			return 0, ctx.Err()
    94  		})
    95  		g.Go(func(ctx context.Context) (int, error) {
    96  			return 0, err1
    97  		})
    98  		res, err := g.Wait()
    99  		require.Len(t, res, 0)
   100  		require.ErrorIs(t, err, context.Canceled)
   101  		require.ErrorIs(t, err, err1)
   102  	})
   103  
   104  	t.Run("WithCancelOnError and panic", func(t *testing.T) {
   105  		t.Parallel()
   106  		p := NewWithResults[int]().
   107  			WithContext(context.Background()).
   108  			WithCancelOnError()
   109  		var cancelledTasks atomic.Int64
   110  		p.Go(func(ctx context.Context) (int, error) {
   111  			<-ctx.Done()
   112  			cancelledTasks.Add(1)
   113  			return 0, ctx.Err()
   114  		})
   115  		p.Go(func(ctx context.Context) (int, error) {
   116  			<-ctx.Done()
   117  			cancelledTasks.Add(1)
   118  			return 0, ctx.Err()
   119  		})
   120  		p.Go(func(ctx context.Context) (int, error) {
   121  			panic("abort!")
   122  		})
   123  		assert.Panics(t, func() { _, _ = p.Wait() })
   124  		assert.EqualValues(t, 2, cancelledTasks.Load())
   125  	})
   126  
   127  	t.Run("no WithCancelOnError", func(t *testing.T) {
   128  		t.Parallel()
   129  		g := NewWithResults[int]().WithContext(context.Background())
   130  		g.Go(func(ctx context.Context) (int, error) {
   131  			select {
   132  			case <-ctx.Done():
   133  				return 0, ctx.Err()
   134  			case <-time.After(10 * time.Millisecond):
   135  				return 0, nil
   136  			}
   137  		})
   138  		g.Go(func(ctx context.Context) (int, error) {
   139  			return 0, err1
   140  		})
   141  		res, err := g.Wait()
   142  		require.Len(t, res, 1)
   143  		require.NotErrorIs(t, err, context.Canceled)
   144  		require.ErrorIs(t, err, err1)
   145  	})
   146  
   147  	t.Run("WithCollectErrored", func(t *testing.T) {
   148  		t.Parallel()
   149  		g := NewWithResults[int]().WithContext(context.Background()).WithCollectErrored()
   150  		g.Go(func(context.Context) (int, error) { return 0, err1 })
   151  		res, err := g.Wait()
   152  		require.Len(t, res, 1) // errored value is collected
   153  		require.ErrorIs(t, err, err1)
   154  	})
   155  
   156  	t.Run("WithFirstError", func(t *testing.T) {
   157  		t.Parallel()
   158  		g := NewWithResults[int]().WithContext(context.Background()).WithFirstError()
   159  		sync := make(chan struct{})
   160  		g.Go(func(ctx context.Context) (int, error) {
   161  			defer close(sync)
   162  			return 0, err1
   163  		})
   164  		g.Go(func(ctx context.Context) (int, error) {
   165  			// This test has a race condition. After the first goroutine
   166  			// completes, this goroutine is woken up because sync is closed.
   167  			// However, this goroutine might be woken up before the error from
   168  			// the first goroutine is registered. To prevent that, we sleep for
   169  			// another 10 milliseconds, giving the other goroutine time to return
   170  			// and register its error before this goroutine returns its error.
   171  			<-sync
   172  			time.Sleep(10 * time.Millisecond)
   173  			return 0, err2
   174  		})
   175  		res, err := g.Wait()
   176  		require.Len(t, res, 0)
   177  		require.ErrorIs(t, err, err1)
   178  		require.NotErrorIs(t, err, err2)
   179  	})
   180  
   181  	t.Run("limit", func(t *testing.T) {
   182  		t.Parallel()
   183  		for _, maxConcurrency := range []int{1, 10, 100} {
   184  			t.Run(strconv.Itoa(maxConcurrency), func(t *testing.T) {
   185  				maxConcurrency := maxConcurrency // copy
   186  
   187  				t.Parallel()
   188  				ctx := context.Background()
   189  				g := NewWithResults[int]().WithContext(ctx).WithMaxGoroutines(maxConcurrency)
   190  
   191  				var currentConcurrent atomic.Int64
   192  				taskCount := maxConcurrency * 10
   193  				expected := make([]int, taskCount)
   194  				for i := 0; i < taskCount; i++ {
   195  					i := i
   196  					expected[i] = i
   197  					g.Go(func(context.Context) (int, error) {
   198  						cur := currentConcurrent.Add(1)
   199  						if cur > int64(maxConcurrency) {
   200  							return 0, fmt.Errorf("expected no more than %d concurrent goroutines", maxConcurrency)
   201  						}
   202  						time.Sleep(time.Millisecond)
   203  						currentConcurrent.Add(-1)
   204  						return i, nil
   205  					})
   206  				}
   207  				res, err := g.Wait()
   208  				sort.Ints(res)
   209  				require.Equal(t, expected, res)
   210  				require.NoError(t, err)
   211  				require.Equal(t, int64(0), currentConcurrent.Load())
   212  			})
   213  		}
   214  	})
   215  }
   216  

View as plain text