1
2
3
4
5
6
7 package driver
8
9 import (
10 "bytes"
11 "compress/zlib"
12 "fmt"
13 "io"
14 "sync"
15
16 "github.com/golang/snappy"
17 "github.com/klauspost/compress/zstd"
18 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
19 )
20
21
22 type CompressionOpts struct {
23 Compressor wiremessage.CompressorID
24 ZlibLevel int
25 ZstdLevel int
26 UncompressedSize int32
27 }
28
29 var zstdEncoders sync.Map
30
31 func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
32 if v, ok := zstdEncoders.Load(level); ok {
33 return v.(*zstd.Encoder), nil
34 }
35 encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
36 if err != nil {
37 return nil, err
38 }
39 zstdEncoders.Store(level, encoder)
40 return encoder, nil
41 }
42
// zlibEncoders caches one *zlibEncoder per zlib compression level.
var zlibEncoders sync.Map // map[int]*zlibEncoder

// getZlibEncoder returns the cached encoder for level, creating it on first
// use. It returns the error from zlib.NewWriterLevel when level is not a valid
// zlib compression level.
//
// The cache is filled with LoadOrStore so that concurrent first calls for the
// same level all receive the same encoder. A Load-then-Store sequence would let
// racing goroutines each keep a private encoder, with the last Store silently
// replacing the others. A candidate built by a goroutine that lost the race is
// simply dropped.
func getZlibEncoder(level int) (*zlibEncoder, error) {
	if v, ok := zlibEncoders.Load(level); ok {
		return v.(*zlibEncoder), nil
	}
	writer, err := zlib.NewWriterLevel(nil, level)
	if err != nil {
		return nil, err
	}
	candidate := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
	actual, _ := zlibEncoders.LoadOrStore(level, candidate)
	return actual.(*zlibEncoder), nil
}

// zlibEncoder pairs a reusable zlib.Writer with the buffer it writes into.
// Both are reset and reused on every Encode call, so mu serializes access to
// them.
type zlibEncoder struct {
	mu     sync.Mutex
	writer *zlib.Writer
	buf    *bytes.Buffer
}

// Encode compresses src into a complete zlib stream and returns it.
//
// The result is appended to dst[:0], so dst's existing contents are
// overwritten and its capacity is reused when large enough; dst may be nil.
// The returned slice is a copy of the internal buffer, not a view of it, so it
// remains valid after later Encode calls. Calls on the same encoder run one at
// a time under e.mu.
func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Start from an empty buffer and a writer targeting it; this also makes the
	// writer reusable after the Close performed by the previous call.
	e.buf.Reset()
	e.writer.Reset(e.buf)

	if _, err := e.writer.Write(src); err != nil {
		return nil, err
	}
	// Close flushes pending compressed data and writes the stream trailer into
	// e.buf (it does not close e.buf itself).
	if err := e.writer.Close(); err != nil {
		return nil, err
	}
	return append(dst[:0], e.buf.Bytes()...), nil
}
83
84
85 func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
86 switch opts.Compressor {
87 case wiremessage.CompressorNoOp:
88 return in, nil
89 case wiremessage.CompressorSnappy:
90 return snappy.Encode(nil, in), nil
91 case wiremessage.CompressorZLib:
92 encoder, err := getZlibEncoder(opts.ZlibLevel)
93 if err != nil {
94 return nil, err
95 }
96 return encoder.Encode(nil, in)
97 case wiremessage.CompressorZstd:
98 encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
99 if err != nil {
100 return nil, err
101 }
102 return encoder.EncodeAll(in, nil), nil
103 default:
104 return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
105 }
106 }
107
108
109 func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
110 switch opts.Compressor {
111 case wiremessage.CompressorNoOp:
112 return in, nil
113 case wiremessage.CompressorSnappy:
114 l, err := snappy.DecodedLen(in)
115 if err != nil {
116 return nil, fmt.Errorf("decoding compressed length %w", err)
117 } else if int32(l) != opts.UncompressedSize {
118 return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
119 }
120 uncompressed = make([]byte, opts.UncompressedSize)
121 return snappy.Decode(uncompressed, in)
122 case wiremessage.CompressorZLib:
123 r, err := zlib.NewReader(bytes.NewReader(in))
124 if err != nil {
125 return nil, err
126 }
127 defer func() {
128 err = r.Close()
129 }()
130 uncompressed = make([]byte, opts.UncompressedSize)
131 _, err = io.ReadFull(r, uncompressed)
132 if err != nil {
133 return nil, err
134 }
135 return uncompressed, nil
136 case wiremessage.CompressorZstd:
137 r, err := zstd.NewReader(bytes.NewBuffer(in))
138 if err != nil {
139 return nil, err
140 }
141 defer r.Close()
142 uncompressed = make([]byte, opts.UncompressedSize)
143 _, err = io.ReadFull(r, uncompressed)
144 if err != nil {
145 return nil, err
146 }
147 return uncompressed, nil
148 default:
149 return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
150 }
151 }
152
View as plain text