1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package middleware
16
17 import (
18 stdContext "context"
19 "fmt"
20 "net/http"
21 "net/url"
22 "path"
23 "strings"
24 "sync"
25
26 "github.com/go-openapi/analysis"
27 "github.com/go-openapi/errors"
28 "github.com/go-openapi/loads"
29 "github.com/go-openapi/spec"
30 "github.com/go-openapi/strfmt"
31
32 "github.com/go-openapi/runtime"
33 "github.com/go-openapi/runtime/logger"
34 "github.com/go-openapi/runtime/middleware/untyped"
35 "github.com/go-openapi/runtime/security"
36 )
37
38
39 var Debug = logger.DebugEnabled()
40
41
42 var Logger logger.Logger = logger.StandardLogger{}
43
44 func debugLogfFunc(lg logger.Logger) func(string, ...any) {
45 if logger.DebugEnabled() {
46 if lg == nil {
47 return Logger.Debugf
48 }
49
50 return lg.Debugf
51 }
52
53
54 return func(_ string, _ ...any) {}
55 }
56
57
// Builder decorates an http.Handler (for instance with extra middleware) and
// returns the resulting handler.
type Builder func(http.Handler) http.Handler

// PassthroughBuilder is the identity Builder: the handler it receives is the
// handler it returns.
func PassthroughBuilder(handler http.Handler) http.Handler {
	return handler
}
62
63
64
65 type RequestBinder interface {
66 BindRequest(*http.Request, *MatchedRoute) error
67 }
68
69
70
71 type Responder interface {
72 WriteResponse(http.ResponseWriter, runtime.Producer)
73 }
74
75
76 type ResponderFunc func(http.ResponseWriter, runtime.Producer)
77
78
79 func (fn ResponderFunc) WriteResponse(rw http.ResponseWriter, pr runtime.Producer) {
80 fn(rw, pr)
81 }
82
83
84
85
// Context is the runtime state shared by the middleware for one API: the
// loaded spec document, its analysis, the routable API, the router and the
// debug logging function.
type Context struct {
	spec      *loads.Document
	analyzer  *analysis.Spec
	api       RoutableAPI
	router    Router
	debugLogf func(string, ...any) // built by debugLogfFunc: a no-op unless debug logging is enabled
}

