...
1 package cors
2
3 import (
4 "net/http"
5 "strconv"
6 "strings"
7 "time"
8 )
9
// settings holds the CORS configuration in the form used at request time:
// the origin-matching rules plus response headers that are precomputed once
// from the user's Config.
type settings struct {
	// allowAllOrigins makes validateOrigin accept every origin and answer "*".
	allowAllOrigins bool
	// allowedOriginFunc, when non-nil (and allowAllOrigins is false), alone
	// decides whether an origin is accepted.
	allowedOriginFunc func(string) bool
	// allowedOrigins is the exact-match origin list consulted by
	// validateOrigin when neither of the fields above applies.
	allowedOrigins []string
	// allowedMethods and allowedHeaders are the deduplicated method and header
	// lists taken from Config. NOTE(review): validateMethod/validateHeader in
	// this file do not currently read them.
	allowedMethods []string
	allowedHeaders []string
	// exposedHeaders presumably mirrors Config.ExposedHeaders — confirm how it
	// is populated and used.
	exposedHeaders []string
	// normalHeaders is the header set built by generateNormalHeaders.
	normalHeaders http.Header
	// preflightHeaders is the header set built by generatePreflightHeaders.
	preflightHeaders http.Header
}
20
21 func newSettings(c Config) *settings {
22 if err := c.Validate(); err != nil {
23 panic(err.Error())
24 }
25 return &settings{
26 allowedOriginFunc: c.AllowOriginFunc,
27 allowAllOrigins: c.AllowAllOrigins,
28 allowedOrigins: c.AllowedOrigins,
29 allowedMethods: distinct(c.AllowedMethods),
30 allowedHeaders: distinct(c.AllowedHeaders),
31 normalHeaders: generateNormalHeaders(c),
32 preflightHeaders: generatePreflightHeaders(c),
33 }
34 }
35
36 func (c *settings) validateOrigin(origin string) (string, bool) {
37 if c.allowAllOrigins {
38 return "*", true
39 }
40 if c.allowedOriginFunc != nil {
41 return origin, c.allowedOriginFunc(origin)
42 }
43 for _, value := range c.allowedOrigins {
44 if value == origin {
45 return origin, true
46 }
47 }
48 return "", false
49 }
50
// validateMethod reports whether method is acceptable.
//
// NOTE(review): this is currently a permissive placeholder. It returns true
// for every method and does not consult c.allowedMethods. Confirm whether
// callers rely on that before tightening it.
func (c *settings) validateMethod(method string) bool {
	// Placeholder: neither method nor c.allowedMethods is inspected.
	return true
}
55
// validateHeader reports whether header is acceptable.
//
// NOTE(review): this is currently a permissive placeholder. It returns true
// for every header and does not consult c.allowedHeaders. Confirm whether
// callers rely on that before tightening it.
func (c *settings) validateHeader(header string) bool {
	// Placeholder: neither header nor c.allowedHeaders is inspected.
	return true
}
60
61 func generateNormalHeaders(c Config) http.Header {
62 headers := make(http.Header)
63 if c.AllowCredentials {
64 headers.Set("Access-Control-Allow-Credentials", "true")
65 }
66 if len(c.ExposedHeaders) > 0 {
67 headers.Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ", "))
68 }
69 return headers
70 }
71
72 func generatePreflightHeaders(c Config) http.Header {
73 headers := make(http.Header)
74 if c.AllowCredentials {
75 headers.Set("Access-Control-Allow-Credentials", "true")
76 }
77 if len(c.AllowedMethods) > 0 {
78 headers.Set("Access-Control-Allow-Methods", strings.Join(c.AllowedMethods, ", "))
79 }
80 if len(c.AllowedHeaders) > 0 {
81 headers.Set("Access-Control-Allow-Headers", strings.Join(c.AllowedHeaders, ", "))
82 }
83 if c.MaxAge > time.Duration(0) {
84 headers.Set("Access-Control-Max-Age", strconv.FormatInt(int64(c.MaxAge/time.Second), 10))
85 }
86 return headers
87 }
88
// distinct returns the elements of s with duplicates removed, keeping the
// first occurrence of each value in its original order.
//
// s is never modified. The previous version compacted s in place, which
// silently rewrote any slice the caller still held (for example a Config
// field). A nil or empty s is returned as-is.
func distinct(s []string) []string {
	if len(s) == 0 {
		return s
	}
	seen := make(map[string]struct{}, len(s))
	out := make([]string, 0, len(s))
	for _, v := range s {
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}
99
// parse splits a comma-separated list into its elements, trimming the
// surrounding whitespace of each one. An empty string yields nil. Empty
// elements are kept as "" (e.g. "a,,b" gives ["a" "" "b"]).
func parse(content string) []string {
	if content == "" {
		return nil
	}
	fields := strings.Split(content, ",")
	for i, f := range fields {
		fields[i] = strings.TrimSpace(f)
	}
	return fields
}
110
View as plain text