1
2
3
4
5 package aws
6
7 import (
8 "context"
9 "crypto/hmac"
10 "crypto/sha256"
11 "encoding/hex"
12 "encoding/json"
13 "errors"
14 "fmt"
15 "net"
16 "net/url"
17 "os"
18 "runtime"
19 "strings"
20 "time"
21
22 "github.com/twmb/franz-go/pkg/sasl"
23 )
24
25
26
27
28
// Auth contains the AWS credentials used to sign an AWS_MSK_IAM SASL
// authentication request.
type Auth struct {
	// AccessKey is the AWS access key ID. It is sent, together with the
	// credential scope, in the X-Amz-Credential field of the signed request.
	AccessKey string

	// SecretKey is the AWS secret access key. It is used only to derive the
	// SigV4 signing key; it is not included in the authentication payload.
	SecretKey string

	// SessionToken is an optional AWS session token. If non-empty, it is
	// sent as X-Amz-Security-Token; if empty, that field is omitted.
	SessionToken string

	// UserAgent is an optional user agent included in the authentication
	// payload. If empty, it defaults to "franz-go/<go version>/<hostname>".
	UserAgent string

	// The blank field forces struct literals in other packages to be keyed,
	// so fields can be added later without breaking callers.
	_ struct{}
}
57
// hostname is this machine's host name, used in the default user agent sent
// during authentication. It is never empty: if os.Hostname reports no name,
// "unknown" is used instead.
var hostname = func() string {
	// The error is deliberately ignored: the host name is informational only,
	// and an empty result is replaced with a placeholder below.
	name, _ := os.Hostname()
	if name == "" {
		return "unknown"
	}
	return name
}()
65
66
67
68
69
70
71 func (a Auth) AsManagedStreamingIAMMechanism() sasl.Mechanism {
72 return ManagedStreamingIAM(func(context.Context) (Auth, error) {
73 return a, nil
74 })
75 }
76
77 type mskiam func(context.Context) (Auth, error)
78
79
80
81
82 func ManagedStreamingIAM(authFn func(context.Context) (Auth, error)) sasl.Mechanism {
83 return mskiam(authFn)
84 }
85
86 func (mskiam) Name() string { return "AWS_MSK_IAM" }
87
88 func (fn mskiam) Authenticate(ctx context.Context, host string) (sasl.Session, []byte, error) {
89 auth, err := fn(ctx)
90 if err != nil {
91 return nil, nil, err
92 }
93
94 challenge, err := challenge(auth, host)
95 if err != nil {
96 return nil, nil, err
97 }
98
99 return new(session), challenge, nil
100 }
101
// session handles the server's single reply to the AWS_MSK_IAM client-first
// message.
type session struct{}

