1 package handlers
2
3 import (
4 "net/http"
5 "strconv"
6 "strings"
7 )
8
9
// CORSOption configures the handler built by CORS. Each option is a setter on
// the internal cors value; options are applied in the order they are given.
type CORSOption func(*cors) error

// OriginValidator reports whether a request origin is allowed. When one is
// configured it decides on its own whether an origin is accepted, and the
// AllowedOrigins list is not consulted for that decision.
type OriginValidator func(string) bool

// cors is the http.Handler that enforces the CORS policy and delegates
// permitted requests to the wrapped handler h.
type cors struct {
	// h is the wrapped handler that serves the actual request.
	h http.Handler

	// Lists checked against a preflight request's requested headers and method.
	allowedHeaders []string
	allowedMethods []string

	// Origin policy: an explicit list and/or a validator function.
	allowedOrigins         []string
	allowedOriginValidator OriginValidator

	// Response-side settings.
	exposedHeaders   []string
	maxAge           int
	ignoreOptions    bool
	allowCredentials bool
	optionStatusCode int
}

// Package defaults used when building a cors handler. Requested headers found
// in defaultCorsHeaders are always accepted during preflight and never echoed
// back; methods in defaultCorsMethods are never listed in
// Access-Control-Allow-Methods.
var (
	defaultCorsOptionStatusCode = http.StatusOK
	defaultCorsMethods          = []string{"GET", "HEAD", "POST"}
	defaultCorsHeaders          = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
)

