1
16
17
18 package restapi
19
import (
	"context"
	"crypto/tls"
	go_errors "errors"
	"fmt"
	"net/http"
	"net/http/httputil"
	"runtime/debug"
	"strconv"
	"time"

	_ "embed"

	"github.com/go-chi/chi/middleware"
	"github.com/go-openapi/errors"
	"github.com/go-openapi/runtime"
	"github.com/mitchellh/mapstructure"
	"github.com/rs/cors"
	"github.com/spf13/viper"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"

	pkgapi "github.com/sigstore/rekor/pkg/api"
	"github.com/sigstore/rekor/pkg/generated/restapi/operations"
	"github.com/sigstore/rekor/pkg/generated/restapi/operations/entries"
	"github.com/sigstore/rekor/pkg/generated/restapi/operations/index"
	"github.com/sigstore/rekor/pkg/generated/restapi/operations/pubkey"
	"github.com/sigstore/rekor/pkg/generated/restapi/operations/tlog"
	"github.com/sigstore/rekor/pkg/log"
	"github.com/sigstore/rekor/pkg/util"

	"golang.org/x/exp/slices"
)
53
54
55
type (
	// contextKey is a private string type for context keys, so values stored
	// by this package cannot collide with keys from other packages.
	contextKey string

	// apiToRecord is a per-request, mutable record of which API route served
	// the request. Both fields stay nil unless route-level middleware fills
	// them in.
	apiToRecord struct {
		method *string // HTTP method of the matched route
		path   *string // route path template of the matched route
	}
)

