1 package awstesting
2
3 import (
4 "context"
5 "io"
6 "net/http"
7 "os"
8 "runtime"
9 "strings"
10 "time"
11
12 "github.com/aws/aws-sdk-go-v2/aws"
13 )
14
15
// ZeroReader is an io.Reader that never runs dry: every Read fills the
// whole buffer with zero bytes and reports success.
type ZeroReader struct{}

// Read overwrites every byte of b with 0 and returns len(b) and a nil error.
func (r *ZeroReader) Read(b []byte) (int, error) {
	for idx := range b {
		b[idx] = 0
	}
	return len(b), nil
}
25
26
27
28
// ReadCloser is an io.ReadCloser for unit tests. It serves Size bytes of 'a'
// and records whether Close was called, so tests can detect bodies that were
// never closed.
type ReadCloser struct {
	// Size is the number of bytes still to be served; each Read decrements it.
	Size int
	// Closed is set by Close; once true, Read returns (0, io.EOF).
	Closed bool
	// set records that at least one Read has already run (passed to FillData).
	set bool
	// FillData, if non-nil, is called by every Read issued before Close with:
	// whether an earlier Read already ran, the caller's buffer, the bytes still
	// remaining after this Read, and the number of bytes this Read wrote.
	FillData func(bool, []byte, int, int)
}

// Read writes up to min(len(b), Size) bytes of 'a' into b, decrements Size,
// invokes FillData when set, and returns io.EOF together with the final chunk
// once Size reaches zero (or immediately if the reader is closed).
func (r *ReadCloser) Read(b []byte) (int, error) {
	if r.Closed {
		return 0, io.EOF
	}

	// Serve whichever is smaller: what is left, or what fits in b.
	n := r.Size
	if len(b) < n {
		n = len(b)
	}
	r.Size -= n

	for i := 0; i < n; i++ {
		b[i] = 'a'
	}

	if fill := r.FillData; fill != nil {
		fill(r.set, b, r.Size, n)
	}
	r.set = true

	var err error
	if r.Size <= 0 {
		err = io.EOF
	}
	return n, err
}

// Close marks the reader as closed so later Reads return io.EOF. It never
// returns an error.
func (r *ReadCloser) Close() error {
	r.Closed = true
	return nil
}
69
70
// FakeContext is a context.Context stand-in whose Done channel and Err value
// are supplied directly by the test. It never has a deadline and carries no
// values.
type FakeContext struct {
	Error  error
	DoneCh chan struct{}
}

// Deadline always reports that no deadline is set.
func (c *FakeContext) Deadline() (deadline time.Time, ok bool) {
	var zero time.Time
	return zero, false
}

// Done returns DoneCh as-is; a nil DoneCh yields a nil channel.
func (c *FakeContext) Done() <-chan struct{} {
	ch := c.DoneCh
	return ch
}

// Err returns the test-supplied Error field unchanged.
func (c *FakeContext) Err() error {
	err := c.Error
	return err
}

