...
1 package requestcompression
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "github.com/aws/smithy-go/middleware"
8 smithyhttp "github.com/aws/smithy-go/transport/http"
9 "io"
10 "net/http"
11 )
12
// captureUncompressedRequestID is the identifier returned by
// captureUncompressedRequestMiddleware.ID.
const (
	captureUncompressedRequestID = "CaptureUncompressedRequest"
)
14
15
16 func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error {
17 return stack.Serialize.Insert(&captureUncompressedRequestMiddleware{
18 buf: buf,
19 }, "RequestCompression", middleware.Before)
20 }
21
// captureUncompressedRequestMiddleware is a serialize-step middleware that
// copies the request body stream into buf and then rewinds the stream before
// passing the request to the next handler.
type captureUncompressedRequestMiddleware struct {
	// NOTE(review): req is not read or written by the methods shown here;
	// confirm whether it is still needed.
	req *http.Request
	// buf receives (is appended with) a copy of the request body on every
	// HandleSerialize call.
	buf *bytes.Buffer
	// NOTE(review): bytes is not read or written by the methods shown here;
	// confirm whether it is still needed.
	bytes []byte
}
27
28
29 func (*captureUncompressedRequestMiddleware) ID() string {
30 return captureUncompressedRequestID
31 }
32
33
34 func (m *captureUncompressedRequestMiddleware) HandleSerialize(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
35 ) (
36 output middleware.SerializeOutput, metadata middleware.Metadata, err error,
37 ) {
38 request, ok := input.Request.(*smithyhttp.Request)
39 if !ok {
40 return output, metadata, fmt.Errorf("error when retrieving http request")
41 }
42
43 _, err = io.Copy(m.buf, request.GetStream())
44 if err != nil {
45 return output, metadata, fmt.Errorf("error when copying http request stream: %q", err)
46 }
47 if err = request.RewindStream(); err != nil {
48 return output, metadata, fmt.Errorf("error when rewinding request stream: %q", err)
49 }
50
51 return next.HandleSerialize(ctx, input)
52 }
53
View as plain text