// ctxKeyAPIToRecord is the context key under which a *apiToRecord is stored.
var ctxKeyAPIToRecord = contextKey("apiToRecord")
67
68 func configureFlags(_ *operations.RekorServerAPI) {
69
70 }
71
// configureAPI wires up the generated RekorServerAPI and returns the fully
// wrapped root http.Handler.
//
// It installs the error hook, logger and JSON/PEM consumers and producers.
// Every route handler is first set to its "not implemented" variant from
// pkgapi. Only the endpoints named in the "enabled_api_endpoints" config value
// (plus "searchIndex" when "enable_retrieve_api" is true) are then switched to
// their real implementations. An unrecognised endpoint name panics via
// log.Logger.Panicf.
//
// The function also attaches per-route middleware, registers the
// "signedCheckpoint" format and sets the shutdown hooks.
func configureAPI(api *operations.RekorServerAPI) http.Handler {
	// All errors produced by the generated server are logged before being served.
	api.ServeError = logAndServeError

	// The generated server's own log output goes to the project logger at info level.
	api.Logger = log.Logger.Infof

	api.JSONConsumer = runtime.JSONConsumer()
	api.JSONProducer = runtime.JSONProducer()

	// application/x-pem-file responses are written with the plain-text producer.
	api.ApplicationXPemFileProducer = runtime.TextProducer()

	// Default every route to its "not implemented" handler; the loop below
	// replaces the ones enabled in configuration.
	// NOTE(review): TlogGetLogInfoHandler gets no "not implemented" default
	// here, so it stays unset unless "getLogInfo" is enabled — confirm intended.
	api.IndexSearchIndexHandler = index.SearchIndexHandlerFunc(pkgapi.SearchIndexNotImplementedHandler)
	api.EntriesCreateLogEntryHandler = entries.CreateLogEntryHandlerFunc(pkgapi.CreateLogEntryNotImplementedHandler)
	api.EntriesGetLogEntryByIndexHandler = entries.GetLogEntryByIndexHandlerFunc(pkgapi.GetLogEntryByIndexNotImplementedHandler)
	api.EntriesGetLogEntryByUUIDHandler = entries.GetLogEntryByUUIDHandlerFunc(pkgapi.GetLogEntryByUUIDNotImplementedHandler)
	api.EntriesSearchLogQueryHandler = entries.SearchLogQueryHandlerFunc(pkgapi.SearchLogQueryNotImplementedHandler)
	api.PubkeyGetPublicKeyHandler = pubkey.GetPublicKeyHandlerFunc(pkgapi.GetPublicKeyNotImplementedHandler)
	api.TlogGetLogProofHandler = tlog.GetLogProofHandlerFunc(pkgapi.GetLogProofNotImplementedHandler)

	enabledAPIEndpoints := viper.GetStringSlice("enabled_api_endpoints")
	// The boolean "enable_retrieve_api" setting also turns on searchIndex
	// (presumably kept for compatibility with older configurations).
	if !slices.Contains(enabledAPIEndpoints, "searchIndex") && viper.GetBool("enable_retrieve_api") {
		enabledAPIEndpoints = append(enabledAPIEndpoints, "searchIndex")
	}

	// Pass 1: install the real handler for each enabled endpoint.
	for _, enabledAPI := range enabledAPIEndpoints {
		log.Logger.Infof("Enabling API endpoint: %s", enabledAPI)
		switch enabledAPI {
		case "searchIndex":
			api.IndexSearchIndexHandler = index.SearchIndexHandlerFunc(pkgapi.SearchIndexHandler)
		case "getLogInfo":
			api.TlogGetLogInfoHandler = tlog.GetLogInfoHandlerFunc(pkgapi.GetLogInfoHandler)
		case "getPublicKey":
			api.PubkeyGetPublicKeyHandler = pubkey.GetPublicKeyHandlerFunc(pkgapi.GetPublicKeyHandler)
		case "getLogProof":
			api.TlogGetLogProofHandler = tlog.GetLogProofHandlerFunc(pkgapi.GetLogProofHandler)
		case "createLogEntry":
			api.EntriesCreateLogEntryHandler = entries.CreateLogEntryHandlerFunc(pkgapi.CreateLogEntryHandler)
		case "getLogEntryByIndex":
			api.EntriesGetLogEntryByIndexHandler = entries.GetLogEntryByIndexHandlerFunc(pkgapi.GetLogEntryByIndexHandler)
		case "getLogEntryByUUID":
			api.EntriesGetLogEntryByUUIDHandler = entries.GetLogEntryByUUIDHandlerFunc(pkgapi.GetLogEntryByUUIDHandler)
		case "searchLogQuery":
			api.EntriesSearchLogQueryHandler = entries.SearchLogQueryHandlerFunc(pkgapi.SearchLogQueryHandler)
		default:
			log.Logger.Panicf("Unknown API endpoint requested: %s", enabledAPI)
		}
	}

	// Pass 2: attach per-route middleware only after every handler field above
	// has its final value. GET routes get middleware.NoCache, and every enabled
	// route registers its method/path for metrics via recordMetricsForAPI.
	// NOTE(review): keep this a separate pass. go-swagger's generated
	// AddMiddlewareFor typically builds its route cache from the handler fields
	// on first use, so interleaving it with the assignments above could freeze
	// stale handlers — confirm before merging the loops.
	for _, enabledAPI := range enabledAPIEndpoints {
		switch enabledAPI {
		case "searchIndex":
			recordMetricsForAPI(api, "POST", "/api/v1/index/retrieve")
		case "getLogInfo":
			api.AddMiddlewareFor("GET", "/api/v1/log", middleware.NoCache)
			recordMetricsForAPI(api, "GET", "/api/v1/log")
		case "getPublicKey":
			api.AddMiddlewareFor("GET", "/api/v1/log/publicKey", middleware.NoCache)
			recordMetricsForAPI(api, "GET", "/api/v1/log/publicKey")
		case "getLogProof":
			api.AddMiddlewareFor("GET", "/api/v1/log/proof", middleware.NoCache)
			recordMetricsForAPI(api, "GET", "/api/v1/log/proof")
		case "createLogEntry":
			recordMetricsForAPI(api, "POST", "/api/v1/log/entries")
		case "getLogEntryByIndex":
			api.AddMiddlewareFor("GET", "/api/v1/log/entries", middleware.NoCache)
			recordMetricsForAPI(api, "GET", "/api/v1/log/entries")
		case "getLogEntryByUUID":
			api.AddMiddlewareFor("GET", "/api/v1/log/entries/{entryUUID}", middleware.NoCache)
			recordMetricsForAPI(api, "GET", "/api/v1/log/entries/{entryUUID}")
		case "searchLogQuery":
			recordMetricsForAPI(api, "POST", "/api/v1/log/entries/retrieve")
		}
	}
	// Custom string format validated by util.SignedCheckpointValidator.
	api.RegisterFormat("signedCheckpoint", &util.SignedNote{}, util.SignedCheckpointValidator)

	api.PreServerShutdown = func() {}
	// On server shutdown, stop the pkgapi layer.
	api.ServerShutdown = func() {
		pkgapi.StopAPI()
	}

	// setupMiddlewares wraps only the routed API; setupGlobalMiddleware wraps
	// the whole server, including the non-API "/" and "/ping" paths.
	return setupGlobalMiddleware(api.Serve(setupMiddlewares))
}
165
166
// configureTLS is the hook for customizing the server's *tls.Config.
// Rekor uses the configuration as provided, so this is a deliberate no-op.
func configureTLS(_ *tls.Config) {}
170
171
172
173
174
// configureServer is the hook for customizing the *http.Server; the two string
// arguments are ignored. Rekor needs no changes here, so it does nothing.
func configureServer(_ *http.Server, _, _ string) {}
177
178
179
// setupMiddlewares is passed to api.Serve as the builder for the routed API
// handler. No API-only middleware is currently needed, so it is the identity:
// the handler comes back exactly as it was given.
func setupMiddlewares(handler http.Handler) http.Handler {
	return handler // identity: nothing to add
}
183
184 type httpRequestFields struct {
185 requestMethod string
186 requestURL string
187 requestSize int64
188 status int
189 responseSize int
190 userAgent string
191 remoteIp string
192 latency time.Duration
193 protocol string
194 }
195
// MarshalLogObject implements zapcore.ObjectMarshaler, writing the request as
// a structured object. It never returns an error.
//
// requestSize and responseSize are emitted as decimal strings, and latency as
// seconds with nine fractional digits plus an "s" suffix (e.g. "0.001500000s").
// NOTE(review): these names and encodings appear to follow Google Cloud
// Logging's HttpRequest JSON shape — confirm before renaming or retyping them.
//
// This runs for every logged request, so the strconv formatters are used
// instead of fmt.Sprintf. They produce byte-identical text without boxing or
// format parsing.
func (h *httpRequestFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
	enc.AddString("requestMethod", h.requestMethod)
	enc.AddString("requestUrl", h.requestURL)
	enc.AddString("requestSize", strconv.FormatInt(h.requestSize, 10))
	enc.AddInt("status", h.status)
	enc.AddString("responseSize", strconv.Itoa(h.responseSize))
	enc.AddString("userAgent", h.userAgent)
	enc.AddString("remoteIp", h.remoteIp)
	// Same output as fmt.Sprintf("%.9fs", seconds).
	enc.AddString("latency", strconv.FormatFloat(h.latency.Seconds(), 'f', 9, 64)+"s")
	enc.AddString("protocol", h.protocol)
	return nil
}
208
209
210 type zapLogEntry struct {
211 r *http.Request
212 }
213
214 func (z *zapLogEntry) Write(status, bytes int, _ http.Header, elapsed time.Duration, extra interface{}) {
215 var fields []interface{}
216
217
218
219 scheme := "http"
220 if z.r.TLS != nil {
221 scheme = "https"
222 }
223 httpRequestObj := &httpRequestFields{
224 requestMethod: z.r.Method,
225 requestURL: fmt.Sprintf("%s://%s%s", scheme, z.r.Host, z.r.RequestURI),
226 requestSize: z.r.ContentLength,
227 status: status,
228 responseSize: bytes,
229 userAgent: z.r.Header.Get("User-Agent"),
230 remoteIp: z.r.RemoteAddr,
231 latency: elapsed,
232 protocol: z.r.Proto,
233 }
234 fields = append(fields, zap.Object("httpRequest", httpRequestObj))
235 if extra != nil {
236 fields = append(fields, zap.Any("extra", extra))
237 }
238
239 log.ContextLogger(z.r.Context()).With(fields...).Info("completed request")
240 }
241
// Panic implements middleware.LogEntry. It logs an error-level
// "panic detected" line whose "message" field holds the panic value followed
// by the stack trace.
// NOTE(review): this package's own recoverer logs panics itself, so this is
// presumably reached only if chi's Recoverer is used — confirm.
func (z *zapLogEntry) Panic(v interface{}, stack []byte) {
	details := fmt.Sprintf("%v\n%v", v, string(stack))
	log.ContextLogger(z.r.Context()).With(zap.String("message", details)).Errorf("panic detected: %v", v)
}
246
247 type logFormatter struct{}
248
249 func (l *logFormatter) NewLogEntry(r *http.Request) middleware.LogEntry {
250 return &zapLogEntry{r}
251 }
252
253
254
// setupGlobalMiddleware wraps the entire server handler (API and non-API paths)
// in the process-wide middleware chain. Each assignment wraps the previous
// handler, so a request passes through the layers in the reverse of the order
// they are written:
//
//	RequestID -> request-ID context + log Sync -> wrapMetrics -> CORS ->
//	static "/" page -> Heartbeat("/ping") -> chi request logger ->
//	maxBodySize (only if configured) -> recoverer -> handler
//
// The order is significant, so do not reshuffle these statements casually.
// Side effect: this function overwrites chi's package-level
// middleware.DefaultLogger.
func setupGlobalMiddleware(handler http.Handler) http.Handler {
	// Innermost: recover from panics raised by the API handler.
	returnHandler := recoverer(handler)
	// Optional cap on request body size; a value <= 0 disables it.
	maxReqBodySize := viper.GetInt64("max_request_body_size")
	if maxReqBodySize > 0 {
		returnHandler = maxBodySize(maxReqBodySize, returnHandler)
	}
	// Install the zap-backed request logger as chi's global DefaultLogger;
	// middleware.Logger delegates to it.
	middleware.DefaultLogger = middleware.RequestLogger(&logFormatter{})
	returnHandler = middleware.Logger(returnHandler)
	// "/ping" and "/" are answered here, outside the request logger.
	returnHandler = middleware.Heartbeat("/ping")(returnHandler)
	returnHandler = serveStaticContent(returnHandler)

	// rs/cors with its default options.
	handleCORS := cors.Default().Handler
	returnHandler = handleCORS(returnHandler)

	returnHandler = wrapMetrics(returnHandler)

	// Outermost: chi assigns a request ID. It is then attached to the context
	// used by the project logger (presumably so every log line for the request
	// carries it) before the rest of the chain runs.
	return middleware.RequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		r = r.WithContext(log.WithRequestID(ctx, middleware.GetReqID(ctx)))
		// Flush the logger when the request finishes; the error is ignored on purpose.
		// NOTE(review): Sync uses a logger derived from ctx captured before
		// WithRequestID; presumably harmless since the underlying core is shared.
		defer func() {
			_ = log.ContextLogger(ctx).Sync()
		}()

		returnHandler.ServeHTTP(w, r)
	}))
}
281
282
// recordMetricsForAPI attaches route middleware to the given method and path.
// The middleware stores that method and path in the per-request *apiToRecord
// that wrapMetrics places in the context, so the completed request can be
// labelled in the metrics. If no *apiToRecord is found, a warning is logged and
// the request is still served normally.
func recordMetricsForAPI(api *operations.RekorServerAPI, method string, path string) {
	api.AddMiddlewareFor(method, path, func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			info, found := r.Context().Value(ctxKeyAPIToRecord).(*apiToRecord)
			if !found {
				log.ContextLogger(r.Context()).Warn("Could not attach api info - endpoint may not be monitored.")
				next.ServeHTTP(w, r)
				return
			}
			info.method, info.path = &method, &path
			next.ServeHTTP(w, r)
		})
	})
}
299
// wrapMetrics records Prometheus metrics for requests served by instrumented
// API routes.
//
// For every request it puts an empty *apiToRecord into the context, which the
// route middleware from recordMetricsForAPI fills in. When the request
// completes and both the method and path were recorded, it observes:
//   - the request latency in pkgapi's latency histogram and summary, labelled
//     by path and status code;
//   - the request latency labelled by path and method;
//   - a request count labelled by path, method and code.
//
// Requests that never reach an instrumented route (e.g. "/" or "/ping") are
// not recorded.
func wrapMetrics(handler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		apiInfo := apiToRecord{}
		ctx = context.WithValue(ctx, ctxKeyAPIToRecord, &apiInfo)
		r = r.WithContext(ctx)

		start := time.Now()
		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
		defer func() {
			if apiInfo.path == nil || apiInfo.method == nil {
				return
			}

			// Read the clock once so every metric for this request observes the
			// same duration (previously each Observe call re-measured it).
			// NOTE(review): float64(time.Duration) is in nanoseconds; confirm
			// the pkgapi buckets/objectives are defined in that unit.
			elapsed := float64(time.Since(start))
			code := strconv.Itoa(ww.Status())
			path, method := *apiInfo.path, *apiInfo.method

			latencyLabels := map[string]string{
				"path": path,
				"code": code,
			}
			pkgapi.MetricLatency.With(latencyLabels).Observe(elapsed)
			pkgapi.MetricLatencySummary.With(latencyLabels).Observe(elapsed)

			pkgapi.MetricRequestLatency.With(map[string]string{
				"path":   path,
				"method": method,
			}).Observe(elapsed)

			pkgapi.MetricRequestCount.With(map[string]string{
				"path":   path,
				"method": method,
				"code":   code,
			}).Inc()
		}()

		handler.ServeHTTP(ww, r)
	})
}
341
// logAndServeError is the API's ServeError hook. It works in four steps:
//   - Logs err: go-openapi 404 errors at warn level, everything else at error
//     level.
//   - Replaces err with a 413 Request Entity Too Large error when it is a
//     composite error caused by a request body that exceeded the
//     http.MaxBytesReader limit.
//   - Logs the request (decoded into a map) at debug level, if decoding
//     succeeds.
//   - Writes the response with go-openapi's errors.ServeError.
func logAndServeError(w http.ResponseWriter, r *http.Request, err error) {
	ctx := r.Context()

	apiErr, isAPIErr := err.(errors.Error)
	switch {
	case isAPIErr && apiErr.Code() == http.StatusNotFound:
		log.ContextLogger(ctx).Warn(err)
	default:
		log.ContextLogger(ctx).Error(err)
	}

	if isMaxBytesCompositeError(err) {
		err = errors.New(http.StatusRequestEntityTooLarge, http.StatusText(http.StatusRequestEntityTooLarge))
	}

	requestFields := map[string]interface{}{}
	if decodeErr := mapstructure.Decode(r, &requestFields); decodeErr == nil {
		log.ContextLogger(ctx).Debug(requestFields)
	}
	errors.ServeError(w, r, err)
}