// Challenge finishes authentication successfully on any non-empty server
// response and fails on an empty one. The response contents are not
// inspected, and no further client message is sent.
func (session) Challenge(resp []byte) (bool, []byte, error) {
	if len(resp) > 0 {
		return true, nil, nil
	}
	return false, nil, errors.New("empty challenge response: failed")
}
110
const (
	// service is the AWS service name used in the SigV4 credential scope,
	// the signing-key derivation, and the "kafka-cluster:Connect" action.
	service = "kafka-cluster"
)
112
113 func challenge(auth Auth, host string) ([]byte, error) {
114 host, _, err := net.SplitHostPort(host)
115 if err != nil {
116 return nil, err
117 }
118 region, err := identifyRegion(host)
119 if err != nil {
120 return nil, err
121 }
122
123 var (
124 timestamp = time.Now().UTC().Format("20060102T150405Z")
125 date = timestamp[:8]
126 scope = scope(date, region)
127 v = make(url.Values)
128 )
129
130 v.Set("Action", service+":Connect")
131 v.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256")
132 v.Set("X-Amz-Credential", auth.AccessKey+"/"+scope)
133 v.Set("X-Amz-Date", timestamp)
134 v.Set("X-Amz-Expires", "300")
135 v.Set("X-Amz-SignedHeaders", "host")
136 if auth.SessionToken != "" {
137 v.Set("X-Amz-Security-Token", auth.SessionToken)
138 }
139
140 qps := strings.ReplaceAll(v.Encode(), "+", "%20")
141
142 canonicalRequest := task1(host, qps)
143 sts := task2(timestamp, scope, canonicalRequest)
144 signature := task3(auth.SecretKey, region, date, sts)
145
146 v.Set("X-Amz-Signature", signature)
147
148
149
150
151 keyvals := make(map[string]string)
152 for key, values := range v {
153 keyvals[strings.ToLower(key)] = values[0]
154 }
155 keyvals["host"] = host
156 keyvals["version"] = "2020_10_22"
157 ua := auth.UserAgent
158 if ua == "" {
159 ua = strings.Join([]string{"franz-go", runtime.Version(), hostname}, "/")
160 }
161 keyvals["user-agent"] = ua
162
163 marshaled, err := json.Marshal(keyvals)
164 if err != nil {
165 return nil, err
166 }
167 return marshaled, nil
168 }
169
170
171
172 func scope(date, region string) string {
173 return strings.Join([]string{date, region, service, "aws4_request"}, "/")
174 }
175
176
// task1 builds the SigV4 canonical request (signing step 1) for a GET of "/"
// on host. qps must already be the canonical, percent-encoded query string.
// Only the host header is signed, and the payload hash is that of an empty
// body.
func task1(host, qps string) []byte {
	// Hex SHA-256 of the empty string: the presigned request has no body.
	const emptyPayloadSHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"

	// Newline-separated: method, URI, query, canonical headers (each header
	// line newline-terminated, hence the blank entry), signed headers, and
	// payload hash.
	return []byte(strings.Join([]string{
		"GET",
		"/",
		qps,
		"host:" + host,
		"", // the canonical headers block ends with its own newline
		"host",
		emptyPayloadSHA256,
	}, "\n"))
}
207
208
// task2 builds the SigV4 string to sign (signing step 2). It joins four
// newline-separated parts: the algorithm name, the request timestamp, the
// credential scope, and the lowercase hex SHA-256 of canonicalRequest.
func task2(timestamp, scope string, canonicalRequest []byte) []byte {
	digest := sha256.Sum256(canonicalRequest)
	return []byte("AWS4-HMAC-SHA256\n" + timestamp + "\n" + scope + "\n" + hex.EncodeToString(digest[:]))
}
222
223 var aws4requestBytes = []byte("aws4_request")
224
225
226 func task3(secretKey, region, date string, sts []byte) string {
227 key := make([]byte, 0, 100)
228 key = append(key, "AWS4"...)
229 key = append(key, secretKey...)
230
231 h := hmac.New(sha256.New, key)
232 h.Write([]byte(date))
233
234 key = h.Sum(key[:0])
235 h = hmac.New(sha256.New, key)
236 h.Write([]byte(region))
237
238 key = h.Sum(key[:0])
239 h = hmac.New(sha256.New, key)
240 h.Write([]byte(service))
241
242 key = h.Sum(key[:0])
243 h = hmac.New(sha256.New, key)
244 h.Write(aws4requestBytes)
245
246 key = h.Sum(key[:0])
247 h = hmac.New(sha256.New, key)
248 h.Write(sts)
249
250 return hex.EncodeToString(h.Sum(key[:0]))
251 }
252
253
// suffixes are the AWS partition DNS suffixes recognized by identifyRegion.
// They must be lowercase. No entry is a suffix of another, so at most one
// can match a given host.
var suffixes = []string{
	".amazonaws.com",
	".amazonaws.com.cn",
	".c2s.ic.gov",
	".sc2s.sgov.gov",
}

// identifyRegion extracts the AWS region from a broker hostname (without a
// port). For example, "b-1.c.kafka.us-east-1.amazonaws.com" yields
// "us-east-1". The region is the DNS label immediately before one of the
// known partition suffixes.
//
// Matching is case-insensitive, as DNS names are, and the region is returned
// lowercase to match AWS region naming. An error is returned if no suffix
// matches, or if the label before the suffix is missing or empty: an empty
// region would produce an invalid credential scope.
func identifyRegion(host string) (string, error) {
	lowered := strings.ToLower(host)
	for _, suffix := range suffixes {
		if !strings.HasSuffix(lowered, suffix) {
			continue
		}
		serviceRegion := strings.TrimSuffix(lowered, suffix)
		regionDot := strings.LastIndexByte(serviceRegion, '.')
		// No other suffix can match once this one has, so a missing or
		// empty region label is a definitive failure.
		if regionDot == -1 || regionDot == len(serviceRegion)-1 {
			break
		}
		return serviceRegion[regionDot+1:], nil
	}
	return "", fmt.Errorf("cannot determine the region in %+q", host)
}
276
View as plain text