// routableUntypedAPI adapts an *untyped.API so it can be stored as a
// RoutableAPI (see NewContext), with one precomputed http.Handler per
// operation.
type routableUntypedAPI struct {
	api             *untyped.API
	hlock           *sync.Mutex                        // taken by HandlerFor around lookups in handlers
	handlers        map[string]map[string]http.Handler // upper-cased HTTP method -> spec path -> handler
	defaultConsumes string                             // copied from api.DefaultConsumes at construction
	defaultProduces string                             // copied from api.DefaultProduces at construction
}
101
// newRoutableUntypedAPI builds the RoutableAPI adapter for an untyped API.
//
// For every operation in the analyzed spec for which api has a registered
// operation handler, it creates an http.Handler that:
//  1. resolves the matched route (RouteInfo),
//  2. binds and validates the request (BindAndValidate),
//  3. invokes the operation handler, and
//  4. writes the result or error via context.Respond.
//
// Operations with security requirements are additionally wrapped with
// newSecureAPI. Handlers are stored under the upper-cased method and the spec
// path. It returns nil when spec or api is nil.
func newRoutableUntypedAPI(spec *loads.Document, api *untyped.API, context *Context) *routableUntypedAPI {
	var handlers map[string]map[string]http.Handler
	if spec == nil || api == nil {
		return nil
	}
	analyzer := analysis.New(spec.Spec())
	for method, hls := range analyzer.Operations() {
		// keys are upper-cased so HandlerFor (which upper-cases too) matches any case
		um := strings.ToUpper(method)
		for path, op := range hls {
			schemes := analyzer.SecurityRequirementsFor(op)

			// operations without a registered handler get no entry at all
			if oh, ok := api.OperationHandlerFor(method, path); ok {
				// maps are created lazily, only once a handler is actually registered
				if handlers == nil {
					handlers = make(map[string]map[string]http.Handler)
				}
				if b, ok := handlers[um]; !ok || b == nil {
					handlers[um] = make(map[string]http.Handler)
				}

				var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					// look up the route, reusing one cached in the request context if present
					// NOTE(review): route is nil when no route matches, and route.Produces
					// below would then panic; presumably the router only dispatches here
					// after a successful match — confirm.
					route, rCtx, _ := context.RouteInfo(r)
					if rCtx != nil {
						r = rCtx
					}

					// bind and validate the request parameters
					// (the local name `validation` shadows the package-level type of the same name here)
					var bound interface{}
					var validation error
					bound, r, validation = context.BindAndValidate(r, route)
					if validation != nil {
						context.Respond(w, r, route.Produces, route, validation)
						return
					}

					// invoke the operation handler with the bound parameters
					result, err := oh.Handle(bound)
					if err != nil {
						// the handler failed: respond with its error
						context.Respond(w, r, route.Produces, route, err)
						return
					}

					// success: respond with the handler's result
					context.Respond(w, r, route.Produces, route, result)
				})

				// operations declaring security requirements go through the security wrapper
				if len(schemes) > 0 {
					handler = newSecureAPI(context, handler)
				}
				handlers[um][path] = handler
			}
		}
	}

	return &routableUntypedAPI{
		api:             api,
		hlock:           new(sync.Mutex),
		handlers:        handlers,
		defaultProduces: api.DefaultProduces,
		defaultConsumes: api.DefaultConsumes,
	}
}
165
166 func (r *routableUntypedAPI) HandlerFor(method, path string) (http.Handler, bool) {
167 r.hlock.Lock()
168 paths, ok := r.handlers[strings.ToUpper(method)]
169 if !ok {
170 r.hlock.Unlock()
171 return nil, false
172 }
173 handler, ok := paths[path]
174 r.hlock.Unlock()
175 return handler, ok
176 }
177 func (r *routableUntypedAPI) ServeErrorFor(_ string) func(http.ResponseWriter, *http.Request, error) {
178 return r.api.ServeError
179 }
180 func (r *routableUntypedAPI) ConsumersFor(mediaTypes []string) map[string]runtime.Consumer {
181 return r.api.ConsumersFor(mediaTypes)
182 }
183 func (r *routableUntypedAPI) ProducersFor(mediaTypes []string) map[string]runtime.Producer {
184 return r.api.ProducersFor(mediaTypes)
185 }
186 func (r *routableUntypedAPI) AuthenticatorsFor(schemes map[string]spec.SecurityScheme) map[string]runtime.Authenticator {
187 return r.api.AuthenticatorsFor(schemes)
188 }
189 func (r *routableUntypedAPI) Authorizer() runtime.Authorizer {
190 return r.api.Authorizer()
191 }
192 func (r *routableUntypedAPI) Formats() strfmt.Registry {
193 return r.api.Formats()
194 }
195
196 func (r *routableUntypedAPI) DefaultProduces() string {
197 return r.defaultProduces
198 }
199
200 func (r *routableUntypedAPI) DefaultConsumes() string {
201 return r.defaultConsumes
202 }
203
204
205
206
207 func NewRoutableContext(spec *loads.Document, routableAPI RoutableAPI, routes Router) *Context {
208 var an *analysis.Spec
209 if spec != nil {
210 an = analysis.New(spec.Spec())
211 }
212
213 return NewRoutableContextWithAnalyzedSpec(spec, an, routableAPI, routes)
214 }
215
216
217
218
219 func NewRoutableContextWithAnalyzedSpec(spec *loads.Document, an *analysis.Spec, routableAPI RoutableAPI, routes Router) *Context {
220
221 if !((spec == nil && an == nil) || (spec != nil && an != nil)) {
222 panic(errors.New(http.StatusInternalServerError, "routable context requires either both spec doc and analysis, or none of them"))
223 }
224
225 return &Context{
226 spec: spec,
227 api: routableAPI,
228 analyzer: an,
229 router: routes,
230 debugLogf: debugLogfFunc(nil),
231 }
232 }
233
234
235
236
237 func NewContext(spec *loads.Document, api *untyped.API, routes Router) *Context {
238 var an *analysis.Spec
239 if spec != nil {
240 an = analysis.New(spec.Spec())
241 }
242 ctx := &Context{
243 spec: spec,
244 analyzer: an,
245 router: routes,
246 debugLogf: debugLogfFunc(nil),
247 }
248 ctx.api = newRoutableUntypedAPI(spec, api, ctx)
249
250 return ctx
251 }
252
253
254 func Serve(spec *loads.Document, api *untyped.API) http.Handler {
255 return ServeWithBuilder(spec, api, PassthroughBuilder)
256 }
257
258
259
260 func ServeWithBuilder(spec *loads.Document, api *untyped.API, builder Builder) http.Handler {
261 context := NewContext(spec, api, nil)
262 return context.APIHandler(builder)
263 }
264
// contextKey is the unexported type of the keys this package stores in
// request contexts. A private key type cannot collide with keys defined by
// other packages.
type contextKey int8

