...

Source file src/github.com/klauspost/compress/s2/writer.go

Documentation: github.com/klauspost/compress/s2

     1  // Copyright 2011 The Snappy-Go Authors. All rights reserved.
     2  // Copyright (c) 2019+ Klaus Post. All rights reserved.
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file.
     5  
     6  package s2
     7  
     8  import (
     9  	"crypto/rand"
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"runtime"
    15  	"sync"
    16  
    17  	"github.com/klauspost/compress/internal/race"
    18  )
    19  
    20  const (
    21  	levelUncompressed = iota + 1
    22  	levelFast
    23  	levelBetter
    24  	levelBest
    25  )
    26  
    27  // NewWriter returns a new Writer that compresses to w, using the
    28  // framing format described at
    29  // https://github.com/google/snappy/blob/master/framing_format.txt
    30  //
    31  // Users must call Close to guarantee all data has been forwarded to
    32  // the underlying io.Writer and that resources are released.
    33  // They may also call Flush zero or more times before calling Close.
    34  func NewWriter(w io.Writer, opts ...WriterOption) *Writer {
    35  	w2 := Writer{
    36  		blockSize:   defaultBlockSize,
    37  		concurrency: runtime.GOMAXPROCS(0),
    38  		randSrc:     rand.Reader,
    39  		level:       levelFast,
    40  	}
    41  	for _, opt := range opts {
    42  		if err := opt(&w2); err != nil {
    43  			w2.errState = err
    44  			return &w2
    45  		}
    46  	}
    47  	w2.obufLen = obufHeaderLen + MaxEncodedLen(w2.blockSize)
    48  	w2.paramsOK = true
    49  	w2.ibuf = make([]byte, 0, w2.blockSize)
    50  	w2.buffers.New = func() interface{} {
    51  		return make([]byte, w2.obufLen)
    52  	}
    53  	w2.Reset(w)
    54  	return &w2
    55  }
    56  
// Writer is an io.Writer that can write Snappy-compressed bytes.
//
// Writers must be created with NewWriter: in the zero value paramsOK is unset,
// so Reset does nothing and no destination, buffers or pool are configured.
type Writer struct {
	// errMu guards errState; the background writer goroutine records errors too.
	errMu    sync.Mutex
	errState error // first error recorded via err; returned by every later call

	// ibuf is a buffer for the incoming (uncompressed) bytes.
	ibuf []byte

	blockSize     int              // maximum number of uncompressed bytes per block
	obufLen       int              // length of pooled buffers: obufHeaderLen + MaxEncodedLen(blockSize)
	concurrency   int              // 1 means blocks are compressed and written synchronously
	written       int64            // bytes written (or counted) toward the destination so far
	uncompWritten int64            // Bytes sent to compression
	output        chan chan result // in-order queue drained by the writer goroutine; nil when synchronous
	buffers       sync.Pool        // pool of []byte buffers created with length obufLen
	pad           int              // pad output to a multiple of this on close; values <= 1 disable padding

	writer    io.Writer                 // destination set by Reset
	randSrc   io.Reader                 // source of the bytes used to fill padding frames
	writerWg  sync.WaitGroup            // tracks the writer goroutine started by Reset
	index     Index                     // records (compressed, uncompressed) offset pairs as blocks are written
	customEnc func(dst, src []byte) int // optional block encoder set by WriterCustomEncoder

	// wroteStreamHeader is whether we have written the stream header.
	wroteStreamHeader bool
	paramsOK          bool  // set by NewWriter once all options were applied successfully
	snappy            bool  // write the Snappy-compatible stream header and use the Snappy block encoders
	flushOnWrite      bool  // compress on every Write instead of buffering in ibuf
	appendIndex       bool  // append the index to the stream on Close
	level             uint8 // one of levelUncompressed, levelFast, levelBetter, levelBest
}
    88  
