1 package compiler
2
3 import (
4 "fmt"
5 "math"
6 "os"
7 "runtime"
8 "testing"
9 "unsafe"
10
11 "github.com/tetratelabs/wazero/internal/platform"
12 "github.com/tetratelabs/wazero/internal/testing/require"
13 "github.com/tetratelabs/wazero/internal/wasm"
14 "github.com/tetratelabs/wazero/internal/wazeroir"
15 )
16
17 func TestMain(m *testing.M) {
18 if !platform.CompilerSupported() {
19 os.Exit(0)
20 }
21 os.Exit(m.Run())
22 }
23
24
25
26
27
28
29 func init() {
30 var me moduleEngine
31 requireEqual := func(expected, actual int, name string) {
32 if expected != actual {
33 panic(fmt.Sprintf("%s: expected %d, but was %d", name, expected, actual))
34 }
35 }
36 requireEqual(int(unsafe.Offsetof(me.functions)), moduleEngineFunctionsOffset, "moduleEngineFunctionsOffset")
37
38 var ce callEngine
39
40 requireEqual(int(unsafe.Offsetof(ce.fn)), callEngineModuleContextFnOffset, "callEngineModuleContextFnOffset")
41 requireEqual(int(unsafe.Offsetof(ce.moduleInstance)), callEngineModuleContextModuleInstanceOffset, "callEngineModuleContextModuleInstanceOffset")
42 requireEqual(int(unsafe.Offsetof(ce.globalElement0Address)), callEngineModuleContextGlobalElement0AddressOffset, "callEngineModuleContextGlobalElement0AddressOffset")
43 requireEqual(int(unsafe.Offsetof(ce.memoryElement0Address)), callEngineModuleContextMemoryElement0AddressOffset, "callEngineModuleContextMemoryElement0AddressOffset")
44 requireEqual(int(unsafe.Offsetof(ce.memorySliceLen)), callEngineModuleContextMemorySliceLenOffset, "callEngineModuleContextMemorySliceLenOffset")
45 requireEqual(int(unsafe.Offsetof(ce.memoryInstance)), callEngineModuleContextMemoryInstanceOffset, "callEngineModuleContextMemoryInstanceOffset")
46 requireEqual(int(unsafe.Offsetof(ce.tablesElement0Address)), callEngineModuleContextTablesElement0AddressOffset, "callEngineModuleContextTablesElement0AddressOffset")
47 requireEqual(int(unsafe.Offsetof(ce.functionsElement0Address)), callEngineModuleContextFunctionsElement0AddressOffset, "callEngineModuleContextFunctionsElement0AddressOffset")
48 requireEqual(int(unsafe.Offsetof(ce.typeIDsElement0Address)), callEngineModuleContextTypeIDsElement0AddressOffset, "callEngineModuleContextTypeIDsElement0AddressOffset")
49 requireEqual(int(unsafe.Offsetof(ce.dataInstancesElement0Address)), callEngineModuleContextDataInstancesElement0AddressOffset, "callEngineModuleContextDataInstancesElement0AddressOffset")
50 requireEqual(int(unsafe.Offsetof(ce.elementInstancesElement0Address)), callEngineModuleContextElementInstancesElement0AddressOffset, "callEngineModuleContextElementInstancesElement0AddressOffset")
51
52
53 requireEqual(int(unsafe.Offsetof(ce.stackPointer)), callEngineStackContextStackPointerOffset, "callEngineStackContextStackPointerOffset")
54 requireEqual(int(unsafe.Offsetof(ce.stackBasePointerInBytes)), callEngineStackContextStackBasePointerInBytesOffset, "callEngineStackContextStackBasePointerInBytesOffset")
55 requireEqual(int(unsafe.Offsetof(ce.stackElement0Address)), callEngineStackContextStackElement0AddressOffset, "callEngineStackContextStackElement0AddressOffset")
56 requireEqual(int(unsafe.Offsetof(ce.stackLenInBytes)), callEngineStackContextStackLenInBytesOffset, "callEngineStackContextStackLenInBytesOffset")
57
58
59 requireEqual(int(unsafe.Offsetof(ce.statusCode)), callEngineExitContextNativeCallStatusCodeOffset, "callEngineExitContextNativeCallStatusCodeOffset")
60 requireEqual(int(unsafe.Offsetof(ce.builtinFunctionCallIndex)), callEngineExitContextBuiltinFunctionCallIndexOffset, "callEngineExitContextBuiltinFunctionCallIndexOffset")
61 requireEqual(int(unsafe.Offsetof(ce.returnAddress)), callEngineExitContextReturnAddressOffset, "callEngineExitContextReturnAddressOffset")
62 requireEqual(int(unsafe.Offsetof(ce.callerModuleInstance)), callEngineExitContextCallerModuleInstanceOffset, "callEngineExitContextCallerModuleInstanceOffset")
63
64
65 var frame callFrame
66 requireEqual(int(unsafe.Sizeof(frame))/8, callFrameDataSizeInUint64, "callFrameDataSize")
67
68
69 var f function
70 requireEqual(int(unsafe.Offsetof(f.codeInitialAddress)), functionCodeInitialAddressOffset, "functionCodeInitialAddressOffset")
71 requireEqual(int(unsafe.Offsetof(f.moduleInstance)), functionModuleInstanceOffset, "functionModuleInstanceOffset")
72 requireEqual(int(unsafe.Offsetof(f.typeID)), functionTypeIDOffset, "functionTypeIDOffset")
73 requireEqual(int(unsafe.Sizeof(f)), functionSize, "functionModuleInstanceOffset")
74
75
76 var moduleInstance wasm.ModuleInstance
77 requireEqual(int(unsafe.Offsetof(moduleInstance.Globals)), moduleInstanceGlobalsOffset, "moduleInstanceGlobalsOffset")
78 requireEqual(int(unsafe.Offsetof(moduleInstance.MemoryInstance)), moduleInstanceMemoryOffset, "moduleInstanceMemoryOffset")
79 requireEqual(int(unsafe.Offsetof(moduleInstance.Tables)), moduleInstanceTablesOffset, "moduleInstanceTablesOffset")
80 requireEqual(int(unsafe.Offsetof(moduleInstance.Engine)), moduleInstanceEngineOffset, "moduleInstanceEngineOffset")
81 requireEqual(int(unsafe.Offsetof(moduleInstance.TypeIDs)), moduleInstanceTypeIDsOffset, "moduleInstanceTypeIDsOffset")
82 requireEqual(int(unsafe.Offsetof(moduleInstance.DataInstances)), moduleInstanceDataInstancesOffset, "moduleInstanceDataInstancesOffset")
83 requireEqual(int(unsafe.Offsetof(moduleInstance.ElementInstances)), moduleInstanceElementInstancesOffset, "moduleInstanceElementInstancesOffset")
84
85
86 var tableInstance wasm.TableInstance
87 requireEqual(int(unsafe.Offsetof(tableInstance.References)), tableInstanceTableOffset, "tableInstanceTableOffset")
88
89
90 requireEqual(int(unsafe.Offsetof(tableInstance.References)+8), tableInstanceTableLenOffset, "tableInstanceTableLenOffset")
91
92
93 var memoryInstance wasm.MemoryInstance
94 requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)), memoryInstanceBufferOffset, "memoryInstanceBufferOffset")
95
96 requireEqual(int(unsafe.Offsetof(memoryInstance.Buffer)+8), memoryInstanceBufferLenOffset, "memoryInstanceBufferLenOffset")
97
98
99 var globalInstance wasm.GlobalInstance
100 requireEqual(int(unsafe.Offsetof(globalInstance.Val)), globalInstanceValueOffset, "globalInstanceValueOffset")
101
102 var dataInstance wasm.DataInstance
103 requireEqual(int(unsafe.Sizeof(dataInstance)), dataInstanceStructSize, "dataInstanceStructSize")
104
105 var elementInstance wasm.ElementInstance
106 requireEqual(int(unsafe.Sizeof(elementInstance)), elementInstanceStructSize, "elementInstanceStructSize")
107
108 var pointer uintptr
109 requireEqual(int(unsafe.Sizeof(pointer)), 1<<pointerSizeLog2, "pointerSizeLog2")
110 }
111
// compilerEnv bundles the engine state that compiler tests run generated
// machine code against. newCompilerEnvironment wires it as a module engine, a
// call engine created from that module engine, and a module instance whose
// Engine is that same module engine.
type compilerEnv struct {
	// me is the module engine; see moduleEngine().
	me *moduleEngine
	// ce is the call engine whose stack and exit context tests inspect.
	ce *callEngine
	// moduleInstance supplies the memory, tables and globals seen by code run via exec.
	moduleInstance *wasm.ModuleInstance
}
117
118 func (j *compilerEnv) stackTopAsUint32() uint32 {
119 return uint32(j.stack()[j.ce.stackContext.stackPointer-1])
120 }
121
122 func (j *compilerEnv) stackTopAsInt32() int32 {
123 return int32(j.stack()[j.ce.stackContext.stackPointer-1])
124 }
125
126 func (j *compilerEnv) stackTopAsUint64() uint64 {
127 return j.stack()[j.ce.stackContext.stackPointer-1]
128 }
129
130 func (j *compilerEnv) stackTopAsInt64() int64 {
131 return int64(j.stack()[j.ce.stackContext.stackPointer-1])
132 }
133
134 func (j *compilerEnv) stackTopAsFloat32() float32 {
135 return math.Float32frombits(uint32(j.stack()[j.ce.stackContext.stackPointer-1]))
136 }
137
138 func (j *compilerEnv) stackTopAsFloat64() float64 {
139 return math.Float64frombits(j.stack()[j.ce.stackContext.stackPointer-1])
140 }
141
142 func (j *compilerEnv) stackTopAsV128() (lo uint64, hi uint64) {
143 st := j.stack()
144 return st[j.ce.stackContext.stackPointer-2], st[j.ce.stackContext.stackPointer-1]
145 }
146
147 func (j *compilerEnv) memory() []byte {
148 return j.moduleInstance.MemoryInstance.Buffer
149 }
150
151 func (j *compilerEnv) stack() []uint64 {
152 return j.ce.stack
153 }
154
155 func (j *compilerEnv) compilerStatus() nativeCallStatusCode {
156 return j.ce.exitContext.statusCode
157 }
158
159 func (j *compilerEnv) builtinFunctionCallAddress() wasm.Index {
160 return j.ce.exitContext.builtinFunctionCallIndex
161 }
162
163
164 func (j *compilerEnv) stackPointer() uint64 {
165 return j.ce.stackContext.stackPointer - callFrameDataSizeInUint64
166 }
167
168 func (j *compilerEnv) stackBasePointer() uint64 {
169 return j.ce.stackContext.stackBasePointerInBytes >> 3
170 }
171
172 func (j *compilerEnv) setStackPointer(sp uint64) {
173 j.ce.stackContext.stackPointer = sp
174 }
175
176 func (j *compilerEnv) addGlobals(g ...*wasm.GlobalInstance) {
177 j.moduleInstance.Globals = append(j.moduleInstance.Globals, g...)
178 }
179
180 func (j *compilerEnv) globals() []*wasm.GlobalInstance {
181 return j.moduleInstance.Globals
182 }
183
184 func (j *compilerEnv) addTable(table *wasm.TableInstance) {
185 j.moduleInstance.Tables = append(j.moduleInstance.Tables, table)
186 }
187
188 func (j *compilerEnv) setStackBasePointer(sp uint64) {
189 j.ce.stackContext.stackBasePointerInBytes = sp << 3
190 }
191
192 func (j *compilerEnv) module() *wasm.ModuleInstance {
193 return j.moduleInstance
194 }
195
196 func (j *compilerEnv) moduleEngine() *moduleEngine {
197 return j.me
198 }
199
200 func (j *compilerEnv) callEngine() *callEngine {
201 return j.ce
202 }
203
204 func (j *compilerEnv) exec(machineCode []byte) {
205 cm := &compiledModule{compiledCode: &compiledCode{}}
206 if err := cm.executable.Map(len(machineCode)); err != nil {
207 panic(err)
208 }
209 executable := cm.executable.Bytes()
210 copy(executable, machineCode)
211 makeExecutable(executable)
212
213 f := &function{
214 parent: &compiledFunction{parent: cm.compiledCode},
215 codeInitialAddress: uintptr(unsafe.Pointer(&executable[0])),
216 moduleInstance: j.moduleInstance,
217 }
218 j.ce.initialFn = f
219 j.ce.fn = f
220
221 nativecall(
222 uintptr(unsafe.Pointer(&executable[0])),
223 j.ce, j.moduleInstance,
224 )
225 }
226
227 func (j *compilerEnv) requireNewCompiler(t *testing.T, functionType *wasm.FunctionType, fn func() compiler, ir *wazeroir.CompilationResult) compilerImpl {
228 requireSupportedOSArch(t)
229
230 if ir == nil {
231 ir = &wazeroir.CompilationResult{
232 LabelCallers: map[wazeroir.Label]uint32{},
233 }
234 }
235
236 c := fn()
237 c.Init(functionType, ir, false)
238
239 ret, ok := c.(compilerImpl)
240 require.True(t, ok)
241 return ret
242 }
243
244
245
246 type compilerImpl interface {
247 compiler
248 compileExitFromNativeCode(nativeCallStatusCode)
249 compileMaybeGrowStack() error
250 compileReturnFunction() error
251 assignStackPointerCeil(uint64)
252 setStackPointerCeil(uint64)
253 compileReleaseRegisterToStack(loc *runtimeValueLocation)
254 setRuntimeValueLocationStack(*runtimeValueLocationStack)
255 compileEnsureOnRegister(loc *runtimeValueLocation) error
256 compileModuleContextInitialization() error
257 }
258
// defaultMemoryPageNumInTest is the number of Wasm memory pages allocated for
// the module instance built by newCompilerEnvironment.
const (
	defaultMemoryPageNumInTest = 1
)
260
261 func newCompilerEnvironment() *compilerEnv {
262 me := &moduleEngine{}
263 return &compilerEnv{
264 me: me,
265 moduleInstance: &wasm.ModuleInstance{
266 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, wasm.MemoryPageSize*defaultMemoryPageNumInTest)},
267 Tables: []*wasm.TableInstance{},
268 Globals: []*wasm.GlobalInstance{},
269 Engine: me,
270 },
271 ce: me.newCallEngine(initialStackSize, &function{parent: &compiledFunction{parent: &compiledCode{}}}),
272 }
273 }
274
275
276
277 func requireRuntimeLocationStackPointerEqual(t *testing.T, expSP uint64, c compiler) {
278 require.Equal(t, expSP, c.runtimeValueLocationStack().sp-callFrameDataSizeInUint64)
279 }
280
281
282 func TestCompileI32WrapFromI64(t *testing.T) {
283 c := newCompiler()
284 c.Init(&wasm.FunctionType{}, nil, false)
285
286
287 loc := c.runtimeValueLocationStack().pushRuntimeValueLocationOnStack()
288 loc.valueType = runtimeValueTypeI64
289
290 err := c.compileI32WrapFromI64()
291 require.NoError(t, err)
292 require.Equal(t, runtimeValueTypeI32, loc.valueType)
293 }
294
295 func operationPtr(operation wazeroir.UnionOperation) *wazeroir.UnionOperation {
296 return &operation
297 }
298
299 func requireExecutable(original []byte) (executable []byte) {
300 executable, err := platform.MmapCodeSegment(len(original))
301 if err != nil {
302 panic(err)
303 }
304 copy(executable, original)
305 makeExecutable(executable)
306 return executable
307 }
308
309 func makeExecutable(executable []byte) {
310 if runtime.GOARCH == "arm64" {
311 if err := platform.MprotectRX(executable); err != nil {
312 panic(err)
313 }
314 }
315 }
316
View as plain text