...

Source file src/github.com/aws/smithy-go/private/requestcompression/request_compression_test.go

Documentation: github.com/aws/smithy-go/private/requestcompression

     1  package requestcompression
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"fmt"
     8  	"github.com/aws/smithy-go/middleware"
     9  	"github.com/aws/smithy-go/transport/http"
    10  	"io"
    11  	"reflect"
    12  	"strings"
    13  	"testing"
    14  )
    15  
    16  func TestRequestCompression(t *testing.T) {
    17  	cases := map[string]struct {
    18  		DisableRequestCompression   bool
    19  		RequestMinCompressSizeBytes int64
    20  		ContentLength               int64
    21  		Header                      map[string][]string
    22  		Stream                      io.Reader
    23  		ExpectedStream              []byte
    24  		ExpectedHeader              map[string][]string
    25  	}{
    26  		"GZip request stream": {
    27  			Stream:         strings.NewReader("Hi, world!"),
    28  			ExpectedStream: []byte("Hi, world!"),
    29  			ExpectedHeader: map[string][]string{
    30  				"Content-Encoding": {"gzip"},
    31  			},
    32  		},
    33  		"GZip request stream with existing encoding header": {
    34  			Stream:         strings.NewReader("Hi, world!"),
    35  			ExpectedStream: []byte("Hi, world!"),
    36  			Header: map[string][]string{
    37  				"Content-Encoding": {"custom"},
    38  			},
    39  			ExpectedHeader: map[string][]string{
    40  				"Content-Encoding": {"custom, gzip"},
    41  			},
    42  		},
    43  		"GZip request stream smaller than min compress request size": {
    44  			RequestMinCompressSizeBytes: 100,
    45  			Stream:                      strings.NewReader("Hi, world!"),
    46  			ExpectedStream:              []byte("Hi, world!"),
    47  			ExpectedHeader:              map[string][]string{},
    48  		},
    49  		"Disable GZip request stream": {
    50  			DisableRequestCompression: true,
    51  			Stream:                    strings.NewReader("Hi, world!"),
    52  			ExpectedStream:            []byte("Hi, world!"),
    53  			ExpectedHeader:            map[string][]string{},
    54  		},
    55  	}
    56  
    57  	for name, c := range cases {
    58  		t.Run(name, func(t *testing.T) {
    59  			var err error
    60  			req := http.NewStackRequest().(*http.Request)
    61  			req.ContentLength = c.ContentLength
    62  			req, _ = req.SetStream(c.Stream)
    63  			if c.Header != nil {
    64  				req.Header = c.Header
    65  			}
    66  			var updatedRequest *http.Request
    67  
    68  			m := requestCompression{
    69  				disableRequestCompression:   c.DisableRequestCompression,
    70  				requestMinCompressSizeBytes: c.RequestMinCompressSizeBytes,
    71  				compressAlgorithms:          []string{GZIP},
    72  			}
    73  			_, _, err = m.HandleSerialize(context.Background(),
    74  				middleware.SerializeInput{Request: req},
    75  				middleware.SerializeHandlerFunc(func(ctx context.Context, input middleware.SerializeInput) (
    76  					out middleware.SerializeOutput, metadata middleware.Metadata, err error) {
    77  					updatedRequest = input.Request.(*http.Request)
    78  					return out, metadata, nil
    79  				}),
    80  			)
    81  			if err != nil {
    82  				t.Fatalf("expect no error, got %v", err)
    83  			}
    84  
    85  			if stream := updatedRequest.GetStream(); stream != nil {
    86  				if err := testUnzipContent(stream, c.ExpectedStream, c.DisableRequestCompression, c.RequestMinCompressSizeBytes); err != nil {
    87  					t.Errorf("error while checking request stream: %q", err)
    88  				}
    89  			}
    90  
    91  			if e, a := c.ExpectedHeader, map[string][]string(updatedRequest.Header); !reflect.DeepEqual(e, a) {
    92  				t.Errorf("expect request header to be %q, got %q", e, a)
    93  			}
    94  		})
    95  	}
    96  }
    97  
    98  func testUnzipContent(content io.Reader, expect []byte, disableRequestCompression bool, requestMinCompressionSizeBytes int64) error {
    99  	if disableRequestCompression || int64(len(expect)) < requestMinCompressionSizeBytes {
   100  		b, err := io.ReadAll(content)
   101  		if err != nil {
   102  			return fmt.Errorf("error while reading request")
   103  		}
   104  		if e, a := expect, b; !bytes.Equal(e, a) {
   105  			return fmt.Errorf("expect content to be %s, got %s", e, a)
   106  		}
   107  	} else {
   108  		r, err := gzip.NewReader(content)
   109  		if err != nil {
   110  			return fmt.Errorf("error while reading request")
   111  		}
   112  
   113  		var actualBytes bytes.Buffer
   114  		_, err = actualBytes.ReadFrom(r)
   115  		if err != nil {
   116  			return fmt.Errorf("error while unzipping request payload")
   117  		}
   118  
   119  		if e, a := expect, actualBytes.Bytes(); !bytes.Equal(e, a) {
   120  			return fmt.Errorf("expect unzipped content to be %s, got %s", e, a)
   121  		}
   122  	}
   123  
   124  	return nil
   125  }
   126  

View as plain text