...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/compression.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package driver
     8  
     9  import (
    10  	"bytes"
    11  	"compress/zlib"
    12  	"fmt"
    13  	"io"
    14  	"sync"
    15  
    16  	"github.com/golang/snappy"
    17  	"github.com/klauspost/compress/zstd"
    18  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    19  )
    20  
    21  // CompressionOpts holds settings for how to compress a payload
    22  type CompressionOpts struct {
    23  	Compressor       wiremessage.CompressorID
    24  	ZlibLevel        int
    25  	ZstdLevel        int
    26  	UncompressedSize int32
    27  }
    28  
    29  // mustZstdNewWriter creates a zstd.Encoder with the given level and a nil
    30  // destination writer. It panics on any errors and should only be used at
    31  // package initialization time.
    32  func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder {
    33  	enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl))
    34  	if err != nil {
    35  		panic(err)
    36  	}
    37  	return enc
    38  }
    39  
    40  var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{
    41  	0:                           nil, // zstd.speedNotSet
    42  	zstd.SpeedFastest:           mustZstdNewWriter(zstd.SpeedFastest),
    43  	zstd.SpeedDefault:           mustZstdNewWriter(zstd.SpeedDefault),
    44  	zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression),
    45  	zstd.SpeedBestCompression:   mustZstdNewWriter(zstd.SpeedBestCompression),
    46  }
    47  
    48  func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
    49  	if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression {
    50  		return zstdEncoders[level], nil
    51  	}
    52  	// The level is outside the expected range, return an error.
    53  	return nil, fmt.Errorf("invalid zstd compression level: %d", level)
    54  }
    55  
    56  // zlibEncodersOffset is the offset into the zlibEncoders array for a given
    57  // compression level.
    58  const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2
    59  
    60  var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool
    61  
    62  func getZlibEncoder(level int) (*zlibEncoder, error) {
    63  	if zlib.HuffmanOnly <= level && level <= zlib.BestCompression {
    64  		if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil {
    65  			return enc, nil
    66  		}
    67  		writer, err := zlib.NewWriterLevel(nil, level)
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  		enc := &zlibEncoder{writer: writer, level: level}
    72  		return enc, nil
    73  	}
    74  	// The level is outside the expected range, return an error.
    75  	return nil, fmt.Errorf("invalid zlib compression level: %d", level)
    76  }
    77  
    78  func putZlibEncoder(enc *zlibEncoder) {
    79  	if enc != nil {
    80  		zlibEncoders[enc.level+zlibEncodersOffset].Put(enc)
    81  	}
    82  }
    83  
    84  type zlibEncoder struct {
    85  	writer *zlib.Writer
    86  	buf    bytes.Buffer
    87  	level  int
    88  }
    89  
    90  func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
    91  	defer putZlibEncoder(e)
    92  
    93  	e.buf.Reset()
    94  	e.writer.Reset(&e.buf)
    95  
    96  	_, err := e.writer.Write(src)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	err = e.writer.Close()
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  	dst = append(dst[:0], e.buf.Bytes()...)
   105  	return dst, nil
   106  }
   107  
   108  // CompressPayload takes a byte slice and compresses it according to the options passed
   109  func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
   110  	switch opts.Compressor {
   111  	case wiremessage.CompressorNoOp:
   112  		return in, nil
   113  	case wiremessage.CompressorSnappy:
   114  		return snappy.Encode(nil, in), nil
   115  	case wiremessage.CompressorZLib:
   116  		encoder, err := getZlibEncoder(opts.ZlibLevel)
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  		return encoder.Encode(nil, in)
   121  	case wiremessage.CompressorZstd:
   122  		encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
   123  		if err != nil {
   124  			return nil, err
   125  		}
   126  		return encoder.EncodeAll(in, nil), nil
   127  	default:
   128  		return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
   129  	}
   130  }
   131  
   132  var zstdReaderPool = sync.Pool{
   133  	New: func() interface{} {
   134  		r, _ := zstd.NewReader(nil)
   135  		return r
   136  	},
   137  }
   138  
   139  // DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
   140  func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
   141  	switch opts.Compressor {
   142  	case wiremessage.CompressorNoOp:
   143  		return in, nil
   144  	case wiremessage.CompressorSnappy:
   145  		l, err := snappy.DecodedLen(in)
   146  		if err != nil {
   147  			return nil, fmt.Errorf("decoding compressed length %w", err)
   148  		} else if int32(l) != opts.UncompressedSize {
   149  			return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
   150  		}
   151  		out := make([]byte, opts.UncompressedSize)
   152  		return snappy.Decode(out, in)
   153  	case wiremessage.CompressorZLib:
   154  		r, err := zlib.NewReader(bytes.NewReader(in))
   155  		if err != nil {
   156  			return nil, err
   157  		}
   158  		out := make([]byte, opts.UncompressedSize)
   159  		if _, err := io.ReadFull(r, out); err != nil {
   160  			return nil, err
   161  		}
   162  		if err := r.Close(); err != nil {
   163  			return nil, err
   164  		}
   165  		return out, nil
   166  	case wiremessage.CompressorZstd:
   167  		buf := make([]byte, 0, opts.UncompressedSize)
   168  		// Using a pool here is about ~20% faster
   169  		// than using a single global zstd.Reader
   170  		r := zstdReaderPool.Get().(*zstd.Decoder)
   171  		out, err := r.DecodeAll(in, buf)
   172  		zstdReaderPool.Put(r)
   173  		return out, err
   174  	default:
   175  		return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
   176  	}
   177  }
   178  

View as plain text