1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package restapi
19
20 import (
21 "context"
22 "crypto/tls"
23 "crypto/x509"
24 "errors"
25 "fmt"
26 "log"
27 "net"
28 "net/http"
29 "os"
30 "os/signal"
31 "strconv"
32 "sync"
33 "sync/atomic"
34 "syscall"
35 "time"
36
37 "github.com/go-openapi/runtime/flagext"
38 "github.com/go-openapi/swag"
39 flag "github.com/spf13/pflag"
40 "golang.org/x/net/netutil"
41
42 "github.com/sigstore/timestamp-authority/pkg/generated/restapi/operations"
43 )
44
// URL schemes that name the listeners this server can expose.
const (
	schemeHTTP  = "http"
	schemeHTTPS = "https"
	schemeUnix  = "unix"
)

// defaultSchemes lists the listeners used when none are explicitly enabled.
var defaultSchemes []string

// init seeds defaultSchemes with the plain HTTP listener.
func init() {
	defaultSchemes = []string{schemeHTTP}
}
58
59 var (
60 enabledListeners []string
61 cleanupTimeout time.Duration
62 gracefulTimeout time.Duration
63 maxHeaderSize flagext.ByteSize
64
65 socketPath string
66
67 host string
68 port int
69 listenLimit int
70 keepAlive time.Duration
71 readTimeout time.Duration
72 writeTimeout time.Duration
73
74 tlsHost string
75 tlsPort int
76 tlsListenLimit int
77 tlsKeepAlive time.Duration
78 tlsReadTimeout time.Duration
79 tlsWriteTimeout time.Duration
80 tlsCertificate string
81 tlsCertificateKey string
82 tlsCACertificate string
83 )
84
85 func init() {
86 maxHeaderSize = flagext.ByteSize(1000000)
87
88 flag.StringSliceVar(&enabledListeners, "scheme", defaultSchemes, "the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec")
89
90 flag.DurationVar(&cleanupTimeout, "cleanup-timeout", 10*time.Second, "grace period for which to wait before killing idle connections")
91 flag.DurationVar(&gracefulTimeout, "graceful-timeout", 15*time.Second, "grace period for which to wait before shutting down the server")
92 flag.Var(&maxHeaderSize, "max-header-size", "controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body")
93
94 flag.StringVar(&socketPath, "socket-path", "/var/run/todo-list.sock", "the unix socket to listen on")
95
96 flag.StringVar(&host, "host", "localhost", "the IP to listen on")
97 flag.IntVar(&port, "port", 0, "the port to listen on for insecure connections, defaults to a random value")
98 flag.IntVar(&listenLimit, "listen-limit", 0, "limit the number of outstanding requests")
99 flag.DurationVar(&keepAlive, "keep-alive", 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
100 flag.DurationVar(&readTimeout, "read-timeout", 30*time.Second, "maximum duration before timing out read of the request")
101 flag.DurationVar(&writeTimeout, "write-timeout", 30*time.Second, "maximum duration before timing out write of the response")
102
103 flag.StringVar(&tlsHost, "tls-host", "localhost", "the IP to listen on")
104 flag.IntVar(&tlsPort, "tls-port", 0, "the port to listen on for secure connections, defaults to a random value")
105 flag.StringVar(&tlsCertificate, "tls-certificate", "", "the certificate file to use for secure connections")
106 flag.StringVar(&tlsCertificateKey, "tls-key", "", "the private key file to use for secure connections (without passphrase)")
107 flag.StringVar(&tlsCACertificate, "tls-ca", "", "the certificate authority certificate file to be used with mutual tls auth")
108 flag.IntVar(&tlsListenLimit, "tls-listen-limit", 0, "limit the number of outstanding requests")
109 flag.DurationVar(&tlsKeepAlive, "tls-keep-alive", 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
110 flag.DurationVar(&tlsReadTimeout, "tls-read-timeout", 30*time.Second, "maximum duration before timing out read of the request")
111 flag.DurationVar(&tlsWriteTimeout, "tls-write-timeout", 30*time.Second, "maximum duration before timing out write of the response")
112 }
113
// stringEnvOverride resolves a string setting. The first environment variable
// in keys with a non-empty value wins; otherwise orig is returned, falling
// back to def when orig is empty.
func stringEnvOverride(orig string, def string, keys ...string) string {
	for _, key := range keys {
		if fromEnv := os.Getenv(key); fromEnv != "" {
			return fromEnv
		}
	}
	if orig == "" {
		// When def is empty too this still yields "", i.e. orig.
		return def
	}
	return orig
}
125
// intEnvOverride resolves an integer setting. The first environment variable
// in keys with a non-empty value is parsed and returned; a value that is not a
// valid integer prints a message to stderr and exits the process with status
// 1. Without an override, orig is returned, falling back to def when orig is 0.
func intEnvOverride(orig int, def int, keys ...string) int {
	for _, key := range keys {
		raw := os.Getenv(key)
		if raw == "" {
			continue
		}
		parsed, err := strconv.Atoi(raw)
		if err != nil {
			fmt.Fprintln(os.Stderr, key, "is not a valid number")
			os.Exit(1)
		}
		return parsed
	}
	if orig == 0 {
		// When def is 0 too this still yields 0, i.e. orig.
		return def
	}
	return orig
}
142
143
144 func NewServer(api *operations.TimestampServerAPI) *Server {
145 s := new(Server)
146
147 s.EnabledListeners = enabledListeners
148 s.CleanupTimeout = cleanupTimeout
149 s.GracefulTimeout = gracefulTimeout
150 s.MaxHeaderSize = maxHeaderSize
151 s.SocketPath = socketPath
152 s.Host = stringEnvOverride(host, "", "HOST")
153 s.Port = intEnvOverride(port, 0, "PORT")
154 s.ListenLimit = listenLimit
155 s.KeepAlive = keepAlive
156 s.ReadTimeout = readTimeout
157 s.WriteTimeout = writeTimeout
158 s.TLSHost = stringEnvOverride(tlsHost, s.Host, "TLS_HOST", "HOST")
159 s.TLSPort = intEnvOverride(tlsPort, 0, "TLS_PORT")
160 s.TLSCertificate = stringEnvOverride(tlsCertificate, "", "TLS_CERTIFICATE")
161 s.TLSCertificateKey = stringEnvOverride(tlsCertificateKey, "", "TLS_PRIVATE_KEY")
162 s.TLSCACertificate = stringEnvOverride(tlsCACertificate, "", "TLS_CA_CERTIFICATE")
163 s.TLSListenLimit = tlsListenLimit
164 s.TLSKeepAlive = tlsKeepAlive
165 s.TLSReadTimeout = tlsReadTimeout
166 s.TLSWriteTimeout = tlsWriteTimeout
167 s.shutdown = make(chan struct{})
168 s.api = api
169 s.interrupt = make(chan os.Signal, 1)
170 return s
171 }
172
173
174 func (s *Server) ConfigureAPI() {
175 if s.api != nil {
176 s.handler = configureAPI(s.api)
177 }
178 }
179
180
181 func (s *Server) ConfigureFlags() {
182 if s.api != nil {
183 configureFlags(s.api)
184 }
185 }
186
187
188 type Server struct {
189 EnabledListeners []string
190 CleanupTimeout time.Duration
191 GracefulTimeout time.Duration
192 MaxHeaderSize flagext.ByteSize
193
194 SocketPath string
195 domainSocketL net.Listener
196
197 Host string
198 Port int
199 ListenLimit int
200 KeepAlive time.Duration
201 ReadTimeout time.Duration
202 WriteTimeout time.Duration
203 httpServerL net.Listener
204
205 TLSHost string
206 TLSPort int
207 TLSCertificate string
208 TLSCertificateKey string
209 TLSCACertificate string
210 TLSListenLimit int
211 TLSKeepAlive time.Duration
212 TLSReadTimeout time.Duration
213 TLSWriteTimeout time.Duration
214 httpsServerL net.Listener
215
216 api *operations.TimestampServerAPI
217 handler http.Handler
218 hasListeners bool
219 shutdown chan struct{}
220 shuttingDown int32
221 interrupted bool
222 interrupt chan os.Signal
223 }
224
225
226 func (s *Server) Logf(f string, args ...interface{}) {
227 if s.api != nil && s.api.Logger != nil {
228 s.api.Logger(f, args...)
229 } else {
230 log.Printf(f, args...)
231 }
232 }
233
234
235
236 func (s *Server) Fatalf(f string, args ...interface{}) {
237 if s.api != nil && s.api.Logger != nil {
238 s.api.Logger(f, args...)
239 os.Exit(1)
240 } else {
241 log.Fatalf(f, args...)
242 }
243 }
244
245
246 func (s *Server) SetAPI(api *operations.TimestampServerAPI) {
247 if api == nil {
248 s.api = nil
249 s.handler = nil
250 return
251 }
252
253 s.api = api
254 s.handler = configureAPI(api)
255 }
256
257 func (s *Server) hasScheme(scheme string) bool {
258 schemes := s.EnabledListeners
259 if len(schemes) == 0 {
260 schemes = defaultSchemes
261 }
262
263 for _, v := range schemes {
264 if v == scheme {
265 return true
266 }
267 }
268 return false
269 }
270
271
272 func (s *Server) Serve() (err error) {
273 if !s.hasListeners {
274 if err = s.Listen(); err != nil {
275 return err
276 }
277 }
278
279
280 if s.handler == nil {
281 if s.api == nil {
282 return errors.New("can't create the default handler, as no api is set")
283 }
284
285 s.SetHandler(s.api.Serve(nil))
286 }
287
288 wg := new(sync.WaitGroup)
289 once := new(sync.Once)
290 signalNotify(s.interrupt)
291 go handleInterrupt(once, s)
292
293 servers := []*http.Server{}
294
295 if s.hasScheme(schemeUnix) {
296 domainSocket := new(http.Server)
297 domainSocket.MaxHeaderBytes = int(s.MaxHeaderSize)
298 domainSocket.Handler = s.handler
299 if int64(s.CleanupTimeout) > 0 {
300 domainSocket.IdleTimeout = s.CleanupTimeout
301 }
302
303 configureServer(domainSocket, "unix", string(s.SocketPath))
304
305 servers = append(servers, domainSocket)
306 wg.Add(1)
307 s.Logf("Serving timestamp server at unix://%s", s.SocketPath)
308 go func(l net.Listener) {
309 defer wg.Done()
310 if err := domainSocket.Serve(l); err != nil && err != http.ErrServerClosed {
311 s.Fatalf("%v", err)
312 }
313 s.Logf("Stopped serving timestamp server at unix://%s", s.SocketPath)
314 }(s.domainSocketL)
315 }
316
317 if s.hasScheme(schemeHTTP) {
318 httpServer := new(http.Server)
319 httpServer.MaxHeaderBytes = int(s.MaxHeaderSize)
320 httpServer.ReadTimeout = s.ReadTimeout
321 httpServer.WriteTimeout = s.WriteTimeout
322 httpServer.SetKeepAlivesEnabled(int64(s.KeepAlive) > 0)
323 if s.ListenLimit > 0 {
324 s.httpServerL = netutil.LimitListener(s.httpServerL, s.ListenLimit)
325 }
326
327 if int64(s.CleanupTimeout) > 0 {
328 httpServer.IdleTimeout = s.CleanupTimeout
329 }
330
331 httpServer.Handler = s.handler
332
333 configureServer(httpServer, "http", s.httpServerL.Addr().String())
334
335 servers = append(servers, httpServer)
336 wg.Add(1)
337 s.Logf("Serving timestamp server at http://%s", s.httpServerL.Addr())
338 go func(l net.Listener) {
339 defer wg.Done()
340 if err := httpServer.Serve(l); err != nil && err != http.ErrServerClosed {
341 s.Fatalf("%v", err)
342 }
343 s.Logf("Stopped serving timestamp server at http://%s", l.Addr())
344 }(s.httpServerL)
345 }
346
347 if s.hasScheme(schemeHTTPS) {
348 httpsServer := new(http.Server)
349 httpsServer.MaxHeaderBytes = int(s.MaxHeaderSize)
350 httpsServer.ReadTimeout = s.TLSReadTimeout
351 httpsServer.WriteTimeout = s.TLSWriteTimeout
352 httpsServer.SetKeepAlivesEnabled(int64(s.TLSKeepAlive) > 0)
353 if s.TLSListenLimit > 0 {
354 s.httpsServerL = netutil.LimitListener(s.httpsServerL, s.TLSListenLimit)
355 }
356 if int64(s.CleanupTimeout) > 0 {
357 httpsServer.IdleTimeout = s.CleanupTimeout
358 }
359 httpsServer.Handler = s.handler
360
361
362 httpsServer.TLSConfig = &tls.Config{
363
364
365 PreferServerCipherSuites: true,
366
367
368 CurvePreferences: []tls.CurveID{tls.CurveP256},
369
370 NextProtos: []string{"h2", "http/1.1"},
371
372 MinVersion: tls.VersionTLS12,
373
374 CipherSuites: []uint16{
375 tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
376 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
377 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
378 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
379 tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
380 tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
381 },
382 }
383
384
385 if s.TLSCertificate != "" && s.TLSCertificateKey != "" {
386 httpsServer.TLSConfig.Certificates = make([]tls.Certificate, 1)
387 httpsServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(s.TLSCertificate, s.TLSCertificateKey)
388 if err != nil {
389 return err
390 }
391 }
392
393 if s.TLSCACertificate != "" {
394
395 caCert, caCertErr := os.ReadFile(s.TLSCACertificate)
396 if caCertErr != nil {
397 return caCertErr
398 }
399 caCertPool := x509.NewCertPool()
400 ok := caCertPool.AppendCertsFromPEM(caCert)
401 if !ok {
402 return fmt.Errorf("cannot parse CA certificate")
403 }
404 httpsServer.TLSConfig.ClientCAs = caCertPool
405 httpsServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
406 }
407
408
409 configureTLS(httpsServer.TLSConfig)
410
411 if len(httpsServer.TLSConfig.Certificates) == 0 && httpsServer.TLSConfig.GetCertificate == nil {
412
413 if s.TLSCertificate == "" {
414 if s.TLSCertificateKey == "" {
415 s.Fatalf("the required flags `--tls-certificate` and `--tls-key` were not specified")
416 }
417 s.Fatalf("the required flag `--tls-certificate` was not specified")
418 }
419 if s.TLSCertificateKey == "" {
420 s.Fatalf("the required flag `--tls-key` was not specified")
421 }
422
423 s.Fatalf("no certificate was configured for TLS")
424 }
425
426 configureServer(httpsServer, "https", s.httpsServerL.Addr().String())
427
428 servers = append(servers, httpsServer)
429 wg.Add(1)
430 s.Logf("Serving timestamp server at https://%s", s.httpsServerL.Addr())
431 go func(l net.Listener) {
432 defer wg.Done()
433 if err := httpsServer.Serve(l); err != nil && err != http.ErrServerClosed {
434 s.Fatalf("%v", err)
435 }
436 s.Logf("Stopped serving timestamp server at https://%s", l.Addr())
437 }(tls.NewListener(s.httpsServerL, httpsServer.TLSConfig))
438 }
439
440 wg.Add(1)
441 go s.handleShutdown(wg, &servers)
442
443 wg.Wait()
444 return nil
445 }
446
447
448 func (s *Server) Listen() error {
449 if s.hasListeners {
450 return nil
451 }
452
453 if s.hasScheme(schemeHTTPS) {
454
455 if s.TLSHost == "" {
456 s.TLSHost = s.Host
457 }
458
459 if s.TLSListenLimit == 0 {
460 s.TLSListenLimit = s.ListenLimit
461 }
462
463 if int64(s.TLSKeepAlive) == 0 {
464 s.TLSKeepAlive = s.KeepAlive
465 }
466
467 if int64(s.TLSReadTimeout) == 0 {
468 s.TLSReadTimeout = s.ReadTimeout
469 }
470
471 if int64(s.TLSWriteTimeout) == 0 {
472 s.TLSWriteTimeout = s.WriteTimeout
473 }
474 }
475
476 if s.hasScheme(schemeUnix) {
477 domSockListener, err := net.Listen("unix", string(s.SocketPath))
478 if err != nil {
479 return err
480 }
481 s.domainSocketL = domSockListener
482 }
483
484 if s.hasScheme(schemeHTTP) {
485 listener, err := net.Listen("tcp", net.JoinHostPort(s.Host, strconv.Itoa(s.Port)))
486 if err != nil {
487 return err
488 }
489
490 h, p, err := swag.SplitHostPort(listener.Addr().String())
491 if err != nil {
492 return err
493 }
494 s.Host = h
495 s.Port = p
496 s.httpServerL = listener
497 }
498
499 if s.hasScheme(schemeHTTPS) {
500 tlsListener, err := net.Listen("tcp", net.JoinHostPort(s.TLSHost, strconv.Itoa(s.TLSPort)))
501 if err != nil {
502 return err
503 }
504
505 sh, sp, err := swag.SplitHostPort(tlsListener.Addr().String())
506 if err != nil {
507 return err
508 }
509 s.TLSHost = sh
510 s.TLSPort = sp
511 s.httpsServerL = tlsListener
512 }
513
514 s.hasListeners = true
515 return nil
516 }
517
518
519 func (s *Server) Shutdown() error {
520 if atomic.CompareAndSwapInt32(&s.shuttingDown, 0, 1) {
521 close(s.shutdown)
522 }
523 return nil
524 }
525
526 func (s *Server) handleShutdown(wg *sync.WaitGroup, serversPtr *[]*http.Server) {
527
528
529 defer wg.Done()
530
531 <-s.shutdown
532
533 servers := *serversPtr
534
535 ctx, cancel := context.WithTimeout(context.TODO(), s.GracefulTimeout)
536 defer cancel()
537
538
539 s.api.PreServerShutdown()
540
541 shutdownChan := make(chan bool)
542 for i := range servers {
543 server := servers[i]
544 go func() {
545 var success bool
546 defer func() {
547 shutdownChan <- success
548 }()
549 if err := server.Shutdown(ctx); err != nil {
550
551 s.Logf("HTTP server Shutdown: %v", err)
552 } else {
553 success = true
554 }
555 }()
556 }
557
558
559 success := true
560 for range servers {
561 success = success && <-shutdownChan
562 }
563 if success {
564 s.api.ServerShutdown()
565 }
566 }
567
568
569 func (s *Server) GetHandler() http.Handler {
570 return s.handler
571 }
572
573
574 func (s *Server) SetHandler(handler http.Handler) {
575 s.handler = handler
576 }
577
578
579 func (s *Server) UnixListener() (net.Listener, error) {
580 if !s.hasListeners {
581 if err := s.Listen(); err != nil {
582 return nil, err
583 }
584 }
585 return s.domainSocketL, nil
586 }
587
588
589 func (s *Server) HTTPListener() (net.Listener, error) {
590 if !s.hasListeners {
591 if err := s.Listen(); err != nil {
592 return nil, err
593 }
594 }
595 return s.httpServerL, nil
596 }
597
598
599 func (s *Server) TLSListener() (net.Listener, error) {
600 if !s.hasListeners {
601 if err := s.Listen(); err != nil {
602 return nil, err
603 }
604 }
605 return s.httpsServerL, nil
606 }
607
608 func handleInterrupt(once *sync.Once, s *Server) {
609 once.Do(func() {
610 for range s.interrupt {
611 if s.interrupted {
612 s.Logf("Server already shutting down")
613 continue
614 }
615 s.interrupted = true
616 s.Logf("Shutting down... ")
617 if err := s.Shutdown(); err != nil {
618 s.Logf("HTTP server Shutdown: %v", err)
619 }
620 }
621 })
622 }
623
// signalNotify subscribes the interrupt channel to SIGINT and SIGTERM.
func signalNotify(interrupt chan<- os.Signal) {
	signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
}
627
View as plain text