1 package compiler
2
3 import (
4 "fmt"
5 "strings"
6
7 "github.com/tetratelabs/wazero/internal/asm"
8 "github.com/tetratelabs/wazero/internal/wasm"
9 )
10
11 var (
12
13 unreservedGeneralPurposeRegisters []asm.Register
14
15
16 unreservedVectorRegisters []asm.Register
17 )
18
19 func isGeneralPurposeRegister(r asm.Register) bool {
20 return unreservedGeneralPurposeRegisters[0] <= r && r <= unreservedGeneralPurposeRegisters[len(unreservedGeneralPurposeRegisters)-1]
21 }
22
23 func isVectorRegister(r asm.Register) bool {
24 return unreservedVectorRegisters[0] <= r && r <= unreservedVectorRegisters[len(unreservedVectorRegisters)-1]
25 }
26
27
28
29
// runtimeValueLocation describes where one slot of the compile-time value
// stack currently lives: on a register, in a conditional register state, or
// in memory at stackPointer (see onRegister, onConditionalRegister and
// onStack).
type runtimeValueLocation struct {
	// valueType is the type of the value held in this slot. A v128 value
	// occupies two consecutive slots (runtimeValueTypeV128Lo, then
	// runtimeValueTypeV128Hi above it).
	valueType runtimeValueType

	// register is the register currently holding the value, or
	// asm.NilRegister when the value is not on a register.
	register asm.Register

	// conditionalRegister is asm.ConditionalRegisterStateUnset unless the
	// value is currently represented by a conditional register state.
	conditionalRegister asm.ConditionalRegisterState

	// stackPointer is the index of this slot in runtimeValueLocationStack,
	// assigned by push. Presumably it is also the slot's position (in uint64
	// units) on the runtime memory stack — TODO confirm against the backends.
	stackPointer uint64
}
39
40 func (v *runtimeValueLocation) getRegisterType() (ret registerType) {
41 switch v.valueType {
42 case runtimeValueTypeI32, runtimeValueTypeI64:
43 ret = registerTypeGeneralPurpose
44 case runtimeValueTypeF32, runtimeValueTypeF64,
45 runtimeValueTypeV128Lo, runtimeValueTypeV128Hi:
46 ret = registerTypeVector
47 default:
48 panic("BUG")
49 }
50 return
51 }
52
// runtimeValueType is the type of a single slot on the value location stack.
// A v128 value spans two slots, so it has separate Lo and Hi types.
type runtimeValueType byte

const (
	// runtimeValueTypeNone marks a slot whose type has not been assigned, e.g.
	// a slot freshly pushed by pushRuntimeValueLocationOnStack.
	runtimeValueTypeNone runtimeValueType = iota
	runtimeValueTypeI32
	runtimeValueTypeI64
	runtimeValueTypeF32
	runtimeValueTypeF64
	// runtimeValueTypeV128Lo is the lower slot of a v128 value.
	runtimeValueTypeV128Lo
	// runtimeValueTypeV128Hi is the upper slot of a v128 value, pushed right
	// after its V128Lo slot.
	runtimeValueTypeV128Hi
)

