1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package restapi
20
21 import (
22 "context"
23 "crypto/tls"
24 "crypto/x509"
25 "errors"
26 "fmt"
27 "log"
28 "net"
29 "net/http"
30 "os"
31 "os/signal"
32 "strconv"
33 "sync"
34 "sync/atomic"
35 "syscall"
36 "time"
37
38 "github.com/go-openapi/runtime/flagext"
39 "github.com/go-openapi/swag"
40 flag "github.com/spf13/pflag"
41 "golang.org/x/net/netutil"
42
43 "github.com/sigstore/rekor/pkg/generated/restapi/operations"
44 )
45
// URL schemes a listener can be configured to serve.
const (
	schemeHTTP  = "http"
	schemeHTTPS = "https"
	schemeUnix  = "unix"
)

// defaultSchemes lists the listeners enabled when none are configured
// explicitly; it is both the --scheme flag default and the fallback used by
// hasScheme when EnabledListeners is empty.
var defaultSchemes = []string{schemeHTTP}
59
60 var (
61 enabledListeners []string
62 cleanupTimeout time.Duration
63 gracefulTimeout time.Duration
64 maxHeaderSize flagext.ByteSize
65
66 socketPath string
67
68 host string
69 port int
70 listenLimit int
71 keepAlive time.Duration
72 readTimeout time.Duration
73 writeTimeout time.Duration
74
75 tlsHost string
76 tlsPort int
77 tlsListenLimit int
78 tlsKeepAlive time.Duration
79 tlsReadTimeout time.Duration
80 tlsWriteTimeout time.Duration
81 tlsCertificate string
82 tlsCertificateKey string
83 tlsCACertificate string
84 )
85
86 func init() {
87 maxHeaderSize = flagext.ByteSize(1000000)
88
89 flag.StringSliceVar(&enabledListeners, "scheme", defaultSchemes, "the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec")
90
91 flag.DurationVar(&cleanupTimeout, "cleanup-timeout", 10*time.Second, "grace period for which to wait before killing idle connections")
92 flag.DurationVar(&gracefulTimeout, "graceful-timeout", 15*time.Second, "grace period for which to wait before shutting down the server")
93 flag.Var(&maxHeaderSize, "max-header-size", "controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body")
94
95 flag.StringVar(&socketPath, "socket-path", "/var/run/todo-list.sock", "the unix socket to listen on")
96
97 flag.StringVar(&host, "host", "localhost", "the IP to listen on")
98 flag.IntVar(&port, "port", 0, "the port to listen on for insecure connections, defaults to a random value")
99 flag.IntVar(&listenLimit, "listen-limit", 0, "limit the number of outstanding requests")
100 flag.DurationVar(&keepAlive, "keep-alive", 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
101 flag.DurationVar(&readTimeout, "read-timeout", 30*time.Second, "maximum duration before timing out read of the request")
102 flag.DurationVar(&writeTimeout, "write-timeout", 30*time.Second, "maximum duration before timing out write of the response")
103
104 flag.StringVar(&tlsHost, "tls-host", "localhost", "the IP to listen on")
105 flag.IntVar(&tlsPort, "tls-port", 0, "the port to listen on for secure connections, defaults to a random value")
106 flag.StringVar(&tlsCertificate, "tls-certificate", "", "the certificate file to use for secure connections")
107 flag.StringVar(&tlsCertificateKey, "tls-key", "", "the private key file to use for secure connections (without passphrase)")
108 flag.StringVar(&tlsCACertificate, "tls-ca", "", "the certificate authority certificate file to be used with mutual tls auth")
109 flag.IntVar(&tlsListenLimit, "tls-listen-limit", 0, "limit the number of outstanding requests")
110 flag.DurationVar(&tlsKeepAlive, "tls-keep-alive", 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
111 flag.DurationVar(&tlsReadTimeout, "tls-read-timeout", 30*time.Second, "maximum duration before timing out read of the request")
112 flag.DurationVar(&tlsWriteTimeout, "tls-write-timeout", 30*time.Second, "maximum duration before timing out write of the response")
113 }
114
// stringEnvOverride returns the value of the first environment variable in
// keys that is set to a non-empty string. Otherwise it returns orig, or def
// when orig is empty.
func stringEnvOverride(orig string, def string, keys ...string) string {
	for _, key := range keys {
		if val := os.Getenv(key); val != "" {
			return val
		}
	}
	if orig != "" {
		return orig
	}
	return def
}
126
// intEnvOverride returns the integer value of the first environment variable
// in keys that is set to a non-empty string. Otherwise it returns orig, or
// def when orig is zero. A set but non-numeric variable is reported on
// stderr and the process exits with status 1.
func intEnvOverride(orig int, def int, keys ...string) int {
	for _, key := range keys {
		raw := os.Getenv(key)
		if raw == "" {
			continue
		}
		parsed, err := strconv.Atoi(raw)
		if err != nil {
			fmt.Fprintln(os.Stderr, key, "is not a valid number")
			os.Exit(1)
		}
		return parsed
	}
	if orig != 0 {
		return orig
	}
	return def
}
143
144
145 func NewServer(api *operations.RekorServerAPI) *Server {
146 s := new(Server)
147
148 s.EnabledListeners = enabledListeners
149 s.CleanupTimeout = cleanupTimeout
150 s.GracefulTimeout = gracefulTimeout
151 s.MaxHeaderSize = maxHeaderSize
152 s.SocketPath = socketPath
153 s.Host = stringEnvOverride(host, "", "HOST")
154 s.Port = intEnvOverride(port, 0, "PORT")
155 s.ListenLimit = listenLimit
156 s.KeepAlive = keepAlive
157 s.ReadTimeout = readTimeout
158 s.WriteTimeout = writeTimeout
159 s.TLSHost = stringEnvOverride(tlsHost, s.Host, "TLS_HOST", "HOST")
160 s.TLSPort = intEnvOverride(tlsPort, 0, "TLS_PORT")
161 s.TLSCertificate = stringEnvOverride(tlsCertificate, "", "TLS_CERTIFICATE")
162 s.TLSCertificateKey = stringEnvOverride(tlsCertificateKey, "", "TLS_PRIVATE_KEY")
163 s.TLSCACertificate = stringEnvOverride(tlsCACertificate, "", "TLS_CA_CERTIFICATE")
164 s.TLSListenLimit = tlsListenLimit
165 s.TLSKeepAlive = tlsKeepAlive
166 s.TLSReadTimeout = tlsReadTimeout
167 s.TLSWriteTimeout = tlsWriteTimeout
168 s.shutdown = make(chan struct{})
169 s.api = api
170 s.interrupt = make(chan os.Signal, 1)
171 return s
172 }
173
174
175 func (s *Server) ConfigureAPI() {
176 if s.api != nil {
177 s.handler = configureAPI(s.api)
178 }
179 }
180
181
182 func (s *Server) ConfigureFlags() {
183 if s.api != nil {
184 configureFlags(s.api)
185 }
186 }
187
188
// Server serves the rekor API over any combination of a unix domain socket,
// plain HTTP and HTTPS. NewServer fills its exported settings from flags and
// environment variables; Listen opens the listeners and Serve runs them.
type Server struct {
	// EnabledListeners are the schemes to serve ("unix", "http", "https");
	// when empty, defaultSchemes is used.
	EnabledListeners []string
	// CleanupTimeout, when > 0, becomes each http.Server's IdleTimeout.
	CleanupTimeout time.Duration
	// GracefulTimeout bounds how long shutdown waits for the servers to stop.
	GracefulTimeout time.Duration
	// MaxHeaderSize is used as each http.Server's MaxHeaderBytes.
	MaxHeaderSize flagext.ByteSize

	// SocketPath is the unix domain socket path used by the "unix" scheme.
	SocketPath    string
	domainSocketL net.Listener

	// Host and Port are the plain-HTTP bind address; Listen overwrites them
	// with the address actually bound.
	Host string
	Port int
	// ListenLimit, when > 0, wraps the listener with netutil.LimitListener.
	ListenLimit int
	// KeepAlive > 0 enables HTTP keep-alives; its value is not otherwise
	// used in this file.
	KeepAlive    time.Duration
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
	httpServerL  net.Listener

	// TLS settings for the "https" scheme. Listen fills an empty TLSHost and
	// zero TLSListenLimit/TLSKeepAlive/TLSReadTimeout/TLSWriteTimeout from
	// their plain-HTTP counterparts, and overwrites TLSHost/TLSPort with the
	// address actually bound.
	TLSHost           string
	TLSPort           int
	TLSCertificate    string
	TLSCertificateKey string
	// TLSCACertificate, when set, makes the server require and verify client
	// certificates against the CA(s) in this PEM file.
	TLSCACertificate string
	TLSListenLimit   int
	TLSKeepAlive     time.Duration
	TLSReadTimeout   time.Duration
	TLSWriteTimeout  time.Duration
	httpsServerL     net.Listener

	api          *operations.RekorServerAPI
	handler      http.Handler
	hasListeners bool           // set once Listen has opened the listeners
	shutdown     chan struct{}  // closed by Shutdown; handleShutdown waits on it
	shuttingDown int32          // accessed atomically so shutdown is closed only once
	interrupted  bool           // set by handleInterrupt after the first signal
	interrupt    chan os.Signal // receives SIGINT/SIGTERM registered by signalNotify
}
225
226
227 func (s *Server) Logf(f string, args ...interface{}) {
228 if s.api != nil && s.api.Logger != nil {
229 s.api.Logger(f, args...)
230 } else {
231 log.Printf(f, args...)
232 }
233 }
234
235
236
237 func (s *Server) Fatalf(f string, args ...interface{}) {
238 if s.api != nil && s.api.Logger != nil {
239 s.api.Logger(f, args...)
240 os.Exit(1)
241 } else {
242 log.Fatalf(f, args...)
243 }
244 }
245
246
247 func (s *Server) SetAPI(api *operations.RekorServerAPI) {
248 if api == nil {
249 s.api = nil
250 s.handler = nil
251 return
252 }
253
254 s.api = api
255 s.handler = configureAPI(api)
256 }
257
258 func (s *Server) hasScheme(scheme string) bool {
259 schemes := s.EnabledListeners
260 if len(schemes) == 0 {
261 schemes = defaultSchemes
262 }
263
264 for _, v := range schemes {
265 if v == scheme {
266 return true
267 }
268 }
269 return false
270 }
271
272
273 func (s *Server) Serve() (err error) {
274 if !s.hasListeners {
275 if err = s.Listen(); err != nil {
276 return err
277 }
278 }
279
280
281 if s.handler == nil {
282 if s.api == nil {
283 return errors.New("can't create the default handler, as no api is set")
284 }
285
286 s.SetHandler(s.api.Serve(nil))
287 }
288
289 wg := new(sync.WaitGroup)
290 once := new(sync.Once)
291 signalNotify(s.interrupt)
292 go handleInterrupt(once, s)
293
294 servers := []*http.Server{}
295
296 if s.hasScheme(schemeUnix) {
297 domainSocket := new(http.Server)
298 domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
299 domainSocket.Handler = s.handler
300 if int64(s.CleanupTimeout) > 0 {
301 domainSocket.IdleTimeout = s.CleanupTimeout
302 }
303
304 configureServer(domainSocket, "unix", string(s.SocketPath))
305
306 servers = append(servers, domainSocket)
307 wg.Add(1)
308 s.Logf("Serving rekor server at unix://%s", s.SocketPath)
309 go func(l net.Listener) {
310 defer wg.Done()
311 if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
312 s.Fatalf("%v", err)
313 }
314 s.Logf("Stopped serving rekor server at unix://%s", s.SocketPath)
315 }(s.domainSocketL)
316 }
317
318 if s.hasScheme(schemeHTTP) {
319 httpServer := new(http.Server)
320 httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
321 httpServer.ReadTimeout = s.ReadTimeout
322 httpServer.WriteTimeout = s.WriteTimeout
323 httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
324 if s.ListenLimit > 0 {
325 s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
326 }
327
328 if int64(s.CleanupTimeout) > 0 {
329 httpServer.IdleTimeout = s.CleanupTimeout
330 }
331
332 httpServer.Handler = s.handler
333
334 configureServer(httpServer, "http", s.httpServerL.Addr().String())
335
336 servers = append(servers, httpServer)
337 wg.Add(1)
338 s.Logf("Serving rekor server at http://%s", s.httpServerL.Addr())
339 go func(l net.Listener) {
340 defer wg.Done()
341 if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
342 s.Fatalf("%v", err)
343 }
344 s.Logf("Stopped serving rekor server at http://%s", l.Addr())
345 }(s.httpServerL)
346 }
347
348 if s.hasScheme(schemeHTTPS) {
349 httpsServer := new(http.Server)
350 httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
351 httpsServer.ReadTimeout = s.TLSReadTimeout
352 httpsServer.WriteTimeout = s.TLSWriteTimeout
353 httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
354 if s.TLSListenLimit > 0 {
355 s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
356 }
357 if int64(s.CleanupTimeout) > 0 {
358 httpsServer.IdleTimeout = s.CleanupTimeout
359 }
360 httpsServer.Handler = s.handler
361
362
363 httpsServer.TLSConfig = &tls.Config{
364
365
366 PreferServerCipherSuites: true,
367
368
369 CurvePreferences: []tls.CurveID{tls.CurveP256},
370
371 NextProtos: []string{"h2", "http/1.1"},
372
373 MinVersion: tls.VersionTLS12,
374
375 CipherSuites: []uint16{
376 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
377 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
378 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
379 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
380 tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
381 tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
382 },
383 }
384
385
386 if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
387 httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
388 httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(s.TLSCertificate, s.TLSCertificateKey)
389 if err != nil {
390 return err
391 }
392 }
393
394 if s.TLSCACertificate != "" {
395
396 caCert, caCertErr := os.ReadFile(s.TLSCACertificate)
397 if caCertErr != nil {
398 return caCertErr
399 }
400 caCertPool := x509.NewCertPool()
401 ok := caCertPool.AppendCertsFromPEM(caCert)
402 if !ok {
403 return fmt.Errorf("cannot parse CA certificate")
404 }
405 httpsServer.TLSConfig.ClientCAs = caCertPool
406 httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
407 }
408
409
410 configureTLS(httpsServer.TLSConfig)
411
412 if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
413
414 if s.TLSCertificate == "" {
415 if s.TLSCertificateKey == "" {
416 s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
417 }
418 s.Fatalf("the required flag `--tls-certificate` was not specified")
419 }
420 if s.TLSCertificateKey == "" {
421 s.Fatalf("the required flag `--tls-key` was not specified")
422 }
423
424 s.Fatalf("no certificate was configured for TLS")
425 }
426
427 configureServer(httpsServer, "https", s.httpsServerL.Addr().String())
428
429 servers = append(servers, httpsServer)
430 wg.Add(1)
431 s.Logf("Serving rekor server at https://%s", s.httpsServerL.Addr())
432 go func(l net.Listener) {
433 defer wg.Done()
434 if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
435 s.Fatalf("%v", err)
436 }
437 s.Logf("Stopped serving rekor server at https://%s", l.Addr())
438 }(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
439 }
440
441 wg.Add(1)
442 go s.handleShutdown(wg, &servers)
443
444 wg.Wait()
445 return nil
446 }
447
448
449 func (s *Server) Listen() error {
450 if s.hasListeners {
451 return nil
452 }
453
454 if s.hasScheme(schemeHTTPS) {
455
456 if s.TLSHost == "" {
457 s.TLSHost = s.Host
458 }
459
460 if s.TLSListenLimit == 0 {
461 s.TLSListenLimit = s.ListenLimit
462 }
463
464 if int64(s.TLSKeepAlive) == 0 {
465 s.TLSKeepAlive = s.KeepAlive
466 }
467
468 if int64(s.TLSReadTimeout) == 0 {
469 s.TLSReadTimeout = s.ReadTimeout
470 }
471
472 if int64(s.TLSWriteTimeout) == 0 {
473 s.TLSWriteTimeout = s.WriteTimeout
474 }
475 }
476
477 if s.hasScheme(schemeUnix) {
478 domSockListener, err := net.Listen("unix", string(s.SocketPath))
479 if err != nil {
480 return err
481 }
482 s.domainSocketL = domSockListener
483 }
484
485 if s.hasScheme(schemeHTTP) {
486 listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
487 if err != nil {
488 return err
489 }
490
491 h, p, err := swag.SplitHostPort(listener.Addr().String())
492 if err != nil {
493 return err
494 }
495 s.Host = h
496 s.Port = p
497 s.httpServerL = listener
498 }
499
500 if s.hasScheme(schemeHTTPS) {
501 tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
502 if err != nil {
503 return err
504 }
505
506 sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
507 if err != nil {
508 return err
509 }
510 s.TLSHost = sh
511 s.TLSPort = sp
512 s.httpsServerL = tlsListener
513 }
514
515 s.hasListeners = true
516 return nil
517 }
518
519
520 func (s *Server) Shutdown() error {
521 if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
522 close(s.shutdown)
523 }
524 return nil
525 }
526
527 func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
528
529
530 defer wg.Done()
531
532 <-s.shutdown
533
534 servers := *serversPtr
535
536 ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
537 defer cancel()
538
539
540 s.api.PreServerShutdown()
541
542 shutdownChan := make(chan bool)
543 for i := range servers {
544 server := servers[i]
545 go func() {
546 var success bool
547 defer func() {
548 shutdownChan <- success
549 }()
550 if err := server.Shutdown(ctx); err != nil {
551
552 s.Logf("HTTP server Shutdown: %v", err)
553 } else {
554 success = true
555 }
556 }()
557 }
558
559
560 success := true
561 for range servers {
562 success = success && <-shutdownChan
563 }
564 if success {
565 s.api.ServerShutdown()
566 }
567 }
568
569
570 func (s *Server) GetHandler() http.Handler {
571 return s.handler
572 }
573
574
575 func (s *Server) SetHandler(handler http.Handler) {
576 s.handler = handler
577 }
578
579
580 func (s *Server) UnixListener() (net.Listener, error) {
581 if !s.hasListeners {
582 if err := s.Listen(); err != nil {
583 return nil, err
584 }
585 }
586 return s.domainSocketL, nil
587 }
588
589
590 func (s *Server) HTTPListener() (net.Listener, error) {
591 if !s.hasListeners {
592 if err := s.Listen(); err != nil {
593 return nil, err
594 }
595 }
596 return s.httpServerL, nil
597 }
598
599
600 func (s *Server) TLSListener() (net.Listener, error) {
601 if !s.hasListeners {
602 if err := s.Listen(); err != nil {
603 return nil, err
604 }
605 }
606 return s.httpsServerL, nil
607 }
608
609 func handleInterrupt(once *sync.Once, s *Server) {
610 once.Do(func() {
611 for range s.interrupt {
612 if s.interrupted {
613 s.Logf("Server already shutting down")
614 continue
615 }
616 s.interrupted = true
617 s.Logf("Shutting down... ")
618 if err := s.Shutdown(); err != nil {
619 s.Logf("HTTP server Shutdown: %v", err)
620 }
621 }
622 })
623 }
624
// signalNotify arranges for SIGINT and SIGTERM to be delivered on interrupt.
func signalNotify(interrupt chan<- os.Signal) {
	shutdownSignals := []os.Signal{syscall.SIGINT, syscall.SIGTERM}
	signal.Notify(interrupt, shutdownSignals...)
}
628
View as plain text