1
2
3
4
5
6
7
8
9
10
11 package v4
12
13 import (
14 "crypto/hmac"
15 "crypto/sha256"
16 "encoding/hex"
17 "fmt"
18 "io"
19 "io/ioutil"
20 "net/http"
21 "net/url"
22 "sort"
23 "strings"
24 "time"
25
26 "go.mongodb.org/mongo-driver/internal/aws"
27 "go.mongodb.org/mongo-driver/internal/aws/credentials"
28 )
29
const (
	// authorizationHeader names the header that carries the finished
	// signature; authHeaderSignatureElem introduces the signature within it.
	authorizationHeader     = "Authorization"
	authHeaderSignatureElem = "Signature="

	// authHeaderPrefix is the SigV4 algorithm identifier. It opens both the
	// Authorization header value and the string to sign.
	authHeaderPrefix = "AWS4-HMAC-SHA256"

	// Layouts for the full request timestamp and the date-only form that is
	// used in the credential scope.
	timeFormat      = "20060102T150405Z"
	shortTimeFormat = "20060102"

	// awsV4Request terminates the credential scope and the key derivation.
	awsV4Request = "aws4_request"

	// emptyStringSHA256 is the hex-encoded SHA-256 digest of zero bytes. It is
	// used as the payload hash when there is no request body.
	emptyStringSHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
)
42
43 var ignoredHeaders = rules{
44 excludeList{
45 mapRule{
46 authorizationHeader: struct{}{},
47 "User-Agent": struct{}{},
48 "X-Amzn-Trace-Id": struct{}{},
49 },
50 },
51 }
52
53
54
55 type Signer struct {
56
57
58 Credentials *credentials.Credentials
59 }
60
61
62 func NewSigner(credentials *credentials.Credentials) *Signer {
63 v4 := &Signer{
64 Credentials: credentials,
65 }
66
67 return v4
68 }
69
70 type signingCtx struct {
71 ServiceName string
72 Region string
73 Request *http.Request
74 Body io.ReadSeeker
75 Query url.Values
76 Time time.Time
77 SignedHeaderVals http.Header
78
79 credValues credentials.Value
80
81 bodyDigest string
82 signedHeaders string
83 canonicalHeaders string
84 canonicalString string
85 credentialString string
86 stringToSign string
87 signature string
88 }
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115 func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
116 return v4.signWithBody(r, body, service, region, signTime)
117 }
118
119 func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
120 ctx := &signingCtx{
121 Request: r,
122 Body: body,
123 Query: r.URL.Query(),
124 Time: signTime,
125 ServiceName: service,
126 Region: region,
127 }
128
129 for key := range ctx.Query {
130 sort.Strings(ctx.Query[key])
131 }
132
133 if ctx.isRequestSigned() {
134 ctx.Time = time.Now()
135 }
136
137 var err error
138 ctx.credValues, err = v4.Credentials.GetWithContext(r.Context())
139 if err != nil {
140 return http.Header{}, err
141 }
142
143 ctx.sanitizeHostForHeader()
144 ctx.assignAmzQueryValues()
145 if err := ctx.build(); err != nil {
146 return nil, err
147 }
148
149 var reader io.ReadCloser
150 if body != nil {
151 var ok bool
152 if reader, ok = body.(io.ReadCloser); !ok {
153 reader = ioutil.NopCloser(body)
154 }
155 }
156 r.Body = reader
157
158 return ctx.SignedHeaderVals, nil
159 }
160
161
162 func (ctx *signingCtx) sanitizeHostForHeader() {
163 r := ctx.Request
164 host := getHost(r)
165 port := portOnly(host)
166 if port != "" && isDefaultPort(r.URL.Scheme, port) {
167 r.Host = stripPort(host)
168 }
169 }
170
171 func (ctx *signingCtx) assignAmzQueryValues() {
172 if ctx.credValues.SessionToken != "" {
173 ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
174 }
175 }
176
177 func (ctx *signingCtx) build() error {
178 ctx.buildTime()
179 ctx.buildCredentialString()
180
181 if err := ctx.buildBodyDigest(); err != nil {
182 return err
183 }
184
185 unsignedHeaders := ctx.Request.Header
186
187 ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
188 ctx.buildCanonicalString()
189 ctx.buildStringToSign()
190 ctx.buildSignature()
191
192 parts := []string{
193 authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
194 "SignedHeaders=" + ctx.signedHeaders,
195 authHeaderSignatureElem + ctx.signature,
196 }
197 ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
198
199 return nil
200 }
201
202 func (ctx *signingCtx) buildTime() {
203 ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
204 }
205
206 func (ctx *signingCtx) buildCredentialString() {
207 ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
208 }
209
210 func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
211 headers := make([]string, 0, len(header)+1)
212 headers = append(headers, "host")
213 for k, v := range header {
214 if !r.IsValid(k) {
215 continue
216 }
217 if ctx.SignedHeaderVals == nil {
218 ctx.SignedHeaderVals = make(http.Header)
219 }
220
221 lowerCaseKey := strings.ToLower(k)
222 if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
223
224 ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
225 continue
226 }
227
228 headers = append(headers, lowerCaseKey)
229 ctx.SignedHeaderVals[lowerCaseKey] = v
230 }
231 sort.Strings(headers)
232
233 ctx.signedHeaders = strings.Join(headers, ";")
234
235 headerItems := make([]string, len(headers))
236 for i, k := range headers {
237 if k == "host" {
238 if ctx.Request.Host != "" {
239 headerItems[i] = "host:" + ctx.Request.Host
240 } else {
241 headerItems[i] = "host:" + ctx.Request.URL.Host
242 }
243 } else {
244 headerValues := make([]string, len(ctx.SignedHeaderVals[k]))
245 for i, v := range ctx.SignedHeaderVals[k] {
246 headerValues[i] = strings.TrimSpace(v)
247 }
248 headerItems[i] = k + ":" +
249 strings.Join(headerValues, ",")
250 }
251 }
252 stripExcessSpaces(headerItems)
253 ctx.canonicalHeaders = strings.Join(headerItems, "\n")
254 }
255
256 func (ctx *signingCtx) buildCanonicalString() {
257 ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
258
259 uri := getURIPath(ctx.Request.URL)
260
261 uri = EscapePath(uri, false)
262
263 ctx.canonicalString = strings.Join([]string{
264 ctx.Request.Method,
265 uri,
266 ctx.Request.URL.RawQuery,
267 ctx.canonicalHeaders + "\n",
268 ctx.signedHeaders,
269 ctx.bodyDigest,
270 }, "\n")
271 }
272
273 func (ctx *signingCtx) buildStringToSign() {
274 ctx.stringToSign = strings.Join([]string{
275 authHeaderPrefix,
276 formatTime(ctx.Time),
277 ctx.credentialString,
278 hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
279 }, "\n")
280 }
281
282 func (ctx *signingCtx) buildSignature() {
283 creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
284 signature := hmacSHA256(creds, []byte(ctx.stringToSign))
285 ctx.signature = hex.EncodeToString(signature)
286 }
287
288 func (ctx *signingCtx) buildBodyDigest() error {
289 hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
290 if hash == "" {
291 if ctx.Body == nil {
292 hash = emptyStringSHA256
293 } else {
294 if !aws.IsReaderSeekable(ctx.Body) {
295 return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
296 }
297 hashBytes, err := makeSha256Reader(ctx.Body)
298 if err != nil {
299 return err
300 }
301 hash = hex.EncodeToString(hashBytes)
302 }
303 }
304 ctx.bodyDigest = hash
305
306 return nil
307 }
308
309
310 func (ctx *signingCtx) isRequestSigned() bool {
311 return ctx.Request.Header.Get("Authorization") != ""
312 }
313
// hmacSHA256 returns the HMAC-SHA256 of data keyed with key.
func hmacSHA256(key []byte, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	// hash.Hash documents that Write never returns an error.
	_, _ = mac.Write(data)
	return mac.Sum(nil)
}
319
// hashSHA256 returns the 32-byte SHA-256 digest of data.
func hashSHA256(data []byte) []byte {
	digest := sha256.Sum256(data)
	return digest[:]
}
325
326 func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
327 hash := sha256.New()
328 start, err := reader.Seek(0, io.SeekCurrent)
329 if err != nil {
330 return nil, err
331 }
332 defer func() {
333
334 _, err = reader.Seek(start, io.SeekStart)
335 }()
336
337
338
339 size, err := aws.SeekerLen(reader)
340 if err != nil {
341 _, _ = io.Copy(hash, reader)
342 } else {
343 _, _ = io.CopyN(hash, reader, size)
344 }
345
346 return hash.Sum(nil), nil
347 }
348
// doubleSpace is two consecutive ASCII spaces. stripExcessSpaces searches for
// it to find values whose runs of spaces need collapsing; values without it
// take the fast path.
const doubleSpace = "  "
350
351
352
353 func stripExcessSpaces(vals []string) {
354 var j, k, l, m, spaces int
355 for i, str := range vals {
356
357
358
359 for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
360 }
361
362
363 for k = 0; k < j && str[k] == ' '; k++ {
364 }
365
366
367
368 str = str[k : j+1]
369
370
371 j = strings.Index(str, doubleSpace)
372 if j < 0 {
373 vals[i] = str
374 continue
375 }
376
377 buf := []byte(str)
378 for k, m, l = j, j, len(buf); k < l; k++ {
379 if buf[k] == ' ' {
380 if spaces == 0 {
381
382 buf[m] = buf[k]
383 m++
384 }
385 spaces++
386 } else {
387
388 spaces = 0
389 buf[m] = buf[k]
390 m++
391 }
392 }
393
394 vals[i] = string(buf[:m])
395 }
396 }
397
398 func buildSigningScope(region, service string, dt time.Time) string {
399 return strings.Join([]string{
400 formatShortTime(dt),
401 region,
402 service,
403 awsV4Request,
404 }, "/")
405 }
406
407 func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
408 keyDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
409 keyRegion := hmacSHA256(keyDate, []byte(region))
410 keyService := hmacSHA256(keyRegion, []byte(service))
411 signingKey := hmacSHA256(keyService, []byte(awsV4Request))
412 return signingKey
413 }
414
415 func formatShortTime(dt time.Time) string {
416 return dt.UTC().Format(shortTimeFormat)
417 }
418
419 func formatTime(dt time.Time) string {
420 return dt.UTC().Format(timeFormat)
421 }
422
View as plain text