...

Source file src/github.com/klauspost/compress/gzhttp/transport_test.go

Documentation: github.com/klauspost/compress/gzhttp

     1  // Copyright (c) 2021 Klaus Post. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gzhttp
     6  
     7  import (
     8  	"bytes"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"os"
    13  	"runtime"
    14  	"testing"
    15  
    16  	"github.com/klauspost/compress/gzip"
    17  	"github.com/klauspost/compress/zstd"
    18  )
    19  
    20  func TestTransport(t *testing.T) {
    21  	bin, err := os.ReadFile("testdata/benchmark.json")
    22  	if err != nil {
    23  		t.Fatal(err)
    24  	}
    25  
    26  	server := httptest.NewServer(newTestHandler(bin))
    27  
    28  	c := http.Client{Transport: Transport(http.DefaultTransport)}
    29  	resp, err := c.Get(server.URL)
    30  	if err != nil {
    31  		t.Fatal(err)
    32  	}
    33  	got, err := io.ReadAll(resp.Body)
    34  	if err != nil {
    35  		t.Fatal(err)
    36  	}
    37  	if !bytes.Equal(got, bin) {
    38  		t.Errorf("data mismatch")
    39  	}
    40  }
    41  
    42  func TestTransportForced(t *testing.T) {
    43  	raw, err := os.ReadFile("testdata/benchmark.json")
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  
    48  	var buf bytes.Buffer
    49  	zw := gzip.NewWriter(&buf)
    50  	zw.Write(raw)
    51  	zw.Close()
    52  	bin := buf.Bytes()
    53  
    54  	server := httptest.NewServer(newTestHandler(bin))
    55  
    56  	c := http.Client{Transport: Transport(http.DefaultTransport)}
    57  	resp, err := c.Get(server.URL + "/gzipped")
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	got, err := io.ReadAll(resp.Body)
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  	if !bytes.Equal(got, raw) {
    66  		t.Errorf("data mismatch")
    67  	}
    68  }
    69  
    70  func TestTransportForcedDisabled(t *testing.T) {
    71  	raw, err := os.ReadFile("testdata/benchmark.json")
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  
    76  	var buf bytes.Buffer
    77  	zw := gzip.NewWriter(&buf)
    78  	zw.Write(raw)
    79  	zw.Close()
    80  	bin := buf.Bytes()
    81  
    82  	server := httptest.NewServer(newTestHandler(bin))
    83  	c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableGzip(false))}
    84  	resp, err := c.Get(server.URL + "/gzipped")
    85  	if err != nil {
    86  		t.Fatal(err)
    87  	}
    88  	got, err := io.ReadAll(resp.Body)
    89  	if err != nil {
    90  		t.Fatal(err)
    91  	}
    92  	if !bytes.Equal(bin, got) {
    93  		t.Errorf("data mismatch")
    94  	}
    95  }
    96  
    97  func TestTransportZstd(t *testing.T) {
    98  	bin, err := os.ReadFile("testdata/benchmark.json")
    99  	if err != nil {
   100  		t.Fatal(err)
   101  	}
   102  	enc, _ := zstd.NewWriter(nil)
   103  	defer enc.Close()
   104  	zsBin := enc.EncodeAll(bin, nil)
   105  	server := httptest.NewServer(newTestHandler(zsBin))
   106  
   107  	c := http.Client{Transport: Transport(http.DefaultTransport)}
   108  	resp, err := c.Get(server.URL + "/zstd")
   109  	if err != nil {
   110  		t.Fatal(err)
   111  	}
   112  	got, err := io.ReadAll(resp.Body)
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	if !bytes.Equal(got, bin) {
   117  		t.Errorf("data mismatch")
   118  	}
   119  }
   120  
   121  func TestTransportInvalid(t *testing.T) {
   122  	bin, err := os.ReadFile("testdata/benchmark.json")
   123  	if err != nil {
   124  		t.Fatal(err)
   125  	}
   126  
   127  	server := httptest.NewServer(newTestHandler(bin))
   128  
   129  	c := http.Client{Transport: Transport(http.DefaultTransport)}
   130  	// Serves json as gzippped...
   131  	resp, err := c.Get(server.URL + "/gzipped")
   132  	if err != nil {
   133  		t.Fatal(err)
   134  	}
   135  	_, err = io.ReadAll(resp.Body)
   136  	if err == nil {
   137  		t.Fatal("expected error, got nil")
   138  	}
   139  }
   140  
   141  func TestTransportZstdDisabled(t *testing.T) {
   142  	raw, err := os.ReadFile("testdata/benchmark.json")
   143  	if err != nil {
   144  		t.Fatal(err)
   145  	}
   146  
   147  	enc, _ := zstd.NewWriter(nil)
   148  	defer enc.Close()
   149  	zsBin := enc.EncodeAll(raw, nil)
   150  
   151  	server := httptest.NewServer(newTestHandler(zsBin))
   152  	c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableZstd(false))}
   153  	resp, err := c.Get(server.URL + "/zstd")
   154  	if err != nil {
   155  		t.Fatal(err)
   156  	}
   157  	got, err := io.ReadAll(resp.Body)
   158  	if err != nil {
   159  		t.Fatal(err)
   160  	}
   161  	if !bytes.Equal(zsBin, got) {
   162  		t.Errorf("data mismatch")
   163  	}
   164  }
   165  
   166  func TestTransportZstdInvalid(t *testing.T) {
   167  	bin, err := os.ReadFile("testdata/benchmark.json")
   168  	if err != nil {
   169  		t.Fatal(err)
   170  	}
   171  	// Do not encode...
   172  	server := httptest.NewServer(newTestHandler(bin))
   173  
   174  	c := http.Client{Transport: Transport(http.DefaultTransport)}
   175  	resp, err := c.Get(server.URL + "/zstd")
   176  	if err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	_, err = io.ReadAll(resp.Body)
   180  	if err == nil {
   181  		t.Fatal("expected error, got nil")
   182  	}
   183  	t.Log("expected error:", err)
   184  }
   185  
   186  func TestDefaultTransport(t *testing.T) {
   187  	bin, err := os.ReadFile("testdata/benchmark.json")
   188  	if err != nil {
   189  		t.Fatal(err)
   190  	}
   191  
   192  	server := httptest.NewServer(newTestHandler(bin))
   193  
   194  	// Not wrapped...
   195  	c := http.Client{Transport: http.DefaultTransport}
   196  	resp, err := c.Get(server.URL)
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  	got, err := io.ReadAll(resp.Body)
   201  	if err != nil {
   202  		t.Fatal(err)
   203  	}
   204  	if !bytes.Equal(got, bin) {
   205  		t.Errorf("data mismatch")
   206  	}
   207  }
   208  
   209  func TestTransportCustomEval(t *testing.T) {
   210  	bin, err := os.ReadFile("testdata/benchmark.json")
   211  	if err != nil {
   212  		t.Fatal(err)
   213  	}
   214  
   215  	server := httptest.NewServer(newTestHandler(bin))
   216  	calledWith := ""
   217  	c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableZstd(false), TransportCustomEval(func(h http.Header) bool {
   218  		calledWith = h.Get("Content-Encoding")
   219  		return true
   220  	}))}
   221  	resp, err := c.Get(server.URL)
   222  	if err != nil {
   223  		t.Fatal(err)
   224  	}
   225  	got, err := io.ReadAll(resp.Body)
   226  	if err != nil {
   227  		t.Fatal(err)
   228  	}
   229  	if !bytes.Equal(got, bin) {
   230  		t.Errorf("data mismatch")
   231  	}
   232  	if calledWith != "gzip" {
   233  		t.Fatalf("Expected encoding %q, got %q", "gzip", calledWith)
   234  	}
   235  	// Test returning false
   236  	c = http.Client{Transport: Transport(http.DefaultTransport, TransportCustomEval(func(h http.Header) bool {
   237  		calledWith = h.Get("Content-Encoding")
   238  		return false
   239  	}))}
   240  	resp, err = c.Get(server.URL)
   241  	if err != nil {
   242  		t.Fatal(err)
   243  	}
   244  	// Check we got the compressed data
   245  	gotCE := resp.Header.Get("Content-Encoding")
   246  	if gotCE != "gzip" {
   247  		t.Fatalf("Expected encoding %q, got %q", "gzip", gotCE)
   248  	}
   249  	if calledWith != "gzip" {
   250  		t.Fatalf("Expected encoding %q, got %q", "gzip", calledWith)
   251  	}
   252  }
   253  
   254  func BenchmarkTransport(b *testing.B) {
   255  	raw, err := os.ReadFile("testdata/benchmark.json")
   256  	if err != nil {
   257  		b.Fatal(err)
   258  	}
   259  	sz := int64(len(raw))
   260  	var buf bytes.Buffer
   261  	zw := gzip.NewWriter(&buf)
   262  	zw.Write(raw)
   263  	zw.Close()
   264  	bin := buf.Bytes()
   265  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   266  		r.Body.Close()
   267  		w.Header().Set("Content-Encoding", "gzip")
   268  		w.WriteHeader(http.StatusOK)
   269  		w.Write(bin)
   270  	}))
   271  	enc, _ := zstd.NewWriter(nil, zstd.WithWindowSize(128<<10), zstd.WithEncoderLevel(zstd.SpeedBestCompression))
   272  	defer enc.Close()
   273  	zsBin := enc.EncodeAll(raw, nil)
   274  	serverZstd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   275  		r.Body.Close()
   276  		w.Header().Set("Content-Encoding", "zstd")
   277  		w.WriteHeader(http.StatusOK)
   278  		w.Write(zsBin)
   279  	}))
   280  	b.Run("gzhttp", func(b *testing.B) {
   281  		c := http.Client{Transport: Transport(http.DefaultTransport)}
   282  
   283  		b.SetBytes(int64(sz))
   284  		b.ReportAllocs()
   285  		b.ResetTimer()
   286  		for i := 0; i < b.N; i++ {
   287  			resp, err := c.Get(server.URL + "/gzipped")
   288  			if err != nil {
   289  				b.Fatal(err)
   290  			}
   291  			n, err := io.Copy(io.Discard, resp.Body)
   292  			if err != nil {
   293  				b.Fatal(err)
   294  			}
   295  			if n != sz {
   296  				b.Fatalf("size mismatch: want %d, got %d", sz, n)
   297  			}
   298  			resp.Body.Close()
   299  		}
   300  		b.ReportMetric(100*float64(len(bin))/float64(len(raw)), "pct")
   301  	})
   302  	b.Run("stdlib", func(b *testing.B) {
   303  		c := http.Client{Transport: http.DefaultTransport}
   304  		b.SetBytes(int64(sz))
   305  		b.ReportAllocs()
   306  		b.ResetTimer()
   307  		for i := 0; i < b.N; i++ {
   308  			resp, err := c.Get(server.URL + "/gzipped")
   309  			if err != nil {
   310  				b.Fatal(err)
   311  			}
   312  			n, err := io.Copy(io.Discard, resp.Body)
   313  			if err != nil {
   314  				b.Fatal(err)
   315  			}
   316  			if n != sz {
   317  				b.Fatalf("size mismatch: want %d, got %d", sz, n)
   318  			}
   319  			resp.Body.Close()
   320  		}
   321  		b.ReportMetric(100*float64(len(bin))/float64(len(raw)), "pct")
   322  	})
   323  	b.Run("zstd", func(b *testing.B) {
   324  		c := http.Client{Transport: Transport(http.DefaultTransport)}
   325  
   326  		b.SetBytes(int64(sz))
   327  		b.ReportAllocs()
   328  		b.ResetTimer()
   329  		for i := 0; i < b.N; i++ {
   330  			resp, err := c.Get(serverZstd.URL + "/zstd")
   331  			if err != nil {
   332  				b.Fatal(err)
   333  			}
   334  			n, err := io.Copy(io.Discard, resp.Body)
   335  			if err != nil {
   336  				b.Fatal(err)
   337  			}
   338  			if n != sz {
   339  				b.Fatalf("size mismatch: want %d, got %d", sz, n)
   340  			}
   341  			resp.Body.Close()
   342  		}
   343  		b.ReportMetric(100*float64(len(zsBin))/float64(len(raw)), "pct")
   344  	})
   345  	b.Run("gzhttp-par", func(b *testing.B) {
   346  		c := http.Client{
   347  			Transport: Transport(&http.Transport{
   348  				MaxConnsPerHost:     runtime.GOMAXPROCS(0),
   349  				MaxIdleConnsPerHost: runtime.GOMAXPROCS(0),
   350  			}),
   351  		}
   352  
   353  		b.SetBytes(int64(sz))
   354  		b.ReportAllocs()
   355  		b.ResetTimer()
   356  		b.RunParallel(func(pb *testing.PB) {
   357  			for pb.Next() {
   358  				resp, err := c.Get(server.URL + "/gzipped")
   359  				if err != nil {
   360  					b.Fatal(err)
   361  				}
   362  				n, err := io.Copy(io.Discard, resp.Body)
   363  				if err != nil {
   364  					b.Fatal(err)
   365  				}
   366  				if n != sz {
   367  					b.Fatalf("size mismatch: want %d, got %d", sz, n)
   368  				}
   369  				resp.Body.Close()
   370  			}
   371  		})
   372  		b.ReportMetric(100*float64(len(bin))/float64(len(raw)), "pct")
   373  	})
   374  	b.Run("stdlib-par", func(b *testing.B) {
   375  		c := http.Client{Transport: &http.Transport{
   376  			MaxConnsPerHost:     runtime.GOMAXPROCS(0),
   377  			MaxIdleConnsPerHost: runtime.GOMAXPROCS(0),
   378  		}}
   379  		b.SetBytes(int64(sz))
   380  		b.ReportAllocs()
   381  		b.ResetTimer()
   382  		b.RunParallel(func(pb *testing.PB) {
   383  			for pb.Next() {
   384  				resp, err := c.Get(server.URL + "/gzipped")
   385  				if err != nil {
   386  					b.Fatal(err)
   387  				}
   388  				n, err := io.Copy(io.Discard, resp.Body)
   389  				if err != nil {
   390  					b.Fatal(err)
   391  				}
   392  				if n != sz {
   393  					b.Fatalf("size mismatch: want %d, got %d", sz, n)
   394  				}
   395  				resp.Body.Close()
   396  			}
   397  		})
   398  		b.ReportMetric(100*float64(len(bin))/float64(len(raw)), "pct")
   399  	})
   400  	b.Run("zstd-par", func(b *testing.B) {
   401  		c := http.Client{
   402  			Transport: Transport(&http.Transport{
   403  				MaxConnsPerHost:     runtime.GOMAXPROCS(0),
   404  				MaxIdleConnsPerHost: runtime.GOMAXPROCS(0),
   405  			}),
   406  		}
   407  
   408  		b.SetBytes(int64(sz))
   409  		b.ReportAllocs()
   410  		b.ResetTimer()
   411  		b.RunParallel(func(pb *testing.PB) {
   412  			for pb.Next() {
   413  				resp, err := c.Get(serverZstd.URL + "/zstd")
   414  				if err != nil {
   415  					b.Fatal(err)
   416  				}
   417  				n, err := io.Copy(io.Discard, resp.Body)
   418  				if err != nil {
   419  					b.Fatal(err)
   420  				}
   421  				if n != sz {
   422  					b.Fatalf("size mismatch: want %d, got %d", sz, n)
   423  				}
   424  				resp.Body.Close()
   425  			}
   426  		})
   427  		b.ReportMetric(100*float64(len(zsBin))/float64(len(raw)), "pct")
   428  	})
   429  }
   430  

View as plain text