...
1 package cors
2
3 import (
4 "net/http"
5 "strings"
6
7 "github.com/gin-gonic/gin"
8 )
9
10 type cors struct {
11 allowAllOrigins bool
12 allowCredentials bool
13 allowOriginFunc func(string) bool
14 allowOriginWithContextFunc func(*gin.Context, string) bool
15 allowOrigins []string
16 normalHeaders http.Header
17 preflightHeaders http.Header
18 wildcardOrigins [][]string
19 optionsResponseStatusCode int
20 }
21
// Scheme prefixes an origin may start with. They are exported so callers can
// inspect or extend them; how each list is selected is decided by code not
// shown here (presumably Config validation — confirm).
var (
	// DefaultSchemas are the plain web origin schemes.
	DefaultSchemas = []string{
		"http://",
		"https://",
	}

	// ExtensionSchemas are browser-extension origin schemes.
	ExtensionSchemas = []string{
		"chrome-extension://",
		"safari-extension://",
		"moz-extension://",
		"ms-browser-extension://",
	}

	// FileSchemas covers origins loaded from the local filesystem.
	FileSchemas = []string{
		"file://",
	}

	// WebSocketSchemas are the plain and TLS WebSocket schemes.
	WebSocketSchemas = []string{
		"ws://",
		"wss://",
	}
)
41
42 func newCors(config Config) *cors {
43 if err := config.Validate(); err != nil {
44 panic(err.Error())
45 }
46
47 for _, origin := range config.AllowOrigins {
48 if origin == "*" {
49 config.AllowAllOrigins = true
50 }
51 }
52
53 if config.OptionsResponseStatusCode == 0 {
54 config.OptionsResponseStatusCode = http.StatusNoContent
55 }
56
57 return &cors{
58 allowOriginFunc: config.AllowOriginFunc,
59 allowOriginWithContextFunc: config.AllowOriginWithContextFunc,
60 allowAllOrigins: config.AllowAllOrigins,
61 allowCredentials: config.AllowCredentials,
62 allowOrigins: normalize(config.AllowOrigins),
63 normalHeaders: generateNormalHeaders(config),
64 preflightHeaders: generatePreflightHeaders(config),
65 wildcardOrigins: config.parseWildcardRules(),
66 optionsResponseStatusCode: config.OptionsResponseStatusCode,
67 }
68 }
69
70 func (cors *cors) applyCors(c *gin.Context) {
71 origin := c.Request.Header.Get("Origin")
72 if len(origin) == 0 {
73
74 return
75 }
76 host := c.Request.Host
77
78 if origin == "http://"+host || origin == "https://"+host {
79
80
81 return
82 }
83
84 if !cors.isOriginValid(c, origin) {
85 c.AbortWithStatus(http.StatusForbidden)
86 return
87 }
88
89 if c.Request.Method == "OPTIONS" {
90 cors.handlePreflight(c)
91 defer c.AbortWithStatus(cors.optionsResponseStatusCode)
92 } else {
93 cors.handleNormal(c)
94 }
95
96 if !cors.allowAllOrigins {
97 c.Header("Access-Control-Allow-Origin", origin)
98 }
99 }
100
101 func (cors *cors) validateWildcardOrigin(origin string) bool {
102 for _, w := range cors.wildcardOrigins {
103 if w[0] == "*" && strings.HasSuffix(origin, w[1]) {
104 return true
105 }
106 if w[1] == "*" && strings.HasPrefix(origin, w[0]) {
107 return true
108 }
109 if strings.HasPrefix(origin, w[0]) && strings.HasSuffix(origin, w[1]) {
110 return true
111 }
112 }
113
114 return false
115 }
116
117 func (cors *cors) isOriginValid(c *gin.Context, origin string) bool {
118 valid := cors.validateOrigin(origin)
119 if !valid && cors.allowOriginWithContextFunc != nil {
120 valid = cors.allowOriginWithContextFunc(c, origin)
121 }
122 return valid
123 }
124
125 func (cors *cors) validateOrigin(origin string) bool {
126 if cors.allowAllOrigins {
127 return true
128 }
129 for _, value := range cors.allowOrigins {
130 if value == origin {
131 return true
132 }
133 }
134 if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) {
135 return true
136 }
137 if cors.allowOriginFunc != nil {
138 return cors.allowOriginFunc(origin)
139 }
140 return false
141 }
142
143 func (cors *cors) handlePreflight(c *gin.Context) {
144 header := c.Writer.Header()
145 for key, value := range cors.preflightHeaders {
146 header[key] = value
147 }
148 }
149
150 func (cors *cors) handleNormal(c *gin.Context) {
151 header := c.Writer.Header()
152 for key, value := range cors.normalHeaders {
153 header[key] = value
154 }
155 }
156
View as plain text