...

Source file src/github.com/go-chi/chi/middleware/compress_test.go

Documentation: github.com/go-chi/chi/middleware

     1  package middleware
     2  
     3  import (
     4  	"compress/flate"
     5  	"compress/gzip"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/go-chi/chi"
    15  )
    16  
    17  func TestCompressor(t *testing.T) {
    18  	r := chi.NewRouter()
    19  
    20  	compressor := NewCompressor(5, "text/html", "text/css")
    21  	if len(compressor.encoders) != 0 || len(compressor.pooledEncoders) != 2 {
    22  		t.Errorf("gzip and deflate should be pooled")
    23  	}
    24  
    25  	compressor.SetEncoder("nop", func(w io.Writer, _ int) io.Writer {
    26  		return w
    27  	})
    28  
    29  	if len(compressor.encoders) != 1 {
    30  		t.Errorf("nop encoder should be stored in the encoders map")
    31  	}
    32  
    33  	r.Use(compressor.Handler)
    34  
    35  	r.Get("/gethtml", func(w http.ResponseWriter, r *http.Request) {
    36  		w.Header().Set("Content-Type", "text/html")
    37  		w.Write([]byte("textstring"))
    38  	})
    39  
    40  	r.Get("/getcss", func(w http.ResponseWriter, r *http.Request) {
    41  		w.Header().Set("Content-Type", "text/html")
    42  		w.Write([]byte("textstring"))
    43  	})
    44  
    45  	r.Get("/getplain", func(w http.ResponseWriter, r *http.Request) {
    46  		w.Header().Set("Content-Type", "text/html")
    47  		w.Write([]byte("textstring"))
    48  	})
    49  
    50  	ts := httptest.NewServer(r)
    51  	defer ts.Close()
    52  
    53  	tests := []struct {
    54  		name              string
    55  		path              string
    56  		acceptedEncodings []string
    57  		expectedEncoding  string
    58  	}{
    59  		{
    60  			name:              "no expected encodings due to no accepted encodings",
    61  			path:              "/gethtml",
    62  			acceptedEncodings: nil,
    63  			expectedEncoding:  "",
    64  		},
    65  		{
    66  			name:              "no expected encodings due to content type",
    67  			path:              "/getplain",
    68  			acceptedEncodings: nil,
    69  			expectedEncoding:  "",
    70  		},
    71  		{
    72  			name:              "gzip is only encoding",
    73  			path:              "/gethtml",
    74  			acceptedEncodings: []string{"gzip"},
    75  			expectedEncoding:  "gzip",
    76  		},
    77  		{
    78  			name:              "gzip is preferred over deflate",
    79  			path:              "/getcss",
    80  			acceptedEncodings: []string{"gzip", "deflate"},
    81  			expectedEncoding:  "gzip",
    82  		},
    83  		{
    84  			name:              "deflate is used",
    85  			path:              "/getcss",
    86  			acceptedEncodings: []string{"deflate"},
    87  			expectedEncoding:  "deflate",
    88  		},
    89  		{
    90  
    91  			name:              "nop is preferred",
    92  			path:              "/getcss",
    93  			acceptedEncodings: []string{"nop, gzip, deflate"},
    94  			expectedEncoding:  "nop",
    95  		},
    96  	}
    97  
    98  	for _, tc := range tests {
    99  		tc := tc
   100  		t.Run(tc.name, func(t *testing.T) {
   101  			resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...)
   102  			if respString != "textstring" {
   103  				t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString)
   104  			}
   105  			if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding {
   106  				t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got)
   107  			}
   108  
   109  		})
   110  
   111  	}
   112  }
   113  
   114  func TestCompressorWildcards(t *testing.T) {
   115  	tests := []struct {
   116  		name       string
   117  		types      []string
   118  		typesCount int
   119  		wcCount    int
   120  		recover    string
   121  	}{
   122  		{
   123  			name:       "defaults",
   124  			typesCount: 10,
   125  		},
   126  		{
   127  			name:       "no wildcard",
   128  			types:      []string{"text/plain", "text/html"},
   129  			typesCount: 2,
   130  		},
   131  		{
   132  			name:    "invalid wildcard #1",
   133  			types:   []string{"audio/*wav"},
   134  			recover: "middleware/compress: Unsupported content-type wildcard pattern 'audio/*wav'. Only '/*' supported",
   135  		},
   136  		{
   137  			name:    "invalid wildcard #2",
   138  			types:   []string{"application*/*"},
   139  			recover: "middleware/compress: Unsupported content-type wildcard pattern 'application*/*'. Only '/*' supported",
   140  		},
   141  		{
   142  			name:    "valid wildcard",
   143  			types:   []string{"text/*"},
   144  			wcCount: 1,
   145  		},
   146  		{
   147  			name:       "mixed",
   148  			types:      []string{"audio/wav", "text/*"},
   149  			typesCount: 1,
   150  			wcCount:    1,
   151  		},
   152  	}
   153  	for _, tt := range tests {
   154  		t.Run(tt.name, func(t *testing.T) {
   155  			defer func() {
   156  				if tt.recover == "" {
   157  					tt.recover = "<nil>"
   158  				}
   159  				if r := recover(); tt.recover != fmt.Sprintf("%v", r) {
   160  					t.Errorf("Unexpected value recovered: %v", r)
   161  				}
   162  			}()
   163  			compressor := NewCompressor(5, tt.types...)
   164  			if len(compressor.allowedTypes) != tt.typesCount {
   165  				t.Errorf("expected %d allowedTypes, got %d", tt.typesCount, len(compressor.allowedTypes))
   166  			}
   167  			if len(compressor.allowedWildcards) != tt.wcCount {
   168  				t.Errorf("expected %d allowedWildcards, got %d", tt.wcCount, len(compressor.allowedWildcards))
   169  			}
   170  		})
   171  	}
   172  }
   173  
   174  func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) {
   175  	req, err := http.NewRequest(method, ts.URL+path, nil)
   176  	if err != nil {
   177  		t.Fatal(err)
   178  		return nil, ""
   179  	}
   180  	if len(encodings) > 0 {
   181  		encodingsString := strings.Join(encodings, ",")
   182  		req.Header.Set("Accept-Encoding", encodingsString)
   183  	}
   184  
   185  	resp, err := http.DefaultClient.Do(req)
   186  	if err != nil {
   187  		t.Fatal(err)
   188  		return nil, ""
   189  	}
   190  
   191  	respBody := decodeResponseBody(t, resp)
   192  	defer resp.Body.Close()
   193  
   194  	return resp, respBody
   195  }
   196  
   197  func decodeResponseBody(t *testing.T, resp *http.Response) string {
   198  	var reader io.ReadCloser
   199  	switch resp.Header.Get("Content-Encoding") {
   200  	case "gzip":
   201  		var err error
   202  		reader, err = gzip.NewReader(resp.Body)
   203  		if err != nil {
   204  			t.Fatal(err)
   205  		}
   206  	case "deflate":
   207  		reader = flate.NewReader(resp.Body)
   208  	default:
   209  		reader = resp.Body
   210  	}
   211  	respBody, err := ioutil.ReadAll(reader)
   212  	if err != nil {
   213  		t.Fatal(err)
   214  		return ""
   215  	}
   216  	reader.Close()
   217  
   218  	return string(respBody)
   219  }
   220  

View as plain text