// Method, header names and wildcard value used by the CORS handler.
const (
	// Method used by CORS preflight requests.
	corsOptionMethod string = "OPTIONS"

	// Request headers read by the handler ("Origin" is also the Vary value).
	corsOriginHeader         string = "Origin"
	corsRequestMethodHeader  string = "Access-Control-Request-Method"
	corsRequestHeadersHeader string = "Access-Control-Request-Headers"

	// Response headers written by the handler.
	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
	corsMaxAgeHeader           string = "Access-Control-Max-Age"
	corsVaryHeader             string = "Vary"

	// Origin-list entry, and response value, meaning "any origin".
	corsOriginMatchAll string = "*"
)
49
50 func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
51 origin := r.Header.Get(corsOriginHeader)
52 if !ch.isOriginAllowed(origin) {
53 if r.Method != corsOptionMethod || ch.ignoreOptions {
54 ch.h.ServeHTTP(w, r)
55 }
56
57 return
58 }
59
60 if r.Method == corsOptionMethod {
61 if ch.ignoreOptions {
62 ch.h.ServeHTTP(w, r)
63 return
64 }
65
66 if _, ok := r.Header[corsRequestMethodHeader]; !ok {
67 w.WriteHeader(http.StatusBadRequest)
68 return
69 }
70
71 method := r.Header.Get(corsRequestMethodHeader)
72 if !ch.isMatch(method, ch.allowedMethods) {
73 w.WriteHeader(http.StatusMethodNotAllowed)
74 return
75 }
76
77 requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
78 allowedHeaders := []string{}
79 for _, v := range requestHeaders {
80 canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
81 if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
82 continue
83 }
84
85 if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
86 w.WriteHeader(http.StatusForbidden)
87 return
88 }
89
90 allowedHeaders = append(allowedHeaders, canonicalHeader)
91 }
92
93 if len(allowedHeaders) > 0 {
94 w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
95 }
96
97 if ch.maxAge > 0 {
98 w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
99 }
100
101 if !ch.isMatch(method, defaultCorsMethods) {
102 w.Header().Set(corsAllowMethodsHeader, method)
103 }
104 } else {
105 if len(ch.exposedHeaders) > 0 {
106 w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
107 }
108 }
109
110 if ch.allowCredentials {
111 w.Header().Set(corsAllowCredentialsHeader, "true")
112 }
113
114 if len(ch.allowedOrigins) > 1 {
115 w.Header().Set(corsVaryHeader, corsOriginHeader)
116 }
117
118 returnOrigin := origin
119 if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
120 returnOrigin = "*"
121 } else {
122 for _, o := range ch.allowedOrigins {
123
124
125
126 if o == corsOriginMatchAll {
127 returnOrigin = "*"
128 break
129 }
130 }
131 }
132 w.Header().Set(corsAllowOriginHeader, returnOrigin)
133
134 if r.Method == corsOptionMethod {
135 w.WriteHeader(ch.optionStatusCode)
136 return
137 }
138 ch.h.ServeHTTP(w, r)
139 }
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160 func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
161 return func(h http.Handler) http.Handler {
162 ch := parseCORSOptions(opts...)
163 ch.h = h
164 return ch
165 }
166 }
167
168 func parseCORSOptions(opts ...CORSOption) *cors {
169 ch := &cors{
170 allowedMethods: defaultCorsMethods,
171 allowedHeaders: defaultCorsHeaders,
172 allowedOrigins: []string{},
173 optionStatusCode: defaultCorsOptionStatusCode,
174 }
175
176 for _, option := range opts {
177 option(ch)
178 }
179
180 return ch
181 }
182
183
184
185
186
187
188
189
190
191
192
193 func AllowedHeaders(headers []string) CORSOption {
194 return func(ch *cors) error {
195 for _, v := range headers {
196 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
197 if normalizedHeader == "" {
198 continue
199 }
200
201 if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
202 ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
203 }
204 }
205
206 return nil
207 }
208 }
209
210
211
212
213
214 func AllowedMethods(methods []string) CORSOption {
215 return func(ch *cors) error {
216 ch.allowedMethods = []string{}
217 for _, v := range methods {
218 normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
219 if normalizedMethod == "" {
220 continue
221 }
222
223 if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
224 ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
225 }
226 }
227
228 return nil
229 }
230 }
231
232
233
234
235 func AllowedOrigins(origins []string) CORSOption {
236 return func(ch *cors) error {
237 for _, v := range origins {
238 if v == corsOriginMatchAll {
239 ch.allowedOrigins = []string{corsOriginMatchAll}
240 return nil
241 }
242 }
243
244 ch.allowedOrigins = origins
245 return nil
246 }
247 }
248
249
250
251 func AllowedOriginValidator(fn OriginValidator) CORSOption {
252 return func(ch *cors) error {
253 ch.allowedOriginValidator = fn
254 return nil
255 }
256 }
257
258
259
260
261
262
263
264 func OptionStatusCode(code int) CORSOption {
265 return func(ch *cors) error {
266 ch.optionStatusCode = code
267 return nil
268 }
269 }
270
271
272
273 func ExposedHeaders(headers []string) CORSOption {
274 return func(ch *cors) error {
275 ch.exposedHeaders = []string{}
276 for _, v := range headers {
277 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
278 if normalizedHeader == "" {
279 continue
280 }
281
282 if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
283 ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
284 }
285 }
286
287 return nil
288 }
289 }
290
291
292
293
294 func MaxAge(age int) CORSOption {
295 return func(ch *cors) error {
296
297 if age > 600 {
298 age = 600
299 }
300
301 ch.maxAge = age
302 return nil
303 }
304 }
305
306
307
308
309 func IgnoreOptions() CORSOption {
310 return func(ch *cors) error {
311 ch.ignoreOptions = true
312 return nil
313 }
314 }
315
316
317
318 func AllowCredentials() CORSOption {
319 return func(ch *cors) error {
320 ch.allowCredentials = true
321 return nil
322 }
323 }
324
325 func (ch *cors) isOriginAllowed(origin string) bool {
326 if origin == "" {
327 return false
328 }
329
330 if ch.allowedOriginValidator != nil {
331 return ch.allowedOriginValidator(origin)
332 }
333
334 if len(ch.allowedOrigins) == 0 {
335 return true
336 }
337
338 for _, allowedOrigin := range ch.allowedOrigins {
339 if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
340 return true
341 }
342 }
343
344 return false
345 }
346
347 func (ch *cors) isMatch(needle string, haystack []string) bool {
348 for _, v := range haystack {
349 if v == needle {
350 return true
351 }
352 }
353
354 return false
355 }
356
View as plain text