...

Source file src/github.com/emicklei/go-restful/v3/container.go

Documentation: github.com/emicklei/go-restful/v3

     1  package restful
     2  
     3  // Copyright 2013 Ernest Micklei. All rights reserved.
     4  // Use of this source code is governed by a license
     5  // that can be found in the LICENSE file.
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"net/http"
    12  	"os"
    13  	"runtime"
    14  	"strings"
    15  	"sync"
    16  
    17  	"github.com/emicklei/go-restful/v3/log"
    18  )
    19  
    20  // Container holds a collection of WebServices and a http.ServeMux to dispatch http requests.
    21  // The requests are further dispatched to routes of WebServices using a RouteSelector.
    22  type Container struct {
    23  	webServicesLock        sync.RWMutex // guards webServices; Add and Remove also update ServeMux and isRegisteredOnRoot while holding it
    24  	webServices            []*WebService // services registered through Add, in registration order
    25  	ServeMux               *http.ServeMux // mux on which dispatch is registered; Remove replaces it with a freshly built one
    26  	isRegisteredOnRoot     bool // true once dispatch has been registered for the "/" pattern
    27  	containerFilters       []FilterFunction // filters appended by Filter; placed ahead of WebService and Route filters in dispatch
    28  	doNotRecover           bool // default is true; when false, dispatch installs a deferred recover
    29  	recoverHandleFunc      RecoverHandleFunction // called with the recovered panic value; default is logStackOnRecover
    30  	serviceErrorHandleFunc ServiceErrorHandleFunction // called when route selection returns a ServiceError; default is writeServiceError
    31  	router                 RouteSelector // default is a CurlyRouter (RouterJSR311 is a slower alternative)
    32  	contentEncodingEnabled bool          // default is false; a Route can override it through its own contentEncodingEnabled
    33  }
    34  
    35  // NewContainer creates a new Container using a new ServeMux and default router (CurlyRouter)
    36  func NewContainer() *Container {
    37  	return &Container{
    38  		webServices:            []*WebService{},
    39  		ServeMux:               http.NewServeMux(),
    40  		isRegisteredOnRoot:     false,
    41  		containerFilters:       []FilterFunction{},
    42  		doNotRecover:           true,
    43  		recoverHandleFunc:      logStackOnRecover,
    44  		serviceErrorHandleFunc: writeServiceError,
    45  		router:                 CurlyRouter{},
    46  		contentEncodingEnabled: false}
    47  }
    48  
    49  // RecoverHandleFunction declares functions that can be used to handle a panic situation.
    50  // The first argument is the value returned by recover(). The second must be used to communicate an error response.
    51  type RecoverHandleFunction func(interface{}, http.ResponseWriter)
    52  
    53  // RecoverHandler changes the default function (logStackOnRecover) to be called
    54  // when a panic is detected. It only takes effect after DoNotRecover(false), because NewContainer sets doNotRecover to true.
    55  func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
    56  	c.recoverHandleFunc = handler
    57  }
    58  
    59  // ServiceErrorHandleFunction declares functions that can be used to handle a service error situation.
    60  // The first argument is the service error, the second is the request that resulted in the error and
    61  // the third must be used to communicate an error response.
    62  type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)
    63  
    64  // ServiceErrorHandler changes the default function (writeServiceError) to be called
    65  // when route selection fails with a ServiceError.
    66  func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
    67  	c.serviceErrorHandleFunc = handler
    68  }
    69  
    70  // DoNotRecover controls whether panics will be caught to return HTTP 500.
    71  // If set to true, no recovery is installed and Route functions are responsible for handling any error situation; if set to false, a recovered panic is passed to the function set with RecoverHandler.
    72  // Default value is true.
    73  func (c *Container) DoNotRecover(doNot bool) {
    74  	c.doNotRecover = doNot
    75  }
    76  
    77  // Router replaces the RouteSelector used to match requests to routes (the default set by NewContainer is CurlyRouter).
    78  func (c *Container) Router(aRouter RouteSelector) {
    79  	c.router = aRouter
    80  }
    81  
    82  // EnableContentEncoding (default=false) allows for GZIP or DEFLATE encoding of responses. A Route whose contentEncodingEnabled is set overrides this value in dispatch.
    83  func (c *Container) EnableContentEncoding(enabled bool) {
    84  	c.contentEncodingEnabled = enabled
    85  }
    86  
    87  // Add a WebService to the Container. It will detect duplicate root paths and exit in that case.
    88  func (c *Container) Add(service *WebService) *Container {
    89  	c.webServicesLock.Lock()
    90  	defer c.webServicesLock.Unlock()
    91  
    92  	// if rootPath was not set then lazy initialize it
    93  	if len(service.rootPath) == 0 {
    94  		service.Path("/")
    95  	}
    96  
    97  	// cannot have duplicate root paths
    98  	for _, each := range c.webServices {
    99  		if each.RootPath() == service.RootPath() {
   100  			log.Printf("WebService with duplicate root path detected:['%v']", each)
   101  			os.Exit(1)
   102  		}
   103  	}
   104  
   105  	// If not registered on root then add specific mapping
   106  	if !c.isRegisteredOnRoot {
   107  		c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
   108  	}
   109  	c.webServices = append(c.webServices, service)
   110  	return c
   111  }
   112  
   113  // addHandler may set a new HandleFunc for the serveMux
   114  // this function must run inside the critical region protected by the webServicesLock.
   115  // returns true if the function was registered on root ("/")
   116  func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
   117  	pattern := fixedPrefixPath(service.RootPath())
   118  	// check if root path registration is needed
   119  	if "/" == pattern || "" == pattern {
   120  		serveMux.HandleFunc("/", c.dispatch)
   121  		return true
   122  	}
   123  	// detect if registration already exists
   124  	alreadyMapped := false
   125  	for _, each := range c.webServices {
   126  		if each.RootPath() == service.RootPath() {
   127  			alreadyMapped = true
   128  			break
   129  		}
   130  	}
   131  	if !alreadyMapped {
   132  		serveMux.HandleFunc(pattern, c.dispatch)
   133  		if !strings.HasSuffix(pattern, "/") {
   134  			serveMux.HandleFunc(pattern+"/", c.dispatch)
   135  		}
   136  	}
   137  	return false
   138  }
   139  
   140  func (c *Container) Remove(ws *WebService) error {
   141  	if c.ServeMux == http.DefaultServeMux {
   142  		errMsg := fmt.Sprintf("cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
   143  		log.Print(errMsg)
   144  		return errors.New(errMsg)
   145  	}
   146  	c.webServicesLock.Lock()
   147  	defer c.webServicesLock.Unlock()
   148  	// build a new ServeMux and re-register all WebServices
   149  	newServeMux := http.NewServeMux()
   150  	newServices := []*WebService{}
   151  	newIsRegisteredOnRoot := false
   152  	for _, each := range c.webServices {
   153  		if each.rootPath != ws.rootPath {
   154  			// If not registered on root then add specific mapping
   155  			if !newIsRegisteredOnRoot {
   156  				newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
   157  			}
   158  			newServices = append(newServices, each)
   159  		}
   160  	}
   161  	c.webServices, c.ServeMux, c.isRegisteredOnRoot = newServices, newServeMux, newIsRegisteredOnRoot
   162  	return nil
   163  }
   164  
   165  // logStackOnRecover is the default RecoverHandleFunction and is called
   166  // when DoNotRecover is false and the recoverHandleFunc is not set for the container.
   167  // Default implementation logs the stacktrace and writes the stacktrace on the response.
   168  // This may be a security issue as it exposes sourcecode information.
   169  func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
   170  	var buffer bytes.Buffer
   171  	buffer.WriteString(fmt.Sprintf("recover from panic situation: - %v\r\n", panicReason))
   172  	for i := 2; ; i += 1 {
   173  		_, file, line, ok := runtime.Caller(i)
   174  		if !ok {
   175  			break
   176  		}
   177  		buffer.WriteString(fmt.Sprintf("    %s:%d\r\n", file, line))
   178  	}
   179  	log.Print(buffer.String())
   180  	httpWriter.WriteHeader(http.StatusInternalServerError)
   181  	httpWriter.Write(buffer.Bytes())
   182  }
   183  
   184  // writeServiceError is the default ServiceErrorHandleFunction and is called
   185  // when a ServiceError is returned during route selection. Default implementation
   186  // calls resp.WriteErrorString(err.Code, err.Message)
   187  func writeServiceError(err ServiceError, req *Request, resp *Response) {
   188  	for header, values := range err.Header {
   189  		for _, value := range values {
   190  			resp.Header().Add(header, value)
   191  		}
   192  	}
   193  	resp.WriteErrorString(err.Code, err.Message)
   194  }
   195  
   196  // Dispatch the incoming Http Request to a matching WebService.
   197  func (c *Container) Dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
   198  	if httpWriter == nil {
   199  		panic("httpWriter cannot be nil")
   200  	}
   201  	if httpRequest == nil {
   202  		panic("httpRequest cannot be nil")
   203  	}
   204  	c.dispatch(httpWriter, httpRequest)
   205  }
   206  
   207  // Dispatch the incoming Http Request to a matching WebService.
   208  func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
   209  	// so we can assign a compressing one later
   210  	writer := httpWriter
   211  
   212  	// CompressingResponseWriter should be closed after all operations are done
   213  	defer func() {
   214  		if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
   215  			compressWriter.Close()
   216  		}
   217  	}()
   218  
   219  	// Instal panic recovery unless told otherwise
   220  	if !c.doNotRecover { // catch all for 500 response
   221  		defer func() {
   222  			if r := recover(); r != nil {
   223  				c.recoverHandleFunc(r, writer)
   224  				return
   225  			}
   226  		}()
   227  	}
   228  
   229  	// Find best match Route ; err is non nil if no match was found
   230  	var webService *WebService
   231  	var route *Route
   232  	var err error
   233  	func() {
   234  		c.webServicesLock.RLock()
   235  		defer c.webServicesLock.RUnlock()
   236  		webService, route, err = c.router.SelectRoute(
   237  			c.webServices,
   238  			httpRequest)
   239  	}()
   240  	if err != nil {
   241  		// a non-200 response (may be compressed) has already been written
   242  		// run container filters anyway ; they should not touch the response...
   243  		chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
   244  			switch err.(type) {
   245  			case ServiceError:
   246  				ser := err.(ServiceError)
   247  				c.serviceErrorHandleFunc(ser, req, resp)
   248  			}
   249  			// TODO
   250  		}}
   251  		chain.ProcessFilter(NewRequest(httpRequest), NewResponse(writer))
   252  		return
   253  	}
   254  
   255  	// Unless httpWriter is already an CompressingResponseWriter see if we need to install one
   256  	if _, isCompressing := httpWriter.(*CompressingResponseWriter); !isCompressing {
   257  		// Detect if compression is needed
   258  		// assume without compression, test for override
   259  		contentEncodingEnabled := c.contentEncodingEnabled
   260  		if route != nil && route.contentEncodingEnabled != nil {
   261  			contentEncodingEnabled = *route.contentEncodingEnabled
   262  		}
   263  		if contentEncodingEnabled {
   264  			doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter)
   265  			if doCompress {
   266  				var err error
   267  				writer, err = NewCompressingResponseWriter(httpWriter, encoding)
   268  				if err != nil {
   269  					log.Print("unable to install compressor: ", err)
   270  					httpWriter.WriteHeader(http.StatusInternalServerError)
   271  					return
   272  				}
   273  			}
   274  		}
   275  	}
   276  
   277  	pathProcessor, routerProcessesPath := c.router.(PathProcessor)
   278  	if !routerProcessesPath {
   279  		pathProcessor = defaultPathProcessor{}
   280  	}
   281  	pathParams := pathProcessor.ExtractParameters(route, webService, httpRequest.URL.Path)
   282  	wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest, pathParams)
   283  	// pass through filters (if any)
   284  	if size := len(c.containerFilters) + len(webService.filters) + len(route.Filters); size > 0 {
   285  		// compose filter chain
   286  		allFilters := make([]FilterFunction, 0, size)
   287  		allFilters = append(allFilters, c.containerFilters...)
   288  		allFilters = append(allFilters, webService.filters...)
   289  		allFilters = append(allFilters, route.Filters...)
   290  		chain := FilterChain{
   291  			Filters:       allFilters,
   292  			Target:        route.Function,
   293  			ParameterDocs: route.ParameterDocs,
   294  			Operation:     route.Operation,
   295  		}
   296  		chain.ProcessFilter(wrappedRequest, wrappedResponse)
   297  	} else {
   298  		// no filters, handle request by route
   299  		route.Function(wrappedRequest, wrappedResponse)
   300  	}
   301  }
   302  
   303  // fixedPrefixPath returns the fixed part of the partspec ; it may include template vars {}
   304  func fixedPrefixPath(pathspec string) string {
   305  	varBegin := strings.Index(pathspec, "{")
   306  	if -1 == varBegin {
   307  		return pathspec
   308  	}
   309  	return pathspec[:varBegin]
   310  }
   311  
   312  // ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server
   313  func (c *Container) ServeHTTP(httpWriter http.ResponseWriter, httpRequest *http.Request) {
   314  	// Skip, if content encoding is disabled
   315  	if !c.contentEncodingEnabled {
   316  		c.ServeMux.ServeHTTP(httpWriter, httpRequest)
   317  		return
   318  	}
   319  	// content encoding is enabled
   320  
   321  	// Skip, if httpWriter is already an CompressingResponseWriter
   322  	if _, ok := httpWriter.(*CompressingResponseWriter); ok {
   323  		c.ServeMux.ServeHTTP(httpWriter, httpRequest)
   324  		return
   325  	}
   326  
   327  	writer := httpWriter
   328  	// CompressingResponseWriter should be closed after all operations are done
   329  	defer func() {
   330  		if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
   331  			compressWriter.Close()
   332  		}
   333  	}()
   334  
   335  	doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter)
   336  	if doCompress {
   337  		var err error
   338  		writer, err = NewCompressingResponseWriter(httpWriter, encoding)
   339  		if err != nil {
   340  			log.Print("unable to install compressor: ", err)
   341  			httpWriter.WriteHeader(http.StatusInternalServerError)
   342  			return
   343  		}
   344  	}
   345  
   346  	c.ServeMux.ServeHTTP(writer, httpRequest)
   347  }
   348  
   349  // Handle registers the handler for the given pattern. If a handler already exists for pattern, Handle panics.
   350  func (c *Container) Handle(pattern string, handler http.Handler) {
   351  	c.ServeMux.Handle(pattern, http.HandlerFunc(func(httpWriter http.ResponseWriter, httpRequest *http.Request) {
   352  		// Skip, if httpWriter is already an CompressingResponseWriter
   353  		if _, ok := httpWriter.(*CompressingResponseWriter); ok {
   354  			handler.ServeHTTP(httpWriter, httpRequest)
   355  			return
   356  		}
   357  
   358  		writer := httpWriter
   359  
   360  		// CompressingResponseWriter should be closed after all operations are done
   361  		defer func() {
   362  			if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
   363  				compressWriter.Close()
   364  			}
   365  		}()
   366  
   367  		if c.contentEncodingEnabled {
   368  			doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter)
   369  			if doCompress {
   370  				var err error
   371  				writer, err = NewCompressingResponseWriter(httpWriter, encoding)
   372  				if err != nil {
   373  					log.Print("unable to install compressor: ", err)
   374  					httpWriter.WriteHeader(http.StatusInternalServerError)
   375  					return
   376  				}
   377  			}
   378  		}
   379  
   380  		handler.ServeHTTP(writer, httpRequest)
   381  	}))
   382  }
   383  
   384  // HandleWithFilter registers the handler for the given pattern.
   385  // Container's filter chain is applied for handler.
   386  // If a handler already exists for pattern, HandleWithFilter panics.
   387  func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
   388  	f := func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
   389  		if len(c.containerFilters) == 0 {
   390  			handler.ServeHTTP(httpResponse, httpRequest)
   391  			return
   392  		}
   393  
   394  		chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
   395  			handler.ServeHTTP(resp, req.Request)
   396  		}}
   397  		chain.ProcessFilter(NewRequest(httpRequest), NewResponse(httpResponse))
   398  	}
   399  
   400  	c.Handle(pattern, http.HandlerFunc(f))
   401  }
   402  
   403  // Filter appends a container FilterFunction. Container filters are placed ahead of the WebService and Route
   404  // filters when a http.Request is dispatched, and they also run around the error handler when no route was selected.
   405  func (c *Container) Filter(filter FilterFunction) {
   406  	c.containerFilters = append(c.containerFilters, filter)
   407  }
   408  
   409  // RegisteredWebServices returns the collections of added WebServices
   410  func (c *Container) RegisteredWebServices() []*WebService {
   411  	c.webServicesLock.RLock()
   412  	defer c.webServicesLock.RUnlock()
   413  	result := make([]*WebService, len(c.webServices))
   414  	for ix := range c.webServices {
   415  		result[ix] = c.webServices[ix]
   416  	}
   417  	return result
   418  }
   419  
   420  // computeAllowedMethods returns a list of HTTP methods that are valid for a Request
   421  func (c *Container) computeAllowedMethods(req *Request) []string {
   422  	// Go through all RegisteredWebServices() and all its Routes to collect the options
   423  	methods := []string{}
   424  	requestPath := req.Request.URL.Path
   425  	for _, ws := range c.RegisteredWebServices() {
   426  		matches := ws.pathExpr.Matcher.FindStringSubmatch(requestPath)
   427  		if matches != nil {
   428  			finalMatch := matches[len(matches)-1]
   429  			for _, rt := range ws.Routes() {
   430  				matches := rt.pathExpr.Matcher.FindStringSubmatch(finalMatch)
   431  				if matches != nil {
   432  					lastMatch := matches[len(matches)-1]
   433  					if lastMatch == "" || lastMatch == "/" { // do not include if value is neither empty nor ‘/’.
   434  						methods = append(methods, rt.Method)
   435  					}
   436  				}
   437  			}
   438  		}
   439  	}
   440  	// methods = append(methods, "OPTIONS")  not sure about this
   441  	return methods
   442  }
   443  
   444  // newBasicRequestResponse creates a pair of Request,Response from its http versions.
   445  // It is basic because no parameter or (produces) content-type information is given.
   446  func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
   447  	resp := NewResponse(httpWriter)
   448  	resp.requestAccept = httpRequest.Header.Get(HEADER_Accept)
   449  	return NewRequest(httpRequest), resp
   450  }
   451  

View as plain text