1 package awstesting
2
3 import (
4 "bytes"
5 "crypto/rand"
6 "crypto/rsa"
7 "crypto/tls"
8 "crypto/x509"
9 "crypto/x509/pkix"
10 "encoding/pem"
11 "fmt"
12 "io/ioutil"
13 "math/big"
14 "net"
15 "net/http"
16 "net/http/httptest"
17 "os"
18 "strings"
19 "time"
20 )
21
// TLS fixtures populated by this package's init function each time the
// package is loaded. They are generated in memory for local testing only.
var (
	// TLSBundleCA is the PEM-encoded self-signed root CA certificate that
	// signs both TLSBundleCert and ClientTLSCert.
	TLSBundleCA []byte

	// TLSBundleCert is the PEM-encoded server certificate, valid for the IPv4
	// and IPv6 loopback addresses and signed by TLSBundleCA.
	TLSBundleCert []byte

	// TLSBundleKey is the PEM-encoded PKCS #1 RSA private key for
	// TLSBundleCert.
	TLSBundleKey []byte

	// ClientTLSCert is the PEM-encoded client certificate, signed by
	// TLSBundleCA, used to exercise TLS client-certificate authentication.
	ClientTLSCert []byte

	// ClientTLSKey is the PEM-encoded PKCS #1 RSA private key for
	// ClientTLSCert.
	ClientTLSKey []byte
)
38
39 func init() {
40 caPEM, _, caCert, caPrivKey, err := generateRootCA()
41 if err != nil {
42 panic("failed to generate testing root CA, " + err.Error())
43 }
44 TLSBundleCA = caPEM
45
46 serverCertPEM, serverCertPrivKeyPEM, err := generateLocalCert(caCert, caPrivKey)
47 if err != nil {
48 panic("failed to generate testing server cert, " + err.Error())
49 }
50 TLSBundleCert = serverCertPEM
51 TLSBundleKey = serverCertPrivKeyPEM
52
53 clientCertPEM, clientCertPrivKeyPEM, err := generateLocalCert(caCert, caPrivKey)
54 if err != nil {
55 panic("failed to generate testing client cert, " + err.Error())
56 }
57 ClientTLSCert = clientCertPEM
58 ClientTLSKey = clientCertPrivKeyPEM
59 }
60
// generateRootCA creates a self-signed root certificate authority, valid for
// one year, used to sign this package's test certificates.
//
// It returns the PEM-encoded CA certificate, the PEM-encoded PKCS #1 RSA
// private key, the parsed CA certificate, and the CA's private key. The parsed
// certificate (not the creation template) is returned so that certificates
// signed with it as parent take their issuer name and AuthorityKeyId from the
// CA exactly as encoded, including the SubjectKeyId that
// x509.CreateCertificate generates for CA certificates. On error all other
// return values are nil.
func generateRootCA() (
	caPEM, caPrivKeyPEM []byte, caCert *x509.Certificate, caPrivKey *rsa.PrivateKey, err error,
) {
	// Serial numbers must be unique per issuer (RFC 5280 section 4.1.2.2); a
	// random 128-bit serial cannot collide with the certificates this CA signs.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to generate CA serial number, %w", err)
	}

	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Country:      []string{"US"},
			Organization: []string{"AWS SDK for Go Test Certificate"},
			CommonName:   "Test Root CA",
		},
		// Backdated by a minute, presumably to tolerate small clock skew.
		NotBefore: now.Add(-time.Minute),
		NotAfter:  now.AddDate(1, 0, 0),
		KeyUsage:  x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageClientAuth,
			x509.ExtKeyUsageServerAuth,
		},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// 2048 bits is ample for throwaway test certificates and is much faster to
	// generate than 4096 bits; this function is called during package init.
	caPrivKey, err = rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to generate CA RSA key, %w", err)
	}

	// Self-signed: the template is both the certificate and its parent.
	caBytes, err := x509.CreateCertificate(rand.Reader, template, template, &caPrivKey.PublicKey, caPrivKey)
	if err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to generate CA certificate, %w", err)
	}

	caCert, err = x509.ParseCertificate(caBytes)
	if err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to parse CA certificate, %w", err)
	}

	var caPEMBuf bytes.Buffer
	if err = pem.Encode(&caPEMBuf, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	}); err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to PEM encode CA certificate, %w", err)
	}

	var caPrivKeyPEMBuf bytes.Buffer
	if err = pem.Encode(&caPrivKeyPEMBuf, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
	}); err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to PEM encode CA private key, %w", err)
	}

	return caPEMBuf.Bytes(), caPrivKeyPEMBuf.Bytes(), caCert, caPrivKey, nil
}
109
// generateLocalCert creates a certificate for the IPv4 and IPv6 loopback
// addresses with a freshly generated RSA key, signed by parentCert using
// parentPrivKey. The certificate carries both the server-auth and client-auth
// extended key usages, so the same generator serves for server and client
// certificates.
//
// It returns the PEM-encoded certificate and the PEM-encoded PKCS #1 RSA
// private key. On error both are nil.
func generateLocalCert(parentCert *x509.Certificate, parentPrivKey *rsa.PrivateKey) (
	certPEM, certPrivKeyPEM []byte, err error,
) {
	// Serial numbers must be unique per issuer (RFC 5280 section 4.1.2.2); a
	// random 128-bit value avoids colliding with the CA or sibling certificates.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate local certificate serial number, %w", err)
	}

	now := time.Now()
	cert := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Country:      []string{"US"},
			Organization: []string{"AWS SDK for Go Test Certificate"},
			// Must differ from the issuing CA's name; a leaf whose subject
			// equals its issuer looks self-issued.
			CommonName: "Test Local Certificate",
		},
		IPAddresses: []net.IP{
			net.IPv4(127, 0, 0, 1),
			net.IPv6loopback,
		},
		// Backdated by a minute, presumably to tolerate small clock skew.
		NotBefore: now.Add(-time.Minute),
		NotAfter:  now.AddDate(1, 0, 0),
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageClientAuth,
			x509.ExtKeyUsageServerAuth,
		},
		// KeyEncipherment is what TLS RSA key-exchange cipher suites expect of
		// an RSA certificate key.
		KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
	}

	// 2048 bits is ample for throwaway test certificates and is much faster to
	// generate than 4096 bits; this function is called during package init.
	certPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate local RSA private key, %w", err)
	}

	certBytes, err := x509.CreateCertificate(rand.Reader, cert, parentCert, &certPrivKey.PublicKey, parentPrivKey)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate local certificate, %w", err)
	}

	var certPEMBuf bytes.Buffer
	if err = pem.Encode(&certPEMBuf, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certBytes,
	}); err != nil {
		return nil, nil, fmt.Errorf("failed to PEM encode local certificate, %w", err)
	}

	var certPrivKeyPEMBuf bytes.Buffer
	if err = pem.Encode(&certPrivKeyPEMBuf, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
	}); err != nil {
		return nil, nil, fmt.Errorf("failed to PEM encode local private key, %w", err)
	}

	return certPEMBuf.Bytes(), certPrivKeyPEMBuf.Bytes(), nil
}
160
161
162
163 func NewTLSClientCertServer(handler http.Handler) (*httptest.Server, error) {
164 server := httptest.NewUnstartedServer(handler)
165
166 if server.TLS == nil {
167 server.TLS = &tls.Config{}
168 }
169 server.TLS.ClientAuth = tls.RequireAndVerifyClientCert
170
171 if server.TLS.ClientCAs == nil {
172 server.TLS.ClientCAs = x509.NewCertPool()
173 }
174 certPem := append(ClientTLSCert, ClientTLSKey...)
175 if ok := server.TLS.ClientCAs.AppendCertsFromPEM(certPem); !ok {
176 return nil, fmt.Errorf("failed to append client certs")
177 }
178
179 return server, nil
180 }
181
182
183
184 func CreateClientTLSCertFiles() (cert, key string, err error) {
185 cert, err = createTmpFile(ClientTLSCert)
186 if err != nil {
187 return "", "", err
188 }
189
190 key, err = createTmpFile(ClientTLSKey)
191 if err != nil {
192 return "", "", err
193 }
194
195 return cert, key, nil
196 }
197
// availableLocalAddr asks the OS for an unused TCP port on ip and returns the
// resulting "host:port" address. The listener used to reserve the port is
// closed before returning, so another process may claim the port before the
// caller binds it; callers must tolerate that race.
//
// ip may be an IPv4 or IPv6 literal: net.JoinHostPort brackets IPv6 addresses,
// so for example "::1" yields "[::1]:<port>". On error v is empty.
func availableLocalAddr(ip string) (v string, err error) {
	l, err := net.Listen("tcp", net.JoinHostPort(ip, "0"))
	if err != nil {
		return "", err
	}

	addr := l.Addr().String()
	if err = l.Close(); err != nil {
		return "", fmt.Errorf("ip listener close error: %w", err)
	}

	return addr, nil
}
214
215
216
217 func CreateTLSServer(cert, key string, mux *http.ServeMux) (string, error) {
218 addr, err := availableLocalAddr("127.0.0.1")
219 if err != nil {
220 return "", err
221 }
222
223 if mux == nil {
224 mux = http.NewServeMux()
225 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {})
226 }
227
228 go func() {
229 if err := http.ListenAndServeTLS(addr, cert, key, mux); err != nil {
230 panic(err)
231 }
232 }()
233
234 for i := 0; i < 60; i++ {
235 if _, err := http.Get("https://" + addr); err != nil && !strings.Contains(err.Error(), "connection refused") {
236 break
237 }
238
239 time.Sleep(1 * time.Second)
240 }
241
242 return "https://" + addr, nil
243 }
244
245
246
247
248 func CreateTLSBundleFiles() (cert, key, ca string, err error) {
249 cert, err = createTmpFile(TLSBundleCert)
250 if err != nil {
251 return "", "", "", err
252 }
253
254 key, err = createTmpFile(TLSBundleKey)
255 if err != nil {
256 return "", "", "", err
257 }
258
259 ca, err = createTmpFile(TLSBundleCA)
260 if err != nil {
261 return "", "", "", err
262 }
263
264 return cert, key, ca, nil
265 }
266
267
// CleanupTLSBundleFiles removes each of the given files. Every removal is
// attempted even if an earlier one fails, so one bad path does not leave the
// remaining files behind. It returns the first error encountered, or nil if
// all files were removed.
func CleanupTLSBundleFiles(files ...string) error {
	var firstErr error
	for _, file := range files {
		if err := os.Remove(file); err != nil && firstErr == nil {
			firstErr = err
		}
	}

	return firstErr
}
277
// createTmpFile writes b to a new temporary file in os.TempDir() whose name
// starts with "aws-sdk-go-session-test", closes it, and returns its path.
// If writing or closing fails, the partially written file is removed and the
// error is returned with an empty path.
func createTmpFile(b []byte) (string, error) {
	bundleFile, err := ioutil.TempFile(os.TempDir(), "aws-sdk-go-session-test")
	if err != nil {
		return "", err
	}

	if _, err = bundleFile.Write(b); err != nil {
		// Best effort cleanup; the write error is the one worth reporting.
		_ = bundleFile.Close()
		_ = os.Remove(bundleFile.Name())
		return "", err
	}

	// A Close error can mean buffered data never reached the file.
	if err = bundleFile.Close(); err != nil {
		_ = os.Remove(bundleFile.Name())
		return "", err
	}

	return bundleFile.Name(), nil
}
292
View as plain text