...

Source file src/github.com/go-kit/kit/sd/lb/round_robin_test.go

Documentation: github.com/go-kit/kit/sd/lb

     1  package lb
     2  
     3  import (
     4  	"context"
     5  	"reflect"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/go-kit/kit/endpoint"
    12  	"github.com/go-kit/kit/sd"
    13  )
    14  
    15  func TestRoundRobin(t *testing.T) {
    16  	var (
    17  		counts    = []int{0, 0, 0}
    18  		endpoints = []endpoint.Endpoint{
    19  			func(context.Context, interface{}) (interface{}, error) { counts[0]++; return struct{}{}, nil },
    20  			func(context.Context, interface{}) (interface{}, error) { counts[1]++; return struct{}{}, nil },
    21  			func(context.Context, interface{}) (interface{}, error) { counts[2]++; return struct{}{}, nil },
    22  		}
    23  	)
    24  
    25  	endpointer := sd.FixedEndpointer(endpoints)
    26  	balancer := NewRoundRobin(endpointer)
    27  
    28  	for i, want := range [][]int{
    29  		{1, 0, 0},
    30  		{1, 1, 0},
    31  		{1, 1, 1},
    32  		{2, 1, 1},
    33  		{2, 2, 1},
    34  		{2, 2, 2},
    35  		{3, 2, 2},
    36  	} {
    37  		endpoint, err := balancer.Endpoint()
    38  		if err != nil {
    39  			t.Fatal(err)
    40  		}
    41  		endpoint(context.Background(), struct{}{})
    42  		if have := counts; !reflect.DeepEqual(want, have) {
    43  			t.Fatalf("%d: want %v, have %v", i, want, have)
    44  		}
    45  	}
    46  }
    47  
    48  func TestRoundRobinNoEndpoints(t *testing.T) {
    49  	endpointer := sd.FixedEndpointer{}
    50  	balancer := NewRoundRobin(endpointer)
    51  	_, err := balancer.Endpoint()
    52  	if want, have := ErrNoEndpoints, err; want != have {
    53  		t.Errorf("want %v, have %v", want, have)
    54  	}
    55  }
    56  
    57  func TestRoundRobinNoRace(t *testing.T) {
    58  	balancer := NewRoundRobin(sd.FixedEndpointer([]endpoint.Endpoint{
    59  		endpoint.Nop,
    60  		endpoint.Nop,
    61  		endpoint.Nop,
    62  		endpoint.Nop,
    63  		endpoint.Nop,
    64  	}))
    65  
    66  	var (
    67  		n     = 100
    68  		done  = make(chan struct{})
    69  		wg    sync.WaitGroup
    70  		count uint64
    71  	)
    72  
    73  	wg.Add(n)
    74  
    75  	for i := 0; i < n; i++ {
    76  		go func() {
    77  			defer wg.Done()
    78  			for {
    79  				select {
    80  				case <-done:
    81  					return
    82  				default:
    83  					_, _ = balancer.Endpoint()
    84  					atomic.AddUint64(&count, 1)
    85  				}
    86  			}
    87  		}()
    88  	}
    89  
    90  	time.Sleep(time.Second)
    91  	close(done)
    92  	wg.Wait()
    93  
    94  	t.Logf("made %d calls", atomic.LoadUint64(&count))
    95  }
    96  

View as plain text