...

Source file src/github.com/aws/smithy-go/private/requestcompression/middleware_capture_request_compression.go

Documentation: github.com/aws/smithy-go/private/requestcompression

     1  package requestcompression
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"github.com/aws/smithy-go/middleware"
     8  	smithyhttp "github.com/aws/smithy-go/transport/http"
     9  	"io"
    10  	"net/http"
    11  )
    12  
    13  const captureUncompressedRequestID = "CaptureUncompressedRequest"
    14  
    15  // AddCaptureUncompressedRequestMiddleware captures http request before compress encoding for check
    16  func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error {
    17  	return stack.Serialize.Insert(&captureUncompressedRequestMiddleware{
    18  		buf: buf,
    19  	}, "RequestCompression", middleware.Before)
    20  }
    21  
    22  type captureUncompressedRequestMiddleware struct {
    23  	req   *http.Request
    24  	buf   *bytes.Buffer
    25  	bytes []byte
    26  }
    27  
    28  // ID returns id of the captureUncompressedRequestMiddleware
    29  func (*captureUncompressedRequestMiddleware) ID() string {
    30  	return captureUncompressedRequestID
    31  }
    32  
    33  // HandleSerialize captures request payload before it is compressed by request compression middleware
    34  func (m *captureUncompressedRequestMiddleware) HandleSerialize(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
    35  ) (
    36  	output middleware.SerializeOutput, metadata middleware.Metadata, err error,
    37  ) {
    38  	request, ok := input.Request.(*smithyhttp.Request)
    39  	if !ok {
    40  		return output, metadata, fmt.Errorf("error when retrieving http request")
    41  	}
    42  
    43  	_, err = io.Copy(m.buf, request.GetStream())
    44  	if err != nil {
    45  		return output, metadata, fmt.Errorf("error when copying http request stream: %q", err)
    46  	}
    47  	if err = request.RewindStream(); err != nil {
    48  		return output, metadata, fmt.Errorf("error when rewinding request stream: %q", err)
    49  	}
    50  
    51  	return next.HandleSerialize(ctx, input)
    52  }
    53  

View as plain text