1 package compiler
2
3 import (
4 "fmt"
5 "reflect"
6 "testing"
7 "unsafe"
8
9 "github.com/tetratelabs/wazero/internal/asm"
10 "github.com/tetratelabs/wazero/internal/testing/require"
11 "github.com/tetratelabs/wazero/internal/wasm"
12 "github.com/tetratelabs/wazero/internal/wazeroir"
13 )
14
15 func TestCompiler_compileModuleContextInitialization(t *testing.T) {
16 tests := []struct {
17 name string
18 moduleInstance *wasm.ModuleInstance
19 }{
20 {
21 name: "no nil",
22 moduleInstance: &wasm.ModuleInstance{
23 Globals: []*wasm.GlobalInstance{{Val: 100}},
24 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)},
25 Tables: []*wasm.TableInstance{
26 {References: make([]wasm.Reference, 20)},
27 {References: make([]wasm.Reference, 10)},
28 },
29 TypeIDs: make([]wasm.FunctionTypeID, 10),
30 DataInstances: make([][]byte, 10),
31 ElementInstances: make([]wasm.ElementInstance, 10),
32 },
33 },
34 {
35 name: "element instances nil",
36 moduleInstance: &wasm.ModuleInstance{
37 Globals: []*wasm.GlobalInstance{{Val: 100}},
38 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)},
39 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
40 TypeIDs: make([]wasm.FunctionTypeID, 10),
41 DataInstances: make([][]byte, 10),
42 ElementInstances: nil,
43 },
44 },
45 {
46 name: "data instances nil",
47 moduleInstance: &wasm.ModuleInstance{
48 Globals: []*wasm.GlobalInstance{{Val: 100}},
49 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)},
50 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
51 TypeIDs: make([]wasm.FunctionTypeID, 10),
52 DataInstances: nil,
53 ElementInstances: make([]wasm.ElementInstance, 10),
54 },
55 },
56 {
57 name: "globals nil",
58 moduleInstance: &wasm.ModuleInstance{
59 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)},
60 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
61 TypeIDs: make([]wasm.FunctionTypeID, 10),
62 DataInstances: make([][]byte, 10),
63 ElementInstances: make([]wasm.ElementInstance, 10),
64 },
65 },
66 {
67 name: "memory nil",
68 moduleInstance: &wasm.ModuleInstance{
69 Globals: []*wasm.GlobalInstance{{Val: 100}},
70 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
71 TypeIDs: make([]wasm.FunctionTypeID, 10),
72 DataInstances: make([][]byte, 10),
73 ElementInstances: make([]wasm.ElementInstance, 10),
74 },
75 },
76 {
77 name: "table nil",
78 moduleInstance: &wasm.ModuleInstance{
79 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 10)},
80 Tables: []*wasm.TableInstance{{References: nil}},
81 TypeIDs: make([]wasm.FunctionTypeID, 10),
82 DataInstances: make([][]byte, 10),
83 ElementInstances: make([]wasm.ElementInstance, 10),
84 },
85 },
86 {
87 name: "table empty",
88 moduleInstance: &wasm.ModuleInstance{
89 Tables: []*wasm.TableInstance{{References: make([]wasm.Reference, 20)}},
90 TypeIDs: make([]wasm.FunctionTypeID, 10),
91 DataInstances: make([][]byte, 10),
92 ElementInstances: make([]wasm.ElementInstance, 10),
93 },
94 },
95 {
96 name: "memory zero length",
97 moduleInstance: &wasm.ModuleInstance{
98 MemoryInstance: &wasm.MemoryInstance{Buffer: make([]byte, 0)},
99 },
100 },
101 {
102 name: "all nil except mod engine",
103 moduleInstance: &wasm.ModuleInstance{},
104 },
105 }
106
107 for _, tt := range tests {
108 tc := tt
109 t.Run(tc.name, func(t *testing.T) {
110 env := newCompilerEnvironment()
111 env.moduleInstance = tc.moduleInstance
112 ce := env.callEngine()
113
114 ir := &wazeroir.CompilationResult{
115 HasMemory: tc.moduleInstance.MemoryInstance != nil,
116 HasTable: len(tc.moduleInstance.Tables) > 0,
117 HasDataInstances: len(tc.moduleInstance.DataInstances) > 0,
118 HasElementInstances: len(tc.moduleInstance.ElementInstances) > 0,
119 }
120 for _, g := range tc.moduleInstance.Globals {
121 ir.Globals = append(ir.Globals, g.Type)
122 }
123 compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, ir)
124 me := &moduleEngine{functions: make([]function, 10)}
125 tc.moduleInstance.Engine = me
126
127 err := compiler.compileModuleContextInitialization()
128 require.NoError(t, err)
129 require.Zero(t, len(compiler.runtimeValueLocationStack().usedRegisters.list()), "expected no usedRegisters")
130
131 compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned)
132
133 code := asm.CodeSegment{}
134 defer func() { require.NoError(t, code.Unmap()) }()
135
136
137 _, err = compiler.compile(code.NextCodeSection())
138 require.NoError(t, err)
139
140 env.exec(code.Bytes())
141
142
143 require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
144
145
146 bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Globals))
147 require.Equal(t, bufSliceHeader.Data, ce.moduleContext.globalElement0Address)
148
149 if tc.moduleInstance.MemoryInstance != nil {
150 bufSliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.MemoryInstance.Buffer))
151 require.Equal(t, uint64(bufSliceHeader.Len), ce.moduleContext.memorySliceLen)
152 require.Equal(t, bufSliceHeader.Data, ce.moduleContext.memoryElement0Address)
153 require.Equal(t, tc.moduleInstance.MemoryInstance, ce.moduleContext.memoryInstance)
154 }
155
156 if len(tc.moduleInstance.Tables) > 0 {
157 tableHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.Tables))
158 require.Equal(t, tableHeader.Data, ce.moduleContext.tablesElement0Address)
159 require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.TypeIDs[0])), ce.moduleContext.typeIDsElement0Address)
160 require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.Tables[0])), ce.moduleContext.tablesElement0Address)
161 }
162
163 if len(tc.moduleInstance.DataInstances) > 0 {
164 dataInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.DataInstances))
165 require.Equal(t, dataInstancesHeader.Data, ce.moduleContext.dataInstancesElement0Address)
166 require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.DataInstances[0])), ce.moduleContext.dataInstancesElement0Address)
167 }
168
169 if len(tc.moduleInstance.ElementInstances) > 0 {
170 elementInstancesHeader := (*reflect.SliceHeader)(unsafe.Pointer(&tc.moduleInstance.ElementInstances))
171 require.Equal(t, elementInstancesHeader.Data, ce.moduleContext.elementInstancesElement0Address)
172 require.Equal(t, uintptr(unsafe.Pointer(&tc.moduleInstance.ElementInstances[0])), ce.moduleContext.elementInstancesElement0Address)
173 }
174
175 require.Equal(t, uintptr(unsafe.Pointer(&me.functions[0])), ce.moduleContext.functionsElement0Address)
176 })
177 }
178 }
179
180 func TestCompiler_compileMaybeGrowStack(t *testing.T) {
181 t.Run("not grow", func(t *testing.T) {
182 const stackPointerCeil = 5
183 for _, baseOffset := range []uint64{5, 10, 20} {
184 t.Run(fmt.Sprintf("%d", baseOffset), func(t *testing.T) {
185 env := newCompilerEnvironment()
186 compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, nil)
187
188 err := compiler.compilePreamble()
189 require.NoError(t, err)
190
191 stackLen := uint64(len(env.stack()))
192 stackBasePointer := stackLen - baseOffset
193 compiler.assignStackPointerCeil(stackPointerCeil)
194 env.setStackBasePointer(stackBasePointer)
195
196 compiler.compileExitFromNativeCode(nativeCallStatusCodeReturned)
197
198 code := asm.CodeSegment{}
199 defer func() { require.NoError(t, code.Unmap()) }()
200
201
202 _, err = compiler.compile(code.NextCodeSection())
203 require.NoError(t, err)
204 env.exec(code.Bytes())
205
206
207 require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
208 })
209 }
210 })
211
212 defaultStackLen := uint64(initialStackSize)
213 t.Run("grow", func(t *testing.T) {
214 tests := []struct {
215 name string
216 stackPointerCeil uint64
217 stackBasePointer uint64
218 }{
219 {
220 name: "ceil=6/sbp=len-5",
221 stackPointerCeil: 6,
222 stackBasePointer: defaultStackLen - 5,
223 },
224 {
225 name: "ceil=10000/sbp=0",
226 stackPointerCeil: 10000,
227 stackBasePointer: 0,
228 },
229 }
230
231 for _, tc := range tests {
232 tc := tc
233 t.Run(tc.name, func(t *testing.T) {
234 env := newCompilerEnvironment()
235 compiler := env.requireNewCompiler(t, &wasm.FunctionType{}, newCompiler, nil)
236
237 err := compiler.compilePreamble()
238 require.NoError(t, err)
239
240
241 err = compiler.compileReturnFunction()
242 require.NoError(t, err)
243
244 code := asm.CodeSegment{}
245 defer func() { require.NoError(t, code.Unmap()) }()
246
247
248 compiler.setStackPointerCeil(tc.stackPointerCeil)
249 _, err = compiler.compile(code.NextCodeSection())
250 require.NoError(t, err)
251
252
253 env.setStackBasePointer(tc.stackBasePointer)
254 env.exec(code.Bytes())
255
256
257 require.Equal(t, nativeCallStatusCodeCallBuiltInFunction, env.compilerStatus())
258
259
260 returnAddress := env.ce.returnAddress
261 require.True(t, returnAddress != 0, "returnAddress was zero %d", returnAddress)
262 nativecall(returnAddress, env.callEngine(), env.module())
263
264
265 require.Equal(t, nativeCallStatusCodeReturned, env.compilerStatus())
266 })
267 }
268 })
269 }
270
View as plain text