...

Source file src/sigs.k8s.io/controller-runtime/pkg/webhook/server_test.go

Documentation: sigs.k8s.io/controller-runtime/pkg/webhook

     1  /*
     2  Copyright 2019 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package webhook_test
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"net/http"
    26  	"path"
    27  	"reflect"
    28  
    29  	. "github.com/onsi/ginkgo/v2"
    30  	. "github.com/onsi/gomega"
    31  	"k8s.io/client-go/rest"
    32  
    33  	"sigs.k8s.io/controller-runtime/pkg/envtest"
    34  	"sigs.k8s.io/controller-runtime/pkg/webhook"
    35  )
    36  
    37  var _ = Describe("Webhook Server", func() {
    38  	var (
    39  		ctx          context.Context
    40  		ctxCancel    context.CancelFunc
    41  		testHostPort string
    42  		client       *http.Client
    43  		server       webhook.Server
    44  		servingOpts  envtest.WebhookInstallOptions
    45  	)
    46  
    47  	BeforeEach(func() {
    48  		ctx, ctxCancel = context.WithCancel(context.Background())
    49  		// closed in individual tests differently
    50  
    51  		servingOpts = envtest.WebhookInstallOptions{}
    52  		Expect(servingOpts.PrepWithoutInstalling()).To(Succeed())
    53  
    54  		testHostPort = net.JoinHostPort(servingOpts.LocalServingHost, fmt.Sprintf("%d", servingOpts.LocalServingPort))
    55  
    56  		// bypass needing to set up the x509 cert pool, etc ourselves
    57  		clientTransport, err := rest.TransportFor(&rest.Config{
    58  			TLSClientConfig: rest.TLSClientConfig{CAData: servingOpts.LocalServingCAData},
    59  		})
    60  		Expect(err).NotTo(HaveOccurred())
    61  		client = &http.Client{
    62  			Transport: clientTransport,
    63  		}
    64  
    65  		server = webhook.NewServer(webhook.Options{
    66  			Host:    servingOpts.LocalServingHost,
    67  			Port:    servingOpts.LocalServingPort,
    68  			CertDir: servingOpts.LocalServingCertDir,
    69  		})
    70  	})
    71  	AfterEach(func() {
    72  		Expect(servingOpts.Cleanup()).To(Succeed())
    73  	})
    74  
    75  	genericStartServer := func(f func(ctx context.Context)) (done <-chan struct{}) {
    76  		doneCh := make(chan struct{})
    77  		go func() {
    78  			defer GinkgoRecover()
    79  			defer close(doneCh)
    80  			f(ctx)
    81  		}()
    82  		// wait till we can ping the server to start the test
    83  		Eventually(func() error {
    84  			_, err := client.Get(fmt.Sprintf("https://%s/unservedpath", testHostPort))
    85  			return err
    86  		}).Should(Succeed())
    87  
    88  		return doneCh
    89  	}
    90  
    91  	startServer := func() (done <-chan struct{}) {
    92  		return genericStartServer(func(ctx context.Context) {
    93  			Expect(server.Start(ctx)).To(Succeed())
    94  		})
    95  	}
    96  
    97  	// TODO(directxman12): figure out a good way to test all the serving setup
    98  	// with httptest.Server to get all the niceness from that.
    99  
   100  	Context("when serving", func() {
   101  		PIt("should verify the client CA name when asked to", func() {
   102  
   103  		})
   104  		PIt("should support HTTP/2", func() {
   105  
   106  		})
   107  
   108  		// TODO(directxman12): figure out a good way to test the port default, etc
   109  	})
   110  
   111  	It("should panic if a duplicate path is registered", func() {
   112  		server.Register("/somepath", &testHandler{})
   113  		doneCh := startServer()
   114  
   115  		Expect(func() { server.Register("/somepath", &testHandler{}) }).To(Panic())
   116  
   117  		ctxCancel()
   118  		Eventually(doneCh, "4s").Should(BeClosed())
   119  	})
   120  
   121  	Context("when registering new webhooks before starting", func() {
   122  		It("should serve a webhook on the requested path", func() {
   123  			server.Register("/somepath", &testHandler{})
   124  
   125  			Expect(server.StartedChecker()(nil)).ToNot(Succeed())
   126  
   127  			doneCh := startServer()
   128  
   129  			Eventually(func() ([]byte, error) {
   130  				resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
   131  				Expect(err).NotTo(HaveOccurred())
   132  				defer resp.Body.Close()
   133  				return io.ReadAll(resp.Body)
   134  			}).Should(Equal([]byte("gadzooks!")))
   135  
   136  			Expect(server.StartedChecker()(nil)).To(Succeed())
   137  
   138  			ctxCancel()
   139  			Eventually(doneCh, "4s").Should(BeClosed())
   140  		})
   141  	})
   142  
   143  	Context("when registering webhooks after starting", func() {
   144  		var (
   145  			doneCh <-chan struct{}
   146  		)
   147  		BeforeEach(func() {
   148  			doneCh = startServer()
   149  		})
   150  		AfterEach(func() {
   151  			// wait for cleanup to happen
   152  			ctxCancel()
   153  			Eventually(doneCh, "4s").Should(BeClosed())
   154  		})
   155  
   156  		It("should serve a webhook on the requested path", func() {
   157  			server.Register("/somepath", &testHandler{})
   158  			resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
   159  			Expect(err).NotTo(HaveOccurred())
   160  			defer resp.Body.Close()
   161  
   162  			Expect(io.ReadAll(resp.Body)).To(Equal([]byte("gadzooks!")))
   163  		})
   164  	})
   165  
   166  	It("should respect passed in TLS configurations", func() {
   167  		var finalCfg *tls.Config
   168  		tlsCfgFunc := func(cfg *tls.Config) {
   169  			cfg.CipherSuites = []uint16{
   170  				tls.TLS_AES_128_GCM_SHA256,
   171  				tls.TLS_AES_256_GCM_SHA384,
   172  			}
   173  			cfg.MinVersion = tls.VersionTLS12
   174  			// save cfg after changes to test against
   175  			finalCfg = cfg
   176  		}
   177  		server = webhook.NewServer(webhook.Options{
   178  			Host:    servingOpts.LocalServingHost,
   179  			Port:    servingOpts.LocalServingPort,
   180  			CertDir: servingOpts.LocalServingCertDir,
   181  			TLSOpts: []func(*tls.Config){
   182  				tlsCfgFunc,
   183  			},
   184  		})
   185  		server.Register("/somepath", &testHandler{})
   186  		doneCh := genericStartServer(func(ctx context.Context) {
   187  			Expect(server.Start(ctx)).To(Succeed())
   188  		})
   189  
   190  		Eventually(func() ([]byte, error) {
   191  			resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
   192  			Expect(err).NotTo(HaveOccurred())
   193  			defer resp.Body.Close()
   194  			return io.ReadAll(resp.Body)
   195  		}).Should(Equal([]byte("gadzooks!")))
   196  		Expect(finalCfg.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
   197  		Expect(finalCfg.CipherSuites).To(ContainElements(
   198  			tls.TLS_AES_128_GCM_SHA256,
   199  			tls.TLS_AES_256_GCM_SHA384,
   200  		))
   201  
   202  		ctxCancel()
   203  		Eventually(doneCh, "4s").Should(BeClosed())
   204  	})
   205  
   206  	It("should prefer GetCertificate through TLSOpts", func() {
   207  		var finalCfg *tls.Config
   208  		finalCert, err := tls.LoadX509KeyPair(
   209  			path.Join(servingOpts.LocalServingCertDir, "tls.crt"),
   210  			path.Join(servingOpts.LocalServingCertDir, "tls.key"),
   211  		)
   212  		Expect(err).NotTo(HaveOccurred())
   213  		finalGetCertificate := func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
   214  			return &finalCert, nil
   215  		}
   216  		server = &webhook.DefaultServer{Options: webhook.Options{
   217  			Host:    servingOpts.LocalServingHost,
   218  			Port:    servingOpts.LocalServingPort,
   219  			CertDir: servingOpts.LocalServingCertDir,
   220  
   221  			TLSOpts: []func(*tls.Config){
   222  				func(cfg *tls.Config) {
   223  					cfg.GetCertificate = finalGetCertificate
   224  					cfg.MinVersion = tls.VersionTLS12
   225  					// save cfg after changes to test against
   226  					finalCfg = cfg
   227  				},
   228  			},
   229  		}}
   230  		server.Register("/somepath", &testHandler{})
   231  		doneCh := genericStartServer(func(ctx context.Context) {
   232  			Expect(server.Start(ctx)).To(Succeed())
   233  		})
   234  
   235  		Eventually(func() ([]byte, error) {
   236  			resp, err := client.Get(fmt.Sprintf("https://%s/somepath", testHostPort))
   237  			Expect(err).NotTo(HaveOccurred())
   238  			defer resp.Body.Close()
   239  			return io.ReadAll(resp.Body)
   240  		}).Should(Equal([]byte("gadzooks!")))
   241  		Expect(finalCfg.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
   242  		// We can't compare the functions directly, but we can compare their pointers
   243  		if reflect.ValueOf(finalCfg.GetCertificate).Pointer() != reflect.ValueOf(finalGetCertificate).Pointer() {
   244  			Fail("GetCertificate was not set properly, or overwritten")
   245  		}
   246  		cert, err := finalCfg.GetCertificate(nil)
   247  		Expect(err).NotTo(HaveOccurred())
   248  		Expect(cert).To(BeEquivalentTo(&finalCert))
   249  
   250  		ctxCancel()
   251  		Eventually(doneCh, "4s").Should(BeClosed())
   252  	})
   253  })
   254  
   255  type testHandler struct {
   256  }
   257  
   258  func (t *testHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
   259  	if _, err := resp.Write([]byte("gadzooks!")); err != nil {
   260  		panic("unable to write http response!")
   261  	}
   262  }
   263  

View as plain text