1
16
17 package zstdchunked
18
19 import (
20 "bytes"
21 "crypto/sha256"
22 "fmt"
23 "io"
24 "sort"
25 "testing"
26
27 "github.com/containerd/stargz-snapshotter/estargz"
28 "github.com/klauspost/compress/zstd"
29 )
30
31
32 func TestZstdChunked(t *testing.T) {
33 estargz.CompressionTestSuite(t,
34 zstdControllerWithLevel(zstd.SpeedFastest),
35 zstdControllerWithLevel(zstd.SpeedDefault),
36 zstdControllerWithLevel(zstd.SpeedBetterCompression),
37
38 )
39 }
40
41 func zstdControllerWithLevel(compressionLevel zstd.EncoderLevel) estargz.TestingControllerFactory {
42 return func() estargz.TestingController {
43 return &zstdController{&Compressor{CompressionLevel: compressionLevel}, &Decompressor{}}
44 }
45 }
46
47 type zstdController struct {
48 *Compressor
49 *Decompressor
50 }
51
52 func (zc *zstdController) String() string {
53 return fmt.Sprintf("zstd_compression_level=%v", zc.Compressor.CompressionLevel)
54 }
55
56
57
58 func (zc *zstdController) TestStreams(t *testing.T, b []byte, streams []int64) {
59 t.Logf("got zstd streams (compressed size: %d):", len(b))
60
61 if len(streams) == 0 {
62 return
63 }
64
65
66
67 sort.Slice(streams, func(i, j int) bool {
68 return streams[i] < streams[j]
69 })
70 streams[len(streams)-1] = streams[len(streams)-1] - 8
71 wants := map[int64]struct{}{}
72 for _, s := range streams {
73 wants[s] = struct{}{}
74 }
75
76 magicLen := 4
77 zoff := 0
78 numStreams := 0
79 for {
80 if len(b) <= zoff {
81 break
82 } else if len(b)-zoff <= magicLen {
83 t.Fatalf("invalid frame size %d is too small", len(b)-zoff)
84 }
85 delete(wants, int64(zoff))
86 remainingFrames := b[zoff:]
87
88
89 if !bytes.Equal(remainingFrames[:magicLen], zstdFrameMagic) {
90 if !bytes.Equal(remainingFrames[:magicLen], skippableFrameMagic) {
91 t.Fatalf("frame must start from magic bytes; but %x",
92 remainingFrames[:magicLen])
93 }
94 }
95 searchBase := magicLen
96 nextMagicIdx := nextIndex(remainingFrames[searchBase:], zstdFrameMagic)
97 nextSkippableIdx := nextIndex(remainingFrames[searchBase:], skippableFrameMagic)
98 nextFrame := len(remainingFrames)
99 for _, i := range []int{nextMagicIdx, nextSkippableIdx} {
100 if 0 < i && searchBase+i < nextFrame {
101 nextFrame = searchBase + i
102 }
103 }
104 t.Logf(" [%d] at %d in stargz (nextFrame: %d/%d): %v, %v",
105 numStreams, zoff, zoff+nextFrame, len(b), nextMagicIdx, nextSkippableIdx)
106 zoff += nextFrame
107 numStreams++
108 }
109 if len(wants) != 0 {
110 t.Fatalf("some stream offsets not found in the blob: %v", wants)
111 }
112 }
113
// nextIndex returns the index of the first occurrence of sub in s1, or -1 if
// sub does not occur. An empty sub always matches at index 0, including when
// s1 is empty. It delegates to bytes.Index rather than comparing every window
// by hand.
func nextIndex(s1, sub []byte) int {
	return bytes.Index(s1, sub)
}
124
125 func (zc *zstdController) DiffIDOf(t *testing.T, b []byte) string {
126 h := sha256.New()
127 zr, err := zstd.NewReader(bytes.NewReader(b))
128 if err != nil {
129 t.Fatalf("diffIDOf(zstd): %v", err)
130 }
131 defer zr.Close()
132 if _, err := io.Copy(h, zr); err != nil {
133 t.Fatalf("diffIDOf(zstd).Copy: %v", err)
134 }
135 return fmt.Sprintf("sha256:%x", h.Sum(nil))
136 }
137
138
139 func TestZstdChunkedFooter(t *testing.T) {
140 max := int64(200000)
141 for off := int64(0); off <= max; off += 1023 {
142 size := max - off
143 checkZstdChunkedFooter(t, off, size, size/2)
144 }
145 }
146
147 func checkZstdChunkedFooter(t *testing.T, off, size, cSize int64) {
148 footer := zstdFooterBytes(uint64(off), uint64(size), uint64(cSize))
149 if len(footer) != FooterSize {
150 t.Fatalf("for offset %v, footer length was %d, not expected %d. got bytes: %q", off, len(footer), FooterSize, footer)
151 }
152 gotBlobPayloadSize, gotOff, gotSize, err := (&Decompressor{}).ParseFooter(footer)
153 if err != nil {
154 t.Fatalf("failed to parse footer for offset %d, footer: %x: err: %v",
155 off, footer, err)
156 }
157 if gotBlobPayloadSize != off-8 {
158
159 t.Fatalf("ParseFooter(footerBytes(offset %d)) = blobPayloadSize %d; want %d", off, gotBlobPayloadSize, off-8)
160 }
161 if gotOff != off {
162 t.Fatalf("ParseFooter(footerBytes(offset %d)) = off %d; want %d", off, gotOff, off)
163 }
164 if gotSize != cSize {
165 t.Fatalf("ParseFooter(footerBytes(offset %d)) = size %d; want %d", off, gotSize, cSize)
166 }
167 }
168
View as plain text