1 package restful
2
3
4
5
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "net/http"
12 "os"
13 "runtime"
14 "strings"
15 "sync"
16
17 "github.com/emicklei/go-restful/v3/log"
18 )
19
20
21
// Container holds a collection of WebServices together with the http.ServeMux
// on which their root paths are registered, and dispatches incoming requests
// to the matching Route. Create one with NewContainer.
type Container struct {
	// webServicesLock guards webServices: Add and Remove take it for writing,
	// dispatch and RegisteredWebServices take it for reading.
	webServicesLock sync.RWMutex
	// webServices lists the registered services in the order they were added.
	webServices []*WebService
	// ServeMux receives every request passed to ServeHTTP. Remove replaces it
	// with a freshly built mux.
	ServeMux *http.ServeMux
	// isRegisteredOnRoot records whether dispatch has been mapped on "/"; once
	// true, Add no longer registers mux patterns for new services.
	isRegisteredOnRoot bool
	// containerFilters run before the WebService and Route filters in dispatch.
	containerFilters []FilterFunction
	// doNotRecover disables panic recovery in dispatch when true.
	doNotRecover bool
	// recoverHandleFunc receives recovered panics when doNotRecover is false.
	recoverHandleFunc RecoverHandleFunction
	// serviceErrorHandleFunc is invoked by dispatch when route selection fails.
	serviceErrorHandleFunc ServiceErrorHandleFunction
	// router selects the WebService and Route for each request.
	router RouteSelector
	// contentEncodingEnabled is the container-wide default for response
	// compression; in dispatch a Route may override it.
	contentEncodingEnabled bool
}
34
35
// NewContainer returns a Container with no WebServices and no filters, a fresh
// http.ServeMux and a CurlyRouter. Content encoding is disabled and panic
// recovery is off (doNotRecover is true). If recovery is switched on with
// DoNotRecover(false), panics are handed to logStackOnRecover. Service errors
// are written by writeServiceError.
func NewContainer() *Container {
	c := new(Container)
	c.webServices = []*WebService{}
	c.ServeMux = http.NewServeMux()
	c.containerFilters = []FilterFunction{}
	c.doNotRecover = true
	c.recoverHandleFunc = logStackOnRecover
	c.serviceErrorHandleFunc = writeServiceError
	c.router = CurlyRouter{}
	// isRegisteredOnRoot and contentEncodingEnabled keep their zero value (false).
	return c
}
48
49
50
// RecoverHandleFunction handles a panic recovered by the container's dispatch.
// It receives the recovered value and the response writer for the request
// that panicked, and is responsible for writing the response.
type RecoverHandleFunction func(interface{}, http.ResponseWriter)
52
53
54
// RecoverHandler replaces the function that dispatch calls with a recovered
// panic. It only takes effect when recovery is enabled via DoNotRecover(false).
func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
	c.recoverHandleFunc = handler
}
58
59
60
61
// ServiceErrorHandleFunction writes the response for a request whose route
// could not be selected. It receives the ServiceError describing the failure
// together with the wrapped request and response.
type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)
63
64
65
// ServiceErrorHandler replaces the function used by dispatch to write the
// response when route selection fails (writeServiceError by default).
func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
	c.serviceErrorHandleFunc = handler
}
69
70
71
72
// DoNotRecover controls panic recovery in dispatch. When doNot is false,
// panics raised while handling a request are recovered and passed to the
// RecoverHandleFunction. NewContainer sets it to true, so recovery is off
// by default.
func (c *Container) DoNotRecover(doNot bool) {
	c.doNotRecover = doNot
}
76
77
// Router replaces the RouteSelector that dispatch uses to pick the WebService
// and Route for each request (CurlyRouter by default). If the selector also
// implements PathProcessor, dispatch uses it to extract path parameters.
func (c *Container) Router(aRouter RouteSelector) {
	c.router = aRouter
}
81
82
// EnableContentEncoding sets the container-wide default for response
// compression. It is read per request by ServeHTTP, Handle and dispatch.
// In dispatch, a Route's own setting overrides it.
func (c *Container) EnableContentEncoding(enabled bool) {
	c.contentEncodingEnabled = enabled
}
86
87
// Add registers service with the container and returns the container so
// calls can be chained.
//
// A service without a root path is mounted at "/". Root paths must be unique.
// If another registered service already uses the same root path, the
// duplicate is logged and the process terminates with os.Exit(1).
// While no service is mapped on "/", the service's root path prefix is also
// registered on the container's ServeMux.
func (c *Container) Add(service *WebService) *Container {
	c.webServicesLock.Lock()
	defer c.webServicesLock.Unlock()

	if service.rootPath == "" {
		service.Path("/")
	}

	for _, existing := range c.webServices {
		if existing.RootPath() != service.RootPath() {
			continue
		}
		log.Printf("WebService with duplicate root path detected:['%v']", existing)
		os.Exit(1)
	}

	// addHandler inspects c.webServices, so it must run before service is appended.
	if !c.isRegisteredOnRoot {
		c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
	}
	c.webServices = append(c.webServices, service)
	return c
}
112
113
114
115
// addHandler maps c.dispatch on serveMux for the fixed (variable-free) prefix
// of service's root path. It reports whether the mapping was made on "/".
//
// A prefix of "" or "/" is mapped on "/", which forwards every request on
// serveMux to dispatch. Any other prefix is mapped as-is and, when it has no
// trailing slash, also with one so that sub-paths are routed.
//
// Add and Remove call it while holding webServicesLock. It expects
// c.webServices to hold the services already registered on serveMux; service
// itself is ignored if present. Patterns claimed by those services are skipped.
func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
	pattern := fixedPrefixPath(service.RootPath())

	if pattern == "/" || pattern == "" {
		serveMux.HandleFunc("/", c.dispatch)
		return true
	}

	// muxPatterns lists the ServeMux patterns used for a non-root fixed prefix.
	muxPatterns := func(prefix string) []string {
		if strings.HasSuffix(prefix, "/") {
			return []string{prefix}
		}
		return []string{prefix, prefix + "/"}
	}

	// http.ServeMux panics on a duplicate registration, and distinct root
	// paths can share a pattern ("/api" and "/api/{version}" both need
	// "/api/"). Skip every pattern another service has already claimed.
	claimed := make(map[string]bool)
	for _, each := range c.webServices {
		if each == service {
			continue
		}
		for _, p := range muxPatterns(fixedPrefixPath(each.RootPath())) {
			claimed[p] = true
		}
	}
	for _, p := range muxPatterns(pattern) {
		if !claimed[p] {
			serveMux.HandleFunc(p, c.dispatch)
		}
	}
	return false
}
139
// Remove unregisters every WebService whose root path equals ws's root path.
// Because a ServeMux cannot forget a pattern, it rebuilds a new ServeMux from
// the remaining services and installs it as c.ServeMux.
//
// It logs and returns an error when the container uses http.DefaultServeMux,
// presumably because that process-wide mux cannot be swapped out.
func (c *Container) Remove(ws *WebService) error {
	if c.ServeMux == http.DefaultServeMux {
		errMsg := fmt.Sprintf("cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
		log.Print(errMsg)
		return errors.New(errMsg)
	}
	c.webServicesLock.Lock()
	defer c.webServicesLock.Unlock()

	previous := c.webServices
	rebuilt := false
	// If a registration panics, restore the original list so the container
	// stays consistent with its (unchanged) old ServeMux.
	defer func() {
		if !rebuilt {
			c.webServices = previous
		}
	}()

	// addHandler consults c.webServices to find patterns already mapped on the
	// mux. The list is therefore rebuilt incrementally, so that at each step it
	// holds only the services already handled for newServeMux. With the full
	// old list, every service would look "already mapped" and none would be
	// registered.
	newServeMux := http.NewServeMux()
	c.webServices = make([]*WebService, 0, len(previous))
	newIsRegisteredOnRoot := false
	for _, each := range previous {
		if each.rootPath == ws.rootPath {
			continue
		}
		// Once "/" is mapped, dispatch receives every request; no further patterns needed.
		if !newIsRegisteredOnRoot {
			newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
		}
		c.webServices = append(c.webServices, each)
	}
	c.ServeMux, c.isRegisteredOnRoot = newServeMux, newIsRegisteredOnRoot
	rebuilt = true
	return nil
}
164
165
166
167
168
// logStackOnRecover is the default RecoverHandleFunction. It builds a text
// containing the panic reason followed by one "file:line" entry per stack
// frame. It logs that text, then writes it to httpWriter with status 500.
//
// NOTE(review): the response body exposes the panic value and source file
// paths to the client; consider installing a custom RecoverHandler in
// production.
func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
	var trace bytes.Buffer
	fmt.Fprintf(&trace, "recover from panic situation: - %v\r\n", panicReason)

	// Frame 0 is this function and frame 1 its direct caller; the listing
	// starts above those.
	depth := 2
	for {
		_, file, line, ok := runtime.Caller(depth)
		if !ok {
			break
		}
		fmt.Fprintf(&trace, " %s:%d\r\n", file, line)
		depth++
	}

	log.Print(trace.String())
	httpWriter.WriteHeader(http.StatusInternalServerError)
	// Best effort: nothing useful can be done if the client is gone.
	_, _ = httpWriter.Write(trace.Bytes())
}
183
184
185
186
187 func writeServiceError(err ServiceError, req *Request, resp *Response) {
188 for header, values := range err.Header {
189 for _, value := range values {
190 resp.Header().Add(header, value)
191 }
192 }
193 resp.WriteErrorString(err.Code, err.Message)
194 }
195
196
// Dispatch routes a single request through the container's WebServices,
// exactly like requests arriving through the ServeMux. It panics if either
// argument is nil.
func (c *Container) Dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	switch {
	case httpWriter == nil:
		panic("httpWriter cannot be nil")
	case httpRequest == nil:
		panic("httpRequest cannot be nil")
	}
	c.dispatch(httpWriter, httpRequest)
}
206
207
// dispatch selects the WebService and Route matching httpRequest. If enabled,
// it installs a compressing writer. It then runs the container, WebService and
// Route filters (in that order) before invoking the route function.
//
// When route selection fails, only the container filters run. Their target
// reports the failure through the configured ServiceErrorHandleFunction.
// A non-ServiceError from a custom RouteSelector is reported as a 500.
func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
	// writer may be replaced by a compressing writer below; the deferred
	// closures read the variable at exit time and so see the final value.
	writer := httpWriter

	// Close the compressor, if one was installed, on every exit path.
	defer func() {
		if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
			compressWriter.Close()
		}
	}()

	// Deferred after the Close above, so it runs first (LIFO): the recover
	// handler can still write through the compressing writer.
	if !c.doNotRecover {
		defer func() {
			if r := recover(); r != nil {
				c.recoverHandleFunc(r, writer)
			}
		}()
	}

	var webService *WebService
	var route *Route
	var err error
	func() {
		c.webServicesLock.RLock()
		defer c.webServicesLock.RUnlock()
		webService, route, err = c.router.SelectRoute(c.webServices, httpRequest)
	}()
	if err != nil {
		chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
			var serviceErr ServiceError
			if !errors.As(err, &serviceErr) {
				// Report a plain error from a custom RouteSelector as a 500
				// instead of silently sending an empty 200 response.
				serviceErr = ServiceError{Code: http.StatusInternalServerError, Message: err.Error()}
			}
			c.serviceErrorHandleFunc(serviceErr, req, resp)
		}}
		// newBasicRequestResponse records the client's Accept header on the
		// Response (presumably used when the error entity is written).
		chain.ProcessFilter(newBasicRequestResponse(writer, httpRequest))
		return
	}

	// ServeHTTP may already have installed a compressing writer; never wrap twice.
	if _, isCompressing := httpWriter.(*CompressingResponseWriter); !isCompressing {
		// A route-level setting, when present, overrides the container default.
		contentEncodingEnabled := c.contentEncodingEnabled
		if route != nil && route.contentEncodingEnabled != nil {
			contentEncodingEnabled = *route.contentEncodingEnabled
		}
		if contentEncodingEnabled {
			if doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter); doCompress {
				compressingWriter, compressErr := NewCompressingResponseWriter(httpWriter, encoding)
				if compressErr != nil {
					log.Print("unable to install compressor: ", compressErr)
					httpWriter.WriteHeader(http.StatusInternalServerError)
					return
				}
				// Assign only on success. Storing a nil *CompressingResponseWriter in
				// writer would make the deferred type assertion succeed and call
				// Close on a nil pointer.
				writer = compressingWriter
			}
		}
	}

	pathProcessor, routerProcessesPath := c.router.(PathProcessor)
	if !routerProcessesPath {
		pathProcessor = defaultPathProcessor{}
	}
	pathParams := pathProcessor.ExtractParameters(route, webService, httpRequest.URL.Path)
	wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest, pathParams)

	if size := len(c.containerFilters) + len(webService.filters) + len(route.Filters); size > 0 {
		allFilters := make([]FilterFunction, 0, size)
		allFilters = append(allFilters, c.containerFilters...)
		allFilters = append(allFilters, webService.filters...)
		allFilters = append(allFilters, route.Filters...)
		chain := FilterChain{
			Filters:       allFilters,
			Target:        route.Function,
			ParameterDocs: route.ParameterDocs,
			Operation:     route.Operation,
		}
		chain.ProcessFilter(wrappedRequest, wrappedResponse)
		return
	}

	// No filters at any level: call the route function directly.
	route.Function(wrappedRequest, wrappedResponse)
}
302
303
// fixedPrefixPath returns the part of pathspec that precedes its first path
// variable, i.e. everything before the first "{". A pathspec without
// variables is returned unchanged.
func fixedPrefixPath(pathspec string) string {
	if brace := strings.IndexByte(pathspec, '{'); brace >= 0 {
		return pathspec[:brace]
	}
	return pathspec
}
311
312
313 func (c *Container) ServeHTTP(httpWriter http.ResponseWriter, httpRequest *http.Request) {
314
315 if !c.contentEncodingEnabled {
316 c.ServeMux.ServeHTTP(httpWriter, httpRequest)
317 return
318 }
319
320
321
322 if _, ok := httpWriter.(*CompressingResponseWriter); ok {
323 c.ServeMux.ServeHTTP(httpWriter, httpRequest)
324 return
325 }
326
327 writer := httpWriter
328
329 defer func() {
330 if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
331 compressWriter.Close()
332 }
333 }()
334
335 doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter)
336 if doCompress {
337 var err error
338 writer, err = NewCompressingResponseWriter(httpWriter, encoding)
339 if err != nil {
340 log.Print("unable to install compressor: ", err)
341 httpWriter.WriteHeader(http.StatusInternalServerError)
342 return
343 }
344 }
345
346 c.ServeMux.ServeHTTP(writer, httpRequest)
347 }
348
349
350 func (c *Container) Handle(pattern string, handler http.Handler) {
351 c.ServeMux.Handle(pattern, http.HandlerFunc(func(httpWriter http.ResponseWriter, httpRequest *http.Request) {
352
353 if _, ok := httpWriter.(*CompressingResponseWriter); ok {
354 handler.ServeHTTP(httpWriter, httpRequest)
355 return
356 }
357
358 writer := httpWriter
359
360
361 defer func() {
362 if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
363 compressWriter.Close()
364 }
365 }()
366
367 if c.contentEncodingEnabled {
368 doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter)
369 if doCompress {
370 var err error
371 writer, err = NewCompressingResponseWriter(httpWriter, encoding)
372 if err != nil {
373 log.Print("unable to install compressor: ", err)
374 httpWriter.WriteHeader(http.StatusInternalServerError)
375 return
376 }
377 }
378 }
379
380 handler.ServeHTTP(writer, httpRequest)
381 }))
382 }
383
384
385
386
// HandleWithFilter registers handler for pattern (via Handle) and runs the
// container filters in front of it. The filter list is read on every request.
// When it is empty, handler is called directly.
func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
	withFilters := func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
		if len(c.containerFilters) == 0 {
			handler.ServeHTTP(httpResponse, httpRequest)
			return
		}

		// The plain handler becomes the chain target once all filters have run.
		target := func(req *Request, resp *Response) {
			handler.ServeHTTP(resp, req.Request)
		}
		chain := FilterChain{Filters: c.containerFilters, Target: target}
		chain.ProcessFilter(NewRequest(httpRequest), NewResponse(httpResponse))
	}

	c.Handle(pattern, http.HandlerFunc(withFilters))
}
402
403
404
// Filter appends filter to the container filters. dispatch runs these before
// the WebService and Route filters, and also for requests whose route could
// not be selected. HandleWithFilter runs them in front of plain http.Handlers.
//
// NOTE(review): no lock is taken here, and the filter slice is read without
// locking while requests are served. Filters are presumably meant to be added
// before the container starts serving — confirm.
func (c *Container) Filter(filter FilterFunction) {
	c.containerFilters = append(c.containerFilters, filter)
}
408
409
410 func (c *Container) RegisteredWebServices() []*WebService {
411 c.webServicesLock.RLock()
412 defer c.webServicesLock.RUnlock()
413 result := make([]*WebService, len(c.webServices))
414 for ix := range c.webServices {
415 result[ix] = c.webServices[ix]
416 }
417 return result
418 }
419
420
421 func (c *Container) computeAllowedMethods(req *Request) []string {
422
423 methods := []string{}
424 requestPath := req.Request.URL.Path
425 for _, ws := range c.RegisteredWebServices() {
426 matches := ws.pathExpr.Matcher.FindStringSubmatch(requestPath)
427 if matches != nil {
428 finalMatch := matches[len(matches)-1]
429 for _, rt := range ws.Routes() {
430 matches := rt.pathExpr.Matcher.FindStringSubmatch(finalMatch)
431 if matches != nil {
432 lastMatch := matches[len(matches)-1]
433 if lastMatch == "" || lastMatch == "/" {
434 methods = append(methods, rt.Method)
435 }
436 }
437 }
438 }
439 }
440
441 return methods
442 }
443
444
445
// newBasicRequestResponse wraps the raw writer and request without any route
// information. It copies the client's Accept header onto the Response so that
// it is available when no Route was selected.
func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
	response := NewResponse(httpWriter)
	response.requestAccept = httpRequest.Header.Get(HEADER_Accept)
	request := NewRequest(httpRequest)
	return request, response
}
451
View as plain text