1 package requestcompression
2
3 import (
4 "bytes"
5 "compress/gzip"
6 "context"
7 "fmt"
8 "github.com/aws/smithy-go/middleware"
9 "github.com/aws/smithy-go/transport/http"
10 "io"
11 "reflect"
12 "strings"
13 "testing"
14 )
15
16 func TestRequestCompression(t *testing.T) {
17 cases := map[string]struct {
18 DisableRequestCompression bool
19 RequestMinCompressSizeBytes int64
20 ContentLength int64
21 Header map[string][]string
22 Stream io.Reader
23 ExpectedStream []byte
24 ExpectedHeader map[string][]string
25 }{
26 "GZip request stream": {
27 Stream: strings.NewReader("Hi, world!"),
28 ExpectedStream: []byte("Hi, world!"),
29 ExpectedHeader: map[string][]string{
30 "Content-Encoding": {"gzip"},
31 },
32 },
33 "GZip request stream with existing encoding header": {
34 Stream: strings.NewReader("Hi, world!"),
35 ExpectedStream: []byte("Hi, world!"),
36 Header: map[string][]string{
37 "Content-Encoding": {"custom"},
38 },
39 ExpectedHeader: map[string][]string{
40 "Content-Encoding": {"custom, gzip"},
41 },
42 },
43 "GZip request stream smaller than min compress request size": {
44 RequestMinCompressSizeBytes: 100,
45 Stream: strings.NewReader("Hi, world!"),
46 ExpectedStream: []byte("Hi, world!"),
47 ExpectedHeader: map[string][]string{},
48 },
49 "Disable GZip request stream": {
50 DisableRequestCompression: true,
51 Stream: strings.NewReader("Hi, world!"),
52 ExpectedStream: []byte("Hi, world!"),
53 ExpectedHeader: map[string][]string{},
54 },
55 }
56
57 for name, c := range cases {
58 t.Run(name, func(t *testing.T) {
59 var err error
60 req := http.NewStackRequest().(*http.Request)
61 req.ContentLength = c.ContentLength
62 req, _ = req.SetStream(c.Stream)
63 if c.Header != nil {
64 req.Header = c.Header
65 }
66 var updatedRequest *http.Request
67
68 m := requestCompression{
69 disableRequestCompression: c.DisableRequestCompression,
70 requestMinCompressSizeBytes: c.RequestMinCompressSizeBytes,
71 compressAlgorithms: []string{GZIP},
72 }
73 _, _, err = m.HandleSerialize(context.Background(),
74 middleware.SerializeInput{Request: req},
75 middleware.SerializeHandlerFunc(func(ctx context.Context, input middleware.SerializeInput) (
76 out middleware.SerializeOutput, metadata middleware.Metadata, err error) {
77 updatedRequest = input.Request.(*http.Request)
78 return out, metadata, nil
79 }),
80 )
81 if err != nil {
82 t.Fatalf("expect no error, got %v", err)
83 }
84
85 if stream := updatedRequest.GetStream(); stream != nil {
86 if err := testUnzipContent(stream, c.ExpectedStream, c.DisableRequestCompression, c.RequestMinCompressSizeBytes); err != nil {
87 t.Errorf("error while checking request stream: %q", err)
88 }
89 }
90
91 if e, a := c.ExpectedHeader, map[string][]string(updatedRequest.Header); !reflect.DeepEqual(e, a) {
92 t.Errorf("expect request header to be %q, got %q", e, a)
93 }
94 })
95 }
96 }
97
// testUnzipContent reads content and verifies it matches expect.
//
// If disableRequestCompression is set, or len(expect) is below
// requestMinCompressionSizeBytes, content is expected to be uncompressed and is
// compared byte-for-byte. Otherwise content is decoded as a gzip stream and the
// decompressed bytes are compared. len(expect) stands in for the original
// (uncompressed) payload size when deciding which path applies.
//
// It returns nil on a match. It returns an error wrapping the underlying
// failure if reading or gzip decoding fails, or a descriptive error on a
// content mismatch.
func testUnzipContent(content io.Reader, expect []byte, disableRequestCompression bool, requestMinCompressionSizeBytes int64) error {
	if disableRequestCompression || int64(len(expect)) < requestMinCompressionSizeBytes {
		// Compression should not have been applied; compare raw bytes.
		actual, err := io.ReadAll(content)
		if err != nil {
			return fmt.Errorf("error while reading request: %w", err)
		}
		if !bytes.Equal(expect, actual) {
			return fmt.Errorf("expect content to be %s, got %s", expect, actual)
		}
		return nil
	}

	r, err := gzip.NewReader(content)
	if err != nil {
		return fmt.Errorf("error while reading request: %w", err)
	}
	// Close releases decompressor state; it does not close content.
	defer r.Close()

	// Reading to EOF also validates the gzip trailer (CRC-32 and size).
	actual, err := io.ReadAll(r)
	if err != nil {
		return fmt.Errorf("error while unzipping request payload: %w", err)
	}
	if !bytes.Equal(expect, actual) {
		return fmt.Errorf("expect unzipped content to be %s, got %s", expect, actual)
	}
	return nil
}
126
View as plain text