1 package http
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "strings"
9 "testing"
10
11 "github.com/aws/smithy-go/middleware"
12 )
13
14 func TestContentLengthMiddleware(t *testing.T) {
15 cases := map[string]struct {
16 Stream io.Reader
17 ExpectNilStream bool
18 ExpectLen int64
19 ExpectErr string
20 }{
21
22 "bytes.Reader": {
23 Stream: bytes.NewReader(make([]byte, 10)),
24 ExpectLen: 10,
25 ExpectNilStream: false,
26 },
27 "bytes.Buffer": {
28 Stream: bytes.NewBuffer(make([]byte, 10)),
29 ExpectLen: 10,
30 ExpectNilStream: false,
31 },
32 "strings.Reader": {
33 Stream: strings.NewReader("hello"),
34 ExpectLen: 5,
35 ExpectNilStream: false,
36 },
37 "empty stream": {
38 Stream: strings.NewReader(""),
39 ExpectLen: 0,
40 ExpectNilStream: false,
41 },
42 "empty stream bytes": {
43 Stream: bytes.NewReader([]byte{}),
44 ExpectLen: 0,
45 ExpectNilStream: false,
46 },
47 "nil stream": {
48 ExpectLen: 0,
49 ExpectNilStream: true,
50 },
51 "un-seekable and no length": {
52 Stream: &basicReader{buf: make([]byte, 10)},
53 ExpectLen: -1,
54 ExpectNilStream: false,
55 },
56 "with error": {
57 Stream: &errorSecondSeekableReader{err: fmt.Errorf("seek failed")},
58 ExpectErr: "seek failed",
59 ExpectLen: -1,
60 ExpectNilStream: false,
61 },
62 }
63
64 for name, c := range cases {
65 t.Run(name, func(t *testing.T) {
66 var err error
67 req := NewStackRequest().(*Request)
68 req, err = req.SetStream(c.Stream)
69 if err != nil {
70 t.Fatalf("expect to set stream, %v", err)
71 }
72
73 var updatedRequest *Request
74 var m ComputeContentLength
75 _, _, err = m.HandleBuild(context.Background(),
76 middleware.BuildInput{Request: req},
77 middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
78 out middleware.BuildOutput, metadata middleware.Metadata, err error) {
79 updatedRequest = input.Request.(*Request)
80 return out, metadata, nil
81 }),
82 )
83 if len(c.ExpectErr) != 0 {
84 if err == nil {
85 t.Fatalf("expect error, got none")
86 }
87 if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
88 t.Fatalf("expect error to contain %q, got %v", e, a)
89 }
90 return
91 } else if err != nil {
92 t.Fatalf("expect no error, got %v", err)
93 }
94
95 if e, a := c.ExpectLen, updatedRequest.ContentLength; e != a {
96 t.Errorf("expect %v content-length, got %v", e, a)
97 }
98
99 if e, a := c.ExpectNilStream, updatedRequest.stream == nil; e != a {
100 t.Errorf("expect %v nil stream, got %v", e, a)
101 }
102 })
103 }
104 }
105
106 func TestContentLengthMiddleware_HeaderSet(t *testing.T) {
107 req := NewStackRequest().(*Request)
108 req.Header.Set("Content-Length", "1234")
109
110 var err error
111 req, err = req.SetStream(strings.NewReader("hello"))
112 if err != nil {
113 t.Fatalf("expect to set stream, %v", err)
114 }
115
116 var m ComputeContentLength
117 _, _, err = m.HandleBuild(context.Background(),
118 middleware.BuildInput{Request: req},
119 nopBuildHandler,
120 )
121 if err != nil {
122 t.Fatalf("expect middleware to run, %v", err)
123 }
124
125 if e, a := "1234", req.Header.Get("Content-Length"); e != a {
126 t.Errorf("expect Content-Length not to change, got %v", a)
127 }
128 }
129
130 var nopBuildHandler = middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
131 out middleware.BuildOutput, metadata middleware.Metadata, err error) {
132 return out, metadata, nil
133 })
134
// basicReader is a minimal io.Reader over an in-memory buffer. It implements
// only Read (no Seek, no Len), so its size cannot be determined without
// consuming it.
type basicReader struct {
	buf []byte
}

// Read copies as much of the remaining buffer into p as fits and advances
// past the copied bytes. Once the buffer is exhausted it returns io.EOF, as
// the io.Reader contract requires, rather than returning (0, nil) forever,
// which would make read-to-EOF loops spin indefinitely.
func (r *basicReader) Read(p []byte) (int, error) {
	if len(r.buf) == 0 && len(p) > 0 {
		return 0, io.EOF
	}
	n := copy(p, r.buf)
	r.buf = r.buf[n:]
	return n, nil
}
144
// errorSecondSeekableReader is an always-empty reader whose second Seek call
// fails with err. Every other Seek call reports position 0 and no error.
type errorSecondSeekableReader struct {
	err   error // returned by the second call to Seek
	count int   // number of Seek calls observed so far
}

// Read never produces data; it always reports end of stream.
func (s *errorSecondSeekableReader) Read([]byte) (int, error) {
	return 0, io.EOF
}

// Seek ignores its arguments, counts the call, and fails only on the second
// invocation.
func (s *errorSecondSeekableReader) Seek(int64, int) (int64, error) {
	s.count++
	switch s.count {
	case 2:
		return 0, s.err
	default:
		return 0, nil
	}
}
160
161 func TestValidateContentLengthHeader(t *testing.T) {
162 cases := map[string]struct {
163 contentLength int64
164 expectError string
165 }{
166 "success": {
167 contentLength: 10,
168 },
169 "length set to 0": {
170 contentLength: 0,
171 },
172 "content-length unset": {
173 contentLength: -1,
174 expectError: "content length for payload is required and must be at least 0",
175 },
176 }
177
178 for name, c := range cases {
179 t.Run(name, func(t *testing.T) {
180 var err error
181 req := NewStackRequest().(*Request)
182 req.ContentLength = c.contentLength
183
184 var m validateContentLength
185 _, _, err = m.HandleBuild(context.Background(),
186 middleware.BuildInput{Request: req},
187 nopBuildHandler,
188 )
189
190 if len(c.expectError) != 0 {
191 if err == nil {
192 t.Fatalf("expect error, got none")
193 }
194 if e, a := c.expectError, err.Error(); !strings.Contains(a, e) {
195 t.Fatalf("expect error to contain %q, got %v", e, a)
196 }
197 } else if err != nil {
198 t.Fatalf("expect no error, got %v", err)
199 }
200 })
201 }
202 }
203
View as plain text