1 package registry
2
3 import (
4 "bufio"
5 "context"
6 "crypto"
7 "crypto/ecdsa"
8 "crypto/elliptic"
9 "crypto/rand"
10 "crypto/rsa"
11 "crypto/tls"
12 "crypto/x509"
13 "crypto/x509/pkix"
14 "encoding/pem"
15 "fmt"
16 "io/ioutil"
17 "math/big"
18 "net"
19 "net/http"
20 "os"
21 "path"
22 "reflect"
23 "strings"
24 "testing"
25 "time"
26
27 "github.com/docker/distribution/configuration"
28 _ "github.com/docker/distribution/registry/storage/driver/inmemory"
29 )
30
31
32
33
34
35 func TestNextProtos(t *testing.T) {
36 config := &configuration.Configuration{}
37 protos := nextProtos(config)
38 if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
39 t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
40 }
41 config.HTTP.HTTP2.Disabled = false
42 protos = nextProtos(config)
43 if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
44 t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
45 }
46 config.HTTP.HTTP2.Disabled = true
47 protos = nextProtos(config)
48 if !reflect.DeepEqual(protos, []string{"http/1.1"}) {
49 t.Fatalf("expected protos to equal [http/1.1], got %s", protos)
50 }
51 }
52
// registryTLSConfig bundles the TLS material the tests in this file use to
// start a registry over HTTPS.
type registryTLSConfig struct {
	// cipherSuites holds cipher suite names; setupRegistry copies them into
	// the registry's HTTP TLS configuration.
	cipherSuites []string

	// certificatePath and privateKeyPath are the on-disk locations of the
	// PEM-encoded certificate and private key handed to the registry.
	certificatePath, privateKeyPath string

	// certificate is the in-memory form of the same certificate/key pair.
	certificate *tls.Certificate
}
59
60 func setupRegistry(tlsCfg *registryTLSConfig, addr string) (*Registry, error) {
61 config := &configuration.Configuration{}
62
63
64 config.HTTP.Addr = addr
65 config.HTTP.DrainTimeout = time.Duration(10) * time.Second
66 if tlsCfg != nil {
67 config.HTTP.TLS.CipherSuites = tlsCfg.cipherSuites
68 config.HTTP.TLS.Certificate = tlsCfg.certificatePath
69 config.HTTP.TLS.Key = tlsCfg.privateKeyPath
70 }
71 config.Storage = map[string]configuration.Parameters{"inmemory": map[string]interface{}{}}
72 return NewRegistry(context.Background(), config)
73 }
74
75 func TestGracefulShutdown(t *testing.T) {
76 registry, err := setupRegistry(nil, ":5000")
77 if err != nil {
78 t.Fatal(err)
79 }
80
81
82 var errchan chan error
83 go func() {
84 errchan <- registry.ListenAndServe()
85 }()
86 select {
87 case err = <-errchan:
88 t.Fatalf("Error listening: %v", err)
89 default:
90 }
91
92
93 time.Sleep(3 * time.Second)
94
95
96 conn, err := net.Dial("tcp", "localhost:5000")
97 if err != nil {
98 t.Fatal(err)
99 }
100 fmt.Fprintf(conn, "GET /v2/ ")
101
102
103 quit <- os.Interrupt
104 time.Sleep(100 * time.Millisecond)
105
106
107 _, err = net.Dial("tcp", "localhost:5000")
108 if err == nil {
109 t.Fatal("Managed to connect after stopping.")
110 }
111
112
113 fmt.Fprintf(conn, "HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
114 resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
115 if err != nil {
116 t.Fatal(err)
117 }
118 if resp.Status != "200 OK" {
119 t.Error("response status is not 200 OK: ", resp.Status)
120 }
121 if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
122 t.Error("Body is not {}; ", string(body))
123 }
124 }
125
126 func TestGetCipherSuite(t *testing.T) {
127 resp, err := getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA"})
128 if err != nil || len(resp) != 1 || resp[0] != tls.TLS_RSA_WITH_AES_128_CBC_SHA {
129 t.Errorf("expected cipher suite %q, got %q",
130 "TLS_RSA_WITH_AES_128_CBC_SHA",
131 strings.Join(getCipherSuiteNames(resp), ","),
132 )
133 }
134
135 resp, err = getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_AES_128_GCM_SHA256"})
136 if err != nil || len(resp) != 2 ||
137 resp[0] != tls.TLS_RSA_WITH_AES_128_CBC_SHA || resp[1] != tls.TLS_AES_128_GCM_SHA256 {
138 t.Errorf("expected cipher suites %q, got %q",
139 "TLS_RSA_WITH_AES_128_CBC_SHA,TLS_AES_128_GCM_SHA256",
140 strings.Join(getCipherSuiteNames(resp), ","),
141 )
142 }
143
144 _, err = getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA", "bad_input"})
145 if err == nil {
146 t.Error("did not return expected error about unknown cipher suite")
147 }
148 }
149
150 func buildRegistryTLSConfig(name, keyType string, cipherSuites []string) (*registryTLSConfig, error) {
151 var priv interface{}
152 var pub crypto.PublicKey
153 var err error
154 switch keyType {
155 case "rsa":
156 priv, err = rsa.GenerateKey(rand.Reader, 2048)
157 if err != nil {
158 return nil, fmt.Errorf("failed to create rsa private key: %v", err)
159 }
160 rsaKey := priv.(*rsa.PrivateKey)
161 pub = rsaKey.Public()
162 case "ecdsa":
163 priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
164 if err != nil {
165 return nil, fmt.Errorf("failed to create ecdsa private key: %v", err)
166 }
167 ecdsaKey := priv.(*ecdsa.PrivateKey)
168 pub = ecdsaKey.Public()
169 default:
170 return nil, fmt.Errorf("unsupported key type: %v", keyType)
171 }
172
173 notBefore := time.Now()
174 notAfter := notBefore.Add(time.Minute)
175 serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
176 serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
177 if err != nil {
178 return nil, fmt.Errorf("failed to create serial number: %v", err)
179 }
180 cert := x509.Certificate{
181 SerialNumber: serialNumber,
182 Subject: pkix.Name{
183 Organization: []string{"registry_test"},
184 },
185 NotBefore: notBefore,
186 NotAfter: notAfter,
187 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
188 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
189 BasicConstraintsValid: true,
190 IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
191 DNSNames: []string{"localhost"},
192 IsCA: true,
193 }
194 derBytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, pub, priv)
195 if err != nil {
196 return nil, fmt.Errorf("failed to create certificate: %v", err)
197 }
198 if _, err := os.Stat(os.TempDir()); os.IsNotExist(err) {
199 os.Mkdir(os.TempDir(), 1777)
200 }
201
202 certPath := path.Join(os.TempDir(), name+".pem")
203 certOut, err := os.Create(certPath)
204 if err != nil {
205 return nil, fmt.Errorf("failed to create pem: %v", err)
206 }
207 if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
208 return nil, fmt.Errorf("failed to write data to %s: %v", certPath, err)
209 }
210 if err := certOut.Close(); err != nil {
211 return nil, fmt.Errorf("error closing %s: %v", certPath, err)
212 }
213
214 keyPath := path.Join(os.TempDir(), name+".key")
215 keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
216 if err != nil {
217 return nil, fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
218 }
219 privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
220 if err != nil {
221 return nil, fmt.Errorf("unable to marshal private key: %v", err)
222 }
223 if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
224 return nil, fmt.Errorf("failed to write data to key.pem: %v", err)
225 }
226 if err := keyOut.Close(); err != nil {
227 return nil, fmt.Errorf("error closing %s: %v", keyPath, err)
228 }
229
230 tlsCert := tls.Certificate{
231 Certificate: [][]byte{derBytes},
232 PrivateKey: priv,
233 }
234
235 tlsTestCfg := registryTLSConfig{
236 cipherSuites: cipherSuites,
237 certificatePath: certPath,
238 privateKeyPath: keyPath,
239 certificate: &tlsCert,
240 }
241
242 return &tlsTestCfg, nil
243 }
244
245 func TestRegistrySupportedCipherSuite(t *testing.T) {
246 name := "registry_test_server_supported_cipher"
247 cipherSuites := []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}
248 serverTLS, err := buildRegistryTLSConfig(name, "rsa", cipherSuites)
249 if err != nil {
250 t.Fatal(err)
251 }
252
253 registry, err := setupRegistry(serverTLS, ":5001")
254 if err != nil {
255 t.Fatal(err)
256 }
257
258
259 var errchan chan error
260 go func() {
261 errchan <- registry.ListenAndServe()
262 }()
263 select {
264 case err = <-errchan:
265 t.Fatalf("Error listening: %v", err)
266 default:
267 }
268
269
270 time.Sleep(3 * time.Second)
271
272
273 clientCipherSuites, err := getCipherSuites(cipherSuites)
274 if err != nil {
275 t.Fatal(err)
276 }
277 clientTLS := tls.Config{
278 InsecureSkipVerify: true,
279 CipherSuites: clientCipherSuites,
280 }
281 dialer := net.Dialer{
282 Timeout: time.Second * 5,
283 }
284 conn, err := tls.DialWithDialer(&dialer, "tcp", "127.0.0.1:5001", &clientTLS)
285 if err != nil {
286 t.Fatal(err)
287 }
288 fmt.Fprintf(conn, "GET /v2/ HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
289
290 resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
291 if err != nil {
292 t.Fatal(err)
293 }
294 if resp.Status != "200 OK" {
295 t.Error("response status is not 200 OK: ", resp.Status)
296 }
297 if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
298 t.Error("Body is not {}; ", string(body))
299 }
300
301
302 quit <- os.Interrupt
303 time.Sleep(100 * time.Millisecond)
304 }
305
306 func TestRegistryUnsupportedCipherSuite(t *testing.T) {
307 name := "registry_test_server_unsupported_cipher"
308 serverTLS, err := buildRegistryTLSConfig(name, "rsa", []string{"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA358"})
309 if err != nil {
310 t.Fatal(err)
311 }
312
313 registry, err := setupRegistry(serverTLS, ":5002")
314 if err != nil {
315 t.Fatal(err)
316 }
317
318
319 var errchan chan error
320 go func() {
321 errchan <- registry.ListenAndServe()
322 }()
323 select {
324 case err = <-errchan:
325 t.Fatalf("Error listening: %v", err)
326 default:
327 }
328
329
330 time.Sleep(3 * time.Second)
331
332
333 clientTLS := tls.Config{
334 InsecureSkipVerify: true,
335 CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
336 }
337 dialer := net.Dialer{
338 Timeout: time.Second * 5,
339 }
340 _, err = tls.DialWithDialer(&dialer, "tcp", "127.0.0.1:5002", &clientTLS)
341 if err == nil {
342 t.Error("expected TLS connection to timeout")
343 }
344
345
346 quit <- os.Interrupt
347 time.Sleep(100 * time.Millisecond)
348 }
349
View as plain text