1 package compiler
2
3 import (
4 "bytes"
5 "crypto/sha256"
6 "encoding/binary"
7 "errors"
8 "io"
9 "math"
10 "testing"
11 "testing/iotest"
12
13 "github.com/tetratelabs/wazero/internal/asm"
14 "github.com/tetratelabs/wazero/internal/filecache"
15 "github.com/tetratelabs/wazero/internal/testing/require"
16 "github.com/tetratelabs/wazero/internal/u32"
17 "github.com/tetratelabs/wazero/internal/u64"
18 "github.com/tetratelabs/wazero/internal/wasm"
19 )
20
// testVersion is the version string handed to serializeCompiledModule and
// deserializeCompiledModule throughout these tests. It is left empty, so the
// expected encodings carry a zero length byte and no version bytes.
var testVersion string
22
// concat joins the given byte slices, in order, into one slice. It returns nil
// when no non-empty slices are given, matching the behaviour of appending to a
// nil slice.
func concat(ins ...[]byte) []byte {
	var out []byte
	for i := range ins {
		out = append(out, ins[i]...)
	}
	return out
}
29
30 func makeCodeSegment(bytes ...byte) asm.CodeSegment {
31 return *asm.NewCodeSegment(bytes)
32 }
33
34 func TestSerializeCompiledModule(t *testing.T) {
35 tests := []struct {
36 in *compiledModule
37 exp []byte
38 }{
39 {
40 in: &compiledModule{
41 compiledCode: &compiledCode{
42 executable: makeCodeSegment(1, 2, 3, 4, 5),
43 },
44 functions: []compiledFunction{
45 {executableOffset: 0, stackPointerCeil: 12345},
46 },
47 },
48 exp: concat(
49 []byte(wazeroMagic),
50 []byte{byte(len(testVersion))},
51 []byte(testVersion),
52 []byte{0},
53 u32.LeBytes(1),
54 u64.LeBytes(12345),
55 u64.LeBytes(0),
56 u64.LeBytes(5),
57 []byte{1, 2, 3, 4, 5},
58 ),
59 },
60 {
61 in: &compiledModule{
62 compiledCode: &compiledCode{
63 executable: makeCodeSegment(1, 2, 3, 4, 5),
64 },
65 functions: []compiledFunction{
66 {executableOffset: 0, stackPointerCeil: 12345},
67 },
68 ensureTermination: true,
69 },
70 exp: concat(
71 []byte(wazeroMagic),
72 []byte{byte(len(testVersion))},
73 []byte(testVersion),
74 []byte{1},
75 u32.LeBytes(1),
76 u64.LeBytes(12345),
77 u64.LeBytes(0),
78 u64.LeBytes(5),
79 []byte{1, 2, 3, 4, 5},
80 ),
81 },
82 {
83 in: &compiledModule{
84 compiledCode: &compiledCode{
85 executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3),
86 },
87 functions: []compiledFunction{
88 {executableOffset: 0, stackPointerCeil: 12345},
89 {executableOffset: 5, stackPointerCeil: 0xffffffff},
90 },
91 ensureTermination: true,
92 },
93 exp: concat(
94 []byte(wazeroMagic),
95 []byte{byte(len(testVersion))},
96 []byte(testVersion),
97 []byte{1},
98 u32.LeBytes(2),
99
100 u64.LeBytes(12345),
101 u64.LeBytes(0),
102
103 u64.LeBytes(0xffffffff),
104 u64.LeBytes(5),
105
106 u64.LeBytes(8),
107 []byte{1, 2, 3, 4, 5, 1, 2, 3},
108 ),
109 },
110 }
111
112 for i, tc := range tests {
113 actual, err := io.ReadAll(serializeCompiledModule(testVersion, tc.in))
114 require.NoError(t, err, i)
115 require.Equal(t, tc.exp, actual, i)
116 }
117 }
118
119 func TestDeserializeCompiledModule(t *testing.T) {
120 tests := []struct {
121 name string
122 in []byte
123 importedFunctionCount uint32
124 expCompiledModule *compiledModule
125 expStaleCache bool
126 expErr string
127 }{
128 {
129 name: "invalid header",
130 in: []byte{1},
131 expErr: "compilationcache: invalid header length: 1",
132 },
133 {
134 name: "version mismatch",
135 in: concat(
136 []byte(wazeroMagic),
137 []byte{byte(len("1233123.1.1"))},
138 []byte("1233123.1.1"),
139 u32.LeBytes(1),
140 ),
141 expStaleCache: true,
142 },
143 {
144 name: "version mismatch",
145 in: concat(
146 []byte(wazeroMagic),
147 []byte{byte(len("1"))},
148 []byte("1"),
149 u32.LeBytes(1),
150 ),
151 expStaleCache: true,
152 },
153 {
154 name: "one function",
155 in: concat(
156 []byte(wazeroMagic),
157 []byte{byte(len(testVersion))},
158 []byte(testVersion),
159 []byte{0},
160 u32.LeBytes(1),
161 u64.LeBytes(12345),
162 u64.LeBytes(0),
163
164 u64.LeBytes(5),
165 []byte{1, 2, 3, 4, 5},
166 ),
167 expCompiledModule: &compiledModule{
168 compiledCode: &compiledCode{
169 executable: makeCodeSegment(1, 2, 3, 4, 5),
170 },
171 functions: []compiledFunction{
172 {executableOffset: 0, stackPointerCeil: 12345, index: 0},
173 },
174 },
175 expStaleCache: false,
176 expErr: "",
177 },
178 {
179 name: "one function with ensure termination",
180 in: concat(
181 []byte(wazeroMagic),
182 []byte{byte(len(testVersion))},
183 []byte(testVersion),
184 []byte{1},
185 u32.LeBytes(1),
186 u64.LeBytes(12345),
187 u64.LeBytes(0),
188 u64.LeBytes(5),
189 []byte{1, 2, 3, 4, 5},
190 ),
191 expCompiledModule: &compiledModule{
192 compiledCode: &compiledCode{
193 executable: makeCodeSegment(1, 2, 3, 4, 5),
194 },
195 functions: []compiledFunction{{executableOffset: 0, stackPointerCeil: 12345, index: 0}},
196 ensureTermination: true,
197 },
198 expStaleCache: false,
199 expErr: "",
200 },
201 {
202 name: "two functions",
203 in: concat(
204 []byte(wazeroMagic),
205 []byte{byte(len(testVersion))},
206 []byte(testVersion),
207 []byte{0},
208 u32.LeBytes(2),
209
210 u64.LeBytes(12345),
211 u64.LeBytes(0),
212
213 u64.LeBytes(0xffffffff),
214 u64.LeBytes(7),
215
216 u64.LeBytes(10),
217 []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
218 ),
219 importedFunctionCount: 1,
220 expCompiledModule: &compiledModule{
221 compiledCode: &compiledCode{
222 executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
223 },
224 functions: []compiledFunction{
225 {executableOffset: 0, stackPointerCeil: 12345, index: 1},
226 {executableOffset: 7, stackPointerCeil: 0xffffffff, index: 2},
227 },
228 },
229 expStaleCache: false,
230 expErr: "",
231 },
232 {
233 name: "reading stack pointer",
234 in: concat(
235 []byte(wazeroMagic),
236 []byte{byte(len(testVersion))},
237 []byte(testVersion),
238 []byte{0},
239 u32.LeBytes(2),
240
241 u64.LeBytes(12345),
242 u64.LeBytes(5),
243
244 ),
245 expErr: "compilationcache: error reading func[1] stack pointer ceil: EOF",
246 },
247 {
248 name: "reading executable offset",
249 in: concat(
250 []byte(wazeroMagic),
251 []byte{byte(len(testVersion))},
252 []byte(testVersion),
253 []byte{0},
254 u32.LeBytes(2),
255
256 u64.LeBytes(12345),
257 u64.LeBytes(5),
258
259 u64.LeBytes(12345),
260 ),
261 expErr: "compilationcache: error reading func[1] executable offset: EOF",
262 },
263 {
264 name: "mmapping",
265 in: concat(
266 []byte(wazeroMagic),
267 []byte{byte(len(testVersion))},
268 []byte(testVersion),
269 []byte{0},
270 u32.LeBytes(2),
271
272 u64.LeBytes(12345),
273 u64.LeBytes(0),
274
275 u64.LeBytes(12345),
276 u64.LeBytes(5),
277
278 u64.LeBytes(5),
279
280 ),
281 expErr: "compilationcache: error reading executable (len=5): EOF",
282 },
283 }
284
285 for _, tc := range tests {
286 tc := tc
287 t.Run(tc.name, func(t *testing.T) {
288 cm, staleCache, err := deserializeCompiledModule(testVersion, io.NopCloser(bytes.NewReader(tc.in)),
289 &wasm.Module{ImportFunctionCount: tc.importedFunctionCount})
290
291 if tc.expCompiledModule != nil {
292 require.Equal(t, len(tc.expCompiledModule.functions), len(cm.functions))
293 for i := 0; i < len(cm.functions); i++ {
294 require.Equal(t, cm.compiledCode, cm.functions[i].parent)
295 tc.expCompiledModule.functions[i].parent = cm.compiledCode
296 }
297 }
298
299 if tc.expErr != "" {
300 require.EqualError(t, err, tc.expErr)
301 } else {
302 require.NoError(t, err)
303 require.Equal(t, tc.expCompiledModule, cm)
304 }
305
306 require.Equal(t, tc.expStaleCache, staleCache)
307 })
308 }
309 }
310
311 func TestEngine_getCompiledModuleFromCache(t *testing.T) {
312 valid := concat(
313 []byte(wazeroMagic),
314 []byte{byte(len(testVersion))},
315 []byte(testVersion),
316 []byte{0},
317 u32.LeBytes(2),
318
319 u64.LeBytes(12345),
320 u64.LeBytes(0),
321
322 u64.LeBytes(0xffffffff),
323 u64.LeBytes(5),
324
325 u64.LeBytes(10),
326 []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
327 )
328
329 tests := []struct {
330 name string
331 ext map[wasm.ModuleID][]byte
332 key wasm.ModuleID
333 isHostMod bool
334 expCompiledModule *compiledModule
335 expHit bool
336 expErr string
337 expDeleted bool
338 }{
339 {name: "extern cache not given"},
340 {
341 name: "not hit",
342 ext: map[wasm.ModuleID][]byte{},
343 },
344 {
345 name: "host module",
346 ext: map[wasm.ModuleID][]byte{{}: valid},
347 isHostMod: true,
348 },
349 {
350 name: "error in Cache.Get",
351 ext: map[wasm.ModuleID][]byte{{}: {}},
352 expErr: "compilationcache: error reading header: EOF",
353 },
354 {
355 name: "error in deserialization",
356 ext: map[wasm.ModuleID][]byte{{}: {1, 2, 3}},
357 expErr: "compilationcache: invalid header length: 3",
358 },
359 {
360 name: "stale cache",
361 ext: map[wasm.ModuleID][]byte{{}: concat(
362 []byte(wazeroMagic),
363 []byte{byte(len("1233123.1.1"))},
364 []byte("1233123.1.1"),
365 u32.LeBytes(1),
366 )},
367 expDeleted: true,
368 },
369 {
370 name: "hit",
371 ext: map[wasm.ModuleID][]byte{
372 {}: valid,
373 },
374 expHit: true,
375 expCompiledModule: &compiledModule{
376 compiledCode: &compiledCode{
377 executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
378 },
379 functions: []compiledFunction{
380 {stackPointerCeil: 12345, executableOffset: 0, index: 0},
381 {stackPointerCeil: 0xffffffff, executableOffset: 5, index: 1},
382 },
383 },
384 },
385 }
386
387 for _, tc := range tests {
388 tc := tc
389 t.Run(tc.name, func(t *testing.T) {
390 m := &wasm.Module{ID: tc.key, IsHostModule: tc.isHostMod}
391 if exp := tc.expCompiledModule; exp != nil {
392 exp.source = m
393 for i := range tc.expCompiledModule.functions {
394 tc.expCompiledModule.functions[i].parent = exp.compiledCode
395 }
396 }
397
398 e := engine{}
399 if tc.ext != nil {
400 tmp := t.TempDir()
401 e.fileCache = filecache.New(tmp)
402 for key, value := range tc.ext {
403 err := e.fileCache.Add(key, bytes.NewReader(value))
404 require.NoError(t, err)
405 }
406 }
407
408 codes, hit, err := e.getCompiledModuleFromCache(m)
409 if tc.expErr != "" {
410 require.EqualError(t, err, tc.expErr)
411 } else {
412 require.NoError(t, err)
413 }
414
415 require.Equal(t, tc.expHit, hit)
416 require.Equal(t, tc.expCompiledModule, codes)
417
418 if tc.ext != nil && tc.expDeleted {
419 _, hit, err := e.fileCache.Get(tc.key)
420 require.NoError(t, err)
421 require.False(t, hit)
422 }
423 })
424 }
425 }
426
427 func TestEngine_addCompiledModuleToCache(t *testing.T) {
428 t.Run("not defined", func(t *testing.T) {
429 e := engine{}
430 err := e.addCompiledModuleToCache(nil, nil)
431 require.NoError(t, err)
432 })
433 t.Run("host module", func(t *testing.T) {
434 tc := filecache.New(t.TempDir())
435 e := engine{fileCache: tc}
436 cm := &compiledModule{
437 compiledCode: &compiledCode{
438 executable: makeCodeSegment(1, 2, 3),
439 },
440 functions: []compiledFunction{{stackPointerCeil: 123}},
441 }
442 m := &wasm.Module{ID: sha256.Sum256(nil), IsHostModule: true}
443 err := e.addCompiledModuleToCache(m, cm)
444 require.NoError(t, err)
445
446 _, hit, err := tc.Get(m.ID)
447 require.NoError(t, err)
448 require.False(t, hit)
449 })
450 t.Run("add", func(t *testing.T) {
451 tc := filecache.New(t.TempDir())
452 e := engine{fileCache: tc}
453 m := &wasm.Module{}
454 cm := &compiledModule{
455 compiledCode: &compiledCode{
456 executable: makeCodeSegment(1, 2, 3),
457 },
458 functions: []compiledFunction{{stackPointerCeil: 123}},
459 }
460 err := e.addCompiledModuleToCache(m, cm)
461 require.NoError(t, err)
462
463 content, ok, err := tc.Get(m.ID)
464 require.NoError(t, err)
465 require.True(t, ok)
466 actual, err := io.ReadAll(content)
467 require.NoError(t, err)
468 require.Equal(t, concat(
469 []byte(wazeroMagic),
470 []byte{byte(len(testVersion))},
471 []byte(testVersion),
472 []byte{0},
473 u32.LeBytes(1),
474 u64.LeBytes(123),
475 u64.LeBytes(0),
476 u64.LeBytes(3),
477 []byte{1, 2, 3},
478 ), actual)
479 require.NoError(t, content.Close())
480 })
481 }
482
483 func Test_readUint64(t *testing.T) {
484 tests := []struct {
485 name string
486 input uint64
487 }{
488 {
489 name: "zero",
490 input: 0,
491 },
492 {
493 name: "half",
494 input: math.MaxUint32,
495 },
496 {
497 name: "max",
498 input: math.MaxUint64,
499 },
500 }
501
502 for _, tt := range tests {
503 tc := tt
504
505 t.Run(tc.name, func(t *testing.T) {
506 input := make([]byte, 8)
507 binary.LittleEndian.PutUint64(input, tc.input)
508
509 var b [8]byte
510 n, err := readUint64(bytes.NewReader(input), &b)
511 require.NoError(t, err)
512 require.Equal(t, tc.input, n)
513
514
515 var expectedB [8]byte
516 require.Equal(t, expectedB, b)
517 })
518 }
519 }
520
521 func Test_readUint64_errors(t *testing.T) {
522 tests := []struct {
523 name string
524 input io.Reader
525 expectedErr string
526 }{
527 {
528 name: "zero",
529 input: bytes.NewReader([]byte{}),
530 expectedErr: "EOF",
531 },
532 {
533 name: "not enough",
534 input: bytes.NewReader([]byte{1, 2}),
535 expectedErr: "EOF",
536 },
537 {
538 name: "error reading",
539 input: iotest.ErrReader(errors.New("ice cream")),
540 expectedErr: "ice cream",
541 },
542 }
543
544 for _, tt := range tests {
545 tc := tt
546
547 t.Run(tc.name, func(t *testing.T) {
548 var b [8]byte
549 _, err := readUint64(tc.input, &b)
550 require.EqualError(t, err, tc.expectedErr)
551 })
552 }
553 }
554
View as plain text