...

Source file src/github.com/sourcegraph/conc/pool/context_pool_test.go

Documentation: github.com/sourcegraph/conc/pool

     1  package pool
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strconv"
     8  	"sync/atomic"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  func ExampleContextPool_WithCancelOnError() {
    17  	p := New().
    18  		WithMaxGoroutines(4).
    19  		WithContext(context.Background()).
    20  		WithCancelOnError()
    21  	for i := 0; i < 3; i++ {
    22  		i := i
    23  		p.Go(func(ctx context.Context) error {
    24  			if i == 2 {
    25  				return errors.New("I will cancel all other tasks!")
    26  			}
    27  			<-ctx.Done()
    28  			return nil
    29  		})
    30  	}
    31  	err := p.Wait()
    32  	fmt.Println(err)
    33  	// Output:
    34  	// I will cancel all other tasks!
    35  }
    36  
    37  func TestContextPool(t *testing.T) {
    38  	t.Parallel()
    39  
    40  	err1 := errors.New("err1")
    41  	err2 := errors.New("err2")
    42  	bgctx := context.Background()
    43  
    44  	t.Run("panics on configuration after init", func(t *testing.T) {
    45  		t.Run("before wait", func(t *testing.T) {
    46  			t.Parallel()
    47  			g := New().WithContext(context.Background())
    48  			g.Go(func(context.Context) error { return nil })
    49  			require.Panics(t, func() { g.WithMaxGoroutines(10) })
    50  		})
    51  
    52  		t.Run("after wait", func(t *testing.T) {
    53  			t.Parallel()
    54  			g := New().WithContext(context.Background())
    55  			g.Go(func(context.Context) error { return nil })
    56  			_ = g.Wait()
    57  			require.Panics(t, func() { g.WithMaxGoroutines(10) })
    58  		})
    59  	})
    60  
    61  	t.Run("behaves the same as ErrorGroup", func(t *testing.T) {
    62  		t.Parallel()
    63  
    64  		t.Run("wait returns no error if no errors", func(t *testing.T) {
    65  			t.Parallel()
    66  			p := New().WithContext(bgctx)
    67  			p.Go(func(context.Context) error { return nil })
    68  			require.NoError(t, p.Wait())
    69  		})
    70  
    71  		t.Run("wait errors if func returns error", func(t *testing.T) {
    72  			t.Parallel()
    73  			p := New().WithContext(bgctx)
    74  			p.Go(func(context.Context) error { return err1 })
    75  			require.ErrorIs(t, p.Wait(), err1)
    76  		})
    77  
    78  		t.Run("wait error is all returned errors", func(t *testing.T) {
    79  			t.Parallel()
    80  			p := New().WithErrors().WithContext(bgctx)
    81  			p.Go(func(context.Context) error { return err1 })
    82  			p.Go(func(context.Context) error { return nil })
    83  			p.Go(func(context.Context) error { return err2 })
    84  			err := p.Wait()
    85  			require.ErrorIs(t, err, err1)
    86  			require.ErrorIs(t, err, err2)
    87  		})
    88  	})
    89  
    90  	t.Run("context error propagates", func(t *testing.T) {
    91  		t.Parallel()
    92  
    93  		t.Run("canceled", func(t *testing.T) {
    94  			t.Parallel()
    95  			ctx, cancel := context.WithCancel(bgctx)
    96  			p := New().WithContext(ctx)
    97  			p.Go(func(ctx context.Context) error {
    98  				<-ctx.Done()
    99  				return ctx.Err()
   100  			})
   101  			cancel()
   102  			require.ErrorIs(t, p.Wait(), context.Canceled)
   103  		})
   104  
   105  		t.Run("timed out", func(t *testing.T) {
   106  			t.Parallel()
   107  			ctx, cancel := context.WithTimeout(bgctx, time.Millisecond)
   108  			defer cancel()
   109  			p := New().WithContext(ctx)
   110  			p.Go(func(ctx context.Context) error {
   111  				<-ctx.Done()
   112  				return ctx.Err()
   113  			})
   114  			require.ErrorIs(t, p.Wait(), context.DeadlineExceeded)
   115  		})
   116  	})
   117  
   118  	t.Run("WithCancelOnError", func(t *testing.T) {
   119  		t.Parallel()
   120  		p := New().WithContext(bgctx).WithCancelOnError()
   121  		p.Go(func(ctx context.Context) error {
   122  			<-ctx.Done()
   123  			return ctx.Err()
   124  		})
   125  		p.Go(func(ctx context.Context) error {
   126  			return err1
   127  		})
   128  		err := p.Wait()
   129  		require.ErrorIs(t, err, context.Canceled)
   130  		require.ErrorIs(t, err, err1)
   131  	})
   132  
   133  	t.Run("no WithCancelOnError", func(t *testing.T) {
   134  		t.Parallel()
   135  		p := New().WithContext(bgctx)
   136  		p.Go(func(ctx context.Context) error {
   137  			select {
   138  			case <-ctx.Done():
   139  				return ctx.Err()
   140  			case <-time.After(10 * time.Millisecond):
   141  				return nil
   142  			}
   143  		})
   144  		p.Go(func(ctx context.Context) error {
   145  			return err1
   146  		})
   147  		err := p.Wait()
   148  		require.ErrorIs(t, err, err1)
   149  		require.NotErrorIs(t, err, context.Canceled)
   150  	})
   151  
   152  	t.Run("WithFirstError", func(t *testing.T) {
   153  		t.Parallel()
   154  		p := New().WithContext(bgctx).WithFirstError()
   155  		sync := make(chan struct{})
   156  		p.Go(func(ctx context.Context) error {
   157  			defer close(sync)
   158  			return err1
   159  		})
   160  		p.Go(func(ctx context.Context) error {
   161  			// This test has a race condition. After the first goroutine
   162  			// completes, this goroutine is woken up because sync is closed.
   163  			// However, this goroutine might be woken up before the error from
   164  			// the first goroutine is registered. To prevent that, we sleep for
   165  			// another 10 milliseconds, giving the other goroutine time to return
   166  			// and register its error before this goroutine returns its error.
   167  			<-sync
   168  			time.Sleep(10 * time.Millisecond)
   169  			return err2
   170  		})
   171  		err := p.Wait()
   172  		require.ErrorIs(t, err, err1)
   173  		require.NotErrorIs(t, err, err2)
   174  	})
   175  
   176  	t.Run("WithFirstError and WithCancelOnError", func(t *testing.T) {
   177  		t.Parallel()
   178  		p := New().WithContext(bgctx).WithFirstError().WithCancelOnError()
   179  		p.Go(func(ctx context.Context) error {
   180  			return err1
   181  		})
   182  		p.Go(func(ctx context.Context) error {
   183  			<-ctx.Done()
   184  			return ctx.Err()
   185  		})
   186  		err := p.Wait()
   187  		require.ErrorIs(t, err, err1)
   188  		require.NotErrorIs(t, err, context.Canceled)
   189  	})
   190  
   191  	t.Run("WithCancelOnError and panic", func(t *testing.T) {
   192  		t.Parallel()
   193  		p := New().WithContext(bgctx).WithCancelOnError()
   194  		var cancelledTasks atomic.Int64
   195  		p.Go(func(ctx context.Context) error {
   196  			<-ctx.Done()
   197  			cancelledTasks.Add(1)
   198  			return ctx.Err()
   199  		})
   200  		p.Go(func(ctx context.Context) error {
   201  			<-ctx.Done()
   202  			cancelledTasks.Add(1)
   203  			return ctx.Err()
   204  		})
   205  		p.Go(func(ctx context.Context) error {
   206  			panic("abort!")
   207  		})
   208  		assert.Panics(t, func() { _ = p.Wait() })
   209  		assert.EqualValues(t, 2, cancelledTasks.Load())
   210  	})
   211  
   212  	t.Run("limit", func(t *testing.T) {
   213  		t.Parallel()
   214  		for _, maxConcurrent := range []int{1, 10, 100} {
   215  			t.Run(strconv.Itoa(maxConcurrent), func(t *testing.T) {
   216  				maxConcurrent := maxConcurrent // copy
   217  
   218  				t.Parallel()
   219  				p := New().WithContext(bgctx).WithMaxGoroutines(maxConcurrent)
   220  
   221  				var currentConcurrent atomic.Int64
   222  				for i := 0; i < 100; i++ {
   223  					p.Go(func(context.Context) error {
   224  						cur := currentConcurrent.Add(1)
   225  						if cur > int64(maxConcurrent) {
   226  							return fmt.Errorf("expected no more than %d concurrent goroutine", maxConcurrent)
   227  						}
   228  						time.Sleep(time.Millisecond)
   229  						currentConcurrent.Add(-1)
   230  						return nil
   231  					})
   232  				}
   233  				require.NoError(t, p.Wait())
   234  				require.Equal(t, int64(0), currentConcurrent.Load())
   235  			})
   236  		}
   237  	})
   238  }
   239  

View as plain text