...

Source file src/github.com/go-kit/kit/sd/lb/retry_test.go

Documentation: github.com/go-kit/kit/sd/lb

     1  package lb_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/go-kit/kit/endpoint"
    10  	"github.com/go-kit/kit/sd"
    11  	"github.com/go-kit/kit/sd/lb"
    12  )
    13  
    14  func TestRetryMaxTotalFail(t *testing.T) {
    15  	var (
    16  		endpoints = sd.FixedEndpointer{} // no endpoints
    17  		rr        = lb.NewRoundRobin(endpoints)
    18  		retry     = lb.Retry(999, time.Second, rr) // lots of retries
    19  		ctx       = context.Background()
    20  	)
    21  	if _, err := retry(ctx, struct{}{}); err == nil {
    22  		t.Errorf("expected error, got none") // should fail
    23  	}
    24  }
    25  
    26  func TestRetryMaxPartialFail(t *testing.T) {
    27  	var (
    28  		endpoints = []endpoint.Endpoint{
    29  			func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
    30  			func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
    31  			func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
    32  		}
    33  		endpointer = sd.FixedEndpointer{
    34  			0: endpoints[0],
    35  			1: endpoints[1],
    36  			2: endpoints[2],
    37  		}
    38  		retries = len(endpoints) - 1 // not quite enough retries
    39  		rr      = lb.NewRoundRobin(endpointer)
    40  		ctx     = context.Background()
    41  	)
    42  	if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err == nil {
    43  		t.Errorf("expected error two, got none")
    44  	}
    45  }
    46  
    47  func TestRetryMaxSuccess(t *testing.T) {
    48  	var (
    49  		endpoints = []endpoint.Endpoint{
    50  			func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error one") },
    51  			func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("error two") },
    52  			func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
    53  		}
    54  		endpointer = sd.FixedEndpointer{
    55  			0: endpoints[0],
    56  			1: endpoints[1],
    57  			2: endpoints[2],
    58  		}
    59  		retries = len(endpoints) // exactly enough retries
    60  		rr      = lb.NewRoundRobin(endpointer)
    61  		ctx     = context.Background()
    62  	)
    63  	if _, err := lb.Retry(retries, time.Second, rr)(ctx, struct{}{}); err != nil {
    64  		t.Error(err)
    65  	}
    66  }
    67  
    68  func TestRetryTimeout(t *testing.T) {
    69  	var (
    70  		step    = make(chan struct{})
    71  		e       = func(context.Context, interface{}) (interface{}, error) { <-step; return struct{}{}, nil }
    72  		timeout = time.Millisecond
    73  		retry   = lb.Retry(999, timeout, lb.NewRoundRobin(sd.FixedEndpointer{0: e}))
    74  		errs    = make(chan error, 1)
    75  		invoke  = func() { _, err := retry(context.Background(), struct{}{}); errs <- err }
    76  	)
    77  
    78  	go func() { step <- struct{}{} }() // queue up a flush of the endpoint
    79  	invoke()                           // invoke the endpoint and trigger the flush
    80  	if err := <-errs; err != nil {     // that should succeed
    81  		t.Error(err)
    82  	}
    83  
    84  	go func() { time.Sleep(10 * timeout); step <- struct{}{} }() // a delayed flush
    85  	invoke()                                                     // invoke the endpoint
    86  	if err := <-errs; err != context.DeadlineExceeded {          // that should not succeed
    87  		t.Errorf("wanted %v, got none", context.DeadlineExceeded)
    88  	}
    89  }
    90  
    91  func TestAbortEarlyCustomMessage(t *testing.T) {
    92  	var (
    93  		myErr     = errors.New("aborting early")
    94  		cb        = func(int, error) (bool, error) { return false, myErr }
    95  		endpoints = sd.FixedEndpointer{} // no endpoints
    96  		rr        = lb.NewRoundRobin(endpoints)
    97  		retry     = lb.RetryWithCallback(time.Second, rr, cb) // lots of retries
    98  		ctx       = context.Background()
    99  	)
   100  	_, err := retry(ctx, struct{}{})
   101  	if want, have := myErr, err.(lb.RetryError).Final; want != have {
   102  		t.Errorf("want %v, have %v", want, have)
   103  	}
   104  }
   105  
   106  func TestErrorPassedUnchangedToCallback(t *testing.T) {
   107  	var (
   108  		myErr = errors.New("my custom error")
   109  		cb    = func(_ int, err error) (bool, error) {
   110  			if want, have := myErr, err; want != have {
   111  				t.Errorf("want %v, have %v", want, have)
   112  			}
   113  			return false, nil
   114  		}
   115  		endpoint = func(ctx context.Context, request interface{}) (interface{}, error) {
   116  			return nil, myErr
   117  		}
   118  		endpoints = sd.FixedEndpointer{endpoint} // no endpoints
   119  		rr        = lb.NewRoundRobin(endpoints)
   120  		retry     = lb.RetryWithCallback(time.Second, rr, cb) // lots of retries
   121  		ctx       = context.Background()
   122  	)
   123  	_, err := retry(ctx, struct{}{})
   124  	if want, have := myErr, err.(lb.RetryError).Final; want != have {
   125  		t.Errorf("want %v, have %v", want, have)
   126  	}
   127  }
   128  
   129  func TestHandleNilCallback(t *testing.T) {
   130  	var (
   131  		endpointer = sd.FixedEndpointer{
   132  			func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil /* OK */ },
   133  		}
   134  		rr  = lb.NewRoundRobin(endpointer)
   135  		ctx = context.Background()
   136  	)
   137  	retry := lb.RetryWithCallback(time.Second, rr, nil)
   138  	if _, err := retry(ctx, struct{}{}); err != nil {
   139  		t.Error(err)
   140  	}
   141  }
   142  

View as plain text