...

Source file src/github.com/letsencrypt/boulder/ratelimits/limiter_test.go

Documentation: github.com/letsencrypt/boulder/ratelimits

     1  package ratelimits
     2  
     3  import (
     4  	"context"
     5  	"math/rand"
     6  	"net"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/jmhodges/clock"
    11  	"github.com/letsencrypt/boulder/metrics"
    12  	"github.com/letsencrypt/boulder/test"
    13  	"github.com/prometheus/client_golang/prometheus"
    14  )
    15  
// tenZeroZeroTwo is the client IP address that is overridden in
// 'testdata/working_override.yml' to have higher burst and count values than
// the default limit.
const tenZeroZeroTwo = "10.0.0.2"
    19  
    20  // newTestLimiter constructs a new limiter with the following configuration:
    21  //   - 'NewRegistrationsPerIPAddress' burst: 20 count: 20 period: 1s
    22  //   - 'NewRegistrationsPerIPAddress:10.0.0.2' burst: 40 count: 40 period: 1s
    23  func newTestLimiter(t *testing.T, s source, clk clock.FakeClock) *Limiter {
    24  	l, err := NewLimiter(clk, s, "testdata/working_default.yml", "testdata/working_override.yml", metrics.NoopRegisterer)
    25  	test.AssertNotError(t, err, "should not error")
    26  	return l
    27  }
    28  
    29  func setup(t *testing.T) (context.Context, map[string]*Limiter, clock.FakeClock, string) {
    30  	testCtx := context.Background()
    31  	clk := clock.NewFake()
    32  
    33  	// Generate a random IP address to avoid collisions during and between test
    34  	// runs.
    35  	randIP := make(net.IP, 4)
    36  	for i := 0; i < 4; i++ {
    37  		randIP[i] = byte(rand.Intn(256))
    38  	}
    39  
    40  	// Construct a limiter for each source.
    41  	return testCtx, map[string]*Limiter{
    42  		"inmem": newInmemTestLimiter(t, clk),
    43  		"redis": newRedisTestLimiter(t, clk),
    44  	}, clk, randIP.String()
    45  }
    46  
    47  func Test_Limiter_WithBadLimitsPath(t *testing.T) {
    48  	t.Parallel()
    49  	_, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/does-not-exist.yml", "", metrics.NoopRegisterer)
    50  	test.AssertError(t, err, "should error")
    51  
    52  	_, err = NewLimiter(clock.NewFake(), newInmem(), "testdata/defaults.yml", "testdata/does-not-exist.yml", metrics.NoopRegisterer)
    53  	test.AssertError(t, err, "should error")
    54  }
    55  
    56  func Test_Limiter_getLimitNoExist(t *testing.T) {
    57  	t.Parallel()
    58  	l, err := NewLimiter(clock.NewFake(), newInmem(), "testdata/working_default.yml", "", metrics.NoopRegisterer)
    59  	test.AssertNotError(t, err, "should not error")
    60  	_, err = l.getLimit(Name(9999), "")
    61  	test.AssertError(t, err, "should error")
    62  
    63  }
    64  
    65  func Test_Limiter_CheckWithLimitNoExist(t *testing.T) {
    66  	t.Parallel()
    67  	testCtx, limiters, _, testIP := setup(t)
    68  	for name, l := range limiters {
    69  		t.Run(name, func(t *testing.T) {
    70  			_, err := l.Check(testCtx, Name(9999), testIP, 1)
    71  			test.AssertError(t, err, "should error")
    72  		})
    73  	}
    74  }
    75  
    76  func Test_Limiter_CheckWithLimitOverrides(t *testing.T) {
    77  	t.Parallel()
    78  	testCtx, limiters, clk, _ := setup(t)
    79  	for name, l := range limiters {
    80  		t.Run(name, func(t *testing.T) {
    81  			// Verify our overrideUsageGauge is being set correctly. 0.0 == 0% of
    82  			// the bucket has been consumed.
    83  			test.AssertMetricWithLabelsEquals(t, l.overrideUsageGauge, prometheus.Labels{
    84  				"limit": NewRegistrationsPerIPAddress.String(), "client_id": tenZeroZeroTwo}, 0)
    85  
    86  			// Attempt to check a spend of 41 requests (a cost > the limit burst
    87  			// capacity), this should fail with a specific error.
    88  			_, err := l.Check(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
    89  			test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
    90  
    91  			// Attempt to spend 41 requests (a cost > the limit burst capacity),
    92  			// this should fail with a specific error.
    93  			_, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 41)
    94  			test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
    95  
    96  			// Attempt to spend all 40 requests, this should succeed.
    97  			d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 40)
    98  			test.AssertNotError(t, err, "should not error")
    99  			test.Assert(t, d.Allowed, "should be allowed")
   100  
   101  			// Attempting to spend 1 more, this should fail.
   102  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
   103  			test.AssertNotError(t, err, "should not error")
   104  			test.Assert(t, !d.Allowed, "should not be allowed")
   105  			test.AssertEquals(t, d.Remaining, int64(0))
   106  			test.AssertEquals(t, d.ResetIn, time.Second)
   107  
   108  			// Verify our overrideUsageGauge is being set correctly. 1.0 == 100% of
   109  			// the bucket has been consumed.
   110  			test.AssertMetricWithLabelsEquals(t, l.overrideUsageGauge, prometheus.Labels{
   111  				"limit_name": NewRegistrationsPerIPAddress.String(), "client_id": tenZeroZeroTwo}, 1.0)
   112  
   113  			// Verify our RetryIn is correct. 1 second == 1000 milliseconds and
   114  			// 1000/40 = 25 milliseconds per request.
   115  			test.AssertEquals(t, d.RetryIn, time.Millisecond*25)
   116  
   117  			// Wait 50 milliseconds and try again.
   118  			clk.Add(d.RetryIn)
   119  
   120  			// We should be allowed to spend 1 more request.
   121  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
   122  			test.AssertNotError(t, err, "should not error")
   123  			test.Assert(t, d.Allowed, "should be allowed")
   124  			test.AssertEquals(t, d.Remaining, int64(0))
   125  			test.AssertEquals(t, d.ResetIn, time.Second)
   126  
   127  			// Wait 1 second for a full bucket reset.
   128  			clk.Add(d.ResetIn)
   129  
   130  			// Quickly spend 40 requests in a row.
   131  			for i := 0; i < 40; i++ {
   132  				d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
   133  				test.AssertNotError(t, err, "should not error")
   134  				test.Assert(t, d.Allowed, "should be allowed")
   135  				test.AssertEquals(t, d.Remaining, int64(39-i))
   136  			}
   137  
   138  			// Attempting to spend 1 more, this should fail.
   139  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo, 1)
   140  			test.AssertNotError(t, err, "should not error")
   141  			test.Assert(t, !d.Allowed, "should not be allowed")
   142  			test.AssertEquals(t, d.Remaining, int64(0))
   143  			test.AssertEquals(t, d.ResetIn, time.Second)
   144  
   145  			// Reset between tests.
   146  			err = l.Reset(testCtx, NewRegistrationsPerIPAddress, tenZeroZeroTwo)
   147  			test.AssertNotError(t, err, "should not error")
   148  		})
   149  	}
   150  }
   151  
   152  func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) {
   153  	t.Parallel()
   154  	testCtx, limiters, _, testIP := setup(t)
   155  	for name, l := range limiters {
   156  		t.Run(name, func(t *testing.T) {
   157  			// Check on an empty bucket should initialize it and return the
   158  			// theoretical next state of that bucket if the cost were spent.
   159  			d, err := l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
   160  			test.AssertNotError(t, err, "should not error")
   161  			test.Assert(t, d.Allowed, "should be allowed")
   162  			test.AssertEquals(t, d.Remaining, int64(19))
   163  			// Verify our ResetIn timing is correct. 1 second == 1000
   164  			// milliseconds and 1000/20 = 50 milliseconds per request.
   165  			test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
   166  			test.AssertEquals(t, d.RetryIn, time.Duration(0))
   167  
   168  			// However, that cost should not be spent yet, a 0 cost check should
   169  			// tell us that we actually have 20 remaining.
   170  			d, err = l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 0)
   171  			test.AssertNotError(t, err, "should not error")
   172  			test.Assert(t, d.Allowed, "should be allowed")
   173  			test.AssertEquals(t, d.Remaining, int64(20))
   174  			test.AssertEquals(t, d.ResetIn, time.Duration(0))
   175  			test.AssertEquals(t, d.RetryIn, time.Duration(0))
   176  
   177  			// Reset our bucket.
   178  			err = l.Reset(testCtx, NewRegistrationsPerIPAddress, testIP)
   179  			test.AssertNotError(t, err, "should not error")
   180  
   181  			// Similar to above, but we'll use Spend() instead of Check() to
   182  			// initialize the bucket. Spend should return the same result as
   183  			// Check.
   184  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
   185  			test.AssertNotError(t, err, "should not error")
   186  			test.Assert(t, d.Allowed, "should be allowed")
   187  			test.AssertEquals(t, d.Remaining, int64(19))
   188  			// Verify our ResetIn timing is correct. 1 second == 1000
   189  			// milliseconds and 1000/20 = 50 milliseconds per request.
   190  			test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
   191  			test.AssertEquals(t, d.RetryIn, time.Duration(0))
   192  
   193  			// However, that cost should not be spent yet, a 0 cost check should
   194  			// tell us that we actually have 19 remaining.
   195  			d, err = l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, 0)
   196  			test.AssertNotError(t, err, "should not error")
   197  			test.Assert(t, d.Allowed, "should be allowed")
   198  			test.AssertEquals(t, d.Remaining, int64(19))
   199  			// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
   200  			// 1000/20 = 50 milliseconds per request.
   201  			test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
   202  			test.AssertEquals(t, d.RetryIn, time.Duration(0))
   203  		})
   204  	}
   205  }
   206  
   207  func Test_Limiter_RefundAndSpendCostErr(t *testing.T) {
   208  	t.Parallel()
   209  	testCtx, limiters, _, testIP := setup(t)
   210  	for name, l := range limiters {
   211  		t.Run(name, func(t *testing.T) {
   212  			// Spend a cost of 0, which should fail.
   213  			_, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 0)
   214  			test.AssertErrorIs(t, err, ErrInvalidCost)
   215  
   216  			// Spend a negative cost, which should fail.
   217  			_, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, -1)
   218  			test.AssertErrorIs(t, err, ErrInvalidCost)
   219  
   220  			// Refund a cost of 0, which should fail.
   221  			_, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 0)
   222  			test.AssertErrorIs(t, err, ErrInvalidCost)
   223  
   224  			// Refund a negative cost, which should fail.
   225  			_, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, -1)
   226  			test.AssertErrorIs(t, err, ErrInvalidCost)
   227  		})
   228  	}
   229  }
   230  
   231  func Test_Limiter_CheckWithBadCost(t *testing.T) {
   232  	t.Parallel()
   233  	testCtx, limiters, _, testIP := setup(t)
   234  	for name, l := range limiters {
   235  		t.Run(name, func(t *testing.T) {
   236  			_, err := l.Check(testCtx, NewRegistrationsPerIPAddress, testIP, -1)
   237  			test.AssertErrorIs(t, err, ErrInvalidCostForCheck)
   238  		})
   239  	}
   240  }
   241  
   242  func Test_Limiter_DefaultLimits(t *testing.T) {
   243  	t.Parallel()
   244  	testCtx, limiters, clk, testIP := setup(t)
   245  	for name, l := range limiters {
   246  		t.Run(name, func(t *testing.T) {
   247  			// Attempt to spend 21 requests (a cost > the limit burst capacity),
   248  			// this should fail with a specific error.
   249  			_, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 21)
   250  			test.AssertErrorIs(t, err, ErrInvalidCostOverLimit)
   251  
   252  			// Attempt to spend all 20 requests, this should succeed.
   253  			d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20)
   254  			test.AssertNotError(t, err, "should not error")
   255  			test.Assert(t, d.Allowed, "should be allowed")
   256  			test.AssertEquals(t, d.Remaining, int64(0))
   257  			test.AssertEquals(t, d.ResetIn, time.Second)
   258  
   259  			// Attempting to spend 1 more, this should fail.
   260  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
   261  			test.AssertNotError(t, err, "should not error")
   262  			test.Assert(t, !d.Allowed, "should not be allowed")
   263  			test.AssertEquals(t, d.Remaining, int64(0))
   264  			test.AssertEquals(t, d.ResetIn, time.Second)
   265  
   266  			// Verify our ResetIn is correct. 1 second == 1000 milliseconds and
   267  			// 1000/20 = 50 milliseconds per request.
   268  			test.AssertEquals(t, d.RetryIn, time.Millisecond*50)
   269  
   270  			// Wait 50 milliseconds and try again.
   271  			clk.Add(d.RetryIn)
   272  
   273  			// We should be allowed to spend 1 more request.
   274  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
   275  			test.AssertNotError(t, err, "should not error")
   276  			test.Assert(t, d.Allowed, "should be allowed")
   277  			test.AssertEquals(t, d.Remaining, int64(0))
   278  			test.AssertEquals(t, d.ResetIn, time.Second)
   279  
   280  			// Wait 1 second for a full bucket reset.
   281  			clk.Add(d.ResetIn)
   282  
   283  			// Quickly spend 20 requests in a row.
   284  			for i := 0; i < 20; i++ {
   285  				d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
   286  				test.AssertNotError(t, err, "should not error")
   287  				test.Assert(t, d.Allowed, "should be allowed")
   288  				test.AssertEquals(t, d.Remaining, int64(19-i))
   289  			}
   290  
   291  			// Attempting to spend 1 more, this should fail.
   292  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
   293  			test.AssertNotError(t, err, "should not error")
   294  			test.Assert(t, !d.Allowed, "should not be allowed")
   295  			test.AssertEquals(t, d.Remaining, int64(0))
   296  			test.AssertEquals(t, d.ResetIn, time.Second)
   297  		})
   298  	}
   299  }
   300  
   301  func Test_Limiter_RefundAndReset(t *testing.T) {
   302  	t.Parallel()
   303  	testCtx, limiters, clk, testIP := setup(t)
   304  	for name, l := range limiters {
   305  		t.Run(name, func(t *testing.T) {
   306  			// Attempt to spend all 20 requests, this should succeed.
   307  			d, err := l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20)
   308  			test.AssertNotError(t, err, "should not error")
   309  			test.Assert(t, d.Allowed, "should be allowed")
   310  			test.AssertEquals(t, d.Remaining, int64(0))
   311  			test.AssertEquals(t, d.ResetIn, time.Second)
   312  
   313  			// Refund 10 requests.
   314  			d, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 10)
   315  			test.AssertNotError(t, err, "should not error")
   316  			test.AssertEquals(t, d.Remaining, int64(10))
   317  
   318  			// Spend 10 requests, this should succeed.
   319  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 10)
   320  			test.AssertNotError(t, err, "should not error")
   321  			test.Assert(t, d.Allowed, "should be allowed")
   322  			test.AssertEquals(t, d.Remaining, int64(0))
   323  			test.AssertEquals(t, d.ResetIn, time.Second)
   324  
   325  			err = l.Reset(testCtx, NewRegistrationsPerIPAddress, testIP)
   326  			test.AssertNotError(t, err, "should not error")
   327  
   328  			// Attempt to spend 20 more requests, this should succeed.
   329  			d, err = l.Spend(testCtx, NewRegistrationsPerIPAddress, testIP, 20)
   330  			test.AssertNotError(t, err, "should not error")
   331  			test.Assert(t, d.Allowed, "should be allowed")
   332  			test.AssertEquals(t, d.Remaining, int64(0))
   333  			test.AssertEquals(t, d.ResetIn, time.Second)
   334  
   335  			// Reset to full.
   336  			clk.Add(d.ResetIn)
   337  
   338  			// Refund 1 requests above our limit, this should fail.
   339  			d, err = l.Refund(testCtx, NewRegistrationsPerIPAddress, testIP, 1)
   340  			test.AssertNotError(t, err, "should not error")
   341  			test.Assert(t, !d.Allowed, "should not be allowed")
   342  			test.AssertEquals(t, d.Remaining, int64(20))
   343  		})
   344  	}
   345  }
   346  

View as plain text