...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/compression_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package driver
     8  
     9  import (
    10  	"bytes"
    11  	"compress/zlib"
    12  	"os"
    13  	"testing"
    14  
    15  	"github.com/golang/snappy"
    16  	"github.com/klauspost/compress/zstd"
    17  
    18  	"go.mongodb.org/mongo-driver/internal/assert"
    19  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    20  )
    21  
    22  func TestCompression(t *testing.T) {
    23  	compressors := []wiremessage.CompressorID{
    24  		wiremessage.CompressorNoOp,
    25  		wiremessage.CompressorSnappy,
    26  		wiremessage.CompressorZLib,
    27  		wiremessage.CompressorZstd,
    28  	}
    29  
    30  	for _, compressor := range compressors {
    31  		t.Run(compressor.String(), func(t *testing.T) {
    32  			payload := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt")
    33  			opts := CompressionOpts{
    34  				Compressor:       compressor,
    35  				ZlibLevel:        wiremessage.DefaultZlibLevel,
    36  				ZstdLevel:        wiremessage.DefaultZstdLevel,
    37  				UncompressedSize: int32(len(payload)),
    38  			}
    39  			compressed, err := CompressPayload(payload, opts)
    40  			assert.NoError(t, err)
    41  			assert.NotEqual(t, 0, len(compressed))
    42  			decompressed, err := DecompressPayload(compressed, opts)
    43  			assert.NoError(t, err)
    44  			assert.Equal(t, payload, decompressed)
    45  		})
    46  	}
    47  }
    48  
    49  func TestCompressionLevels(t *testing.T) {
    50  	in := []byte("abc")
    51  	wr := new(bytes.Buffer)
    52  
    53  	t.Run("ZLib", func(t *testing.T) {
    54  		opts := CompressionOpts{
    55  			Compressor: wiremessage.CompressorZLib,
    56  		}
    57  		for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ {
    58  			opts.ZlibLevel = lvl
    59  			_, err1 := CompressPayload(in, opts)
    60  			_, err2 := zlib.NewWriterLevel(wr, lvl)
    61  			if err2 != nil {
    62  				assert.Error(t, err1, "expected an error for ZLib level %d", lvl)
    63  			} else {
    64  				assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl)
    65  			}
    66  		}
    67  	})
    68  
    69  	t.Run("Zstd", func(t *testing.T) {
    70  		opts := CompressionOpts{
    71  			Compressor: wiremessage.CompressorZstd,
    72  		}
    73  		for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ {
    74  			opts.ZstdLevel = int(lvl)
    75  			_, err1 := CompressPayload(in, opts)
    76  			_, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel)))
    77  			if err2 != nil {
    78  				assert.Error(t, err1, "expected an error for Zstd level %d", lvl)
    79  			} else {
    80  				assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl)
    81  			}
    82  		}
    83  	})
    84  }
    85  
    86  func TestDecompressFailures(t *testing.T) {
    87  	t.Parallel()
    88  
    89  	t.Run("snappy decompress huge size", func(t *testing.T) {
    90  		t.Parallel()
    91  
    92  		opts := CompressionOpts{
    93  			Compressor:       wiremessage.CompressorSnappy,
    94  			UncompressedSize: 100, // reasonable size
    95  		}
    96  		// Compressed data is twice as large as declared above.
    97  		// In test we use actual compression so that the decompress action would pass without fix (thus failing test).
    98  		// When decompression starts it allocates a buffer of the defined size, regardless of a valid compressed body following.
    99  		compressedData, err := CompressPayload(make([]byte, opts.UncompressedSize*2), opts)
   100  		assert.NoError(t, err, "premature error making compressed example")
   101  
   102  		_, err = DecompressPayload(compressedData, opts)
   103  		assert.Error(t, err)
   104  	})
   105  }
   106  
   107  var (
   108  	compressionPayload      []byte
   109  	compressedSnappyPayload []byte
   110  	compressedZLibPayload   []byte
   111  	compressedZstdPayload   []byte
   112  )
   113  
   114  func initCompressionPayload(b *testing.B) {
   115  	if compressionPayload != nil {
   116  		return
   117  	}
   118  	data, err := os.ReadFile("testdata/compression.go")
   119  	if err != nil {
   120  		b.Fatal(err)
   121  	}
   122  	for i := 1; i < 10; i++ {
   123  		data = append(data, data...)
   124  	}
   125  	compressionPayload = data
   126  
   127  	compressedSnappyPayload = snappy.Encode(compressedSnappyPayload[:0], data)
   128  
   129  	{
   130  		var buf bytes.Buffer
   131  		enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedDefault))
   132  		if err != nil {
   133  			b.Fatal(err)
   134  		}
   135  		compressedZstdPayload = enc.EncodeAll(data, nil)
   136  	}
   137  
   138  	{
   139  		var buf bytes.Buffer
   140  		enc := zlib.NewWriter(&buf)
   141  		if _, err := enc.Write(data); err != nil {
   142  			b.Fatal(err)
   143  		}
   144  		if err := enc.Close(); err != nil {
   145  			b.Fatal(err)
   146  		}
   147  		if err := enc.Close(); err != nil {
   148  			b.Fatal(err)
   149  		}
   150  		compressedZLibPayload = append(compressedZLibPayload[:0], buf.Bytes()...)
   151  	}
   152  
   153  	b.ResetTimer()
   154  }
   155  
   156  func BenchmarkCompressPayload(b *testing.B) {
   157  	initCompressionPayload(b)
   158  
   159  	compressors := []wiremessage.CompressorID{
   160  		wiremessage.CompressorSnappy,
   161  		wiremessage.CompressorZLib,
   162  		wiremessage.CompressorZstd,
   163  	}
   164  
   165  	for _, compressor := range compressors {
   166  		b.Run(compressor.String(), func(b *testing.B) {
   167  			opts := CompressionOpts{
   168  				Compressor: compressor,
   169  				ZlibLevel:  wiremessage.DefaultZlibLevel,
   170  				ZstdLevel:  wiremessage.DefaultZstdLevel,
   171  			}
   172  			payload := compressionPayload
   173  			b.SetBytes(int64(len(payload)))
   174  			b.ReportAllocs()
   175  			b.RunParallel(func(pb *testing.PB) {
   176  				for pb.Next() {
   177  					_, err := CompressPayload(payload, opts)
   178  					if err != nil {
   179  						b.Error(err)
   180  					}
   181  				}
   182  			})
   183  		})
   184  	}
   185  }
   186  
   187  func BenchmarkDecompressPayload(b *testing.B) {
   188  	initCompressionPayload(b)
   189  
   190  	benchmarks := []struct {
   191  		compressor wiremessage.CompressorID
   192  		payload    []byte
   193  	}{
   194  		{wiremessage.CompressorSnappy, compressedSnappyPayload},
   195  		{wiremessage.CompressorZLib, compressedZLibPayload},
   196  		{wiremessage.CompressorZstd, compressedZstdPayload},
   197  	}
   198  
   199  	for _, bench := range benchmarks {
   200  		b.Run(bench.compressor.String(), func(b *testing.B) {
   201  			opts := CompressionOpts{
   202  				Compressor:       bench.compressor,
   203  				ZlibLevel:        wiremessage.DefaultZlibLevel,
   204  				ZstdLevel:        wiremessage.DefaultZstdLevel,
   205  				UncompressedSize: int32(len(compressionPayload)),
   206  			}
   207  			payload := bench.payload
   208  			b.SetBytes(int64(len(compressionPayload)))
   209  			b.ReportAllocs()
   210  			b.RunParallel(func(pb *testing.PB) {
   211  				for pb.Next() {
   212  					_, err := DecompressPayload(payload, opts)
   213  					if err != nil {
   214  						b.Fatal(err)
   215  					}
   216  				}
   217  			})
   218  		})
   219  	}
   220  }
   221  

View as plain text