// result is one unit of output handed to the writer goroutine through a
// per-block channel queued on Writer.output.
type result struct {
	// b holds the encoded bytes to write. It may be empty (Flush sends nil as a
	// synchronization marker), in which case nothing is written.
	b []byte
	// Uncompressed start offset
	// (passed to Index.add together with the current compressed offset).
	startOffset int64
}
    94  
    95  // err returns the previously set error.
    96  // If no error has been set it is set to err if not nil.
    97  func (w *Writer) err(err error) error {
    98  	w.errMu.Lock()
    99  	errSet := w.errState
   100  	if errSet == nil && err != nil {
   101  		w.errState = err
   102  		errSet = err
   103  	}
   104  	w.errMu.Unlock()
   105  	return errSet
   106  }
   107  
   108  // Reset discards the writer's state and switches the Snappy writer to write to w.
   109  // This permits reusing a Writer rather than allocating a new one.
   110  func (w *Writer) Reset(writer io.Writer) {
   111  	if !w.paramsOK {
   112  		return
   113  	}
   114  	// Close previous writer, if any.
   115  	if w.output != nil {
   116  		close(w.output)
   117  		w.writerWg.Wait()
   118  		w.output = nil
   119  	}
   120  	w.errState = nil
   121  	w.ibuf = w.ibuf[:0]
   122  	w.wroteStreamHeader = false
   123  	w.written = 0
   124  	w.writer = writer
   125  	w.uncompWritten = 0
   126  	w.index.reset(w.blockSize)
   127  
   128  	// If we didn't get a writer, stop here.
   129  	if writer == nil {
   130  		return
   131  	}
   132  	// If no concurrency requested, don't spin up writer goroutine.
   133  	if w.concurrency == 1 {
   134  		return
   135  	}
   136  
   137  	toWrite := make(chan chan result, w.concurrency)
   138  	w.output = toWrite
   139  	w.writerWg.Add(1)
   140  
   141  	// Start a writer goroutine that will write all output in order.
   142  	go func() {
   143  		defer w.writerWg.Done()
   144  
   145  		// Get a queued write.
   146  		for write := range toWrite {
   147  			// Wait for the data to be available.
   148  			input := <-write
   149  			in := input.b
   150  			if len(in) > 0 {
   151  				if w.err(nil) == nil {
   152  					// Don't expose data from previous buffers.
   153  					toWrite := in[:len(in):len(in)]
   154  					// Write to output.
   155  					n, err := writer.Write(toWrite)
   156  					if err == nil && n != len(toWrite) {
   157  						err = io.ErrShortBuffer
   158  					}
   159  					_ = w.err(err)
   160  					w.err(w.index.add(w.written, input.startOffset))
   161  					w.written += int64(n)
   162  				}
   163  			}
   164  			if cap(in) >= w.obufLen {
   165  				w.buffers.Put(in)
   166  			}
   167  			// close the incoming write request.
   168  			// This can be used for synchronizing flushes.
   169  			close(write)
   170  		}
   171  	}()
   172  }
   173  
   174  // Write satisfies the io.Writer interface.
   175  func (w *Writer) Write(p []byte) (nRet int, errRet error) {
   176  	if err := w.err(nil); err != nil {
   177  		return 0, err
   178  	}
   179  	if w.flushOnWrite {
   180  		return w.write(p)
   181  	}
   182  	// If we exceed the input buffer size, start writing
   183  	for len(p) > (cap(w.ibuf)-len(w.ibuf)) && w.err(nil) == nil {
   184  		var n int
   185  		if len(w.ibuf) == 0 {
   186  			// Large write, empty buffer.
   187  			// Write directly from p to avoid copy.
   188  			n, _ = w.write(p)
   189  		} else {
   190  			n = copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
   191  			w.ibuf = w.ibuf[:len(w.ibuf)+n]
   192  			w.write(w.ibuf)
   193  			w.ibuf = w.ibuf[:0]
   194  		}
   195  		nRet += n
   196  		p = p[n:]
   197  	}
   198  	if err := w.err(nil); err != nil {
   199  		return nRet, err
   200  	}
   201  	// p should always be able to fit into w.ibuf now.
   202  	n := copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
   203  	w.ibuf = w.ibuf[:len(w.ibuf)+n]
   204  	nRet += n
   205  	return nRet, nil
   206  }
   207  
   208  // ReadFrom implements the io.ReaderFrom interface.
   209  // Using this is typically more efficient since it avoids a memory copy.
   210  // ReadFrom reads data from r until EOF or error.
   211  // The return value n is the number of bytes read.
   212  // Any error except io.EOF encountered during the read is also returned.
   213  func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
   214  	if err := w.err(nil); err != nil {
   215  		return 0, err
   216  	}
   217  	if len(w.ibuf) > 0 {
   218  		err := w.AsyncFlush()
   219  		if err != nil {
   220  			return 0, err
   221  		}
   222  	}
   223  	if br, ok := r.(byter); ok {
   224  		buf := br.Bytes()
   225  		if err := w.EncodeBuffer(buf); err != nil {
   226  			return 0, err
   227  		}
   228  		return int64(len(buf)), w.AsyncFlush()
   229  	}
   230  	for {
   231  		inbuf := w.buffers.Get().([]byte)[:w.blockSize+obufHeaderLen]
   232  		n2, err := io.ReadFull(r, inbuf[obufHeaderLen:])
   233  		if err != nil {
   234  			if err == io.ErrUnexpectedEOF {
   235  				err = io.EOF
   236  			}
   237  			if err != io.EOF {
   238  				return n, w.err(err)
   239  			}
   240  		}
   241  		if n2 == 0 {
   242  			break
   243  		}
   244  		n += int64(n2)
   245  		err2 := w.writeFull(inbuf[:n2+obufHeaderLen])
   246  		if w.err(err2) != nil {
   247  			break
   248  		}
   249  
   250  		if err != nil {
   251  			// We got EOF and wrote everything
   252  			break
   253  		}
   254  	}
   255  
   256  	return n, w.err(nil)
   257  }
   258  
   259  // AddSkippableBlock will add a skippable block to the stream.
   260  // The ID must be 0x80-0xfe (inclusive).
   261  // Length of the skippable block must be <= 16777215 bytes.
   262  func (w *Writer) AddSkippableBlock(id uint8, data []byte) (err error) {
   263  	if err := w.err(nil); err != nil {
   264  		return err
   265  	}
   266  	if len(data) == 0 {
   267  		return nil
   268  	}
   269  	if id < 0x80 || id > chunkTypePadding {
   270  		return fmt.Errorf("invalid skippable block id %x", id)
   271  	}
   272  	if len(data) > maxChunkSize {
   273  		return fmt.Errorf("skippable block excessed maximum size")
   274  	}
   275  	var header [4]byte
   276  	chunkLen := len(data)
   277  	header[0] = id
   278  	header[1] = uint8(chunkLen >> 0)
   279  	header[2] = uint8(chunkLen >> 8)
   280  	header[3] = uint8(chunkLen >> 16)
   281  	if w.concurrency == 1 {
   282  		write := func(b []byte) error {
   283  			n, err := w.writer.Write(b)
   284  			if err = w.err(err); err != nil {
   285  				return err
   286  			}
   287  			if n != len(b) {
   288  				return w.err(io.ErrShortWrite)
   289  			}
   290  			w.written += int64(n)
   291  			return w.err(nil)
   292  		}
   293  		if !w.wroteStreamHeader {
   294  			w.wroteStreamHeader = true
   295  			if w.snappy {
   296  				if err := write([]byte(magicChunkSnappy)); err != nil {
   297  					return err
   298  				}
   299  			} else {
   300  				if err := write([]byte(magicChunk)); err != nil {
   301  					return err
   302  				}
   303  			}
   304  		}
   305  		if err := write(header[:]); err != nil {
   306  			return err
   307  		}
   308  		return write(data)
   309  	}
   310  
   311  	// Create output...
   312  	if !w.wroteStreamHeader {
   313  		w.wroteStreamHeader = true
   314  		hWriter := make(chan result)
   315  		w.output <- hWriter
   316  		if w.snappy {
   317  			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunkSnappy)}
   318  		} else {
   319  			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunk)}
   320  		}
   321  	}
   322  
   323  	// Copy input.
   324  	inbuf := w.buffers.Get().([]byte)[:4]
   325  	copy(inbuf, header[:])
   326  	inbuf = append(inbuf, data...)
   327  
   328  	output := make(chan result, 1)
   329  	// Queue output.
   330  	w.output <- output
   331  	output <- result{startOffset: w.uncompWritten, b: inbuf}
   332  
   333  	return nil
   334  }
   335  
