...

Source file src/github.com/aws/smithy-go/transport/http/middleware_content_length_test.go

Documentation: github.com/aws/smithy-go/transport/http

     1  package http
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/aws/smithy-go/middleware"
    12  )
    13  
    14  func TestContentLengthMiddleware(t *testing.T) {
    15  	cases := map[string]struct {
    16  		Stream          io.Reader
    17  		ExpectNilStream bool
    18  		ExpectLen       int64
    19  		ExpectErr       string
    20  	}{
    21  		// Cases
    22  		"bytes.Reader": {
    23  			Stream:          bytes.NewReader(make([]byte, 10)),
    24  			ExpectLen:       10,
    25  			ExpectNilStream: false,
    26  		},
    27  		"bytes.Buffer": {
    28  			Stream:          bytes.NewBuffer(make([]byte, 10)),
    29  			ExpectLen:       10,
    30  			ExpectNilStream: false,
    31  		},
    32  		"strings.Reader": {
    33  			Stream:          strings.NewReader("hello"),
    34  			ExpectLen:       5,
    35  			ExpectNilStream: false,
    36  		},
    37  		"empty stream": {
    38  			Stream:          strings.NewReader(""),
    39  			ExpectLen:       0,
    40  			ExpectNilStream: false,
    41  		},
    42  		"empty stream bytes": {
    43  			Stream:          bytes.NewReader([]byte{}),
    44  			ExpectLen:       0,
    45  			ExpectNilStream: false,
    46  		},
    47  		"nil stream": {
    48  			ExpectLen:       0,
    49  			ExpectNilStream: true,
    50  		},
    51  		"un-seekable and no length": {
    52  			Stream:          &basicReader{buf: make([]byte, 10)},
    53  			ExpectLen:       -1,
    54  			ExpectNilStream: false,
    55  		},
    56  		"with error": {
    57  			Stream:          &errorSecondSeekableReader{err: fmt.Errorf("seek failed")},
    58  			ExpectErr:       "seek failed",
    59  			ExpectLen:       -1,
    60  			ExpectNilStream: false,
    61  		},
    62  	}
    63  
    64  	for name, c := range cases {
    65  		t.Run(name, func(t *testing.T) {
    66  			var err error
    67  			req := NewStackRequest().(*Request)
    68  			req, err = req.SetStream(c.Stream)
    69  			if err != nil {
    70  				t.Fatalf("expect to set stream, %v", err)
    71  			}
    72  
    73  			var updatedRequest *Request
    74  			var m ComputeContentLength
    75  			_, _, err = m.HandleBuild(context.Background(),
    76  				middleware.BuildInput{Request: req},
    77  				middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
    78  					out middleware.BuildOutput, metadata middleware.Metadata, err error) {
    79  					updatedRequest = input.Request.(*Request)
    80  					return out, metadata, nil
    81  				}),
    82  			)
    83  			if len(c.ExpectErr) != 0 {
    84  				if err == nil {
    85  					t.Fatalf("expect error, got none")
    86  				}
    87  				if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
    88  					t.Fatalf("expect error to contain %q, got %v", e, a)
    89  				}
    90  				return
    91  			} else if err != nil {
    92  				t.Fatalf("expect no error, got %v", err)
    93  			}
    94  
    95  			if e, a := c.ExpectLen, updatedRequest.ContentLength; e != a {
    96  				t.Errorf("expect %v content-length, got %v", e, a)
    97  			}
    98  
    99  			if e, a := c.ExpectNilStream, updatedRequest.stream == nil; e != a {
   100  				t.Errorf("expect %v nil stream, got %v", e, a)
   101  			}
   102  		})
   103  	}
   104  }
   105  
   106  func TestContentLengthMiddleware_HeaderSet(t *testing.T) {
   107  	req := NewStackRequest().(*Request)
   108  	req.Header.Set("Content-Length", "1234")
   109  
   110  	var err error
   111  	req, err = req.SetStream(strings.NewReader("hello"))
   112  	if err != nil {
   113  		t.Fatalf("expect to set stream, %v", err)
   114  	}
   115  
   116  	var m ComputeContentLength
   117  	_, _, err = m.HandleBuild(context.Background(),
   118  		middleware.BuildInput{Request: req},
   119  		nopBuildHandler,
   120  	)
   121  	if err != nil {
   122  		t.Fatalf("expect middleware to run, %v", err)
   123  	}
   124  
   125  	if e, a := "1234", req.Header.Get("Content-Length"); e != a {
   126  		t.Errorf("expect Content-Length not to change, got %v", a)
   127  	}
   128  }
   129  
   130  var nopBuildHandler = middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
   131  	out middleware.BuildOutput, metadata middleware.Metadata, err error) {
   132  	return out, metadata, nil
   133  })
   134  
   135  type basicReader struct {
   136  	buf []byte
   137  }
   138  
   139  func (r *basicReader) Read(p []byte) (int, error) {
   140  	n := copy(p, r.buf)
   141  	r.buf = r.buf[n:]
   142  	return n, nil
   143  }
   144  
   145  type errorSecondSeekableReader struct {
   146  	err   error
   147  	count int
   148  }
   149  
   150  func (r *errorSecondSeekableReader) Read(p []byte) (int, error) {
   151  	return 0, io.EOF
   152  }
   153  func (r *errorSecondSeekableReader) Seek(offset int64, whence int) (int64, error) {
   154  	r.count++
   155  	if r.count == 2 {
   156  		return 0, r.err
   157  	}
   158  	return 0, nil
   159  }
   160  
   161  func TestValidateContentLengthHeader(t *testing.T) {
   162  	cases := map[string]struct {
   163  		contentLength int64
   164  		expectError   string
   165  	}{
   166  		"success": {
   167  			contentLength: 10,
   168  		},
   169  		"length set to 0": {
   170  			contentLength: 0,
   171  		},
   172  		"content-length unset": {
   173  			contentLength: -1,
   174  			expectError:   "content length for payload is required and must be at least 0",
   175  		},
   176  	}
   177  
   178  	for name, c := range cases {
   179  		t.Run(name, func(t *testing.T) {
   180  			var err error
   181  			req := NewStackRequest().(*Request)
   182  			req.ContentLength = c.contentLength
   183  
   184  			var m validateContentLength
   185  			_, _, err = m.HandleBuild(context.Background(),
   186  				middleware.BuildInput{Request: req},
   187  				nopBuildHandler,
   188  			)
   189  
   190  			if len(c.expectError) != 0 {
   191  				if err == nil {
   192  					t.Fatalf("expect error, got none")
   193  				}
   194  				if e, a := c.expectError, err.Error(); !strings.Contains(a, e) {
   195  					t.Fatalf("expect error to contain %q, got %v", e, a)
   196  				}
   197  			} else if err != nil {
   198  				t.Fatalf("expect no error, got %v", err)
   199  			}
   200  		})
   201  	}
   202  }
   203  

View as plain text