1 package processcreds
2
3 import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "os"
10 "os/exec"
11 "runtime"
12 "time"
13
14 "github.com/aws/aws-sdk-go-v2/aws"
15 "github.com/aws/aws-sdk-go-v2/internal/sdkio"
16 )
17
const (
	// ProviderName is the value this package reports as the Source of the
	// credentials it returns.
	ProviderName = `ProcessProvider`

	// DefaultTimeout is the timeout NewProviderCommand configures before any
	// caller-supplied option is applied: one minute.
	DefaultTimeout = time.Minute
)
26
27
28
// ProviderError wraps a failure that occurred while preparing or running the
// credential process, or while interpreting its output.
type ProviderError struct {
	// Err is the underlying cause.
	Err error
}

// Error implements the error interface, prefixing the underlying cause with
// a fixed "process provider error" label.
func (pe *ProviderError) Error() string {
	cause := pe.Err
	return fmt.Sprintf("process provider error: %v", cause)
}

// Unwrap exposes the underlying cause so that errors.Is and errors.As can
// inspect it.
func (pe *ProviderError) Unwrap() error {
	return pe.Err
}
42
43
44
45 type Provider struct {
46
47
48
49
50
51 commandBuilder NewCommandBuilder
52
53 options Options
54 }
55
56
// Options is the set of configurable settings for a Provider.
type Options struct {
	// Timeout bounds how long the credential process may run. The bound is
	// applied whenever the value is zero or positive; a negative value
	// leaves the caller's context untouched.
	Timeout time.Duration
}
61
62
63
// NewCommandBuilder is implemented by anything able to produce the command a
// Provider should execute to obtain credentials.
type NewCommandBuilder interface {
	// NewCommand returns the command to run, or an error if it cannot be
	// prepared. The context is the one the Provider will run the command
	// under.
	NewCommand(ctx context.Context) (*exec.Cmd, error)
}
67
68
69
// NewCommandBuilderFunc adapts a plain function so it can be used wherever a
// NewCommandBuilder is expected.
type NewCommandBuilderFunc func(ctx context.Context) (*exec.Cmd, error)

