1 package compiler
2
3 import (
4 "testing"
5
6 "github.com/tetratelabs/wazero/internal/asm"
7 "github.com/tetratelabs/wazero/internal/testing/require"
8 "github.com/tetratelabs/wazero/internal/wasm"
9 )
10
11 func Test_isIntRegister(t *testing.T) {
12 for _, r := range unreservedGeneralPurposeRegisters {
13 require.True(t, isGeneralPurposeRegister(r))
14 }
15 }
16
17 func Test_isVectorRegister(t *testing.T) {
18 for _, r := range unreservedVectorRegisters {
19 require.True(t, isVectorRegister(r))
20 }
21 }
22
23 func TestRuntimeValueLocationStack_basic(t *testing.T) {
24 s := newRuntimeValueLocationStack()
25
26 loc := s.pushRuntimeValueLocationOnStack()
27 require.Equal(t, uint64(1), s.sp)
28 require.Equal(t, uint64(0), loc.stackPointer)
29
30 tmpReg := unreservedGeneralPurposeRegisters[0]
31 loc = s.pushRuntimeValueLocationOnRegister(tmpReg, runtimeValueTypeI64)
32 require.Equal(t, uint64(2), s.sp)
33 require.Equal(t, uint64(1), loc.stackPointer)
34 require.Equal(t, tmpReg, loc.register)
35 require.Equal(t, loc.valueType, runtimeValueTypeI64)
36
37 tmpReg2 := unreservedGeneralPurposeRegisters[1]
38 s.markRegisterUsed(tmpReg2)
39 require.True(t, s.usedRegisters.exist(tmpReg2))
40
41 s.releaseRegister(loc)
42 require.False(t, s.usedRegisters.exist(loc.register))
43 require.Equal(t, asm.NilRegister, loc.register)
44
45 for i := 0; i < 1000; i++ {
46 s.pushRuntimeValueLocationOnStack()
47 }
48 for i := 0; i < 1000; i++ {
49 s.pop()
50 }
51 require.Equal(t, uint64(1002), s.stackPointerCeil)
52 }
53
54 func TestRuntimeValueLocationStack_takeFreeRegister(t *testing.T) {
55 s := newRuntimeValueLocationStack()
56
57 r, ok := s.takeFreeRegister(registerTypeGeneralPurpose)
58 require.True(t, ok)
59 require.True(t, isGeneralPurposeRegister(r))
60
61 for _, r := range unreservedGeneralPurposeRegisters {
62 s.markRegisterUsed(r)
63 }
64
65 _, ok = s.takeFreeRegister(registerTypeGeneralPurpose)
66 require.False(t, ok)
67
68 r, ok = s.takeFreeRegister(registerTypeVector)
69 require.True(t, ok)
70 require.True(t, isVectorRegister(r))
71
72 for _, r := range unreservedVectorRegisters {
73 s.markRegisterUsed(r)
74 }
75
76 _, ok = s.takeFreeRegister(registerTypeVector)
77 require.False(t, ok)
78 }
79
80 func TestRuntimeValueLocationStack_takeStealTargetFromUsedRegister(t *testing.T) {
81 s := newRuntimeValueLocationStack()
82 intReg := unreservedGeneralPurposeRegisters[0]
83 floatReg := unreservedVectorRegisters[0]
84 intLocation := s.push(intReg, asm.ConditionalRegisterStateUnset)
85 floatLocation := s.push(floatReg, asm.ConditionalRegisterStateUnset)
86
87 target, ok := s.takeStealTargetFromUsedRegister(registerTypeVector)
88 require.True(t, ok)
89 require.Equal(t, floatLocation, target)
90
91 target, ok = s.takeStealTargetFromUsedRegister(registerTypeGeneralPurpose)
92 require.True(t, ok)
93 require.Equal(t, intLocation, target)
94
95 popped := s.pop()
96 require.Equal(t, floatLocation, popped)
97
98 target, ok = s.takeStealTargetFromUsedRegister(registerTypeVector)
99 require.False(t, ok)
100 require.Nil(t, target)
101
102 popped = s.pop()
103 require.Equal(t, intLocation, popped)
104
105 target, ok = s.takeStealTargetFromUsedRegister(registerTypeGeneralPurpose)
106 require.False(t, ok)
107 require.Nil(t, target)
108 }
109
110 func TestRuntimeValueLocationStack_setupInitialStack(t *testing.T) {
111 const f32 = wasm.ValueTypeF32
112 tests := []struct {
113 name string
114 sig *wasm.FunctionType
115 expectedSP uint64
116 }{
117 {
118 name: "no params / no results",
119 sig: &wasm.FunctionType{},
120 expectedSP: callFrameDataSizeInUint64,
121 },
122 {
123 name: "no results",
124 sig: &wasm.FunctionType{
125 Params: []wasm.ValueType{f32, f32},
126 ParamNumInUint64: 2,
127 },
128 expectedSP: callFrameDataSizeInUint64 + 2,
129 },
130 {
131 name: "no params",
132 sig: &wasm.FunctionType{
133 Results: []wasm.ValueType{f32, f32},
134 ResultNumInUint64: 2,
135 },
136 expectedSP: callFrameDataSizeInUint64 + 2,
137 },
138 {
139 name: "params == results",
140 sig: &wasm.FunctionType{
141 Params: []wasm.ValueType{f32, f32},
142 ParamNumInUint64: 2,
143 Results: []wasm.ValueType{f32, f32},
144 ResultNumInUint64: 2,
145 },
146 expectedSP: callFrameDataSizeInUint64 + 2,
147 },
148 {
149 name: "params > results",
150 sig: &wasm.FunctionType{
151 Params: []wasm.ValueType{f32, f32, f32},
152 ParamNumInUint64: 3,
153 Results: []wasm.ValueType{f32, f32},
154 ResultNumInUint64: 2,
155 },
156 expectedSP: callFrameDataSizeInUint64 + 3,
157 },
158 {
159 name: "params < results",
160 sig: &wasm.FunctionType{
161 Params: []wasm.ValueType{f32},
162 ParamNumInUint64: 1,
163 Results: []wasm.ValueType{f32, f32, f32},
164 ResultNumInUint64: 3,
165 },
166 expectedSP: callFrameDataSizeInUint64 + 3,
167 },
168 }
169
170 for _, tc := range tests {
171 tc := tc
172 t.Run(tc.name, func(t *testing.T) {
173 s := newRuntimeValueLocationStack()
174 s.init(tc.sig)
175 require.Equal(t, tc.expectedSP, s.sp)
176
177 callFrameLocations := s.stack[s.sp-callFrameDataSizeInUint64 : s.sp]
178 for _, loc := range callFrameLocations {
179 require.Equal(t, runtimeValueTypeI64, loc.valueType)
180 }
181 })
182 }
183 }
184
185 func TestRuntimeValueLocation_pushCallFrame(t *testing.T) {
186 for _, sig := range []*wasm.FunctionType{
187 {ParamNumInUint64: 0, ResultNumInUint64: 1},
188 {ParamNumInUint64: 1, ResultNumInUint64: 0},
189 {ParamNumInUint64: 1, ResultNumInUint64: 1},
190 {ParamNumInUint64: 0, ResultNumInUint64: 2},
191 {ParamNumInUint64: 2, ResultNumInUint64: 0},
192 {ParamNumInUint64: 2, ResultNumInUint64: 3},
193 } {
194 sig := sig
195 t.Run(sig.String(), func(t *testing.T) {
196 s := newRuntimeValueLocationStack()
197
198 for i := 0; i < sig.ParamNumInUint64; i++ {
199 _ = s.pushRuntimeValueLocationOnStack()
200 }
201
202 retAddr, stackBasePointer, fn := s.pushCallFrame(sig)
203
204 expOffset := uint64(callFrameOffset(sig))
205 require.Equal(t, expOffset, retAddr.stackPointer)
206 require.Equal(t, expOffset+1, stackBasePointer.stackPointer)
207 require.Equal(t, expOffset+2, fn.stackPointer)
208 })
209 }
210 }
211
212 func Test_usedRegistersMask(t *testing.T) {
213 for _, r := range append(unreservedVectorRegisters, unreservedGeneralPurposeRegisters...) {
214 mask := usedRegistersMask(0)
215 mask.add(r)
216 require.False(t, mask == 0)
217 require.True(t, mask.exist(r))
218 mask.remove(r)
219 require.True(t, mask == 0)
220 require.False(t, mask.exist(r))
221 }
222 }
223
224 func TestRuntimeValueLocation_cloneFrom(t *testing.T) {
225 t.Run("sp<cap", func(t *testing.T) {
226 v := runtimeValueLocationStack{sp: 7, stack: make([]runtimeValueLocation, 5, 10)}
227 orig := v.stack
228 v.cloneFrom(runtimeValueLocationStack{sp: 3, usedRegisters: 0xffff, stack: []runtimeValueLocation{
229 {register: 3}, {register: 2}, {register: 1},
230 }})
231 require.Equal(t, uint64(3), v.sp)
232 require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters)
233
234 require.Equal(t, &orig[0], &v.stack[0])
235 require.Equal(t, v.stack[0].register, asm.Register(3))
236 require.Equal(t, v.stack[1].register, asm.Register(2))
237 require.Equal(t, v.stack[2].register, asm.Register(1))
238 })
239 t.Run("sp=cap", func(t *testing.T) {
240 v := runtimeValueLocationStack{stack: make([]runtimeValueLocation, 0, 3)}
241 orig := v.stack[:cap(v.stack)]
242 v.cloneFrom(runtimeValueLocationStack{sp: 3, usedRegisters: 0xffff, stack: []runtimeValueLocation{
243 {register: 3}, {register: 2}, {register: 1},
244 }})
245 require.Equal(t, uint64(3), v.sp)
246 require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters)
247
248 require.Equal(t, &orig[0], &v.stack[0])
249 require.Equal(t, v.stack[0].register, asm.Register(3))
250 require.Equal(t, v.stack[1].register, asm.Register(2))
251 require.Equal(t, v.stack[2].register, asm.Register(1))
252 })
253 t.Run("sp>cap", func(t *testing.T) {
254 v := runtimeValueLocationStack{stack: make([]runtimeValueLocation, 0, 3)}
255 orig := v.stack[:cap(v.stack)]
256 v.cloneFrom(runtimeValueLocationStack{sp: 5, usedRegisters: 0xffff, stack: []runtimeValueLocation{
257 {register: 5}, {register: 4}, {register: 3}, {register: 2}, {register: 1},
258 }})
259 require.Equal(t, uint64(5), v.sp)
260 require.Equal(t, usedRegistersMask(0xffff), v.usedRegisters)
261
262 require.NotEqual(t, &orig[0], &v.stack[0])
263 require.Equal(t, v.stack[0].register, asm.Register(5))
264 require.Equal(t, v.stack[1].register, asm.Register(4))
265 require.Equal(t, v.stack[2].register, asm.Register(3))
266 require.Equal(t, v.stack[3].register, asm.Register(2))
267 require.Equal(t, v.stack[4].register, asm.Register(1))
268 })
269 }
270
View as plain text