...
1 package protocol
2
3 import (
4 "context"
5 "fmt"
6 "github.com/aws/smithy-go/middleware"
7 smithyhttp "github.com/aws/smithy-go/transport/http"
8 "net/http"
9 "strconv"
10 )
11
// captureRequestID is the identifier reported by the capture middleware's ID
// method.
const (
	captureRequestID = "CaptureProtocolTestRequest"
)
13
14
// AddCaptureRequestMiddleware adds a middleware to the Build step of stack,
// using middleware.After as its relative position. When the operation runs,
// that middleware copies the HTTP request, as built at that point, into *req.
//
// req must be non-nil because the captured request is written through it. A
// nil req is rejected here with an error. Otherwise it would cause a
// nil-pointer panic later, when the operation is invoked.
func AddCaptureRequestMiddleware(stack *middleware.Stack, req *http.Request) error {
	if req == nil {
		return fmt.Errorf("capture request middleware requires a non-nil *http.Request")
	}
	return stack.Build.Add(&captureRequestMiddleware{
		req: req,
	}, middleware.After)
}
20
21 type captureRequestMiddleware struct {
22 req *http.Request
23 }
24
25 func (*captureRequestMiddleware) ID() string {
26 return captureRequestID
27 }
28
// HandleBuild copies the request produced by smithyhttp.Request.Build into
// m.req. It then passes input, unmodified, to the next build handler.
//
// Two adjustments are applied to the captured copy:
//   - An empty URL.RawPath is filled in from URL.Path. Presumably this lets
//     callers always compare the path via RawPath; confirm against callers.
//   - A known, positive ContentLength is mirrored into the Content-Length
//     header. net/http uses -1 to mean "unknown length", so non-positive
//     values are not recorded. This avoids a bogus "Content-Length: -1".
//
// It returns an error, without calling next, if input.Request is not a
// *smithyhttp.Request.
func (m *captureRequestMiddleware) HandleBuild(
	ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler,
) (
	output middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
	request, ok := input.Request.(*smithyhttp.Request)
	if !ok {
		return output, metadata, fmt.Errorf("error while retrieving http request")
	}

	*m.req = *request.Build(ctx)

	if u := m.req.URL; u != nil && len(u.RawPath) == 0 {
		u.RawPath = u.Path
	}

	// Only a concrete size is meaningful; -1 signals an unknown length.
	if v := m.req.ContentLength; v > 0 {
		// Setting a key on a nil http.Header map would panic.
		if m.req.Header == nil {
			m.req.Header = make(http.Header)
		}
		m.req.Header.Set("Content-Length", strconv.FormatInt(v, 10))
	}

	return next.HandleBuild(ctx, input)
}
48
View as plain text