// isMaxBytesCompositeError reports whether err is a go-openapi
// *errors.CompositeError containing an *errors.ParseError whose Reason wraps an
// *http.MaxBytesError.
func isMaxBytesCompositeError(err error) bool {
	compErr, ok := err.(*errors.CompositeError)
	if !ok {
		return false
	}
	for _, embedded := range compErr.Errors {
		parseErr, isParseErr := embedded.(*errors.ParseError)
		if !isParseErr {
			continue
		}
		var maxBytesErr *http.MaxBytesError
		if go_errors.As(parseErr.Reason, &maxBytesErr) {
			return true
		}
	}
	return false
}
365
366
367 var homePageBytes []byte
368
369 func serveStaticContent(handler http.Handler) http.Handler {
370 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
371 if r.URL.Path == "/" {
372 w.Header().Add("Content-Type", "text/html")
373 w.WriteHeader(200)
374 _, _ = w.Write(homePageBytes)
375 return
376 }
377 handler.ServeHTTP(w, r)
378 })
379 }
380
381
// recoverer recovers from panics raised by next. For each panic it:
//   - logs an error-level "panic detected" line carrying the dumped request
//     line/headers and the goroutine stack;
//   - responds via errors.ServeError with a nil error, which go-openapi's
//     ServeError treats as an unknown 500 error.
//
// http.ErrAbortHandler is re-panicked instead of recovered. net/http uses that
// sentinel to abort the response and close the connection without logging.
// Swallowing it, as this middleware previously did, lets the server finish a
// truncated response as if the request had succeeded.
func recoverer(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rvr := recover()
			if rvr == nil {
				return
			}
			if rvr == http.ErrAbortHandler {
				panic(rvr)
			}

			var fields []interface{}
			// Dump only the request line and headers (body=false).
			if request, err := httputil.DumpRequest(r, false); err == nil {
				fields = append(fields, zap.ByteString("request_headers", request))
			}
			// Without the stack the log line would not say where the panic came from.
			fields = append(fields, zap.ByteString("stack", debug.Stack()))

			log.ContextLogger(r.Context()).With(fields...).Errorf("panic detected: %v", rvr)

			errors.ServeError(w, r, nil)
		}()

		next.ServeHTTP(w, r)
	}

	return http.HandlerFunc(fn)
}
406
407
// maxBodySize limits how many request-body bytes next may read to maxLength.
// Reads past the limit fail with an *http.MaxBytesError, and net/http is asked
// (where possible) to close the connection after the response.
//
// It delegates to the standard library's http.MaxBytesHandler. That handler
// applies the limit to a shallow copy of the request instead of assigning to
// the caller's r.Body, which respects the http.Handler rule that handlers must
// not modify the provided Request.
func maxBodySize(maxLength int64, next http.Handler) http.Handler {
	return http.MaxBytesHandler(next, maxLength)
}
416
View as plain text