// Request-context keys. The blank identifier consumes iota's zero value, so
// no key equals contextKey(0).
const (
	_                    contextKey = iota
	ctxContentType                  // *contentTypeValue, cached by Context.ContentType
	ctxResponseFormat               // string, cached by Context.ResponseFormat
	ctxMatchedRoute                 // *MatchedRoute, cached by Context.RouteInfo
	ctxBoundParams                  // validation result, cached by Context.BindAndValidate (read back as *validation)
	ctxSecurityPrincipal            // authenticated principal, set by Context.Authorize
	ctxSecurityScopes               // scopes set by Context.Authorize (read back as []string)
)
276
277
278 func MatchedRouteFrom(req *http.Request) *MatchedRoute {
279 mr := req.Context().Value(ctxMatchedRoute)
280 if mr == nil {
281 return nil
282 }
283 if res, ok := mr.(*MatchedRoute); ok {
284 return res
285 }
286 return nil
287 }
288
289
290 func SecurityPrincipalFrom(req *http.Request) interface{} {
291 return req.Context().Value(ctxSecurityPrincipal)
292 }
293
294
295 func SecurityScopesFrom(req *http.Request) []string {
296 rs := req.Context().Value(ctxSecurityScopes)
297 if res, ok := rs.([]string); ok {
298 return res
299 }
300 return nil
301 }
302
// contentTypeValue is the parsed request Content-Type cached in the request
// context under ctxContentType by Context.ContentType.
type contentTypeValue struct {
	MediaType string // media type as returned by runtime.ContentType
	Charset   string // charset as returned by runtime.ContentType
}
307
308
309 func (c *Context) BasePath() string {
310 return c.spec.BasePath()
311 }
312
313
314
315
316 func (c *Context) SetLogger(lg logger.Logger) {
317 c.debugLogf = debugLogfFunc(lg)
318 }
319
320
321 func (c *Context) RequiredProduces() []string {
322 return c.analyzer.RequiredProduces()
323 }
324
325
326
// BindValidRequest validates the request against the matched route and,
// when those checks pass, binds it with binder.
//
// For a request with a body, these checks apply:
//   - the Content-Type must parse,
//   - it must be accepted by route.Consumes,
//   - it must have a consumer in route.Consumers.
//
// On success route.Consumer is set to that consumer (this mutates route). A
// response format must then be negotiable against route.Produces.
//
// Failures of these checks are accumulated and returned together as a
// composite validation error, and binder is not called. An error from binder
// itself is returned unchanged.
func (c *Context) BindValidRequest(request *http.Request, route *MatchedRoute, binder RequestBinder) error {
	var res []error
	var requestContentType string

	// check and validate the content type, then select the consumer
	if runtime.HasBody(request) {
		ct, _, err := runtime.ContentType(request.Header)
		if err != nil {
			res = append(res, err)
		} else {
			c.debugLogf("validating content type for %q against [%s]", ct, strings.Join(route.Consumes, ", "))
			if err := validateContentType(route.Consumes, ct); err != nil {
				res = append(res, err)
			}
			if len(res) == 0 {
				cons, ok := route.Consumers[ct]
				if !ok {
					// an allowed media type without a registered consumer is reported as a 500
					res = append(res, errors.New(500, "no consumer registered for %s", ct))
				} else {
					route.Consumer = cons
					requestContentType = ct
				}
			}
		}
	}

	// check and validate the response format
	if len(res) == 0 {
		// When the route declares no Produces and no content type came from a
		// request body (e.g. a bodiless GET or DELETE), "*/*" is passed to
		// NegotiateContentType as the default.
		if len(route.Produces) == 0 && requestContentType == "" {
			requestContentType = "*/*"
		}

		if str := NegotiateContentType(request, route.Produces, requestContentType); str == "" {
			res = append(res, errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces))
		}
	}

	// Bind only when every check above passed. NOTE(review): the binder is
	// presumably expected to validate the parameters it binds as well — confirm.
	if binder != nil && len(res) == 0 {
		if err := binder.BindRequest(request, route); err != nil {
			return err
		}
	}

	if len(res) > 0 {
		return errors.CompositeValidationError(res...)
	}
	return nil
}
380
381
382
383
384
385
386 func (c *Context) ContentType(request *http.Request) (string, string, *http.Request, error) {
387 var rCtx = request.Context()
388
389 if v, ok := rCtx.Value(ctxContentType).(*contentTypeValue); ok {
390 return v.MediaType, v.Charset, request, nil
391 }
392
393 mt, cs, err := runtime.ContentType(request.Header)
394 if err != nil {
395 return "", "", nil, err
396 }
397 rCtx = stdContext.WithValue(rCtx, ctxContentType, &contentTypeValue{mt, cs})
398 return mt, cs, request.WithContext(rCtx), nil
399 }
400
401
402 func (c *Context) LookupRoute(request *http.Request) (*MatchedRoute, bool) {
403 if route, ok := c.router.Lookup(request.Method, request.URL.EscapedPath()); ok {
404 return route, ok
405 }
406 return nil, false
407 }
408
409
410
411
412
413
414 func (c *Context) RouteInfo(request *http.Request) (*MatchedRoute, *http.Request, bool) {
415 var rCtx = request.Context()
416
417 if v, ok := rCtx.Value(ctxMatchedRoute).(*MatchedRoute); ok {
418 return v, request, ok
419 }
420
421 if route, ok := c.LookupRoute(request); ok {
422 rCtx = stdContext.WithValue(rCtx, ctxMatchedRoute, route)
423 return route, request.WithContext(rCtx), ok
424 }
425
426 return nil, nil, false
427 }
428
429
430
431
432 func (c *Context) ResponseFormat(r *http.Request, offers []string) (string, *http.Request) {
433 var rCtx = r.Context()
434
435 if v, ok := rCtx.Value(ctxResponseFormat).(string); ok {
436 c.debugLogf("[%s %s] found response format %q in context", r.Method, r.URL.Path, v)
437 return v, r
438 }
439
440 format := NegotiateContentType(r, offers, "")
441 if format != "" {
442 c.debugLogf("[%s %s] set response format %q in context", r.Method, r.URL.Path, format)
443 r = r.WithContext(stdContext.WithValue(rCtx, ctxResponseFormat, format))
444 }
445 c.debugLogf("[%s %s] negotiated response format %q", r.Method, r.URL.Path, format)
446 return format, r
447 }
448
449
450 func (c *Context) AllowedMethods(request *http.Request) []string {
451 return c.router.OtherMethods(request.Method, request.URL.EscapedPath())
452 }
453
454
455 func (c *Context) ResetAuth(request *http.Request) *http.Request {
456 rctx := request.Context()
457 rctx = stdContext.WithValue(rctx, ctxSecurityPrincipal, nil)
458 rctx = stdContext.WithValue(rctx, ctxSecurityScopes, nil)
459 return request.WithContext(rctx)
460 }
461
462
463
464
465
// Authorize authenticates and authorizes the request for the matched route.
//
// It returns (nil, nil, nil) when route is nil or has no auth. Note that the
// returned request is then nil, so callers must keep using their own request.
//
// A principal already present in the request context is returned without
// re-authenticating. Otherwise the route's authenticators run. Authentication
// fails when:
//   - no authenticator applies,
//   - one returns an error (returned as-is), or
//   - the principal is nil and anonymous access is not allowed
//     (errors.Unauthenticated("invalid credentials")).
//
// If route.Authorizer is set, its errors.Error failures are returned as-is and
// other errors are wrapped as 403 Forbidden. On success the principal and
// scopes are stored in the context of the returned request.
func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (interface{}, *http.Request, error) {
	if route == nil || !route.HasAuth() {
		return nil, nil, nil
	}

	// reuse a principal established earlier in this request's lifetime
	var rCtx = request.Context()
	if v := rCtx.Value(ctxSecurityPrincipal); v != nil {
		return v, request, nil
	}

	applies, usr, err := route.Authenticators.Authenticate(request, route)
	// && binds tighter than ||: a nil principal is only rejected when anonymous access is not allowed
	if !applies || err != nil || !route.Authenticators.AllowsAnonymous() && usr == nil {
		if err != nil {
			return nil, nil, err
		}
		return nil, nil, errors.Unauthenticated("invalid credentials")
	}
	if route.Authorizer != nil {
		if err := route.Authorizer.Authorize(request, usr); err != nil {
			// preserve structured API errors (and their status codes) as-is
			if _, ok := err.(errors.Error); ok {
				return nil, nil, err
			}

			return nil, nil, errors.New(http.StatusForbidden, err.Error())
		}
	}

	// store the principal and scopes in the request context
	rCtx = request.Context()

	rCtx = stdContext.WithValue(rCtx, ctxSecurityPrincipal, usr)
	// NOTE(review): scopes come from route.Authenticator (singular), presumably the
	// authenticator selected during Authenticate above — confirm.
	rCtx = stdContext.WithValue(rCtx, ctxSecurityScopes, route.Authenticator.AllScopes())
	return usr, request.WithContext(rCtx), nil
}
499
500
501
502
503
504 func (c *Context) BindAndValidate(request *http.Request, matched *MatchedRoute) (interface{}, *http.Request, error) {
505 var rCtx = request.Context()
506
507 if v, ok := rCtx.Value(ctxBoundParams).(*validation); ok {
508 c.debugLogf("got cached validation (valid: %t)", len(v.result) == 0)
509 if len(v.result) > 0 {
510 return v.bound, request, errors.CompositeValidationError(v.result...)
511 }
512 return v.bound, request, nil
513 }
514 result := validateRequest(c, request, matched)
515 rCtx = stdContext.WithValue(rCtx, ctxBoundParams, result)
516 request = request.WithContext(rCtx)
517 if len(result.result) > 0 {
518 return result.bound, request, errors.CompositeValidationError(result.result...)
519 }
520 c.debugLogf("no validation errors found")
521 return result.bound, request, nil
522 }
523
524
525 func (c *Context) NotFound(rw http.ResponseWriter, r *http.Request) {
526 c.Respond(rw, r, []string{c.api.DefaultProduces()}, nil, errors.NotFound("not found"))
527 }
528
529
530 func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, data interface{}) {
531 c.debugLogf("responding to %s %s with produces: %v", r.Method, r.URL.Path, produces)
532 offers := []string{}
533 for _, mt := range produces {
534 if mt != c.api.DefaultProduces() {
535 offers = append(offers, mt)
536 }
537 }
538
539 offers = append(offers, c.api.DefaultProduces())
540 c.debugLogf("offers: %v", offers)
541
542 var format string
543 format, r = c.ResponseFormat(r, offers)
544 rw.Header().Set(runtime.HeaderContentType, format)
545
546 if resp, ok := data.(Responder); ok {
547 producers := route.Producers
548
549
550 prod, ok := producers[normalizeOffer(format)]
551 if !ok {
552 prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()}))
553 pr, ok := prods[c.api.DefaultProduces()]
554 if !ok {
555 panic(errors.New(http.StatusInternalServerError, cantFindProducer(format)))
556 }
557 prod = pr
558 }
559 resp.WriteResponse(rw, prod)
560 return
561 }
562
563 if err, ok := data.(error); ok {
564 if format == "" {
565 rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime)
566 }
567
568 if realm := security.FailedBasicAuth(r); realm != "" {
569 rw.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", realm))
570 }
571
572 if route == nil || route.Operation == nil {
573 c.api.ServeErrorFor("")(rw, r, err)
574 return
575 }
576 c.api.ServeErrorFor(route.Operation.ID)(rw, r, err)
577 return
578 }
579
580 if route == nil || route.Operation == nil {
581 rw.WriteHeader(http.StatusOK)
582 if r.Method == http.MethodHead {
583 return
584 }
585 producers := c.api.ProducersFor(normalizeOffers(offers))
586 prod, ok := producers[format]
587 if !ok {
588 panic(errors.New(http.StatusInternalServerError, cantFindProducer(format)))
589 }
590 if err := prod.Produce(rw, data); err != nil {
591 panic(err)
592 }
593 return
594 }
595
596 if _, code, ok := route.Operation.SuccessResponse(); ok {
597 rw.WriteHeader(code)
598 if code == http.StatusNoContent || r.Method == http.MethodHead {
599 return
600 }
601
602 producers := route.Producers
603 prod, ok := producers[format]
604 if !ok {
605 if !ok {
606 prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()}))
607 pr, ok := prods[c.api.DefaultProduces()]
608 if !ok {
609 panic(errors.New(http.StatusInternalServerError, cantFindProducer(format)))
610 }
611 prod = pr
612 }
613 }
614 if err := prod.Produce(rw, data); err != nil {
615 panic(err)
616 }
617 return
618 }
619
620 c.api.ServeErrorFor(route.Operation.ID)(rw, r, errors.New(http.StatusInternalServerError, "can't produce response"))
621 }
622
623
624
625
626
627
628
629 func (c *Context) APIHandlerSwaggerUI(builder Builder, opts ...UIOption) http.Handler {
630 b := builder
631 if b == nil {
632 b = PassthroughBuilder
633 }
634
635 specPath, uiOpts, specOpts := c.uiOptionsForHandler(opts)
636 var swaggerUIOpts SwaggerUIOpts
637 fromCommonToAnyOptions(uiOpts, &swaggerUIOpts)
638
639 return Spec(specPath, c.spec.Raw(), SwaggerUI(swaggerUIOpts, c.RoutesHandler(b)), specOpts...)
640 }
641
642
643
644
645
646
647
648 func (c *Context) APIHandlerRapiDoc(builder Builder, opts ...UIOption) http.Handler {
649 b := builder
650 if b == nil {
651 b = PassthroughBuilder
652 }
653
654 specPath, uiOpts, specOpts := c.uiOptionsForHandler(opts)
655 var rapidocUIOpts RapiDocOpts
656 fromCommonToAnyOptions(uiOpts, &rapidocUIOpts)
657
658 return Spec(specPath, c.spec.Raw(), RapiDoc(rapidocUIOpts, c.RoutesHandler(b)), specOpts...)
659 }
660
661
662
663
664
665
666
667 func (c *Context) APIHandler(builder Builder, opts ...UIOption) http.Handler {
668 b := builder
669 if b == nil {
670 b = PassthroughBuilder
671 }
672
673 specPath, uiOpts, specOpts := c.uiOptionsForHandler(opts)
674 var redocOpts RedocOpts
675 fromCommonToAnyOptions(uiOpts, &redocOpts)
676
677 return Spec(specPath, c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b)), specOpts...)
678 }
679
680 func (c Context) uiOptionsForHandler(opts []UIOption) (string, uiOptions, []SpecOption) {
681 var title string
682 sp := c.spec.Spec()
683 if sp != nil && sp.Info != nil && sp.Info.Title != "" {
684 title = sp.Info.Title
685 }
686
687
688 optsForContext := []UIOption{
689 WithUIBasePath(c.BasePath()),
690 WithUITitle(title),
691 }
692 optsForContext = append(optsForContext, opts...)
693 uiOpts := uiOptionsWithDefaults(optsForContext)
694
695
696
697 u, _ := url.Parse(uiOpts.SpecURL)
698 var specPath string
699 if u != nil {
700 specPath = u.Path
701 }
702
703 pth, doc := path.Split(specPath)
704 if pth == "." {
705 pth = ""
706 }
707
708 return pth, uiOpts, []SpecOption{WithSpecDocument(doc)}
709 }
710
711
712 func (c *Context) RoutesHandler(builder Builder) http.Handler {
713 b := builder
714 if b == nil {
715 b = PassthroughBuilder
716 }
717 return NewRouter(c, b(NewOperationExecutor(c)))
718 }
719
// cantFindProducer builds the error message used when no producer is
// available for the given response format.
func cantFindProducer(format string) string {
	const prefix = "can't find a producer for "
	return prefix + format
}
723
View as plain text