...
1 package cors
2
3 import (
4 "errors"
5 "fmt"
6 "strings"
7 "time"
8
9 "github.com/gin-gonic/gin"
10 )
11
12
// Config holds the settings used to build the CORS middleware via New.
// Call Validate to detect conflicting or incomplete origin settings.
type Config struct {
	// AllowAllOrigins permits every origin. Validate rejects it when it is
	// combined with AllowOrigins, AllowOriginFunc or
	// AllowOriginWithContextFunc. Default enables it.
	AllowAllOrigins bool

	// AllowOrigins lists permitted origins. Validate requires each entry to
	// either contain a '*' or start with one of the allowed schemas (plain
	// DefaultSchemas plus those enabled by AllowBrowserExtensions,
	// AllowWebSockets, AllowFiles and CustomSchemas). Entries containing '*'
	// become wildcard rules only when AllowWildcard is true, and
	// parseWildcardRules panics on an entry with more than one '*'.
	AllowOrigins []string

	// AllowOriginFunc is an optional predicate over the request origin.
	// Validate counts it as an origin setting: it satisfies the "some origin
	// setting is required" check and conflicts with AllowAllOrigins.
	// NOTE(review): how it is combined with AllowOrigins at request time is
	// decided by the middleware (newCors/applyCors), which is not shown here.
	AllowOriginFunc func(origin string) bool

	// AllowOriginWithContextFunc is like AllowOriginFunc but also receives the
	// gin request context. Validate treats it exactly like AllowOriginFunc.
	AllowOriginWithContextFunc func(c *gin.Context, origin string) bool

	// AllowMethods lists permitted HTTP methods. DefaultConfig sets GET, POST,
	// PUT, PATCH, DELETE, HEAD and OPTIONS; AddAllowMethods appends to it.
	AllowMethods []string

	// AllowPrivateNetwork is presumably emitted as the Private Network Access
	// preflight header — TODO confirm against the middleware implementation.
	AllowPrivateNetwork bool

	// AllowHeaders lists permitted request headers. DefaultConfig sets
	// Origin, Content-Length and Content-Type; AddAllowHeaders appends to it.
	AllowHeaders []string

	// AllowCredentials is false in DefaultConfig. Presumably mapped to the
	// Access-Control-Allow-Credentials header — confirm in the middleware.
	AllowCredentials bool

	// ExposeHeaders lists response headers to expose to the client;
	// AddExposeHeaders appends to it.
	ExposeHeaders []string

	// MaxAge is 12 hours in DefaultConfig. Presumably the preflight cache
	// lifetime (Access-Control-Max-Age) — confirm in the middleware.
	MaxAge time.Duration

	// AllowWildcard enables '*' wildcard entries in AllowOrigins; when false,
	// parseWildcardRules returns no rules.
	AllowWildcard bool

	// AllowBrowserExtensions adds ExtensionSchemas to the schemas accepted
	// for AllowOrigins entries.
	AllowBrowserExtensions bool

	// CustomSchemas lists extra origin prefixes accepted for AllowOrigins
	// entries, appended after all built-in schema groups.
	CustomSchemas []string

	// AllowWebSockets adds WebSocketSchemas to the schemas accepted for
	// AllowOrigins entries.
	AllowWebSockets bool

	// AllowFiles adds FileSchemas to the schemas accepted for AllowOrigins
	// entries.
	AllowFiles bool

	// OptionsResponseStatusCode — NOTE(review): presumably the status code
	// written for preflight OPTIONS responses; it is not read in this file.
	OptionsResponseStatusCode int
}
73
74
75 func (c *Config) AddAllowMethods(methods ...string) {
76 c.AllowMethods = append(c.AllowMethods, methods...)
77 }
78
79
80 func (c *Config) AddAllowHeaders(headers ...string) {
81 c.AllowHeaders = append(c.AllowHeaders, headers...)
82 }
83
84
85 func (c *Config) AddExposeHeaders(headers ...string) {
86 c.ExposeHeaders = append(c.ExposeHeaders, headers...)
87 }
88
89 func (c Config) getAllowedSchemas() []string {
90 allowedSchemas := DefaultSchemas
91 if c.AllowBrowserExtensions {
92 allowedSchemas = append(allowedSchemas, ExtensionSchemas...)
93 }
94 if c.AllowWebSockets {
95 allowedSchemas = append(allowedSchemas, WebSocketSchemas...)
96 }
97 if c.AllowFiles {
98 allowedSchemas = append(allowedSchemas, FileSchemas...)
99 }
100 if c.CustomSchemas != nil {
101 allowedSchemas = append(allowedSchemas, c.CustomSchemas...)
102 }
103 return allowedSchemas
104 }
105
106 func (c Config) validateAllowedSchemas(origin string) bool {
107 allowedSchemas := c.getAllowedSchemas()
108 for _, schema := range allowedSchemas {
109 if strings.HasPrefix(origin, schema) {
110 return true
111 }
112 }
113 return false
114 }
115
116
117 func (c Config) Validate() error {
118 hasOriginFn := c.AllowOriginFunc != nil
119 hasOriginFn = hasOriginFn || c.AllowOriginWithContextFunc != nil
120
121 if c.AllowAllOrigins && (hasOriginFn || len(c.AllowOrigins) > 0) {
122 originFields := strings.Join([]string{
123 "AllowOriginFunc",
124 "AllowOriginFuncWithContext",
125 "AllowOrigins",
126 }, " or ")
127 return fmt.Errorf(
128 "conflict settings: all origins enabled. %s is not needed",
129 originFields,
130 )
131 }
132 if !c.AllowAllOrigins && !hasOriginFn && len(c.AllowOrigins) == 0 {
133 return errors.New("conflict settings: all origins disabled")
134 }
135 for _, origin := range c.AllowOrigins {
136 if !strings.Contains(origin, "*") && !c.validateAllowedSchemas(origin) {
137 return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
138 }
139 }
140 return nil
141 }
142
143 func (c Config) parseWildcardRules() [][]string {
144 var wRules [][]string
145
146 if !c.AllowWildcard {
147 return wRules
148 }
149
150 for _, o := range c.AllowOrigins {
151 if !strings.Contains(o, "*") {
152 continue
153 }
154
155 if c := strings.Count(o, "*"); c > 1 {
156 panic(errors.New("only one * is allowed").Error())
157 }
158
159 i := strings.Index(o, "*")
160 if i == 0 {
161 wRules = append(wRules, []string{"*", o[1:]})
162 continue
163 }
164 if i == (len(o) - 1) {
165 wRules = append(wRules, []string{o[:i], "*"})
166 continue
167 }
168
169 wRules = append(wRules, []string{o[:i], o[i+1:]})
170 }
171
172 return wRules
173 }
174
175
176 func DefaultConfig() Config {
177 return Config{
178 AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"},
179 AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"},
180 AllowCredentials: false,
181 MaxAge: 12 * time.Hour,
182 }
183 }
184
185
186 func Default() gin.HandlerFunc {
187 config := DefaultConfig()
188 config.AllowAllOrigins = true
189 return New(config)
190 }
191
192
193 func New(config Config) gin.HandlerFunc {
194 cors := newCors(config)
195 return func(c *gin.Context) {
196 cors.applyCors(c)
197 }
198 }
199
View as plain text