...

Source file src/github.com/sigstore/rekor/pkg/generated/restapi/configure_rekor_server.go

Documentation: github.com/sigstore/rekor/pkg/generated/restapi

     1  /*
     2  Copyright © 2020 The Sigstore Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  // This file is safe to edit. Once it exists it will not be overwritten
    17  
    18  package restapi
    19  
    20  import (
    21  	"context"
    22  	"crypto/tls"
    23  	go_errors "errors"
    24  	"fmt"
    25  	"net/http"
    26  	"net/http/httputil"
    27  	"strconv"
    28  	"time"
    29  
     30  	// using embed to add the static html page during build time
    31  	_ "embed"
    32  
    33  	"github.com/go-chi/chi/middleware"
    34  	"github.com/go-openapi/errors"
    35  	"github.com/go-openapi/runtime"
    36  	"github.com/mitchellh/mapstructure"
    37  	"github.com/rs/cors"
    38  	"github.com/spf13/viper"
    39  	"go.uber.org/zap"
    40  	"go.uber.org/zap/zapcore"
    41  
    42  	pkgapi "github.com/sigstore/rekor/pkg/api"
    43  	"github.com/sigstore/rekor/pkg/generated/restapi/operations"
    44  	"github.com/sigstore/rekor/pkg/generated/restapi/operations/entries"
    45  	"github.com/sigstore/rekor/pkg/generated/restapi/operations/index"
    46  	"github.com/sigstore/rekor/pkg/generated/restapi/operations/pubkey"
    47  	"github.com/sigstore/rekor/pkg/generated/restapi/operations/tlog"
    48  	"github.com/sigstore/rekor/pkg/log"
    49  	"github.com/sigstore/rekor/pkg/util"
    50  
    51  	"golang.org/x/exp/slices"
    52  )
    53  
    54  //go:generate swagger generate server --target ../../generated --name RekorServer --spec ../../../openapi.yaml --principal interface{} --exclude-main
    55  
    56  type contextKey string
    57  
    58  var (
    59  	ctxKeyAPIToRecord = contextKey("apiToRecord")
    60  )
    61  
    62  // Context payload for recording metrics.
    63  type apiToRecord struct {
    64  	method *string // Method to record in metrics, if any.
    65  	path   *string // Path to record in metrics, if any.
    66  }
    67  
    68  func configureFlags(_ *operations.RekorServerAPI) {
    69  	// api.CommandLineOptionsGroups = []swag.CommandLineOptionsGroup{ ... }
    70  }
    71  
    72  func configureAPI(api *operations.RekorServerAPI) http.Handler {
    73  	// configure the api here
    74  	api.ServeError = logAndServeError
    75  
    76  	// Set your custom logger if needed. Default one is log.Printf
    77  	// Expected interface func(string, ...interface{})
    78  	//
    79  	// Example:
    80  	// api.Logger = log.Printf
    81  	api.Logger = log.Logger.Infof
    82  
    83  	// api.UseSwaggerUI()
    84  	// To continue using redoc as your UI, uncomment the following line
    85  	// api.UseRedoc()
    86  
    87  	api.JSONConsumer = runtime.JSONConsumer()
    88  	api.JSONProducer = runtime.JSONProducer()
    89  
    90  	api.ApplicationXPemFileProducer = runtime.TextProducer()
    91  
    92  	// disable all endpoints to start
    93  	api.IndexSearchIndexHandler = index.SearchIndexHandlerFunc(pkgapi.SearchIndexNotImplementedHandler)
    94  	api.EntriesCreateLogEntryHandler = entries.CreateLogEntryHandlerFunc(pkgapi.CreateLogEntryNotImplementedHandler)
    95  	api.EntriesGetLogEntryByIndexHandler = entries.GetLogEntryByIndexHandlerFunc(pkgapi.GetLogEntryByIndexNotImplementedHandler)
    96  	api.EntriesGetLogEntryByUUIDHandler = entries.GetLogEntryByUUIDHandlerFunc(pkgapi.GetLogEntryByUUIDNotImplementedHandler)
    97  	api.EntriesSearchLogQueryHandler = entries.SearchLogQueryHandlerFunc(pkgapi.SearchLogQueryNotImplementedHandler)
    98  	api.PubkeyGetPublicKeyHandler = pubkey.GetPublicKeyHandlerFunc(pkgapi.GetPublicKeyNotImplementedHandler)
    99  	api.TlogGetLogProofHandler = tlog.GetLogProofHandlerFunc(pkgapi.GetLogProofNotImplementedHandler)
   100  
   101  	enabledAPIEndpoints := viper.GetStringSlice("enabled_api_endpoints")
   102  	if !slices.Contains(enabledAPIEndpoints, "searchIndex") && viper.GetBool("enable_retrieve_api") {
   103  		enabledAPIEndpoints = append(enabledAPIEndpoints, "searchIndex")
   104  	}
   105  
   106  	for _, enabledAPI := range enabledAPIEndpoints {
   107  		log.Logger.Infof("Enabling API endpoint: %s", enabledAPI)
   108  		switch enabledAPI {
   109  		case "searchIndex":
   110  			api.IndexSearchIndexHandler = index.SearchIndexHandlerFunc(pkgapi.SearchIndexHandler)
   111  		case "getLogInfo":
   112  			api.TlogGetLogInfoHandler = tlog.GetLogInfoHandlerFunc(pkgapi.GetLogInfoHandler)
   113  		case "getPublicKey":
   114  			api.PubkeyGetPublicKeyHandler = pubkey.GetPublicKeyHandlerFunc(pkgapi.GetPublicKeyHandler)
   115  		case "getLogProof":
   116  			api.TlogGetLogProofHandler = tlog.GetLogProofHandlerFunc(pkgapi.GetLogProofHandler)
   117  		case "createLogEntry":
   118  			api.EntriesCreateLogEntryHandler = entries.CreateLogEntryHandlerFunc(pkgapi.CreateLogEntryHandler)
   119  		case "getLogEntryByIndex":
   120  			api.EntriesGetLogEntryByIndexHandler = entries.GetLogEntryByIndexHandlerFunc(pkgapi.GetLogEntryByIndexHandler)
   121  		case "getLogEntryByUUID":
   122  			api.EntriesGetLogEntryByUUIDHandler = entries.GetLogEntryByUUIDHandlerFunc(pkgapi.GetLogEntryByUUIDHandler)
   123  		case "searchLogQuery":
   124  			api.EntriesSearchLogQueryHandler = entries.SearchLogQueryHandlerFunc(pkgapi.SearchLogQueryHandler)
   125  		default:
   126  			log.Logger.Panicf("Unknown API endpoint requested: %s", enabledAPI)
   127  		}
   128  	}
   129  
   130  	// all handlers need to be set before a call to api.AddMiddlewareFor
   131  	for _, enabledAPI := range enabledAPIEndpoints {
   132  		switch enabledAPI {
   133  		case "searchIndex":
   134  			recordMetricsForAPI(api, "POST", "/api/v1/index/retrieve") // add metrics
   135  		case "getLogInfo":
   136  			api.AddMiddlewareFor("GET", "/api/v1/log", middleware.NoCache) // not cacheable
   137  			recordMetricsForAPI(api, "GET", "/api/v1/log")                 // add metrics
   138  		case "getPublicKey":
   139  			api.AddMiddlewareFor("GET", "/api/v1/log/publicKey", middleware.NoCache) // not cacheable
   140  			recordMetricsForAPI(api, "GET", "/api/v1/log/publicKey")                 // add metrics
   141  		case "getLogProof":
   142  			api.AddMiddlewareFor("GET", "/api/v1/log/proof", middleware.NoCache) // not cacheable
   143  			recordMetricsForAPI(api, "GET", "/api/v1/log/proof")                 // add metrics
   144  		case "createLogEntry":
   145  			recordMetricsForAPI(api, "POST", "/api/v1/log/entries") // add metrics
   146  		case "getLogEntryByIndex":
   147  			api.AddMiddlewareFor("GET", "/api/v1/log/entries", middleware.NoCache) // not cacheable
   148  			recordMetricsForAPI(api, "GET", "/api/v1/log/entries")                 // add metrics
   149  		case "getLogEntryByUUID":
   150  			api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}", middleware.NoCache) // not cacheable
   151  			recordMetricsForAPI(api, "GET", "/api/v1/log/entries/{entryUUID}")                 // add metrics
   152  		case "searchLogQuery":
   153  			recordMetricsForAPI(api, "POST", "/api/v1/log/entries/retrieve") // add metrics
   154  		}
   155  	}
   156  	api.RegisterFormat("signedCheckpoint", &util.SignedNote{}, util.SignedCheckpointValidator)
   157  
   158  	api.PreServerShutdown = func() {}
   159  	api.ServerShutdown = func() {
   160  		pkgapi.StopAPI()
   161  	}
   162  
   163  	return setupGlobalMiddleware(api.Serve(setupMiddlewares))
   164  }
   165  
   166  // The TLS configuration before HTTPS server starts.
   167  func configureTLS(_ *tls.Config) {
   168  	// Make all necessary changes to the TLS configuration here.
   169  }
   170  
   171  // As soon as server is initialized but not run yet, this function will be called.
   172  // If you need to modify a config, store server instance to stop it individually later, this is the place.
   173  // This function can be called multiple times, depending on the number of serving schemes.
   174  // scheme value will be set accordingly: "http", "https" or "unix"
   175  func configureServer(_ *http.Server, _, _ string) {
   176  }
   177  
   178  // The middleware configuration is for the handler executors. These do not apply to the swagger.json document.
   179  // The middleware executes after routing but before authentication, binding and validation
   180  func setupMiddlewares(handler http.Handler) http.Handler {
   181  	return handler
   182  }
   183  
   184  type httpRequestFields struct {
   185  	requestMethod string
   186  	requestURL    string
   187  	requestSize   int64
   188  	status        int
   189  	responseSize  int
   190  	userAgent     string
   191  	remoteIp      string //revive:disable:var-naming
   192  	latency       time.Duration
   193  	protocol      string
   194  }
   195  
   196  func (h *httpRequestFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
   197  	enc.AddString("requestMethod", h.requestMethod)
   198  	enc.AddString("requestUrl", h.requestURL)
   199  	enc.AddString("requestSize", fmt.Sprintf("%d", h.requestSize))
   200  	enc.AddInt("status", h.status)
   201  	enc.AddString("responseSize", fmt.Sprintf("%d", h.responseSize))
   202  	enc.AddString("userAgent", h.userAgent)
   203  	enc.AddString("remoteIp", h.remoteIp)
   204  	enc.AddString("latency", fmt.Sprintf("%.9fs", h.latency.Seconds())) // formatted per GCP expectations
   205  	enc.AddString("protocol", h.protocol)
   206  	return nil
   207  }
   208  
   209  // We need this type to act as an adapter between zap and the middleware request logger.
   210  type zapLogEntry struct {
   211  	r *http.Request
   212  }
   213  
   214  func (z *zapLogEntry) Write(status, bytes int, _ http.Header, elapsed time.Duration, extra interface{}) {
   215  	var fields []interface{}
   216  
   217  	// follows https://cloud.google.com/logging/docs/reference/v2/rest/v2/LogEntry as a convention
   218  	// append HTTP Request / Response Information
   219  	scheme := "http"
   220  	if z.r.TLS != nil {
   221  		scheme = "https"
   222  	}
   223  	httpRequestObj := &httpRequestFields{
   224  		requestMethod: z.r.Method,
   225  		requestURL:    fmt.Sprintf("%s://%s%s", scheme, z.r.Host, z.r.RequestURI),
   226  		requestSize:   z.r.ContentLength,
   227  		status:        status,
   228  		responseSize:  bytes,
   229  		userAgent:     z.r.Header.Get("User-Agent"),
   230  		remoteIp:      z.r.RemoteAddr,
   231  		latency:       elapsed,
   232  		protocol:      z.r.Proto,
   233  	}
   234  	fields = append(fields, zap.Object("httpRequest", httpRequestObj))
   235  	if extra != nil {
   236  		fields = append(fields, zap.Any("extra", extra))
   237  	}
   238  
   239  	log.ContextLogger(z.r.Context()).With(fields...).Info("completed request")
   240  }
   241  
   242  func (z *zapLogEntry) Panic(v interface{}, stack []byte) {
   243  	fields := []interface{}{zap.String("message", fmt.Sprintf("%v\n%v", v, string(stack)))}
   244  	log.ContextLogger(z.r.Context()).With(fields...).Errorf("panic detected: %v", v)
   245  }
   246  
   247  type logFormatter struct{}
   248  
   249  func (l *logFormatter) NewLogEntry(r *http.Request) middleware.LogEntry {
   250  	return &zapLogEntry{r}
   251  }
   252  
   253  // The middleware configuration happens before anything, this middleware also applies to serving the swagger.json document.
   254  // So this is a good place to plug in a panic handling middleware, logging and metrics
   255  func setupGlobalMiddleware(handler http.Handler) http.Handler {
   256  	returnHandler := recoverer(handler)
   257  	maxReqBodySize := viper.GetInt64("max_request_body_size")
   258  	if maxReqBodySize > 0 {
   259  		returnHandler = maxBodySize(maxReqBodySize, returnHandler)
   260  	}
   261  	middleware.DefaultLogger = middleware.RequestLogger(&logFormatter{})
   262  	returnHandler = middleware.Logger(returnHandler)
   263  	returnHandler = middleware.Heartbeat("/ping")(returnHandler)
   264  	returnHandler = serveStaticContent(returnHandler)
   265  
   266  	handleCORS := cors.Default().Handler
   267  	returnHandler = handleCORS(returnHandler)
   268  
   269  	returnHandler = wrapMetrics(returnHandler)
   270  
   271  	return middleware.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   272  		ctx := r.Context()
   273  		r = r.WithContext(log.WithRequestID(ctx, middleware.GetReqID(ctx)))
   274  		defer func() {
   275  			_ = log.ContextLogger(ctx).Sync()
   276  		}()
   277  
   278  		returnHandler.ServeHTTP(w, r)
   279  	}))
   280  }
   281  
   282  // Populates the the apiToRecord for this method/path so metrics are emitted.
   283  func recordMetricsForAPI(api *operations.RekorServerAPI, method string, path string) {
   284  	metricsHandler := func(handler http.Handler) http.Handler {
   285  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   286  			ctx := r.Context()
   287  			if apiInfo, ok := ctx.Value(ctxKeyAPIToRecord).(*apiToRecord); ok {
   288  				apiInfo.method = &method
   289  				apiInfo.path = &path
   290  			} else {
   291  				log.ContextLogger(ctx).Warn("Could not attach api info - endpoint may not be monitored.")
   292  			}
   293  			handler.ServeHTTP(w, r)
   294  		})
   295  	}
   296  
   297  	api.AddMiddlewareFor(method, path, metricsHandler)
   298  }
   299  
   300  func wrapMetrics(handler http.Handler) http.Handler {
   301  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   302  		ctx := r.Context()
   303  		apiInfo := apiToRecord{}
   304  		ctx = context.WithValue(ctx, ctxKeyAPIToRecord, &apiInfo)
   305  		r = r.WithContext(ctx)
   306  
   307  		start := time.Now()
   308  		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
   309  		defer func() {
   310  			// Only record metrics for APIs that need instrumentation.
   311  			if apiInfo.path != nil && apiInfo.method != nil {
   312  				code := strconv.Itoa(ww.Status())
   313  				labels := map[string]string{
   314  					"path": *apiInfo.path,
   315  					"code": code,
   316  				}
   317  				// This logs latency broken down by URL path and response code
   318  				// TODO(var-sdk): delete these metrics once the new metrics are safely rolled out.
   319  				pkgapi.MetricLatency.With(labels).Observe(float64(time.Since(start)))
   320  				pkgapi.MetricLatencySummary.With(labels).Observe(float64(time.Since(start)))
   321  
   322  				pkgapi.MetricRequestLatency.With(
   323  					map[string]string{
   324  						"path":   *apiInfo.path,
   325  						"method": *apiInfo.method,
   326  					}).Observe(float64(time.Since(start)))
   327  
   328  				pkgapi.MetricRequestCount.With(
   329  					map[string]string{
   330  						"path":   *apiInfo.path,
   331  						"method": *apiInfo.method,
   332  						"code":   code,
   333  					}).Inc()
   334  			}
   335  		}()
   336  
   337  		handler.ServeHTTP(ww, r)
   338  
   339  	})
   340  }
   341  
   342  func logAndServeError(w http.ResponseWriter, r *http.Request, err error) {
   343  	ctx := r.Context()
   344  	if apiErr, ok := err.(errors.Error); ok && apiErr.Code() == http.StatusNotFound {
   345  		log.ContextLogger(ctx).Warn(err)
   346  	} else {
   347  		log.ContextLogger(ctx).Error(err)
   348  	}
   349  	if compErr, ok := err.(*errors.CompositeError); ok {
   350  		// iterate over composite error looking for something more specific
   351  		for _, embeddedErr := range compErr.Errors {
   352  			var maxBytesError *http.MaxBytesError
   353  			if parseErr, ok := embeddedErr.(*errors.ParseError); ok && go_errors.As(parseErr.Reason, &maxBytesError) {
   354  				err = errors.New(http.StatusRequestEntityTooLarge, http.StatusText(http.StatusRequestEntityTooLarge))
   355  				break
   356  			}
   357  		}
   358  	}
   359  	requestFields := map[string]interface{}{}
   360  	if decodeErr := mapstructure.Decode(r, &requestFields); decodeErr == nil {
   361  		log.ContextLogger(ctx).Debug(requestFields)
   362  	}
   363  	errors.ServeError(w, r, err)
   364  }
   365  
   366  //go:embed rekorHomePage.html
   367  var homePageBytes []byte
   368  
   369  func serveStaticContent(handler http.Handler) http.Handler {
   370  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   371  		if r.URL.Path == "/" {
   372  			w.Header().Add("Content-Type", "text/html")
   373  			w.WriteHeader(200)
   374  			_, _ = w.Write(homePageBytes)
   375  			return
   376  		}
   377  		handler.ServeHTTP(w, r)
   378  	})
   379  }
   380  
   381  // recoverer
   382  func recoverer(next http.Handler) http.Handler {
   383  	fn := func(w http.ResponseWriter, r *http.Request) {
   384  		defer func() {
   385  			if rvr := recover(); rvr != nil && rvr != http.ErrAbortHandler {
   386  				var fields []interface{}
   387  
   388  				// get context before dump request in case there is an error
   389  				ctx := r.Context()
   390  				request, err := httputil.DumpRequest(r, false)
   391  				if err == nil {
   392  					fields = append(fields, zap.ByteString("request_headers", request))
   393  				}
   394  
   395  				log.ContextLogger(ctx).With(fields...).Errorf("panic detected: %v", rvr)
   396  
   397  				errors.ServeError(w, r, nil)
   398  			}
   399  		}()
   400  
   401  		next.ServeHTTP(w, r)
   402  	}
   403  
   404  	return http.HandlerFunc(fn)
   405  }
   406  
   407  // maxBodySize limits the request body
   408  func maxBodySize(maxLength int64, next http.Handler) http.Handler {
   409  	fn := func(w http.ResponseWriter, r *http.Request) {
   410  		r.Body = http.MaxBytesReader(w, r.Body, maxLength)
   411  		next.ServeHTTP(w, r)
   412  	}
   413  
   414  	return http.HandlerFunc(fn)
   415  }
   416  

View as plain text