...

Source file src/github.com/gin-contrib/cors/config.go

Documentation: github.com/gin-contrib/cors

     1  package cors
     2  
     3  import (
     4  	"net/http"
     5  	"strings"
     6  
     7  	"github.com/gin-gonic/gin"
     8  )
     9  
    10  type cors struct {
    11  	allowAllOrigins            bool
    12  	allowCredentials           bool
    13  	allowOriginFunc            func(string) bool
    14  	allowOriginWithContextFunc func(*gin.Context, string) bool
    15  	allowOrigins               []string
    16  	normalHeaders              http.Header
    17  	preflightHeaders           http.Header
    18  	wildcardOrigins            [][]string
    19  	optionsResponseStatusCode  int
    20  }
    21  
    22  var (
    23  	DefaultSchemas = []string{
    24  		"http://",
    25  		"https://",
    26  	}
    27  	ExtensionSchemas = []string{
    28  		"chrome-extension://",
    29  		"safari-extension://",
    30  		"moz-extension://",
    31  		"ms-browser-extension://",
    32  	}
    33  	FileSchemas = []string{
    34  		"file://",
    35  	}
    36  	WebSocketSchemas = []string{
    37  		"ws://",
    38  		"wss://",
    39  	}
    40  )
    41  
    42  func newCors(config Config) *cors {
    43  	if err := config.Validate(); err != nil {
    44  		panic(err.Error())
    45  	}
    46  
    47  	for _, origin := range config.AllowOrigins {
    48  		if origin == "*" {
    49  			config.AllowAllOrigins = true
    50  		}
    51  	}
    52  
    53  	if config.OptionsResponseStatusCode == 0 {
    54  		config.OptionsResponseStatusCode = http.StatusNoContent
    55  	}
    56  
    57  	return &cors{
    58  		allowOriginFunc:            config.AllowOriginFunc,
    59  		allowOriginWithContextFunc: config.AllowOriginWithContextFunc,
    60  		allowAllOrigins:            config.AllowAllOrigins,
    61  		allowCredentials:           config.AllowCredentials,
    62  		allowOrigins:               normalize(config.AllowOrigins),
    63  		normalHeaders:              generateNormalHeaders(config),
    64  		preflightHeaders:           generatePreflightHeaders(config),
    65  		wildcardOrigins:            config.parseWildcardRules(),
    66  		optionsResponseStatusCode:  config.OptionsResponseStatusCode,
    67  	}
    68  }
    69  
    70  func (cors *cors) applyCors(c *gin.Context) {
    71  	origin := c.Request.Header.Get("Origin")
    72  	if len(origin) == 0 {
    73  		// request is not a CORS request
    74  		return
    75  	}
    76  	host := c.Request.Host
    77  
    78  	if origin == "http://"+host || origin == "https://"+host {
    79  		// request is not a CORS request but have origin header.
    80  		// for example, use fetch api
    81  		return
    82  	}
    83  
    84  	if !cors.isOriginValid(c, origin) {
    85  		c.AbortWithStatus(http.StatusForbidden)
    86  		return
    87  	}
    88  
    89  	if c.Request.Method == "OPTIONS" {
    90  		cors.handlePreflight(c)
    91  		defer c.AbortWithStatus(cors.optionsResponseStatusCode)
    92  	} else {
    93  		cors.handleNormal(c)
    94  	}
    95  
    96  	if !cors.allowAllOrigins {
    97  		c.Header("Access-Control-Allow-Origin", origin)
    98  	}
    99  }
   100  
   101  func (cors *cors) validateWildcardOrigin(origin string) bool {
   102  	for _, w := range cors.wildcardOrigins {
   103  		if w[0] == "*" && strings.HasSuffix(origin, w[1]) {
   104  			return true
   105  		}
   106  		if w[1] == "*" && strings.HasPrefix(origin, w[0]) {
   107  			return true
   108  		}
   109  		if strings.HasPrefix(origin, w[0]) && strings.HasSuffix(origin, w[1]) {
   110  			return true
   111  		}
   112  	}
   113  
   114  	return false
   115  }
   116  
   117  func (cors *cors) isOriginValid(c *gin.Context, origin string) bool {
   118  	valid := cors.validateOrigin(origin)
   119  	if !valid && cors.allowOriginWithContextFunc != nil {
   120  		valid = cors.allowOriginWithContextFunc(c, origin)
   121  	}
   122  	return valid
   123  }
   124  
   125  func (cors *cors) validateOrigin(origin string) bool {
   126  	if cors.allowAllOrigins {
   127  		return true
   128  	}
   129  	for _, value := range cors.allowOrigins {
   130  		if value == origin {
   131  			return true
   132  		}
   133  	}
   134  	if len(cors.wildcardOrigins) > 0 && cors.validateWildcardOrigin(origin) {
   135  		return true
   136  	}
   137  	if cors.allowOriginFunc != nil {
   138  		return cors.allowOriginFunc(origin)
   139  	}
   140  	return false
   141  }
   142  
   143  func (cors *cors) handlePreflight(c *gin.Context) {
   144  	header := c.Writer.Header()
   145  	for key, value := range cors.preflightHeaders {
   146  		header[key] = value
   147  	}
   148  }
   149  
   150  func (cors *cors) handleNormal(c *gin.Context) {
   151  	header := c.Writer.Header()
   152  	for key, value := range cors.normalHeaders {
   153  		header[key] = value
   154  	}
   155  }
   156  

View as plain text