// EncodeBuffer will add a buffer to the stream.
// This is the fastest way to encode a stream,
// but the input buffer cannot be written to by the caller
// until Flush or Close has been called when concurrency != 1.
//
// If you cannot control that, use the regular Write function.
//
// Note that input is not buffered.
// This means that each write will result in discrete blocks being created.
// For buffered writes, use the regular Write function.
//
// Data still held by Write is flushed first so it precedes buf in the stream.
// With concurrency != 1, buf is cut into blocks of at most w.blockSize bytes.
// Each block's result channel is queued on w.output before its compression
// goroutine starts, which fixes the output order. The goroutines read directly
// from buf (no copy is made), which is why buf must stay untouched.
func (w *Writer) EncodeBuffer(buf []byte) (err error) {
	if err := w.err(nil); err != nil {
		return err
	}

	// WriterFlushOnWrite: write copies buf into pooled buffers in async mode.
	if w.flushOnWrite {
		_, err := w.write(buf)
		return err
	}
	// Flush queued data first.
	if len(w.ibuf) > 0 {
		err := w.AsyncFlush()
		if err != nil {
			return err
		}
	}
	if w.concurrency == 1 {
		_, err := w.writeSync(buf)
		return err
	}

	// Spawn goroutine and write block to output channel.
	// The stream header is queued ahead of the first block.
	if !w.wroteStreamHeader {
		w.wroteStreamHeader = true
		hWriter := make(chan result)
		w.output <- hWriter
		if w.snappy {
			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunkSnappy)}
		} else {
			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunk)}
		}
	}

	for len(buf) > 0 {
		// Cut input.
		uncompressed := buf
		if len(uncompressed) > w.blockSize {
			uncompressed = uncompressed[:w.blockSize]
		}
		buf = buf[len(uncompressed):]
		// Get an output buffer.
		// NOTE(review): obuf holds only len(uncompressed)+obufHeaderLen bytes, so this
		// assumes encodeBlock returns 0 rather than writing past that space — confirm
		// against the block encoders.
		obuf := w.buffers.Get().([]byte)[:len(uncompressed)+obufHeaderLen]
		// Race-detector annotation from internal/race (presumably a no-op in normal builds).
		race.WriteSlice(obuf)

		output := make(chan result)
		// Queue output now, so we keep order.
		w.output <- output
		res := result{
			startOffset: w.uncompWritten,
		}
		w.uncompWritten += int64(len(uncompressed))
		go func() {
			race.ReadSlice(uncompressed)

			checksum := crc(uncompressed)

			// Set to uncompressed.
			// chunkLen counts the 4 checksum bytes plus the body.
			chunkType := uint8(chunkTypeUncompressedData)
			chunkLen := 4 + len(uncompressed)

			// Attempt compressing.
			// The body starts with the uvarint-encoded uncompressed length.
			n := binary.PutUvarint(obuf[obufHeaderLen:], uint64(len(uncompressed)))
			n2 := w.encodeBlock(obuf[obufHeaderLen+n:], uncompressed)

			// Check if we should use this, or store as uncompressed instead.
			if n2 > 0 {
				chunkType = uint8(chunkTypeCompressedData)
				chunkLen = 4 + n + n2
				obuf = obuf[:obufHeaderLen+n+n2]
			} else {
				// copy uncompressed
				// (overwrites the length varint and any partial encoder output).
				copy(obuf[obufHeaderLen:], uncompressed)
			}

			// Fill in the per-chunk header that comes before the body.
			// Bytes 0-3: chunk type and 24-bit little-endian length;
			// bytes 4-7: little-endian checksum of the uncompressed data.
			obuf[0] = chunkType
			obuf[1] = uint8(chunkLen >> 0)
			obuf[2] = uint8(chunkLen >> 8)
			obuf[3] = uint8(chunkLen >> 16)
			obuf[4] = uint8(checksum >> 0)
			obuf[5] = uint8(checksum >> 8)
			obuf[6] = uint8(checksum >> 16)
			obuf[7] = uint8(checksum >> 24)

			// Queue final output.
			// The writer goroutine returns obuf to the pool after writing it.
			res.b = obuf
			output <- res
		}()
	}
	return nil
}
   437  
   438  func (w *Writer) encodeBlock(obuf, uncompressed []byte) int {
   439  	if w.customEnc != nil {
   440  		if ret := w.customEnc(obuf, uncompressed); ret >= 0 {
   441  			return ret
   442  		}
   443  	}
   444  	if w.snappy {
   445  		switch w.level {
   446  		case levelFast:
   447  			return encodeBlockSnappy(obuf, uncompressed)
   448  		case levelBetter:
   449  			return encodeBlockBetterSnappy(obuf, uncompressed)
   450  		case levelBest:
   451  			return encodeBlockBestSnappy(obuf, uncompressed)
   452  		}
   453  		return 0
   454  	}
   455  	switch w.level {
   456  	case levelFast:
   457  		return encodeBlock(obuf, uncompressed)
   458  	case levelBetter:
   459  		return encodeBlockBetter(obuf, uncompressed)
   460  	case levelBest:
   461  		return encodeBlockBest(obuf, uncompressed, nil)
   462  	}
   463  	return 0
   464  }
   465  