// String returns a short wasm-style name for the type ("i32", "v128.lo", ...),
// or "" for runtimeValueTypeNone and for unknown values.
func (r runtimeValueType) String() (ret string) {
	switch r {
	case runtimeValueTypeI32:
		return "i32"
	case runtimeValueTypeI64:
		return "i64"
	case runtimeValueTypeF32:
		return "f32"
	case runtimeValueTypeF64:
		return "f64"
	case runtimeValueTypeV128Lo:
		return "v128.lo"
	case runtimeValueTypeV128Hi:
		return "v128.hi"
	}
	return ""
}
82
83 func (v *runtimeValueLocation) setRegister(reg asm.Register) {
84 v.register = reg
85 v.conditionalRegister = asm.ConditionalRegisterStateUnset
86 }
87
88 func (v *runtimeValueLocation) onRegister() bool {
89 return v.register != asm.NilRegister && v.conditionalRegister == asm.ConditionalRegisterStateUnset
90 }
91
92 func (v *runtimeValueLocation) onStack() bool {
93 return v.register == asm.NilRegister && v.conditionalRegister == asm.ConditionalRegisterStateUnset
94 }
95
96 func (v *runtimeValueLocation) onConditionalRegister() bool {
97 return v.conditionalRegister != asm.ConditionalRegisterStateUnset
98 }
99
100 func (v *runtimeValueLocation) String() string {
101 var location string
102 if v.onStack() {
103 location = fmt.Sprintf("stack(%d)", v.stackPointer)
104 } else if v.onConditionalRegister() {
105 location = fmt.Sprintf("conditional(%d)", v.conditionalRegister)
106 } else if v.onRegister() {
107 location = fmt.Sprintf("register(%s)", registerNameFn(v.register))
108 }
109 return fmt.Sprintf("{type=%s,location=%s}", v.valueType, location)
110 }
111
112 func newRuntimeValueLocationStack() runtimeValueLocationStack {
113 return runtimeValueLocationStack{
114 unreservedVectorRegisters: unreservedVectorRegisters,
115 unreservedGeneralPurposeRegisters: unreservedGeneralPurposeRegisters,
116 }
117 }
118
119
120
121
122
123
124
125
126
127
128
// runtimeValueLocationStack tracks, during compilation, where every value on
// the virtual value stack currently lives (see runtimeValueLocation).
// Entries in stack[:sp] are live; slots at or above sp are kept only so that
// push can reuse them without reallocating.
type runtimeValueLocationStack struct {
	// stack is the backing storage for the entries; only stack[:sp] is live.
	stack []runtimeValueLocation

	// sp is the number of live entries, i.e. the index the next push writes to.
	sp uint64

	// usedRegisters records which registers are currently marked as in use
	// (markRegisterUsed / markRegisterUnused / releaseRegister).
	usedRegisters usedRegistersMask

	// stackPointerCeil is the highest sp observed by push. reset clears it, and
	// cloneFrom copies it from the source stack.
	stackPointerCeil uint64

	// unreservedGeneralPurposeRegisters and unreservedVectorRegisters are the
	// candidate pools takeFreeRegister chooses from. newRuntimeValueLocationStack
	// and reset copy them from the package-level lists of the same names.
	unreservedGeneralPurposeRegisters, unreservedVectorRegisters []asm.Register
}
142
143 func (v *runtimeValueLocationStack) reset() {
144 stack := v.stack[:0]
145 *v = runtimeValueLocationStack{
146 unreservedVectorRegisters: unreservedVectorRegisters,
147 unreservedGeneralPurposeRegisters: unreservedGeneralPurposeRegisters,
148 stack: stack,
149 }
150 }
151
152 func (v *runtimeValueLocationStack) String() string {
153 var stackStr []string
154 for i := uint64(0); i < v.sp; i++ {
155 stackStr = append(stackStr, v.stack[i].String())
156 }
157 usedRegisters := v.usedRegisters.list()
158 return fmt.Sprintf("sp=%d, stack=[%s], used_registers=[%s]", v.sp, strings.Join(stackStr, ","), strings.Join(usedRegisters, ","))
159 }
160
161
162
163
164 func (v *runtimeValueLocationStack) cloneFrom(from runtimeValueLocationStack) {
165
166 prev := v.stack
167 *v = from
168 v.stack = prev[:cap(prev)]
169
170 if diff := int(from.sp) - len(v.stack); diff > 0 {
171 v.stack = append(v.stack, make([]runtimeValueLocation, diff)...)
172 }
173 copy(v.stack, from.stack[:from.sp])
174 }
175
176
177
178 func (v *runtimeValueLocationStack) pushRuntimeValueLocationOnRegister(reg asm.Register, vt runtimeValueType) (loc *runtimeValueLocation) {
179 loc = v.push(reg, asm.ConditionalRegisterStateUnset)
180 loc.valueType = vt
181 return
182 }
183
184
185 func (v *runtimeValueLocationStack) pushRuntimeValueLocationOnStack() (loc *runtimeValueLocation) {
186 loc = v.push(asm.NilRegister, asm.ConditionalRegisterStateUnset)
187 loc.valueType = runtimeValueTypeNone
188 return
189 }
190
191
192
193 func (v *runtimeValueLocationStack) pushRuntimeValueLocationOnConditionalRegister(state asm.ConditionalRegisterState) (loc *runtimeValueLocation) {
194 loc = v.push(asm.NilRegister, state)
195 loc.valueType = runtimeValueTypeI32
196 return
197 }
198
199
200 func (v *runtimeValueLocationStack) push(reg asm.Register, conditionalRegister asm.ConditionalRegisterState) (ret *runtimeValueLocation) {
201 if v.sp >= uint64(len(v.stack)) {
202
203
204 v.stack = append(v.stack, runtimeValueLocation{})
205 }
206 ret = &v.stack[v.sp]
207 ret.register, ret.conditionalRegister, ret.stackPointer = reg, conditionalRegister, v.sp
208 v.sp++
209
210
211 if v.sp > v.stackPointerCeil {
212 v.stackPointerCeil = v.sp
213 }
214 return
215 }
216
217 func (v *runtimeValueLocationStack) pop() (loc *runtimeValueLocation) {
218 v.sp--
219 loc = &v.stack[v.sp]
220 return
221 }
222
223 func (v *runtimeValueLocationStack) popV128() (loc *runtimeValueLocation) {
224 v.sp -= 2
225 loc = &v.stack[v.sp]
226 return
227 }
228
229 func (v *runtimeValueLocationStack) peek() (loc *runtimeValueLocation) {
230 loc = &v.stack[v.sp-1]
231 return
232 }
233
234 func (v *runtimeValueLocationStack) releaseRegister(loc *runtimeValueLocation) {
235 v.markRegisterUnused(loc.register)
236 loc.register = asm.NilRegister
237 loc.conditionalRegister = asm.ConditionalRegisterStateUnset
238 }
239
240 func (v *runtimeValueLocationStack) markRegisterUnused(regs ...asm.Register) {
241 for _, reg := range regs {
242 v.usedRegisters.remove(reg)
243 }
244 }
245
246 func (v *runtimeValueLocationStack) markRegisterUsed(regs ...asm.Register) {
247 for _, reg := range regs {
248 v.usedRegisters.add(reg)
249 }
250 }
251
// registerType classifies registers into the pools the allocator hands out.
type registerType byte

