...
1 package http
2
3 import (
4 "context"
5 "fmt"
6 "io"
7 "io/ioutil"
8 "net/http"
9 "net/url"
10 "strings"
11
12 iointernal "github.com/aws/smithy-go/transport/http/internal/io"
13 )
14
15
16
// Request wraps the standard library's http.Request and keeps the request
// payload as a separate stream rather than in the embedded request's Body.
type Request struct {
	// Request is the embedded net/http request; its fields and methods are
	// promoted onto this type.
	*http.Request

	// Field order below is intentionally stable: unkeyed composite literals
	// of this type depend on it.

	// stream is the payload reader; Build attaches it as the outgoing Body.
	stream io.Reader
	// isStreamSeekable records whether stream was an io.Seeker when it was set.
	isStreamSeekable bool
	// streamStartPos is the stream offset captured when the stream was set;
	// rewinds and length calculations are measured from this offset.
	streamStartPos int64
}
23
24
25
26
27 func NewStackRequest() interface{} {
28 return &Request{
29 Request: &http.Request{
30 URL: &url.URL{},
31 Header: http.Header{},
32 ContentLength: -1,
33 },
34 }
35 }
36
37
38 func (r *Request) IsHTTPS() bool {
39 if r.URL == nil {
40 return false
41 }
42 return strings.EqualFold(r.URL.Scheme, "https")
43 }
44
45
46
47 func (r *Request) Clone() *Request {
48 rc := *r
49 rc.Request = rc.Request.Clone(context.TODO())
50 return &rc
51 }
52
53
54
55
56 func (r *Request) StreamLength() (size int64, ok bool, err error) {
57 return streamLength(r.stream, r.isStreamSeekable, r.streamStartPos)
58 }
59
// streamLength reports the number of bytes available in stream, measured from
// startPos.
//
// A nil stream has a known length of zero. A stream that implements
// Len() int reports that value directly. Otherwise, if seekable is false, or
// the stream does not actually implement io.Seeker, the length is unknown and
// ok is false.
//
// For a seekable stream, the end offset is found by seeking to the end. The
// stream is then seeked back to startPos. An error from either seek is
// returned with ok set to false. A startPos beyond the end of the stream
// yields a length of 0.
func streamLength(stream io.Reader, seekable bool, startPos int64) (size int64, ok bool, err error) {
	if stream == nil {
		return 0, true, nil
	}

	if l, hasLen := stream.(interface{ Len() int }); hasLen {
		return int64(l.Len()), true, nil
	}

	if !seekable {
		return 0, false, nil
	}

	// Checked assertion: a caller that claims seekability for a non-Seeker
	// gets "length unknown" instead of a runtime panic.
	s, isSeeker := stream.(io.Seeker)
	if !isSeeker {
		return 0, false, nil
	}

	endOffset, err := s.Seek(0, io.SeekEnd)
	if err != nil {
		return 0, false, err
	}

	// Seek back to startPos rather than 0. Any prefix the caller skipped
	// before handing the stream over must stay skipped: it is neither counted
	// here nor re-read afterwards.
	if _, err = s.Seek(startPos, io.SeekStart); err != nil {
		return 0, false, err
	}

	// Reading from a position past the end yields nothing, so never report a
	// negative size.
	if endOffset < startPos {
		return 0, true, nil
	}
	return endOffset - startPos, true, nil
}
92
93
94
95 func (r *Request) RewindStream() error {
96
97 if r.stream == nil {
98 return nil
99 }
100
101 if !r.isStreamSeekable {
102 return fmt.Errorf("request stream is not seekable")
103 }
104 _, err := r.stream.(io.Seeker).Seek(r.streamStartPos, io.SeekStart)
105 return err
106 }
107
108
109
110 func (r *Request) GetStream() io.Reader {
111 return r.stream
112 }
113
114
115 func (r *Request) IsStreamSeekable() bool {
116 return r.isStreamSeekable
117 }
118
119
120
121
122 func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) {
123 rc = r.Clone()
124
125 if reader == http.NoBody {
126 reader = nil
127 }
128
129 var isStreamSeekable bool
130 var streamStartPos int64
131 switch v := reader.(type) {
132 case io.Seeker:
133 n, err := v.Seek(0, io.SeekCurrent)
134 if err != nil {
135 return r, err
136 }
137 isStreamSeekable = true
138 streamStartPos = n
139 default:
140
141
142
143 length, ok, err := streamLength(reader, false, 0)
144 if err != nil {
145 return nil, err
146 } else if ok && length == 0 {
147 reader = nil
148 }
149 }
150
151 rc.stream = reader
152 rc.isStreamSeekable = isStreamSeekable
153 rc.streamStartPos = streamStartPos
154
155 return rc, err
156 }
157
158
159
160
161 func (r *Request) Build(ctx context.Context) *http.Request {
162 req := r.Request.Clone(ctx)
163
164 if r.stream == nil && req.ContentLength == -1 {
165 req.ContentLength = 0
166 }
167
168 switch stream := r.stream.(type) {
169 case *io.PipeReader:
170 req.Body = ioutil.NopCloser(stream)
171 req.ContentLength = -1
172 default:
173
174
175
176
177 if req.ContentLength != 0 && r.stream != nil {
178 req.Body = iointernal.NewSafeReadCloser(ioutil.NopCloser(stream))
179 }
180 }
181
182 return req
183 }
184
185
186
187 func RequestCloner(v interface{}) interface{} {
188 return v.(*Request).Clone()
189 }
190
View as plain text