// write splits p into blocks of at most w.blockSize bytes and compresses them.
// With concurrency == 1 it delegates to writeSync. Otherwise each block is
// copied into a pooled buffer, so the caller may reuse p once write returns.
// The block's result channel is queued on w.output to fix the output order,
// and the block is then compressed in its own goroutine. The stream header is
// queued first if it has not been written yet.
// It returns the number of bytes of p that were accepted.
func (w *Writer) write(p []byte) (nRet int, errRet error) {
	if err := w.err(nil); err != nil {
		return 0, err
	}
	if w.concurrency == 1 {
		return w.writeSync(p)
	}

	// Spawn goroutine and write block to output channel.
	for len(p) > 0 {
		if !w.wroteStreamHeader {
			w.wroteStreamHeader = true
			hWriter := make(chan result)
			w.output <- hWriter
			if w.snappy {
				hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunkSnappy)}
			} else {
				hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunk)}
			}
		}

		var uncompressed []byte
		if len(p) > w.blockSize {
			uncompressed, p = p[:w.blockSize], p[w.blockSize:]
		} else {
			uncompressed, p = p, nil
		}

		// Copy input.
		// If the block is incompressible, this is used for the result.
		// The data sits after obufHeaderLen bytes so the chunk header can be
		// filled in place without another copy.
		inbuf := w.buffers.Get().([]byte)[:len(uncompressed)+obufHeaderLen]
		obuf := w.buffers.Get().([]byte)[:w.obufLen]
		copy(inbuf[obufHeaderLen:], uncompressed)
		uncompressed = inbuf[obufHeaderLen:]

		output := make(chan result)
		// Queue output now, so we keep order.
		w.output <- output
		res := result{
			startOffset: w.uncompWritten,
		}
		w.uncompWritten += int64(len(uncompressed))

		go func() {
			checksum := crc(uncompressed)

			// Set to uncompressed.
			// chunkLen counts the 4 checksum bytes plus the body.
			chunkType := uint8(chunkTypeUncompressedData)
			chunkLen := 4 + len(uncompressed)

			// Attempt compressing.
			// The body starts with the uvarint-encoded uncompressed length.
			n := binary.PutUvarint(obuf[obufHeaderLen:], uint64(len(uncompressed)))
			n2 := w.encodeBlock(obuf[obufHeaderLen+n:], uncompressed)

			// Check if we should use this, or store as uncompressed instead.
			if n2 > 0 {
				chunkType = uint8(chunkTypeCompressedData)
				chunkLen = 4 + n + n2
				obuf = obuf[:obufHeaderLen+n+n2]
			} else {
				// Use input as output.
				// After the swap, inbuf names the unused output buffer, which is recycled below.
				obuf, inbuf = inbuf, obuf
			}

			// Fill in the per-chunk header that comes before the body.
			// Bytes 0-3: chunk type and 24-bit little-endian length;
			// bytes 4-7: little-endian checksum of the uncompressed data.
			obuf[0] = chunkType
			obuf[1] = uint8(chunkLen >> 0)
			obuf[2] = uint8(chunkLen >> 8)
			obuf[3] = uint8(chunkLen >> 16)
			obuf[4] = uint8(checksum >> 0)
			obuf[5] = uint8(checksum >> 8)
			obuf[6] = uint8(checksum >> 16)
			obuf[7] = uint8(checksum >> 24)

			// Queue final output.
			res.b = obuf
			output <- res

			// Put unused buffer back in pool.
			w.buffers.Put(inbuf)
		}()
		nRet += len(uncompressed)
	}
	return nRet, nil
}
   551  
   552  // writeFull is a special version of write that will always write the full buffer.
   553  // Data to be compressed should start at offset obufHeaderLen and fill the remainder of the buffer.
   554  // The data will be written as a single block.
   555  // The caller is not allowed to use inbuf after this function has been called.
   556  func (w *Writer) writeFull(inbuf []byte) (errRet error) {
   557  	if err := w.err(nil); err != nil {
   558  		return err
   559  	}
   560  
   561  	if w.concurrency == 1 {
   562  		_, err := w.writeSync(inbuf[obufHeaderLen:])
   563  		return err
   564  	}
   565  
   566  	// Spawn goroutine and write block to output channel.
   567  	if !w.wroteStreamHeader {
   568  		w.wroteStreamHeader = true
   569  		hWriter := make(chan result)
   570  		w.output <- hWriter
   571  		if w.snappy {
   572  			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunkSnappy)}
   573  		} else {
   574  			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunk)}
   575  		}
   576  	}
   577  
   578  	// Get an output buffer.
   579  	obuf := w.buffers.Get().([]byte)[:w.obufLen]
   580  	uncompressed := inbuf[obufHeaderLen:]
   581  
   582  	output := make(chan result)
   583  	// Queue output now, so we keep order.
   584  	w.output <- output
   585  	res := result{
   586  		startOffset: w.uncompWritten,
   587  	}
   588  	w.uncompWritten += int64(len(uncompressed))
   589  
   590  	go func() {
   591  		checksum := crc(uncompressed)
   592  
   593  		// Set to uncompressed.
   594  		chunkType := uint8(chunkTypeUncompressedData)
   595  		chunkLen := 4 + len(uncompressed)
   596  
   597  		// Attempt compressing.
   598  		n := binary.PutUvarint(obuf[obufHeaderLen:], uint64(len(uncompressed)))
   599  		n2 := w.encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
   600  
   601  		// Check if we should use this, or store as uncompressed instead.
   602  		if n2 > 0 {
   603  			chunkType = uint8(chunkTypeCompressedData)
   604  			chunkLen = 4 + n + n2
   605  			obuf = obuf[:obufHeaderLen+n+n2]
   606  		} else {
   607  			// Use input as output.
   608  			obuf, inbuf = inbuf, obuf
   609  		}
   610  
   611  		// Fill in the per-chunk header that comes before the body.
   612  		obuf[0] = chunkType
   613  		obuf[1] = uint8(chunkLen >> 0)
   614  		obuf[2] = uint8(chunkLen >> 8)
   615  		obuf[3] = uint8(chunkLen >> 16)
   616  		obuf[4] = uint8(checksum >> 0)
   617  		obuf[5] = uint8(checksum >> 8)
   618  		obuf[6] = uint8(checksum >> 16)
   619  		obuf[7] = uint8(checksum >> 24)
   620  
   621  		// Queue final output.
   622  		res.b = obuf
   623  		output <- res
   624  
   625  		// Put unused buffer back in pool.
   626  		w.buffers.Put(inbuf)
   627  	}()
   628  	return nil
   629  }
   630  
   631  func (w *Writer) writeSync(p []byte) (nRet int, errRet error) {
   632  	if err := w.err(nil); err != nil {
   633  		return 0, err
   634  	}
   635  	if !w.wroteStreamHeader {
   636  		w.wroteStreamHeader = true
   637  		var n int
   638  		var err error
   639  		if w.snappy {
   640  			n, err = w.writer.Write([]byte(magicChunkSnappy))
   641  		} else {
   642  			n, err = w.writer.Write([]byte(magicChunk))
   643  		}
   644  		if err != nil {
   645  			return 0, w.err(err)
   646  		}
   647  		if n != len(magicChunk) {
   648  			return 0, w.err(io.ErrShortWrite)
   649  		}
   650  		w.written += int64(n)
   651  	}
   652  
   653  	for len(p) > 0 {
   654  		var uncompressed []byte
   655  		if len(p) > w.blockSize {
   656  			uncompressed, p = p[:w.blockSize], p[w.blockSize:]
   657  		} else {
   658  			uncompressed, p = p, nil
   659  		}
   660  
   661  		obuf := w.buffers.Get().([]byte)[:w.obufLen]
   662  		checksum := crc(uncompressed)
   663  
   664  		// Set to uncompressed.
   665  		chunkType := uint8(chunkTypeUncompressedData)
   666  		chunkLen := 4 + len(uncompressed)
   667  
   668  		// Attempt compressing.
   669  		n := binary.PutUvarint(obuf[obufHeaderLen:], uint64(len(uncompressed)))
   670  		n2 := w.encodeBlock(obuf[obufHeaderLen+n:], uncompressed)
   671  
   672  		if n2 > 0 {
   673  			chunkType = uint8(chunkTypeCompressedData)
   674  			chunkLen = 4 + n + n2
   675  			obuf = obuf[:obufHeaderLen+n+n2]
   676  		} else {
   677  			obuf = obuf[:8]
   678  		}
   679  
   680  		// Fill in the per-chunk header that comes before the body.
   681  		obuf[0] = chunkType
   682  		obuf[1] = uint8(chunkLen >> 0)
   683  		obuf[2] = uint8(chunkLen >> 8)
   684  		obuf[3] = uint8(chunkLen >> 16)
   685  		obuf[4] = uint8(checksum >> 0)
   686  		obuf[5] = uint8(checksum >> 8)
   687  		obuf[6] = uint8(checksum >> 16)
   688  		obuf[7] = uint8(checksum >> 24)
   689  
   690  		n, err := w.writer.Write(obuf)
   691  		if err != nil {
   692  			return 0, w.err(err)
   693  		}
   694  		if n != len(obuf) {
   695  			return 0, w.err(io.ErrShortWrite)
   696  		}
   697  		w.err(w.index.add(w.written, w.uncompWritten))
   698  		w.written += int64(n)
   699  		w.uncompWritten += int64(len(uncompressed))
   700  
   701  		if chunkType == chunkTypeUncompressedData {
   702  			// Write uncompressed data.
   703  			n, err := w.writer.Write(uncompressed)
   704  			if err != nil {
   705  				return 0, w.err(err)
   706  			}
   707  			if n != len(uncompressed) {
   708  				return 0, w.err(io.ErrShortWrite)
   709  			}
   710  			w.written += int64(n)
   711  		}
   712  		w.buffers.Put(obuf)
   713  		// Queue final output.
   714  		nRet += len(uncompressed)
   715  	}
   716  	return nRet, nil
   717  }
   718  
   719  // AsyncFlush writes any buffered bytes to a block and starts compressing it.
   720  // It does not wait for the output has been written as Flush() does.
   721  func (w *Writer) AsyncFlush() error {
   722  	if err := w.err(nil); err != nil {
   723  		return err
   724  	}
   725  
   726  	// Queue any data still in input buffer.
   727  	if len(w.ibuf) != 0 {
   728  		if !w.wroteStreamHeader {
   729  			_, err := w.writeSync(w.ibuf)
   730  			w.ibuf = w.ibuf[:0]
   731  			return w.err(err)
   732  		} else {
   733  			_, err := w.write(w.ibuf)
   734  			w.ibuf = w.ibuf[:0]
   735  			err = w.err(err)
   736  			if err != nil {
   737  				return err
   738  			}
   739  		}
   740  	}
   741  	return w.err(nil)
   742  }
   743  
   744  // Flush flushes the Writer to its underlying io.Writer.
   745  // This does not apply padding.
   746  func (w *Writer) Flush() error {
   747  	if err := w.AsyncFlush(); err != nil {
   748  		return err
   749  	}
   750  	if w.output == nil {
   751  		return w.err(nil)
   752  	}
   753  
   754  	// Send empty buffer
   755  	res := make(chan result)
   756  	w.output <- res
   757  	// Block until this has been picked up.
   758  	res <- result{b: nil, startOffset: w.uncompWritten}
   759  	// When it is closed, we have flushed.
   760  	<-res
   761  	return w.err(nil)
   762  }
   763  
   764  // Close calls Flush and then closes the Writer.
   765  // Calling Close multiple times is ok,
   766  // but calling CloseIndex after this will make it not return the index.
   767  func (w *Writer) Close() error {
   768  	_, err := w.closeIndex(w.appendIndex)
   769  	return err
   770  }
   771  
   772  // CloseIndex calls Close and returns an index on first call.
   773  // This is not required if you are only adding index to a stream.
   774  func (w *Writer) CloseIndex() ([]byte, error) {
   775  	return w.closeIndex(true)
   776  }
   777  
   778  func (w *Writer) closeIndex(idx bool) ([]byte, error) {
   779  	err := w.Flush()
   780  	if w.output != nil {
   781  		close(w.output)
   782  		w.writerWg.Wait()
   783  		w.output = nil
   784  	}
   785  
   786  	var index []byte
   787  	if w.err(err) == nil && w.writer != nil {
   788  		// Create index.
   789  		if idx {
   790  			compSize := int64(-1)
   791  			if w.pad <= 1 {
   792  				compSize = w.written
   793  			}
   794  			index = w.index.appendTo(w.ibuf[:0], w.uncompWritten, compSize)
   795  			// Count as written for padding.
   796  			if w.appendIndex {
   797  				w.written += int64(len(index))
   798  			}
   799  		}
   800  
   801  		if w.pad > 1 {
   802  			tmp := w.ibuf[:0]
   803  			if len(index) > 0 {
   804  				// Allocate another buffer.
   805  				tmp = w.buffers.Get().([]byte)[:0]
   806  				defer w.buffers.Put(tmp)
   807  			}
   808  			add := calcSkippableFrame(w.written, int64(w.pad))
   809  			frame, err := skippableFrame(tmp, add, w.randSrc)
   810  			if err = w.err(err); err != nil {
   811  				return nil, err
   812  			}
   813  			n, err2 := w.writer.Write(frame)
   814  			if err2 == nil && n != len(frame) {
   815  				err2 = io.ErrShortWrite
   816  			}
   817  			_ = w.err(err2)
   818  		}
   819  		if len(index) > 0 && w.appendIndex {
   820  			n, err2 := w.writer.Write(index)
   821  			if err2 == nil && n != len(index) {
   822  				err2 = io.ErrShortWrite
   823  			}
   824  			_ = w.err(err2)
   825  		}
   826  	}
   827  	err = w.err(errClosed)
   828  	if err == errClosed {
   829  		return index, nil
   830  	}
   831  	return nil, err
   832  }
   833  
   834  // calcSkippableFrame will return a total size to be added for written
   835  // to be divisible by multiple.
   836  // The value will always be > skippableFrameHeader.
   837  // The function will panic if written < 0 or wantMultiple <= 0.
   838  func calcSkippableFrame(written, wantMultiple int64) int {
   839  	if wantMultiple <= 0 {
   840  		panic("wantMultiple <= 0")
   841  	}
   842  	if written < 0 {
   843  		panic("written < 0")
   844  	}
   845  	leftOver := written % wantMultiple
   846  	if leftOver == 0 {
   847  		return 0
   848  	}
   849  	toAdd := wantMultiple - leftOver
   850  	for toAdd < skippableFrameHeader {
   851  		toAdd += wantMultiple
   852  	}
   853  	return int(toAdd)
   854  }
   855  
   856  // skippableFrame will add a skippable frame with a total size of bytes.
   857  // total should be >= skippableFrameHeader and < maxBlockSize + skippableFrameHeader
   858  func skippableFrame(dst []byte, total int, r io.Reader) ([]byte, error) {
   859  	if total == 0 {
   860  		return dst, nil
   861  	}
   862  	if total < skippableFrameHeader {
   863  		return dst, fmt.Errorf("s2: requested skippable frame (%d) < 4", total)
   864  	}
   865  	if int64(total) >= maxBlockSize+skippableFrameHeader {
   866  		return dst, fmt.Errorf("s2: requested skippable frame (%d) >= max 1<<24", total)
   867  	}
   868  	// Chunk type 0xfe "Section 4.4 Padding (chunk type 0xfe)"
   869  	dst = append(dst, chunkTypePadding)
   870  	f := uint32(total - skippableFrameHeader)
   871  	// Add chunk length.
   872  	dst = append(dst, uint8(f), uint8(f>>8), uint8(f>>16))
   873  	// Add data
   874  	start := len(dst)
   875  	dst = append(dst, make([]byte, f)...)
   876  	_, err := io.ReadFull(r, dst[start:])
   877  	return dst, err
   878  }
   879  
