...
1 package cors
2
3 import (
4 "errors"
5 "net/http"
6 "strings"
7 "time"
8
9 "github.com/gin-gonic/gin"
10 )
11
// Config describes the CORS policy applied by the middleware built with New.
// Origins should be selected in exactly one way — AllowAllOrigins,
// AllowedOrigins, or AllowOriginFunc — and Validate rejects any other
// combination (including none at all).
type Config struct {
	// AbortOnError makes the middleware abort a CORS request that fails
	// validation with 403 Forbidden. When false, no 403 is sent; the failing
	// request simply does not receive an Access-Control-Allow-Origin header.
	AbortOnError bool

	// AllowAllOrigins accepts requests from any origin. Validate rejects it
	// when AllowedOrigins or AllowOriginFunc is also set.
	AllowAllOrigins bool

	// AllowedOrigins lists the accepted origins. Validate requires every entry
	// to start with "http://" or "https://".
	// NOTE(review): the exact matching rules (case, wildcards) live in
	// settings.validateOrigin, which is not visible here — confirm there.
	AllowedOrigins []string

	// AllowOriginFunc is a custom origin check, presumably returning true for
	// origins that should be accepted (its call site is not visible here —
	// confirm in newSettings/validateOrigin). Validate rejects it when
	// AllowedOrigins or AllowAllOrigins is also set.
	AllowOriginFunc func(origin string) bool

	// AllowedMethods lists the HTTP methods accepted for cross-origin
	// requests; presumably checked against a preflight's
	// Access-Control-Request-Method by settings.validateMethod — confirm.
	// Extend it with AddAllowedMethods.
	AllowedMethods []string

	// AllowedHeaders lists the request headers accepted for cross-origin
	// requests; presumably checked by settings.validateHeader — confirm.
	// Extend it with AddAllowedHeaders.
	AllowedHeaders []string

	// ExposedHeaders lists response headers the client may read; presumably
	// emitted as Access-Control-Expose-Headers by newSettings — confirm.
	// Extend it with AddExposedHeaders.
	ExposedHeaders []string

	// AllowCredentials presumably controls the
	// Access-Control-Allow-Credentials response header — confirm in
	// newSettings.
	AllowCredentials bool

	// MaxAge presumably becomes Access-Control-Max-Age on preflight responses
	// — confirm in newSettings. The package defaults use 12 hours.
	MaxAge time.Duration
}
48
49 func (c *Config) AddAllowedMethods(methods ...string) {
50 c.AllowedMethods = append(c.AllowedMethods, methods...)
51 }
52
53 func (c *Config) AddAllowedHeaders(headers ...string) {
54 c.AllowedHeaders = append(c.AllowedHeaders, headers...)
55 }
56
57 func (c *Config) AddExposedHeaders(headers ...string) {
58 c.ExposedHeaders = append(c.ExposedHeaders, headers...)
59 }
60
61 func (c Config) Validate() error {
62 if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowedOrigins) > 0) {
63 return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
64 }
65 if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowedOrigins) == 0 {
66 return errors.New("conflict settings: all origins disabled")
67 }
68 if c.AllowOriginFunc != nil && len(c.AllowedOrigins) > 0 {
69 return errors.New("conflict settings: if a allow origin func is provided, AllowedOrigins is not needed")
70 }
71 for _, origin := range c.AllowedOrigins {
72 if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
73 return errors.New("bad origin: origins must include http:// or https://")
74 }
75 }
76 return nil
77 }
78
79 var defaultConfig = Config{
80 AbortOnError: false,
81 AllowAllOrigins: true,
82 AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "HEAD"},
83 AllowedHeaders: []string{"Content-Type"},
84
85 AllowCredentials: false,
86 MaxAge: 12 * time.Hour,
87 }
88
89 func DefaultConfig() Config {
90 cp := defaultConfig
91 return cp
92 }
93
94 func Default() gin.HandlerFunc {
95 return New(defaultConfig)
96 }
97
98 func New(config Config) gin.HandlerFunc {
99 s := newSettings(config)
100
101
102 return func(c *gin.Context) {
103 origin := c.Request.Header.Get("Origin")
104 if len(origin) == 0 {
105 return
106 }
107 origin, valid := s.validateOrigin(origin)
108 if valid {
109 if c.Request.Method == "OPTIONS" {
110 valid = handlePreflight(c, s)
111 } else {
112 valid = handleNormal(c, s)
113 }
114 }
115
116 if !valid {
117 if config.AbortOnError {
118 c.AbortWithStatus(http.StatusForbidden)
119 }
120 return
121 }
122 c.Header("Access-Control-Allow-Origin", origin)
123 }
124 }
125
126 func handlePreflight(c *gin.Context, s *settings) bool {
127 c.AbortWithStatus(200)
128 if !s.validateMethod(c.Request.Header.Get("Access-Control-Request-Method")) {
129 return false
130 }
131 if !s.validateHeader(c.Request.Header.Get("Access-Control-Request-Header")) {
132 return false
133 }
134 for key, value := range s.preflightHeaders {
135 c.Writer.Header()[key] = value
136 }
137 return true
138 }
139
140 func handleNormal(c *gin.Context, s *settings) bool {
141 for key, value := range s.normalHeaders {
142 c.Writer.Header()[key] = value
143 }
144 return true
145 }
146
View as plain text