1
16
17 package ioutils
18
19 import (
20 "bytes"
21 "fmt"
22 "math/rand"
23 "testing"
24
25 "github.com/stretchr/testify/assert"
26 )
27
28 func TestLimitWriter(t *testing.T) {
29 r := rand.New(rand.NewSource(1234))
30
31 tests := []struct {
32 inputSize, limit, writeSize int64
33 }{
34
35 {100, 101, 100},
36 {100, 100, 100},
37 {100, 99, 100},
38 {1, 1, 1},
39 {100, 10, 100},
40 {100, 0, 100},
41 {100, -1, 100},
42
43 {100, 101, 10},
44 {100, 100, 10},
45 {100, 99, 10},
46 {100, 10, 10},
47 {100, 0, 10},
48 {100, -1, 10},
49 }
50
51 for _, test := range tests {
52 t.Run(fmt.Sprintf("inputSize=%d limit=%d writes=%d", test.inputSize, test.limit, test.writeSize), func(t *testing.T) {
53 input := make([]byte, test.inputSize)
54 r.Read(input)
55 output := &bytes.Buffer{}
56 w := LimitWriter(output, test.limit)
57
58 var (
59 err error
60 written int64
61 n int
62 )
63 for written < test.inputSize && err == nil {
64 n, err = w.Write(input[written : written+test.writeSize])
65 written += int64(n)
66 }
67
68 expectWritten := bounded(0, test.inputSize, test.limit)
69 assert.EqualValues(t, expectWritten, written)
70 if expectWritten <= 0 {
71 assert.Empty(t, output)
72 } else {
73 assert.Equal(t, input[:expectWritten], output.Bytes())
74 }
75
76 if test.limit < test.inputSize {
77 assert.Error(t, err)
78 } else {
79 assert.NoError(t, err)
80 }
81 })
82 }
83 }
84
// bounded returns val clamped to the closed interval [min, max]. The upper
// bound is applied first, so min takes precedence when min > max.
func bounded(min, val, max int64) int64 {
	clamped := val
	if clamped > max {
		clamped = max
	}
	if clamped >= min {
		return clamped
	}
	return min
}
94
View as plain text