// Value ignores key and always returns nil.
func (c *FakeContext) Value(key interface{}) interface{} {
	return nil
}
95
96
97
98 func StashEnv(envToKeep ...string) []string {
99 if runtime.GOOS == "windows" {
100 envToKeep = append(envToKeep, "ComSpec")
101 envToKeep = append(envToKeep, "SYSTEM32")
102 envToKeep = append(envToKeep, "SYSTEMROOT")
103 }
104 envToKeep = append(envToKeep, "PATH", "HOME", "USERPROFILE")
105 extraEnv := getEnvs(envToKeep)
106 originalEnv := os.Environ()
107 os.Clearenv()
108 for key, val := range extraEnv {
109 os.Setenv(key, val)
110 }
111 return originalEnv
112 }
113
114
115
116
// PopEnv replaces the whole process environment with env, a list of
// "key=value" strings such as the one returned by StashEnv. Each entry is
// split at its first '='; an entry without '=' is restored with an empty
// value. Errors from os.Setenv are ignored.
func PopEnv(env []string) {
	os.Clearenv()

	for _, kv := range env {
		key, value, _ := strings.Cut(kv, "=")
		os.Setenv(key, value)
	}
}
129
130
131
132 type MockCredentialsProvider struct {
133 RetrieveFn func(ctx context.Context) (aws.Credentials, error)
134 InvalidateFn func()
135 }
136
137
138 func (p MockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
139 return p.RetrieveFn(ctx)
140 }
141
142
143 func (p MockCredentialsProvider) Invalidate() {
144 p.InvalidateFn()
145 }
146
// getEnvs looks up each named environment variable and returns those that are
// set to a non-empty value, keyed by name. The returned map is never nil.
func getEnvs(envs []string) map[string]string {
	found := map[string]string{}
	for _, name := range envs {
		val, ok := os.LookupEnv(name)
		if !ok || val == "" {
			continue
		}
		found[name] = val
	}
	return found
}
156
// Authorization header preambles (signing algorithm identifiers) recognized
// by ParseSigV4Signature.
const (
	signaturePreambleSigV4  = "AWS4-HMAC-SHA256"
	signaturePreambleSigV4A = "AWS4-ECDSA-P256-SHA256"
)

// SigV4Signature holds the fields extracted from a SigV4 or SigV4A
// Authorization header by ParseSigV4Signature.
type SigV4Signature struct {
	// Preamble is the signing algorithm, e.g. "AWS4-HMAC-SHA256".
	Preamble string
	// SigningName is the service name taken from the credential scope.
	SigningName string
	// SigningRegion is the region from the credential scope (SigV4) or the
	// value of the X-Amz-Region-Set header (SigV4A).
	SigningRegion string
	// SignedHeaders lists the signed header names, without the
	// "SignedHeaders=" key. It is nil when no signed headers were found.
	SignedHeaders []string
	// Signature is the signature value, without the "Signature=" key.
	Signature string
}

// ParseSigV4Signature deserializes the signature fields from the
// Authorization header in header. The header is expected to look like
//
//	<preamble> Credential=<scope>, SignedHeaders=<h1;h2;...>, Signature=<sig>
//
// where the credential scope is
//
//	sigv4 : AccessKeyID/Date/SigningRegion/SigningName/aws4_request
//	sigv4a: AccessKeyID/Date/SigningName/aws4_request
//
// SigV4A carries its region set in the separate X-Amz-Region-Set header.
//
// Parameters are matched by key, so their order and the spacing after the
// separating commas do not matter. Missing or malformed parts leave the
// corresponding fields empty rather than panicking, and an unrecognized
// preamble leaves SigningName and SigningRegion empty.
func ParseSigV4Signature(header http.Header) *SigV4Signature {
	auth := header.Get("Authorization")

	preamble, params, _ := strings.Cut(auth, " ")

	var credential, signedHeaders, signature string
	for _, param := range strings.Split(params, ",") {
		key, value, _ := strings.Cut(strings.TrimSpace(param), "=")
		switch key {
		case "Credential":
			credential = value
		case "SignedHeaders":
			signedHeaders = value
		case "Signature":
			signature = value
		}
	}

	scope := strings.Split(credential, "/")

	var signingName, signingRegion string
	switch preamble {
	case signaturePreambleSigV4:
		// Guard the indexes: a truncated scope must not panic the test.
		if len(scope) > 3 {
			signingRegion = scope[2]
			signingName = scope[3]
		}
	case signaturePreambleSigV4A:
		if len(scope) > 2 {
			signingName = scope[2]
		}
		signingRegion = header.Get("X-Amz-Region-Set")
	}

	var headers []string
	if signedHeaders != "" {
		headers = strings.Split(signedHeaders, ";")
	}

	return &SigV4Signature{
		Preamble:      preamble,
		SigningName:   signingName,
		SigningRegion: signingRegion,
		SignedHeaders: headers,
		Signature:     signature,
	}
}
202
View as plain text