1 package wazerotest
2
3 import (
4 "context"
5 "encoding/binary"
6 "errors"
7 "fmt"
8 "math"
9 "reflect"
10 "strconv"
11 "sync"
12 "sync/atomic"
13
14 "github.com/tetratelabs/wazero/api"
15 "github.com/tetratelabs/wazero/internal/internalapi"
16 "github.com/tetratelabs/wazero/sys"
17 )
18
const (
	// exitStatusMarker is OR-ed into Module.exitStatus when the module is
	// closed. It keeps a close with exit code 0 distinguishable from a module
	// that was never closed, whose exitStatus is still the zero value.
	exitStatusMarker = 1 << 63
)
22
23
24
// Module is a test double for api.Module. It holds its functions, globals and
// optional memory in exported fields. The export lookup maps are built
// lazily, on the first call to an accessor that runs initialize.
type Module struct {
	internalapi.WazeroOnlyType

	// ModuleName is returned by Name and included in String.
	ModuleName string

	// Functions are the module's functions. A function's position in this
	// slice becomes its index, and it is exported under each of its
	// ExportNames.
	Functions []*Function

	// Globals are the module's globals, each exported under its ExportNames.
	Globals []*Global

	// ExportMemory, when non-nil, is the module's memory. It is exported
	// under the name "memory".
	ExportMemory *Memory

	// exitStatus is zero while the module is open and
	// exitStatusMarker|exitCode once it has been closed.
	exitStatus atomic.Uint64

	// once guards initialize, which fills the export maps below.
	once                        sync.Once
	exportedFunctions           map[string]api.Function
	exportedFunctionDefinitions map[string]api.FunctionDefinition
	exportedGlobals             map[string]api.Global
	exportedMemoryDefinitions   map[string]api.MemoryDefinition
}
50
51
52 func NewModule(memory *Memory, functions ...*Function) *Module {
53 return &Module{Functions: functions, ExportMemory: memory}
54 }
55
56
57 func (m *Module) String() string {
58 return "module[" + m.ModuleName + "]"
59 }
60
61
62 func (m *Module) Name() string {
63 return m.ModuleName
64 }
65
66
67 func (m *Module) Memory() api.Memory {
68 if m.ExportMemory != nil {
69 m.once.Do(m.initialize)
70 return m.ExportMemory
71 }
72 return nil
73 }
74
75
76 func (m *Module) ExportedFunction(name string) api.Function {
77 m.once.Do(m.initialize)
78 return m.exportedFunctions[name]
79 }
80
81
82 func (m *Module) ExportedFunctionDefinitions() map[string]api.FunctionDefinition {
83 m.once.Do(m.initialize)
84 return m.exportedFunctionDefinitions
85 }
86
87
88 func (m *Module) ExportedMemory(name string) api.Memory {
89 if m.ExportMemory != nil && name == "memory" {
90 m.once.Do(m.initialize)
91 return m.ExportMemory
92 }
93 return nil
94 }
95
96
97 func (m *Module) ExportedMemoryDefinitions() map[string]api.MemoryDefinition {
98 m.once.Do(m.initialize)
99 return m.exportedMemoryDefinitions
100 }
101
102
103 func (m *Module) ExportedGlobal(name string) api.Global {
104 m.once.Do(m.initialize)
105 return m.exportedGlobals[name]
106 }
107
108
109 func (m *Module) Close(ctx context.Context) error {
110 return m.CloseWithExitCode(ctx, 0)
111 }
112
113
114 func (m *Module) CloseWithExitCode(ctx context.Context, exitCode uint32) error {
115 m.exitStatus.CompareAndSwap(0, exitStatusMarker|uint64(exitCode))
116 return nil
117 }
118
119
120 func (m *Module) IsClosed() bool {
121 _, exited := m.ExitStatus()
122 return exited
123 }
124
125
126 func (m *Module) NumGlobal() int {
127 return len(m.Globals)
128 }
129
130
131 func (m *Module) Global(i int) api.Global {
132 m.once.Do(m.initialize)
133 return m.Globals[i]
134 }
135
136
137
138 func (m *Module) NumFunction() int {
139 return len(m.Functions)
140 }
141
142 func (m *Module) Function(i int) api.Function {
143 m.once.Do(m.initialize)
144 return m.Functions[i]
145 }
146
147 func (m *Module) ExitStatus() (exitCode uint32, exited bool) {
148 exitStatus := m.exitStatus.Load()
149 return uint32(exitStatus), exitStatus != 0
150 }
151
152 func (m *Module) initialize() {
153 m.exportedFunctions = make(map[string]api.Function)
154 m.exportedFunctionDefinitions = make(map[string]api.FunctionDefinition)
155 m.exportedGlobals = make(map[string]api.Global)
156 m.exportedMemoryDefinitions = make(map[string]api.MemoryDefinition)
157
158 for index, function := range m.Functions {
159 for _, exportName := range function.ExportNames {
160 m.exportedFunctions[exportName] = function
161 m.exportedFunctionDefinitions[exportName] = function.Definition()
162 }
163 function.module = m
164 function.index = index
165 }
166
167 for _, global := range m.Globals {
168 for _, exportName := range global.ExportNames {
169 m.exportedGlobals[exportName] = global
170 }
171 }
172
173 if m.ExportMemory != nil {
174 m.ExportMemory.module = m
175 m.exportedMemoryDefinitions["memory"] = m.ExportMemory.Definition()
176 }
177 }
178
179
180
// Global is a test double for api.Global holding a single raw value.
type Global struct {
	internalapi.WazeroOnlyType

	// ValueType is the wasm type of the global, returned by Type.
	ValueType api.ValueType

	// Value is the raw 64-bit encoding returned by Get. It is interpreted
	// according to ValueType; the GlobalI32/I64/F32/F64 constructors produce
	// it with the matching api.EncodeXxx function.
	Value uint64

	// ExportNames lists the names under which the owning Module exports this
	// global.
	ExportNames []string
}
193
194 func (g *Global) String() string {
195 switch g.ValueType {
196 case api.ValueTypeI32:
197 return strconv.FormatInt(int64(api.DecodeI32(g.Value)), 10)
198 case api.ValueTypeI64:
199 return strconv.FormatInt(int64(g.Value), 10)
200 case api.ValueTypeF32:
201 return strconv.FormatFloat(float64(api.DecodeF32(g.Value)), 'g', -1, 32)
202 case api.ValueTypeF64:
203 return strconv.FormatFloat(api.DecodeF64(g.Value), 'g', -1, 64)
204 default:
205 return "0x" + strconv.FormatUint(g.Value, 16)
206 }
207 }
208
209 func (g *Global) Type() api.ValueType {
210 return g.ValueType
211 }
212
213 func (g *Global) Get() uint64 {
214 return g.Value
215 }
216
217 func GlobalI32(value int32, export ...string) *Global {
218 return &Global{ValueType: api.ValueTypeI32, Value: api.EncodeI32(value), ExportNames: export}
219 }
220
221 func GlobalI64(value int64, export ...string) *Global {
222 return &Global{ValueType: api.ValueTypeI64, Value: api.EncodeI64(value), ExportNames: export}
223 }
224
225 func GlobalF32(value float32, export ...string) *Global {
226 return &Global{ValueType: api.ValueTypeF32, Value: api.EncodeF32(value), ExportNames: export}
227 }
228
229 func GlobalF64(value float64, export ...string) *Global {
230 return &Global{ValueType: api.ValueTypeF64, Value: api.EncodeF64(value), ExportNames: export}
231 }
232
233
234
235
236
237
// Function is a test double for api.Function backed by a Go implementation.
// Calls fail with errMissingFunctionModule until the function has been linked
// to its owning Module, which Module.initialize does.
type Function struct {
	internalapi.WazeroOnlyType

	// GoModuleFunction is the implementation run by Call and CallWithStack.
	// It receives the owning module and the value stack.
	GoModuleFunction api.GoModuleFunction

	// ParamTypes and ResultTypes describe the signature. Both must be
	// non-nil; use an empty, non-nil slice for "none". Otherwise calls fail
	// with errMissingFunctionSignature.
	ParamTypes  []api.ValueType
	ResultTypes []api.ValueType

	// Optional metadata surfaced through Definition. When DebugName is empty,
	// the definition derives "<module name>.$<index>" instead.
	FunctionName string
	DebugName    string
	ParamNames   []string
	ResultNames  []string
	// ExportNames lists the names under which the owning Module exports this
	// function.
	ExportNames []string

	// module and index are set by Module.initialize: the owning module and
	// this function's position in Module.Functions.
	module *Module
	index  int
}
266
267
268
269
270
271
272
273 func NewFunction(fn any) *Function {
274 functionType := reflect.TypeOf(fn)
275 functionValue := reflect.ValueOf(fn)
276
277 paramTypes := make([]api.ValueType, functionType.NumIn()-2)
278 paramFuncs := make([]func(uint64) reflect.Value, len(paramTypes))
279
280 resultTypes := make([]api.ValueType, functionType.NumOut())
281 resultFuncs := make([]func(reflect.Value) uint64, len(resultTypes))
282
283 for i := range paramTypes {
284 var paramType api.ValueType
285 var paramFunc func(uint64) reflect.Value
286
287 switch functionType.In(i + 2).Kind() {
288 case reflect.Uint32:
289 paramType = api.ValueTypeI32
290 paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(api.DecodeU32(v)) }
291 case reflect.Uint64:
292 paramType = api.ValueTypeI64
293 paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(v) }
294 case reflect.Int32:
295 paramType = api.ValueTypeI32
296 paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(api.DecodeI32(v)) }
297 case reflect.Int64:
298 paramType = api.ValueTypeI64
299 paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(int64(v)) }
300 case reflect.Float32:
301 paramType = api.ValueTypeF32
302 paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(api.DecodeF32(v)) }
303 case reflect.Float64:
304 paramType = api.ValueTypeF64
305 paramFunc = func(v uint64) reflect.Value { return reflect.ValueOf(api.DecodeF64(v)) }
306 default:
307 panic("cannot construct wasm function from go function of type " + functionType.String())
308 }
309
310 paramTypes[i] = paramType
311 paramFuncs[i] = paramFunc
312 }
313
314 for i := range resultTypes {
315 var resultType api.ValueType
316 var resultFunc func(reflect.Value) uint64
317
318 switch functionType.Out(i).Kind() {
319 case reflect.Uint32:
320 resultType = api.ValueTypeI32
321 resultFunc = func(v reflect.Value) uint64 { return v.Uint() }
322 case reflect.Uint64:
323 resultType = api.ValueTypeI64
324 resultFunc = func(v reflect.Value) uint64 { return v.Uint() }
325 case reflect.Int32:
326 resultType = api.ValueTypeI32
327 resultFunc = func(v reflect.Value) uint64 { return api.EncodeI32(int32(v.Int())) }
328 case reflect.Int64:
329 resultType = api.ValueTypeI64
330 resultFunc = func(v reflect.Value) uint64 { return api.EncodeI64(v.Int()) }
331 case reflect.Float32:
332 resultType = api.ValueTypeF32
333 resultFunc = func(v reflect.Value) uint64 { return api.EncodeF32(float32(v.Float())) }
334 case reflect.Float64:
335 resultType = api.ValueTypeF64
336 resultFunc = func(v reflect.Value) uint64 { return api.EncodeF64(v.Float()) }
337 default:
338 panic("cannot construct wasm function from go function of type " + functionType.String())
339 }
340
341 resultTypes[i] = resultType
342 resultFuncs[i] = resultFunc
343 }
344
345 return &Function{
346 GoModuleFunction: api.GoModuleFunc(func(ctx context.Context, mod api.Module, stack []uint64) {
347 in := make([]reflect.Value, 2+len(paramFuncs))
348 in[0] = reflect.ValueOf(ctx)
349 in[1] = reflect.ValueOf(mod)
350 for i, param := range paramFuncs {
351 in[i+2] = param(stack[i])
352 }
353 out := functionValue.Call(in)
354 for i, result := range resultFuncs {
355 stack[i] = result(out[i])
356 }
357 }),
358 ParamTypes: paramTypes,
359 ResultTypes: resultTypes,
360 }
361 }
362
// Errors returned by Function.CallWithStack, and therefore by Function.Call,
// when a Function is not fully configured.
var (
	// errMissingFunctionSignature: ParamTypes or ResultTypes is nil.
	errMissingFunctionSignature = errors.New("missing function signature")
	// errMissingFunctionModule: the function has no owning module yet. The
	// module field is set by Module.initialize.
	errMissingFunctionModule = errors.New("missing function module")
	// errMissingFunctionImplementation: GoModuleFunction is nil.
	errMissingFunctionImplementation = errors.New("missing function implementation")
)
368
369 func (f *Function) Definition() api.FunctionDefinition {
370 return functionDefinition{function: f}
371 }
372
373 func (f *Function) Call(ctx context.Context, params ...uint64) ([]uint64, error) {
374 stackLen := len(f.ParamTypes)
375 if stackLen < len(f.ResultTypes) {
376 stackLen = len(f.ResultTypes)
377 }
378 stack := make([]uint64, stackLen)
379 copy(stack, params)
380 err := f.CallWithStack(ctx, stack)
381 if err != nil {
382 for i := range stack {
383 stack[i] = 0
384 }
385 }
386 return stack[:len(f.ResultTypes)], err
387 }
388
389 func (f *Function) CallWithStack(ctx context.Context, stack []uint64) error {
390 if f.ParamTypes == nil || f.ResultTypes == nil {
391 return errMissingFunctionSignature
392 }
393 if f.GoModuleFunction == nil {
394 return errMissingFunctionImplementation
395 }
396 if f.module == nil {
397 return errMissingFunctionModule
398 }
399 if exitCode, exited := f.module.ExitStatus(); exited {
400 return sys.NewExitError(exitCode)
401 }
402 f.GoModuleFunction.Call(ctx, f.module, stack)
403 return nil
404 }
405
// functionDefinition adapts a *Function to api.FunctionDefinition. It stores
// only the pointer, so every method reads the Function's current fields.
type functionDefinition struct {
	internalapi.WazeroOnlyType
	function *Function
}
410
411 func (def functionDefinition) Name() string {
412 return def.function.FunctionName
413 }
414
415 func (def functionDefinition) DebugName() string {
416 if def.function.DebugName != "" {
417 return def.function.DebugName
418 }
419 return fmt.Sprintf("%s.$%d", def.ModuleName(), def.Index())
420 }
421
422 func (def functionDefinition) GoFunction() any {
423 return def.function.GoModuleFunction
424 }
425
426 func (def functionDefinition) ParamTypes() []api.ValueType {
427 return def.function.ParamTypes
428 }
429
430 func (def functionDefinition) ParamNames() []string {
431 return def.function.ParamNames
432 }
433
434 func (def functionDefinition) ResultTypes() []api.ValueType {
435 return def.function.ResultTypes
436 }
437
438 func (def functionDefinition) ResultNames() []string {
439 return def.function.ResultNames
440 }
441
442 func (def functionDefinition) ModuleName() string {
443 if def.function.module != nil {
444 return def.function.module.ModuleName
445 }
446 return ""
447 }
448
449 func (def functionDefinition) Index() uint32 {
450 return uint32(def.function.index)
451 }
452
453 func (def functionDefinition) Import() (moduleName, name string, isImport bool) {
454 return
455 }
456
457 func (def functionDefinition) ExportNames() []string {
458 return def.function.ExportNames
459 }
460
461
462
// Memory is a test double for api.Memory backed by a plain byte slice.
type Memory struct {
	internalapi.WazeroOnlyType

	// Bytes holds the memory contents, and Size reports len(Bytes). Grow
	// replaces it with a newly allocated copy.
	Bytes []byte

	// Min and Max are the memory limits in pages, reported through
	// Definition. A Max of zero means "no maximum": Grow does not enforce
	// it, and the definition reports no max.
	Min uint32
	Max uint32

	// module is the Module exporting this memory, set by Module.initialize.
	module *Module
}
481
482
483
484 func NewMemory(size int) *Memory {
485 numPages := (size + (PageSize - 1)) / PageSize
486 return &Memory{
487 Bytes: make([]byte, numPages*PageSize),
488 Min: uint32(numPages),
489 }
490 }
491
492
493
494
495 func NewFixedMemory(size int) *Memory {
496 memory := NewMemory(size)
497 memory.Max = memory.Min
498 return memory
499 }
500
501
502
503
// PageSize is the size in bytes of one wasm memory page (64 KiB).
const PageSize = 65536
505
506 func (m *Memory) Definition() api.MemoryDefinition {
507 return memoryDefinition{memory: m}
508 }
509
510 func (m *Memory) Size() uint32 {
511 return uint32(len(m.Bytes))
512 }
513
514 func (m *Memory) Grow(deltaPages uint32) (previousPages uint32, ok bool) {
515 previousPages = uint32(len(m.Bytes) / PageSize)
516 numPages := previousPages + deltaPages
517 if m.Max != 0 && numPages > m.Max {
518 return previousPages, false
519 }
520 bytes := make([]byte, PageSize*numPages)
521 copy(bytes, m.Bytes)
522 m.Bytes = bytes
523 return previousPages, true
524 }
525
526 func (m *Memory) ReadByte(offset uint32) (byte, bool) {
527 if m.isOutOfRange(offset, 1) {
528 return 0, false
529 }
530 return m.Bytes[offset], true
531 }
532
533 func (m *Memory) ReadUint16Le(offset uint32) (uint16, bool) {
534 if m.isOutOfRange(offset, 2) {
535 return 0, false
536 }
537 return binary.LittleEndian.Uint16(m.Bytes[offset:]), true
538 }
539
540 func (m *Memory) ReadUint32Le(offset uint32) (uint32, bool) {
541 if m.isOutOfRange(offset, 4) {
542 return 0, false
543 }
544 return binary.LittleEndian.Uint32(m.Bytes[offset:]), true
545 }
546
547 func (m *Memory) ReadUint64Le(offset uint32) (uint64, bool) {
548 if m.isOutOfRange(offset, 8) {
549 return 0, false
550 }
551 return binary.LittleEndian.Uint64(m.Bytes[offset:]), true
552 }
553
554 func (m *Memory) ReadFloat32Le(offset uint32) (float32, bool) {
555 v, ok := m.ReadUint32Le(offset)
556 return math.Float32frombits(v), ok
557 }
558
559 func (m *Memory) ReadFloat64Le(offset uint32) (float64, bool) {
560 v, ok := m.ReadUint64Le(offset)
561 return math.Float64frombits(v), ok
562 }
563
564 func (m *Memory) Read(offset, length uint32) ([]byte, bool) {
565 if m.isOutOfRange(offset, length) {
566 return nil, false
567 }
568 return m.Bytes[offset : offset+length : offset+length], true
569 }
570
571 func (m *Memory) WriteByte(offset uint32, value byte) bool {
572 if m.isOutOfRange(offset, 1) {
573 return false
574 }
575 m.Bytes[offset] = value
576 return true
577 }
578
579 func (m *Memory) WriteUint16Le(offset uint32, value uint16) bool {
580 if m.isOutOfRange(offset, 2) {
581 return false
582 }
583 binary.LittleEndian.PutUint16(m.Bytes[offset:], value)
584 return true
585 }
586
587 func (m *Memory) WriteUint32Le(offset uint32, value uint32) bool {
588 if m.isOutOfRange(offset, 4) {
589 return false
590 }
591 binary.LittleEndian.PutUint32(m.Bytes[offset:], value)
592 return true
593 }
594
595 func (m *Memory) WriteUint64Le(offset uint32, value uint64) bool {
596 if m.isOutOfRange(offset, 4) {
597 return false
598 }
599 binary.LittleEndian.PutUint64(m.Bytes[offset:], value)
600 return true
601 }
602
603 func (m *Memory) WriteFloat32Le(offset uint32, value float32) bool {
604 return m.WriteUint32Le(offset, math.Float32bits(value))
605 }
606
607 func (m *Memory) WriteFloat64Le(offset uint32, value float64) bool {
608 return m.WriteUint64Le(offset, math.Float64bits(value))
609 }
610
611 func (m *Memory) Write(offset uint32, value []byte) bool {
612 if m.isOutOfRange(offset, uint32(len(value))) {
613 return false
614 }
615 copy(m.Bytes[offset:], value)
616 return true
617 }
618
619 func (m *Memory) WriteString(offset uint32, value string) bool {
620 if m.isOutOfRange(offset, uint32(len(value))) {
621 return false
622 }
623 copy(m.Bytes[offset:], value)
624 return true
625 }
626
627 func (m *Memory) isOutOfRange(offset, length uint32) bool {
628 size := m.Size()
629 return offset >= size || length > size || offset > (size-length)
630 }
631
// memoryDefinition adapts a *Memory to api.MemoryDefinition. It stores only
// the pointer, so every method reads the Memory's current fields.
type memoryDefinition struct {
	internalapi.WazeroOnlyType
	memory *Memory
}
636
637 func (def memoryDefinition) ModuleName() string {
638 if def.memory.module != nil {
639 return def.memory.module.ModuleName
640 }
641 return ""
642 }
643
644 func (def memoryDefinition) Index() uint32 {
645 return 0
646 }
647
648 func (def memoryDefinition) Import() (moduleName, name string, isImport bool) {
649 return
650 }
651
652 func (def memoryDefinition) ExportNames() []string {
653 if def.memory.module != nil {
654 return []string{"memory"}
655 }
656 return nil
657 }
658
659 func (def memoryDefinition) Min() uint32 {
660 return def.memory.Min
661 }
662
663 func (def memoryDefinition) Max() (uint32, bool) {
664 return def.memory.Max, def.memory.Max != 0
665 }
666
// Compile-time assertions that the test doubles implement the corresponding
// wazero api interfaces.
var (
	_ api.Module   = (*Module)(nil)
	_ api.Function = (*Function)(nil)
	_ api.Global   = (*Global)(nil)
)
672
View as plain text