// NewCommand satisfies NewCommandBuilder by invoking the function itself with
// the given context and returning its results unchanged.
func (build NewCommandBuilderFunc) NewCommand(ctx context.Context) (*exec.Cmd, error) {
	return build(ctx)
}
76
77
78
79
80
// DefaultNewCommandBuilder is the NewCommandBuilder used by NewProvider. Its
// NewCommand method hands Args to the platform shell ("sh -c", or
// "cmd.exe /C" on Windows).
type DefaultNewCommandBuilder struct {
	// Args are appended after the shell's command flag. NewCommand rejects
	// an empty list.
	Args []string
}
84
85
86
87
88 func (b DefaultNewCommandBuilder) NewCommand(ctx context.Context) (*exec.Cmd, error) {
89 var cmdArgs []string
90 if runtime.GOOS == "windows" {
91 cmdArgs = []string{"cmd.exe", "/C"}
92 } else {
93 cmdArgs = []string{"sh", "-c"}
94 }
95
96 if len(b.Args) == 0 {
97 return nil, &ProviderError{
98 Err: fmt.Errorf("failed to prepare command: command must not be empty"),
99 }
100 }
101
102 cmdArgs = append(cmdArgs, b.Args...)
103 cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
104 cmd.Env = os.Environ()
105
106 cmd.Stderr = os.Stderr
107 cmd.Stdin = os.Stdin
108
109 return cmd, nil
110 }
111
112
113
114
115
116
117 func NewProvider(command string, options ...func(*Options)) *Provider {
118 var args []string
119
120
121
122
123 if len(command) > 0 {
124 args = []string{command}
125 }
126
127 commanBuilder := DefaultNewCommandBuilder{
128 Args: args,
129 }
130 return NewProviderCommand(commanBuilder, options...)
131 }
132
133
134
135
136
137 func NewProviderCommand(builder NewCommandBuilder, options ...func(*Options)) *Provider {
138 p := &Provider{
139 commandBuilder: builder,
140 options: Options{
141 Timeout: DefaultTimeout,
142 },
143 }
144
145 for _, option := range options {
146 option(&p.options)
147 }
148
149 return p
150 }
151
152
153
// CredentialProcessResponse is the JSON document the credential process is
// expected to write to standard output.
type CredentialProcessResponse struct {
	// Version of the output format. Retrieve accepts only the value 1.
	Version int `json:"Version"`

	// AccessKeyID is read from the "AccessKeyId" JSON key. Retrieve requires
	// it to be non-empty.
	AccessKeyID string `json:"AccessKeyId"`

	// SecretAccessKey is the secret half of the key pair. Retrieve requires
	// it to be non-empty.
	SecretAccessKey string `json:"SecretAccessKey"`

	// SessionToken is copied into the returned credentials as-is; Retrieve
	// does not require it.
	SessionToken string `json:"SessionToken"`

	// Expiration, when present, marks the returned credentials as expiring
	// at that time. When absent the credentials are not marked as expiring.
	Expiration *time.Time `json:"Expiration"`
}
171
172
173
174 func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
175 out, err := p.executeCredentialProcess(ctx)
176 if err != nil {
177 return aws.Credentials{Source: ProviderName}, err
178 }
179
180
181 resp := &CredentialProcessResponse{}
182 if err = json.Unmarshal(out, resp); err != nil {
183 return aws.Credentials{Source: ProviderName}, &ProviderError{
184 Err: fmt.Errorf("parse failed of process output: %s, error: %w", out, err),
185 }
186 }
187
188 if resp.Version != 1 {
189 return aws.Credentials{Source: ProviderName}, &ProviderError{
190 Err: fmt.Errorf("wrong version in process output (not 1)"),
191 }
192 }
193
194 if len(resp.AccessKeyID) == 0 {
195 return aws.Credentials{Source: ProviderName}, &ProviderError{
196 Err: fmt.Errorf("missing AccessKeyId in process output"),
197 }
198 }
199
200 if len(resp.SecretAccessKey) == 0 {
201 return aws.Credentials{Source: ProviderName}, &ProviderError{
202 Err: fmt.Errorf("missing SecretAccessKey in process output"),
203 }
204 }
205
206 creds := aws.Credentials{
207 Source: ProviderName,
208 AccessKeyID: resp.AccessKeyID,
209 SecretAccessKey: resp.SecretAccessKey,
210 SessionToken: resp.SessionToken,
211 }
212
213
214 if resp.Expiration != nil {
215 creds.CanExpire = true
216 creds.Expires = *resp.Expiration
217 }
218
219 return creds, nil
220 }
221
222
223
224 func (p *Provider) executeCredentialProcess(ctx context.Context) ([]byte, error) {
225 if p.options.Timeout >= 0 {
226 var cancelFunc func()
227 ctx, cancelFunc = context.WithTimeout(ctx, p.options.Timeout)
228 defer cancelFunc()
229 }
230
231 cmd, err := p.commandBuilder.NewCommand(ctx)
232 if err != nil {
233 return nil, err
234 }
235
236
237 output := bytes.NewBuffer(make([]byte, 0, int(8*sdkio.KibiByte)))
238 if cmd.Stdout != nil {
239 cmd.Stdout = io.MultiWriter(cmd.Stdout, output)
240 } else {
241 cmd.Stdout = output
242 }
243
244 execCh := make(chan error, 1)
245 go executeCommand(cmd, execCh)
246
247 select {
248 case execError := <-execCh:
249 if execError == nil {
250 break
251 }
252 select {
253 case <-ctx.Done():
254 return output.Bytes(), &ProviderError{
255 Err: fmt.Errorf("credential process timed out: %w", execError),
256 }
257 default:
258 return output.Bytes(), &ProviderError{
259 Err: fmt.Errorf("error in credential_process: %w", execError),
260 }
261 }
262 }
263
264 out := output.Bytes()
265 if runtime.GOOS == "windows" {
266
267 out = bytes.ReplaceAll(out, []byte(`\"`), []byte(`"`))
268 }
269
270 return out, nil
271 }
272
// executeCommand starts cmd, waits for it to finish when the start succeeded,
// and sends exactly one value on result: the start error if starting failed,
// otherwise whatever Wait returned (nil on success).
func executeCommand(cmd *exec.Cmd, result chan error) {
	if err := cmd.Start(); err != nil {
		result <- err
		return
	}

	result <- cmd.Wait()
}
282
View as plain text