...

Source file src/github.com/docker/distribution/registry/registry_test.go

Documentation: github.com/docker/distribution/registry

     1  package registry
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"crypto"
     7  	"crypto/ecdsa"
     8  	"crypto/elliptic"
     9  	"crypto/rand"
    10  	"crypto/rsa"
    11  	"crypto/tls"
    12  	"crypto/x509"
    13  	"crypto/x509/pkix"
    14  	"encoding/pem"
    15  	"fmt"
    16  	"io/ioutil"
    17  	"math/big"
    18  	"net"
    19  	"net/http"
    20  	"os"
    21  	"path"
    22  	"reflect"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/docker/distribution/configuration"
    28  	_ "github.com/docker/distribution/registry/storage/driver/inmemory"
    29  )
    30  
    31  // Tests to ensure nextProtos returns the correct protocols when:
    32  // * config.HTTP.HTTP2.Disabled is not explicitly set => [h2 http/1.1]
    33  // * config.HTTP.HTTP2.Disabled is explicitly set to false [h2 http/1.1]
    34  // * config.HTTP.HTTP2.Disabled is explicitly set to true [http/1.1]
    35  func TestNextProtos(t *testing.T) {
    36  	config := &configuration.Configuration{}
    37  	protos := nextProtos(config)
    38  	if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
    39  		t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
    40  	}
    41  	config.HTTP.HTTP2.Disabled = false
    42  	protos = nextProtos(config)
    43  	if !reflect.DeepEqual(protos, []string{"h2", "http/1.1"}) {
    44  		t.Fatalf("expected protos to equal [h2 http/1.1], got %s", protos)
    45  	}
    46  	config.HTTP.HTTP2.Disabled = true
    47  	protos = nextProtos(config)
    48  	if !reflect.DeepEqual(protos, []string{"http/1.1"}) {
    49  		t.Fatalf("expected protos to equal [http/1.1], got %s", protos)
    50  	}
    51  }
    52  
    53  type registryTLSConfig struct {
    54  	cipherSuites    []string
    55  	certificatePath string
    56  	privateKeyPath  string
    57  	certificate     *tls.Certificate
    58  }
    59  
    60  func setupRegistry(tlsCfg *registryTLSConfig, addr string) (*Registry, error) {
    61  	config := &configuration.Configuration{}
    62  	// TODO: this needs to change to something ephemeral as the test will fail if there is any server
    63  	// already listening on port 5000
    64  	config.HTTP.Addr = addr
    65  	config.HTTP.DrainTimeout = time.Duration(10) * time.Second
    66  	if tlsCfg != nil {
    67  		config.HTTP.TLS.CipherSuites = tlsCfg.cipherSuites
    68  		config.HTTP.TLS.Certificate = tlsCfg.certificatePath
    69  		config.HTTP.TLS.Key = tlsCfg.privateKeyPath
    70  	}
    71  	config.Storage = map[string]configuration.Parameters{"inmemory": map[string]interface{}{}}
    72  	return NewRegistry(context.Background(), config)
    73  }
    74  
    75  func TestGracefulShutdown(t *testing.T) {
    76  	registry, err := setupRegistry(nil, ":5000")
    77  	if err != nil {
    78  		t.Fatal(err)
    79  	}
    80  
    81  	// run registry server
    82  	var errchan chan error
    83  	go func() {
    84  		errchan <- registry.ListenAndServe()
    85  	}()
    86  	select {
    87  	case err = <-errchan:
    88  		t.Fatalf("Error listening: %v", err)
    89  	default:
    90  	}
    91  
    92  	// Wait for some unknown random time for server to start listening
    93  	time.Sleep(3 * time.Second)
    94  
    95  	// send incomplete request
    96  	conn, err := net.Dial("tcp", "localhost:5000")
    97  	if err != nil {
    98  		t.Fatal(err)
    99  	}
   100  	fmt.Fprintf(conn, "GET /v2/ ")
   101  
   102  	// send stop signal
   103  	quit <- os.Interrupt
   104  	time.Sleep(100 * time.Millisecond)
   105  
   106  	// try connecting again. it shouldn't
   107  	_, err = net.Dial("tcp", "localhost:5000")
   108  	if err == nil {
   109  		t.Fatal("Managed to connect after stopping.")
   110  	}
   111  
   112  	// make sure earlier request is not disconnected and response can be received
   113  	fmt.Fprintf(conn, "HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
   114  	resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
   115  	if err != nil {
   116  		t.Fatal(err)
   117  	}
   118  	if resp.Status != "200 OK" {
   119  		t.Error("response status is not 200 OK: ", resp.Status)
   120  	}
   121  	if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
   122  		t.Error("Body is not {}; ", string(body))
   123  	}
   124  }
   125  
   126  func TestGetCipherSuite(t *testing.T) {
   127  	resp, err := getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA"})
   128  	if err != nil || len(resp) != 1 || resp[0] != tls.TLS_RSA_WITH_AES_128_CBC_SHA {
   129  		t.Errorf("expected cipher suite %q, got %q",
   130  			"TLS_RSA_WITH_AES_128_CBC_SHA",
   131  			strings.Join(getCipherSuiteNames(resp), ","),
   132  		)
   133  	}
   134  
   135  	resp, err = getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_AES_128_GCM_SHA256"})
   136  	if err != nil || len(resp) != 2 ||
   137  		resp[0] != tls.TLS_RSA_WITH_AES_128_CBC_SHA || resp[1] != tls.TLS_AES_128_GCM_SHA256 {
   138  		t.Errorf("expected cipher suites %q, got %q",
   139  			"TLS_RSA_WITH_AES_128_CBC_SHA,TLS_AES_128_GCM_SHA256",
   140  			strings.Join(getCipherSuiteNames(resp), ","),
   141  		)
   142  	}
   143  
   144  	_, err = getCipherSuites([]string{"TLS_RSA_WITH_AES_128_CBC_SHA", "bad_input"})
   145  	if err == nil {
   146  		t.Error("did not return expected error about unknown cipher suite")
   147  	}
   148  }
   149  
   150  func buildRegistryTLSConfig(name, keyType string, cipherSuites []string) (*registryTLSConfig, error) {
   151  	var priv interface{}
   152  	var pub crypto.PublicKey
   153  	var err error
   154  	switch keyType {
   155  	case "rsa":
   156  		priv, err = rsa.GenerateKey(rand.Reader, 2048)
   157  		if err != nil {
   158  			return nil, fmt.Errorf("failed to create rsa private key: %v", err)
   159  		}
   160  		rsaKey := priv.(*rsa.PrivateKey)
   161  		pub = rsaKey.Public()
   162  	case "ecdsa":
   163  		priv, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
   164  		if err != nil {
   165  			return nil, fmt.Errorf("failed to create ecdsa private key: %v", err)
   166  		}
   167  		ecdsaKey := priv.(*ecdsa.PrivateKey)
   168  		pub = ecdsaKey.Public()
   169  	default:
   170  		return nil, fmt.Errorf("unsupported key type: %v", keyType)
   171  	}
   172  
   173  	notBefore := time.Now()
   174  	notAfter := notBefore.Add(time.Minute)
   175  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
   176  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
   177  	if err != nil {
   178  		return nil, fmt.Errorf("failed to create serial number: %v", err)
   179  	}
   180  	cert := x509.Certificate{
   181  		SerialNumber: serialNumber,
   182  		Subject: pkix.Name{
   183  			Organization: []string{"registry_test"},
   184  		},
   185  		NotBefore:             notBefore,
   186  		NotAfter:              notAfter,
   187  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   188  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   189  		BasicConstraintsValid: true,
   190  		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
   191  		DNSNames:              []string{"localhost"},
   192  		IsCA:                  true,
   193  	}
   194  	derBytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, pub, priv)
   195  	if err != nil {
   196  		return nil, fmt.Errorf("failed to create certificate: %v", err)
   197  	}
   198  	if _, err := os.Stat(os.TempDir()); os.IsNotExist(err) {
   199  		os.Mkdir(os.TempDir(), 1777)
   200  	}
   201  
   202  	certPath := path.Join(os.TempDir(), name+".pem")
   203  	certOut, err := os.Create(certPath)
   204  	if err != nil {
   205  		return nil, fmt.Errorf("failed to create pem: %v", err)
   206  	}
   207  	if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
   208  		return nil, fmt.Errorf("failed to write data to %s: %v", certPath, err)
   209  	}
   210  	if err := certOut.Close(); err != nil {
   211  		return nil, fmt.Errorf("error closing %s: %v", certPath, err)
   212  	}
   213  
   214  	keyPath := path.Join(os.TempDir(), name+".key")
   215  	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
   216  	if err != nil {
   217  		return nil, fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
   218  	}
   219  	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
   220  	if err != nil {
   221  		return nil, fmt.Errorf("unable to marshal private key: %v", err)
   222  	}
   223  	if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
   224  		return nil, fmt.Errorf("failed to write data to key.pem: %v", err)
   225  	}
   226  	if err := keyOut.Close(); err != nil {
   227  		return nil, fmt.Errorf("error closing %s: %v", keyPath, err)
   228  	}
   229  
   230  	tlsCert := tls.Certificate{
   231  		Certificate: [][]byte{derBytes},
   232  		PrivateKey:  priv,
   233  	}
   234  
   235  	tlsTestCfg := registryTLSConfig{
   236  		cipherSuites:    cipherSuites,
   237  		certificatePath: certPath,
   238  		privateKeyPath:  keyPath,
   239  		certificate:     &tlsCert,
   240  	}
   241  
   242  	return &tlsTestCfg, nil
   243  }
   244  
   245  func TestRegistrySupportedCipherSuite(t *testing.T) {
   246  	name := "registry_test_server_supported_cipher"
   247  	cipherSuites := []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}
   248  	serverTLS, err := buildRegistryTLSConfig(name, "rsa", cipherSuites)
   249  	if err != nil {
   250  		t.Fatal(err)
   251  	}
   252  
   253  	registry, err := setupRegistry(serverTLS, ":5001")
   254  	if err != nil {
   255  		t.Fatal(err)
   256  	}
   257  
   258  	// run registry server
   259  	var errchan chan error
   260  	go func() {
   261  		errchan <- registry.ListenAndServe()
   262  	}()
   263  	select {
   264  	case err = <-errchan:
   265  		t.Fatalf("Error listening: %v", err)
   266  	default:
   267  	}
   268  
   269  	// Wait for some unknown random time for server to start listening
   270  	time.Sleep(3 * time.Second)
   271  
   272  	// send tls request with server supported cipher suite
   273  	clientCipherSuites, err := getCipherSuites(cipherSuites)
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  	clientTLS := tls.Config{
   278  		InsecureSkipVerify: true,
   279  		CipherSuites:       clientCipherSuites,
   280  	}
   281  	dialer := net.Dialer{
   282  		Timeout: time.Second * 5,
   283  	}
   284  	conn, err := tls.DialWithDialer(&dialer, "tcp", "127.0.0.1:5001", &clientTLS)
   285  	if err != nil {
   286  		t.Fatal(err)
   287  	}
   288  	fmt.Fprintf(conn, "GET /v2/ HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
   289  
   290  	resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
   291  	if err != nil {
   292  		t.Fatal(err)
   293  	}
   294  	if resp.Status != "200 OK" {
   295  		t.Error("response status is not 200 OK: ", resp.Status)
   296  	}
   297  	if body, err := ioutil.ReadAll(resp.Body); err != nil || string(body) != "{}" {
   298  		t.Error("Body is not {}; ", string(body))
   299  	}
   300  
   301  	// send stop signal
   302  	quit <- os.Interrupt
   303  	time.Sleep(100 * time.Millisecond)
   304  }
   305  
   306  func TestRegistryUnsupportedCipherSuite(t *testing.T) {
   307  	name := "registry_test_server_unsupported_cipher"
   308  	serverTLS, err := buildRegistryTLSConfig(name, "rsa", []string{"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA358"})
   309  	if err != nil {
   310  		t.Fatal(err)
   311  	}
   312  
   313  	registry, err := setupRegistry(serverTLS, ":5002")
   314  	if err != nil {
   315  		t.Fatal(err)
   316  	}
   317  
   318  	// run registry server
   319  	var errchan chan error
   320  	go func() {
   321  		errchan <- registry.ListenAndServe()
   322  	}()
   323  	select {
   324  	case err = <-errchan:
   325  		t.Fatalf("Error listening: %v", err)
   326  	default:
   327  	}
   328  
   329  	// Wait for some unknown random time for server to start listening
   330  	time.Sleep(3 * time.Second)
   331  
   332  	// send tls request with server unsupported cipher suite
   333  	clientTLS := tls.Config{
   334  		InsecureSkipVerify: true,
   335  		CipherSuites:       []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   336  	}
   337  	dialer := net.Dialer{
   338  		Timeout: time.Second * 5,
   339  	}
   340  	_, err = tls.DialWithDialer(&dialer, "tcp", "127.0.0.1:5002", &clientTLS)
   341  	if err == nil {
   342  		t.Error("expected TLS connection to timeout")
   343  	}
   344  
   345  	// send stop signal
   346  	quit <- os.Interrupt
   347  	time.Sleep(100 * time.Millisecond)
   348  }
   349  

View as plain text