// errClosed is recorded as the sticky error by closeIndex. Once set, later
// Writer calls that check the sticky error return it, and closeIndex itself
// treats it as a successful close.
var errClosed = errors.New("s2: Writer is closed")
   881  
// WriterOption is an option for creating an encoder.
// NewWriter applies options in order and stops at the first one that returns
// an error; the returned Writer then reports that error.
type WriterOption func(*Writer) error
   884  
   885  // WriterConcurrency will set the concurrency,
   886  // meaning the maximum number of decoders to run concurrently.
   887  // The value supplied must be at least 1.
   888  // By default this will be set to GOMAXPROCS.
   889  func WriterConcurrency(n int) WriterOption {
   890  	return func(w *Writer) error {
   891  		if n <= 0 {
   892  			return errors.New("concurrency must be at least 1")
   893  		}
   894  		w.concurrency = n
   895  		return nil
   896  	}
   897  }
   898  
   899  // WriterAddIndex will append an index to the end of a stream
   900  // when it is closed.
   901  func WriterAddIndex() WriterOption {
   902  	return func(w *Writer) error {
   903  		w.appendIndex = true
   904  		return nil
   905  	}
   906  }
   907  
   908  // WriterBetterCompression will enable better compression.
   909  // EncodeBetter compresses better than Encode but typically with a
   910  // 10-40% speed decrease on both compression and decompression.
   911  func WriterBetterCompression() WriterOption {
   912  	return func(w *Writer) error {
   913  		w.level = levelBetter
   914  		return nil
   915  	}
   916  }
   917  
   918  // WriterBestCompression will enable better compression.
   919  // EncodeBetter compresses better than Encode but typically with a
   920  // big speed decrease on compression.
   921  func WriterBestCompression() WriterOption {
   922  	return func(w *Writer) error {
   923  		w.level = levelBest
   924  		return nil
   925  	}
   926  }
   927  
   928  // WriterUncompressed will bypass compression.
   929  // The stream will be written as uncompressed blocks only.
   930  // If concurrency is > 1 CRC and output will still be done async.
   931  func WriterUncompressed() WriterOption {
   932  	return func(w *Writer) error {
   933  		w.level = levelUncompressed
   934  		return nil
   935  	}
   936  }
   937  
   938  // WriterBlockSize allows to override the default block size.
   939  // Blocks will be this size or smaller.
   940  // Minimum size is 4KB and maximum size is 4MB.
   941  //
   942  // Bigger blocks may give bigger throughput on systems with many cores,
   943  // and will increase compression slightly, but it will limit the possible
   944  // concurrency for smaller payloads for both encoding and decoding.
   945  // Default block size is 1MB.
   946  //
   947  // When writing Snappy compatible output using WriterSnappyCompat,
   948  // the maximum block size is 64KB.
   949  func WriterBlockSize(n int) WriterOption {
   950  	return func(w *Writer) error {
   951  		if w.snappy && n > maxSnappyBlockSize || n < minBlockSize {
   952  			return errors.New("s2: block size too large. Must be <= 64K and >=4KB on for snappy compatible output")
   953  		}
   954  		if n > maxBlockSize || n < minBlockSize {
   955  			return errors.New("s2: block size too large. Must be <= 4MB and >=4KB")
   956  		}
   957  		w.blockSize = n
   958  		return nil
   959  	}
   960  }
   961  
   962  // WriterPadding will add padding to all output so the size will be a multiple of n.
   963  // This can be used to obfuscate the exact output size or make blocks of a certain size.
   964  // The contents will be a skippable frame, so it will be invisible by the decoder.
   965  // n must be > 0 and <= 4MB.
   966  // The padded area will be filled with data from crypto/rand.Reader.
   967  // The padding will be applied whenever Close is called on the writer.
   968  func WriterPadding(n int) WriterOption {
   969  	return func(w *Writer) error {
   970  		if n <= 0 {
   971  			return fmt.Errorf("s2: padding must be at least 1")
   972  		}
   973  		// No need to waste our time.
   974  		if n == 1 {
   975  			w.pad = 0
   976  		}
   977  		if n > maxBlockSize {
   978  			return fmt.Errorf("s2: padding must less than 4MB")
   979  		}
   980  		w.pad = n
   981  		return nil
   982  	}
   983  }
   984  
   985  // WriterPaddingSrc will get random data for padding from the supplied source.
   986  // By default crypto/rand is used.
   987  func WriterPaddingSrc(reader io.Reader) WriterOption {
   988  	return func(w *Writer) error {
   989  		w.randSrc = reader
   990  		return nil
   991  	}
   992  }
   993  
   994  // WriterSnappyCompat will write snappy compatible output.
   995  // The output can be decompressed using either snappy or s2.
   996  // If block size is more than 64KB it is set to that.
   997  func WriterSnappyCompat() WriterOption {
   998  	return func(w *Writer) error {
   999  		w.snappy = true
  1000  		if w.blockSize > 64<<10 {
  1001  			// We choose 8 bytes less than 64K, since that will make literal emits slightly more effective.
  1002  			// And allows us to skip some size checks.
  1003  			w.blockSize = (64 << 10) - 8
  1004  		}
  1005  		return nil
  1006  	}
  1007  }
  1008  
  1009  // WriterFlushOnWrite will compress blocks on each call to the Write function.
  1010  //
  1011  // This is quite inefficient as blocks size will depend on the write size.
  1012  //
  1013  // Use WriterConcurrency(1) to also make sure that output is flushed.
  1014  // When Write calls return, otherwise they will be written when compression is done.
  1015  func WriterFlushOnWrite() WriterOption {
  1016  	return func(w *Writer) error {
  1017  		w.flushOnWrite = true
  1018  		return nil
  1019  	}
  1020  }
  1021  
  1022  // WriterCustomEncoder allows to override the encoder for blocks on the stream.
  1023  // The function must compress 'src' into 'dst' and return the bytes used in dst as an integer.
  1024  // Block size (initial varint) should not be added by the encoder.
  1025  // Returning value 0 indicates the block could not be compressed.
  1026  // Returning a negative value indicates that compression should be attempted.
  1027  // The function should expect to be called concurrently.
  1028  func WriterCustomEncoder(fn func(dst, src []byte) int) WriterOption {
  1029  	return func(w *Writer) error {
  1030  		w.customEnc = fn
  1031  		return nil
  1032  	}
  1033  }
  1034  

View as plain text