const (
	// registerTypeGeneralPurpose is the integer register class; i32 and i64
	// slots are assigned to it by runtimeValueLocation.getRegisterType.
	registerTypeGeneralPurpose registerType = iota
	// registerTypeVector is the class used for f32 and f64 slots as well as
	// both halves of a v128 value (see runtimeValueLocation.getRegisterType).
	// Scalar floats and vectors share this single pool.
	registerTypeVector
)

// String returns "int" or "vector" for the known classes and "" otherwise.
func (tp registerType) String() (ret string) {
	switch tp {
	case registerTypeVector:
		return "vector"
	case registerTypeGeneralPurpose:
		return "int"
	default:
		return ""
	}
}
277
278
279 func (v *runtimeValueLocationStack) takeFreeRegister(tp registerType) (reg asm.Register, found bool) {
280 var targetRegs []asm.Register
281 switch tp {
282 case registerTypeVector:
283 targetRegs = v.unreservedVectorRegisters
284 case registerTypeGeneralPurpose:
285 targetRegs = v.unreservedGeneralPurposeRegisters
286 }
287 for _, candidate := range targetRegs {
288 if v.usedRegisters.exist(candidate) {
289 continue
290 }
291 return candidate, true
292 }
293 return 0, false
294 }
295
296
297
298 func (v *runtimeValueLocationStack) takeStealTargetFromUsedRegister(tp registerType) (*runtimeValueLocation, bool) {
299 for i := uint64(0); i < v.sp; i++ {
300 loc := &v.stack[i]
301 if loc.onRegister() {
302 switch tp {
303 case registerTypeVector:
304 if loc.valueType == runtimeValueTypeV128Hi {
305 panic("BUG: V128Hi must be above the corresponding V128Lo")
306 }
307 if isVectorRegister(loc.register) {
308 return loc, true
309 }
310 case registerTypeGeneralPurpose:
311 if isGeneralPurposeRegister(loc.register) {
312 return loc, true
313 }
314 }
315 }
316 }
317 return nil, false
318 }
319
320
321
322
323
324 func (v *runtimeValueLocationStack) init(sig *wasm.FunctionType) {
325 for _, t := range sig.Params {
326 loc := v.pushRuntimeValueLocationOnStack()
327 switch t {
328 case wasm.ValueTypeI32:
329 loc.valueType = runtimeValueTypeI32
330 case wasm.ValueTypeI64, wasm.ValueTypeFuncref, wasm.ValueTypeExternref:
331 loc.valueType = runtimeValueTypeI64
332 case wasm.ValueTypeF32:
333 loc.valueType = runtimeValueTypeF32
334 case wasm.ValueTypeF64:
335 loc.valueType = runtimeValueTypeF64
336 case wasm.ValueTypeV128:
337 loc.valueType = runtimeValueTypeV128Lo
338 hi := v.pushRuntimeValueLocationOnStack()
339 hi.valueType = runtimeValueTypeV128Hi
340 default:
341 panic("BUG")
342 }
343 }
344
345
346
347 for i := 0; i < sig.ResultNumInUint64-sig.ParamNumInUint64; i++ {
348 _ = v.pushRuntimeValueLocationOnStack()
349 }
350
351
352 for i := 0; i < callFrameDataSizeInUint64; i++ {
353 loc := v.pushRuntimeValueLocationOnStack()
354 loc.valueType = runtimeValueTypeI64
355 }
356 }
357
358
359
360
361 func (v *runtimeValueLocationStack) getCallFrameLocations(sig *wasm.FunctionType) (
362 returnAddress, callerStackBasePointerInBytes, callerFunction *runtimeValueLocation,
363 ) {
364 offset := callFrameOffset(sig)
365 return &v.stack[offset], &v.stack[offset+1], &v.stack[offset+2]
366 }
367
368
369
370
371
372 func (v *runtimeValueLocationStack) pushCallFrame(callTargetFunctionType *wasm.FunctionType) (
373 returnAddress, callerStackBasePointerInBytes, callerFunction *runtimeValueLocation,
374 ) {
375
376 reservedSlotsBeforeCallFrame := callTargetFunctionType.ResultNumInUint64 - callTargetFunctionType.ParamNumInUint64
377 for i := 0; i < reservedSlotsBeforeCallFrame; i++ {
378 v.pushRuntimeValueLocationOnStack()
379 }
380
381
382
383
384
385 returnAddress = v.pushRuntimeValueLocationOnStack()
386 returnAddress.valueType = runtimeValueTypeI64
387
388 callerStackBasePointerInBytes = v.pushRuntimeValueLocationOnStack()
389 callerStackBasePointerInBytes.valueType = runtimeValueTypeI64
390
391 callerFunction = v.pushRuntimeValueLocationOnStack()
392 callerFunction.valueType = runtimeValueTypeI64
393 return
394 }
395
396
397 type usedRegistersMask uint64
398
399
400 func (u *usedRegistersMask) add(r asm.Register) {
401 *u = *u | (1 << registerMaskShift(r))
402 }
403
404
405 func (u *usedRegistersMask) remove(r asm.Register) {
406 *u = *u & ^(1 << registerMaskShift(r))
407 }
408
409
410 func (u *usedRegistersMask) exist(r asm.Register) bool {
411 shift := registerMaskShift(r)
412 return (*u & (1 << shift)) > 0
413 }
414
415
416
417 func (u *usedRegistersMask) list() (ret []string) {
418 mask := *u
419 for i := 0; i < 64; i++ {
420 if mask&(1<<i) > 0 {
421 ret = append(ret, registerNameFn(registerFromMaskShift(i)))
422 }
423 }
424 return
425 }
426
View as plain text