...

Source file src/github.com/gin-contrib/cors/cors_test.go

Documentation: github.com/gin-contrib/cors

     1  package cors
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"reflect"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/gin-gonic/gin"
    13  	"github.com/stretchr/testify/assert"
    14  )
    15  
    16  func newTestRouter(config Config) *gin.Engine {
    17  	router := gin.New()
    18  	router.Use(New(config))
    19  	router.GET("/", func(c *gin.Context) {
    20  		c.String(http.StatusOK, "get")
    21  	})
    22  	router.POST("/", func(c *gin.Context) {
    23  		c.String(http.StatusOK, "post")
    24  	})
    25  	router.PATCH("/", func(c *gin.Context) {
    26  		c.String(http.StatusOK, "patch")
    27  	})
    28  	return router
    29  }
    30  
    31  func multiGroupRouter(config Config) *gin.Engine {
    32  	router := gin.New()
    33  	router.Use(New(config))
    34  
    35  	app1 := router.Group("/app1")
    36  	app1.GET("", func(c *gin.Context) {
    37  		c.String(http.StatusOK, "app1")
    38  	})
    39  
    40  	app2 := router.Group("/app2")
    41  	app2.GET("", func(c *gin.Context) {
    42  		c.String(http.StatusOK, "app2")
    43  	})
    44  
    45  	app3 := router.Group("/app3")
    46  	app3.GET("", func(c *gin.Context) {
    47  		c.String(http.StatusOK, "app3")
    48  	})
    49  
    50  	return router
    51  }
    52  
    53  func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder {
    54  	return performRequestWithHeaders(r, method, "/", origin, http.Header{})
    55  }
    56  
    57  func performRequestWithHeaders(
    58  	r http.Handler,
    59  	method, path, origin string,
    60  	header http.Header,
    61  ) *httptest.ResponseRecorder {
    62  	req, _ := http.NewRequestWithContext(context.Background(), method, path, nil)
    63  	// From go/net/http/request.go:
    64  	// For incoming requests, the Host header is promoted to the
    65  	// Request.Host field and removed from the Header map.
    66  	req.Host = header.Get("Host")
    67  	header.Del("Host")
    68  	if len(origin) > 0 {
    69  		header.Set("Origin", origin)
    70  	}
    71  	req.Header = header
    72  	w := httptest.NewRecorder()
    73  	r.ServeHTTP(w, req)
    74  	return w
    75  }
    76  
    77  func TestConfigAddAllow(t *testing.T) {
    78  	config := Config{}
    79  	config.AddAllowMethods("POST")
    80  	config.AddAllowMethods("GET", "PUT")
    81  	config.AddExposeHeaders()
    82  
    83  	config.AddAllowHeaders("Some", " cool")
    84  	config.AddAllowHeaders("header")
    85  	config.AddExposeHeaders()
    86  
    87  	config.AddExposeHeaders()
    88  	config.AddExposeHeaders("exposed", "header")
    89  	config.AddExposeHeaders("hey")
    90  
    91  	assert.Equal(t, config.AllowMethods, []string{"POST", "GET", "PUT"})
    92  	assert.Equal(t, config.AllowHeaders, []string{"Some", " cool", "header"})
    93  	assert.Equal(t, config.ExposeHeaders, []string{"exposed", "header", "hey"})
    94  }
    95  
    96  func TestBadConfig(t *testing.T) {
    97  	assert.Panics(t, func() { New(Config{}) })
    98  	assert.Panics(t, func() {
    99  		New(Config{
   100  			AllowAllOrigins: true,
   101  			AllowOrigins:    []string{"http://google.com"},
   102  		})
   103  	})
   104  	assert.Panics(t, func() {
   105  		New(Config{
   106  			AllowAllOrigins: true,
   107  			AllowOriginFunc: func(origin string) bool { return false },
   108  		})
   109  	})
   110  	assert.Panics(t, func() {
   111  		New(Config{
   112  			AllowOrigins: []string{"google.com"},
   113  		})
   114  	})
   115  }
   116  
   117  func TestNormalize(t *testing.T) {
   118  	values := normalize([]string{
   119  		"http-Access ", "Post", "POST", " poSt  ",
   120  		"HTTP-Access", "",
   121  	})
   122  	assert.Equal(t, values, []string{"http-access", "post", ""})
   123  
   124  	values = normalize(nil)
   125  	assert.Nil(t, values)
   126  
   127  	values = normalize([]string{})
   128  	assert.Equal(t, values, []string{})
   129  }
   130  
   131  func TestConvert(t *testing.T) {
   132  	methods := []string{"Get", "GET", "get"}
   133  	headers := []string{"X-CSRF-TOKEN", "X-CSRF-Token", "x-csrf-token"}
   134  
   135  	assert.Equal(t, []string{"GET", "GET", "GET"}, convert(methods, strings.ToUpper))
   136  	assert.Equal(t, []string{"X-Csrf-Token", "X-Csrf-Token", "X-Csrf-Token"}, convert(headers, http.CanonicalHeaderKey))
   137  }
   138  
   139  func TestGenerateNormalHeaders_AllowAllOrigins(t *testing.T) {
   140  	header := generateNormalHeaders(Config{
   141  		AllowAllOrigins: false,
   142  	})
   143  	assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
   144  	assert.Equal(t, header.Get("Vary"), "Origin")
   145  	assert.Len(t, header, 1)
   146  
   147  	header = generateNormalHeaders(Config{
   148  		AllowAllOrigins: true,
   149  	})
   150  	assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
   151  	assert.Equal(t, header.Get("Vary"), "")
   152  	assert.Len(t, header, 1)
   153  }
   154  
   155  func TestGenerateNormalHeaders_AllowCredentials(t *testing.T) {
   156  	header := generateNormalHeaders(Config{
   157  		AllowCredentials: true,
   158  	})
   159  	assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
   160  	assert.Equal(t, header.Get("Vary"), "Origin")
   161  	assert.Len(t, header, 2)
   162  }
   163  
   164  func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) {
   165  	header := generateNormalHeaders(Config{
   166  		ExposeHeaders: []string{"X-user", "xPassword"},
   167  	})
   168  	assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "X-User,Xpassword")
   169  	assert.Equal(t, header.Get("Vary"), "Origin")
   170  	assert.Len(t, header, 2)
   171  }
   172  
   173  func TestGeneratePreflightHeaders(t *testing.T) {
   174  	header := generatePreflightHeaders(Config{
   175  		AllowAllOrigins: false,
   176  	})
   177  	assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "")
   178  	assert.Equal(t, header.Get("Vary"), "Origin")
   179  	assert.Len(t, header, 1)
   180  
   181  	header = generateNormalHeaders(Config{
   182  		AllowAllOrigins: true,
   183  	})
   184  	assert.Equal(t, header.Get("Access-Control-Allow-Origin"), "*")
   185  	assert.Equal(t, header.Get("Vary"), "")
   186  	assert.Len(t, header, 1)
   187  }
   188  
   189  func TestGeneratePreflightHeaders_AllowCredentials(t *testing.T) {
   190  	header := generatePreflightHeaders(Config{
   191  		AllowCredentials: true,
   192  	})
   193  	assert.Equal(t, header.Get("Access-Control-Allow-Credentials"), "true")
   194  	assert.Equal(t, header.Get("Vary"), "Origin")
   195  	assert.Len(t, header, 2)
   196  }
   197  
   198  func TestGeneratePreflightHeaders_AllowPrivateNetwork(t *testing.T) {
   199  	header := generatePreflightHeaders(Config{
   200  		AllowPrivateNetwork: true,
   201  	})
   202  	assert.Equal(t, header.Get("Access-Control-Allow-Private-Network"), "true")
   203  	assert.Equal(t, header.Get("Vary"), "Origin")
   204  	assert.Len(t, header, 2)
   205  }
   206  
   207  func TestGeneratePreflightHeaders_AllowMethods(t *testing.T) {
   208  	header := generatePreflightHeaders(Config{
   209  		AllowMethods: []string{"GET ", "post", "PUT", " put  "},
   210  	})
   211  	assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "GET,POST,PUT")
   212  	assert.Equal(t, header.Get("Vary"), "Origin")
   213  	assert.Len(t, header, 2)
   214  }
   215  
   216  func TestGeneratePreflightHeaders_AllowHeaders(t *testing.T) {
   217  	header := generatePreflightHeaders(Config{
   218  		AllowHeaders: []string{"X-user", "Content-Type"},
   219  	})
   220  	assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "X-User,Content-Type")
   221  	assert.Equal(t, header.Get("Vary"), "Origin")
   222  	assert.Len(t, header, 2)
   223  }
   224  
   225  func TestGeneratePreflightHeaders_MaxAge(t *testing.T) {
   226  	header := generatePreflightHeaders(Config{
   227  		MaxAge: 12 * time.Hour,
   228  	})
   229  	assert.Equal(t, header.Get("Access-Control-Max-Age"), "43200") // 12*60*60
   230  	assert.Equal(t, header.Get("Vary"), "Origin")
   231  	assert.Len(t, header, 2)
   232  }
   233  
   234  func TestValidateOrigin(t *testing.T) {
   235  	cors := newCors(Config{
   236  		AllowAllOrigins: true,
   237  	})
   238  	assert.True(t, cors.validateOrigin("http://google.com"))
   239  	assert.True(t, cors.validateOrigin("https://google.com"))
   240  	assert.True(t, cors.validateOrigin("example.com"))
   241  	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
   242  
   243  	cors = newCors(Config{
   244  		AllowOrigins: []string{"https://google.com", "https://github.com"},
   245  		AllowOriginFunc: func(origin string) bool {
   246  			return (origin == "http://news.ycombinator.com")
   247  		},
   248  		AllowBrowserExtensions: true,
   249  	})
   250  	assert.False(t, cors.validateOrigin("http://google.com"))
   251  	assert.True(t, cors.validateOrigin("https://google.com"))
   252  	assert.True(t, cors.validateOrigin("https://github.com"))
   253  	assert.True(t, cors.validateOrigin("http://news.ycombinator.com"))
   254  	assert.False(t, cors.validateOrigin("http://example.com"))
   255  	assert.False(t, cors.validateOrigin("google.com"))
   256  	assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
   257  
   258  	cors = newCors(Config{
   259  		AllowOrigins: []string{"https://google.com", "https://github.com"},
   260  	})
   261  	assert.False(t, cors.validateOrigin("chrome-extension://random-extension-id"))
   262  	assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
   263  	assert.False(t, cors.validateOrigin("wss://socket-connection"))
   264  
   265  	cors = newCors(Config{
   266  		AllowOrigins: []string{
   267  			"chrome-extension://*",
   268  			"safari-extension://my-extension-*-app",
   269  			"*.some-domain.com",
   270  		},
   271  		AllowBrowserExtensions: true,
   272  		AllowWildcard:          true,
   273  	})
   274  	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
   275  	assert.True(t, cors.validateOrigin("chrome-extension://another-one"))
   276  	assert.True(t, cors.validateOrigin("safari-extension://my-extension-one-app"))
   277  	assert.True(t, cors.validateOrigin("safari-extension://my-extension-two-app"))
   278  	assert.False(t, cors.validateOrigin("moz-extension://ext-id-we-not-allow"))
   279  	assert.True(t, cors.validateOrigin("http://api.some-domain.com"))
   280  	assert.False(t, cors.validateOrigin("http://api.another-domain.com"))
   281  
   282  	cors = newCors(Config{
   283  		AllowOrigins:    []string{"file://safe-file.js", "wss://some-session-layer-connection"},
   284  		AllowFiles:      true,
   285  		AllowWebSockets: true,
   286  	})
   287  	assert.True(t, cors.validateOrigin("file://safe-file.js"))
   288  	assert.False(t, cors.validateOrigin("file://some-dangerous-file.js"))
   289  	assert.True(t, cors.validateOrigin("wss://some-session-layer-connection"))
   290  	assert.False(t, cors.validateOrigin("ws://not-what-we-expected"))
   291  
   292  	cors = newCors(Config{
   293  		AllowOrigins: []string{"*"},
   294  	})
   295  	assert.True(t, cors.validateOrigin("http://google.com"))
   296  	assert.True(t, cors.validateOrigin("https://google.com"))
   297  	assert.True(t, cors.validateOrigin("example.com"))
   298  	assert.True(t, cors.validateOrigin("chrome-extension://random-extension-id"))
   299  }
   300  
   301  func TestValidateTauri(t *testing.T) {
   302  	c := Config{
   303  		AllowOrigins:           []string{"tauri://localhost:1234"},
   304  		AllowBrowserExtensions: true,
   305  	}
   306  	err := c.Validate()
   307  	assert.Error(t, err)
   308  
   309  	c = Config{
   310  		AllowOrigins:           []string{"tauri://localhost:1234"},
   311  		AllowBrowserExtensions: true,
   312  		CustomSchemas:          []string{"tauri"},
   313  	}
   314  	assert.Nil(t, c.Validate())
   315  }
   316  
   317  func TestDefaultConfig(t *testing.T) {
   318  	config := DefaultConfig()
   319  	config.AllowAllOrigins = true
   320  	router := newTestRouter(config)
   321  	w := performRequest(router, "GET", "http://google.com")
   322  	assert.Equal(t, "get", w.Body.String())
   323  	assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
   324  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   325  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
   326  }
   327  
   328  func TestPassesAllowOrigins(t *testing.T) {
   329  	router := newTestRouter(Config{
   330  		AllowOrigins:     []string{"http://google.com"},
   331  		AllowMethods:     []string{" GeT ", "get", "post", "PUT  ", "Head", "POST"},
   332  		AllowHeaders:     []string{"Content-type", "timeStamp "},
   333  		ExposeHeaders:    []string{"Data", "x-User"},
   334  		AllowCredentials: false,
   335  		MaxAge:           12 * time.Hour,
   336  		AllowOriginFunc: func(origin string) bool {
   337  			return origin == "http://github.com"
   338  		},
   339  		AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {
   340  			return origin == "http://sample.com"
   341  		},
   342  	})
   343  
   344  	// no CORS request, origin == ""
   345  	w := performRequest(router, "GET", "")
   346  	assert.Equal(t, "get", w.Body.String())
   347  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
   348  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   349  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
   350  
   351  	// no CORS request, origin == host
   352  	h := http.Header{}
   353  	h.Set("Host", "facebook.com")
   354  	w = performRequestWithHeaders(router, "GET", "/", "http://facebook.com", h)
   355  	assert.Equal(t, "get", w.Body.String())
   356  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
   357  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   358  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
   359  
   360  	// allowed CORS request
   361  	w = performRequest(router, "GET", "http://google.com")
   362  	assert.Equal(t, "get", w.Body.String())
   363  	assert.Equal(t, "http://google.com", w.Header().Get("Access-Control-Allow-Origin"))
   364  	assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
   365  	assert.Equal(t, "Data,X-User", w.Header().Get("Access-Control-Expose-Headers"))
   366  
   367  	w = performRequest(router, "GET", "http://github.com")
   368  	assert.Equal(t, "get", w.Body.String())
   369  	assert.Equal(t, "http://github.com", w.Header().Get("Access-Control-Allow-Origin"))
   370  	assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
   371  	assert.Equal(t, "Data,X-User", w.Header().Get("Access-Control-Expose-Headers"))
   372  
   373  	// deny CORS request
   374  	w = performRequest(router, "GET", "https://google.com")
   375  	assert.Equal(t, http.StatusForbidden, w.Code)
   376  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
   377  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   378  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
   379  
   380  	// allowed CORS prefligh request
   381  	w = performRequest(router, "OPTIONS", "http://github.com")
   382  	assert.Equal(t, http.StatusNoContent, w.Code)
   383  	assert.Equal(t, "http://github.com", w.Header().Get("Access-Control-Allow-Origin"))
   384  	assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
   385  	assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods"))
   386  	assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
   387  	assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))
   388  
   389  	// allowed CORS prefligh request: allowed via AllowOriginWithContextFunc
   390  	w = performRequest(router, "OPTIONS", "http://sample.com")
   391  	assert.Equal(t, http.StatusNoContent, w.Code)
   392  	assert.Equal(t, "http://sample.com", w.Header().Get("Access-Control-Allow-Origin"))
   393  	assert.Equal(t, "", w.Header().Get("Access-Control-Allow-Credentials"))
   394  	assert.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods"))
   395  	assert.Equal(t, "Content-Type,Timestamp", w.Header().Get("Access-Control-Allow-Headers"))
   396  	assert.Equal(t, "43200", w.Header().Get("Access-Control-Max-Age"))
   397  
   398  	// deny CORS prefligh request
   399  	w = performRequest(router, "OPTIONS", "http://example.com")
   400  	assert.Equal(t, http.StatusForbidden, w.Code)
   401  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
   402  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   403  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"))
   404  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
   405  	assert.Empty(t, w.Header().Get("Access-Control-Max-Age"))
   406  }
   407  
   408  func TestPassesAllowAllOrigins(t *testing.T) {
   409  	router := newTestRouter(Config{
   410  		AllowAllOrigins:  true,
   411  		AllowMethods:     []string{" Patch ", "get", "post", "POST"},
   412  		AllowHeaders:     []string{"Content-type", "  testheader "},
   413  		ExposeHeaders:    []string{"Data2", "x-User2"},
   414  		AllowCredentials: false,
   415  		MaxAge:           10 * time.Hour,
   416  	})
   417  
   418  	// no CORS request, origin == ""
   419  	w := performRequest(router, "GET", "")
   420  	assert.Equal(t, "get", w.Body.String())
   421  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
   422  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   423  	assert.Empty(t, w.Header().Get("Access-Control-Expose-Headers"))
   424  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   425  
   426  	// allowed CORS request
   427  	w = performRequest(router, "POST", "example.com")
   428  	assert.Equal(t, "post", w.Body.String())
   429  	assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
   430  	assert.Equal(t, "Data2,X-User2", w.Header().Get("Access-Control-Expose-Headers"))
   431  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   432  	assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
   433  
   434  	// allowed CORS prefligh request
   435  	w = performRequest(router, "OPTIONS", "https://facebook.com")
   436  	assert.Equal(t, http.StatusNoContent, w.Code)
   437  	assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
   438  	assert.Equal(t, "PATCH,GET,POST", w.Header().Get("Access-Control-Allow-Methods"))
   439  	assert.Equal(t, "Content-Type,Testheader", w.Header().Get("Access-Control-Allow-Headers"))
   440  	assert.Equal(t, "36000", w.Header().Get("Access-Control-Max-Age"))
   441  	assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"))
   442  }
   443  
   444  func TestWildcard(t *testing.T) {
   445  	router := newTestRouter(Config{
   446  		AllowOrigins:  []string{"https://*.github.com", "https://api.*", "http://*", "https://facebook.com", "*.golang.org"},
   447  		AllowMethods:  []string{"GET"},
   448  		AllowWildcard: true,
   449  	})
   450  
   451  	w := performRequest(router, "GET", "https://gist.github.com")
   452  	assert.Equal(t, 200, w.Code)
   453  
   454  	w = performRequest(router, "GET", "https://api.github.com/v1/users")
   455  	assert.Equal(t, 200, w.Code)
   456  
   457  	w = performRequest(router, "GET", "https://giphy.com/")
   458  	assert.Equal(t, 403, w.Code)
   459  
   460  	w = performRequest(router, "GET", "http://hard-to-find-http-example.com")
   461  	assert.Equal(t, 200, w.Code)
   462  
   463  	w = performRequest(router, "GET", "https://facebook.com")
   464  	assert.Equal(t, 200, w.Code)
   465  
   466  	w = performRequest(router, "GET", "https://something.golang.org")
   467  	assert.Equal(t, 200, w.Code)
   468  
   469  	w = performRequest(router, "GET", "https://something.go.org")
   470  	assert.Equal(t, 403, w.Code)
   471  
   472  	router = newTestRouter(Config{
   473  		AllowOrigins: []string{"https://github.com", "https://facebook.com"},
   474  		AllowMethods: []string{"GET"},
   475  	})
   476  
   477  	w = performRequest(router, "GET", "https://gist.github.com")
   478  	assert.Equal(t, 403, w.Code)
   479  
   480  	w = performRequest(router, "GET", "https://github.com")
   481  	assert.Equal(t, 200, w.Code)
   482  }
   483  
   484  func TestMultiGroupRouter(t *testing.T) {
   485  	router := multiGroupRouter(Config{
   486  		AllowMethods: []string{"GET"},
   487  		AllowOriginWithContextFunc: func(c *gin.Context, origin string) bool {
   488  			path := c.Request.URL.Path
   489  			if strings.HasPrefix(path, "/app1") {
   490  				return origin == "http://app1.example.com"
   491  			}
   492  
   493  			if strings.HasPrefix(path, "/app2") {
   494  				return origin == "http://app2.example.com"
   495  			}
   496  
   497  			// app 3 allows all origins
   498  			return true
   499  		},
   500  	})
   501  
   502  	// allowed CORS prefligh request
   503  	emptyHeaders := http.Header{}
   504  	app1Origin := "http://app1.example.com"
   505  	app2Origin := "http://app2.example.com"
   506  	randomOrgin := "http://random.com"
   507  
   508  	// allowed CORS preflight
   509  	w := performRequestWithHeaders(router, "OPTIONS", "/app1", app1Origin, emptyHeaders)
   510  	assert.Equal(t, http.StatusNoContent, w.Code)
   511  
   512  	w = performRequestWithHeaders(router, "OPTIONS", "/app2", app2Origin, emptyHeaders)
   513  	assert.Equal(t, http.StatusNoContent, w.Code)
   514  
   515  	w = performRequestWithHeaders(router, "OPTIONS", "/app3", randomOrgin, emptyHeaders)
   516  	assert.Equal(t, http.StatusNoContent, w.Code)
   517  
   518  	// disallowed CORS preflight
   519  	w = performRequestWithHeaders(router, "OPTIONS", "/app1", randomOrgin, emptyHeaders)
   520  	assert.Equal(t, http.StatusForbidden, w.Code)
   521  
   522  	w = performRequestWithHeaders(router, "OPTIONS", "/app2", randomOrgin, emptyHeaders)
   523  	assert.Equal(t, http.StatusForbidden, w.Code)
   524  }
   525  
   526  func TestParseWildcardRules_NoWildcard(t *testing.T) {
   527  	config := Config{
   528  		AllowOrigins: []string{
   529  			"http://example.com",
   530  			"https://google.com",
   531  			"github.com",
   532  		},
   533  		AllowWildcard: false,
   534  	}
   535  
   536  	var expected [][]string
   537  
   538  	result := config.parseWildcardRules()
   539  
   540  	assert.Equal(t, expected, result)
   541  }
   542  
   543  func TestParseWildcardRules_InvalidWildcard(t *testing.T) {
   544  	config := Config{
   545  		AllowOrigins: []string{
   546  			"http://example.com",
   547  			"https://*.google.com*",
   548  			"*.github.com*",
   549  		},
   550  		AllowWildcard: true,
   551  	}
   552  
   553  	assert.Panics(t, func() {
   554  		config.parseWildcardRules()
   555  	})
   556  }
   557  
   558  func TestParseWildcardRules(t *testing.T) {
   559  	tests := []struct {
   560  		name           string
   561  		config         Config
   562  		expectedResult [][]string
   563  		expectPanic    bool
   564  	}{
   565  		{
   566  			name: "Wildcard not allowed",
   567  			config: Config{
   568  				AllowWildcard: false,
   569  				AllowOrigins:  []string{"http://example.com", "https://*.domain.com"},
   570  			},
   571  			expectedResult: nil,
   572  			expectPanic:    false,
   573  		},
   574  		{
   575  			name: "No wildcards",
   576  			config: Config{
   577  				AllowWildcard: true,
   578  				AllowOrigins:  []string{"http://example.com", "https://example.com"},
   579  			},
   580  			expectedResult: nil,
   581  			expectPanic:    false,
   582  		},
   583  		{
   584  			name: "Single wildcard at the end",
   585  			config: Config{
   586  				AllowWildcard: true,
   587  				AllowOrigins:  []string{"http://*.example.com"},
   588  			},
   589  			expectedResult: [][]string{{"http://", ".example.com"}},
   590  			expectPanic:    false,
   591  		},
   592  		{
   593  			name: "Single wildcard at the beginning",
   594  			config: Config{
   595  				AllowWildcard: true,
   596  				AllowOrigins:  []string{"*.example.com"},
   597  			},
   598  			expectedResult: [][]string{{"*", ".example.com"}},
   599  			expectPanic:    false,
   600  		},
   601  		{
   602  			name: "Single wildcard in the middle",
   603  			config: Config{
   604  				AllowWildcard: true,
   605  				AllowOrigins:  []string{"http://example.*.com"},
   606  			},
   607  			expectedResult: [][]string{{"http://example.", ".com"}},
   608  			expectPanic:    false,
   609  		},
   610  		{
   611  			name: "Multiple wildcards should panic",
   612  			config: Config{
   613  				AllowWildcard: true,
   614  				AllowOrigins:  []string{"http://*.*.com"},
   615  			},
   616  			expectedResult: nil,
   617  			expectPanic:    true,
   618  		},
   619  		{
   620  			name: "Single wildcard in the end",
   621  			config: Config{
   622  				AllowWildcard: true,
   623  				AllowOrigins:  []string{"http://example.com/*"},
   624  			},
   625  			expectedResult: [][]string{{"http://example.com/", "*"}},
   626  			expectPanic:    false,
   627  		},
   628  	}
   629  
   630  	for _, tt := range tests {
   631  		t.Run(tt.name, func(t *testing.T) {
   632  			if tt.expectPanic {
   633  				defer func() {
   634  					if r := recover(); r == nil {
   635  						t.Errorf("The code did not panic")
   636  					}
   637  				}()
   638  			}
   639  
   640  			result := tt.config.parseWildcardRules()
   641  			if !tt.expectPanic && !reflect.DeepEqual(result, tt.expectedResult) {
   642  				t.Errorf("Name: %v, Expected %v, got %v", tt.name, tt.expectedResult, result)
   643  			}
   644  		})
   645  	}
   646  }
   647  

View as plain text