1 package wasm
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "math"
9 "reflect"
10
11 "github.com/tetratelabs/wazero/api"
12 )
13
// paramsKind classifies which implicit leading parameters a reflected Go
// host function declares before its Wasm-visible parameters.
type paramsKind byte

const (
	// paramsKindNoContext means every parameter is a Wasm-visible value.
	paramsKindNoContext paramsKind = 0
	// paramsKindContext means param[0] is a context.Context.
	paramsKindContext paramsKind = 1
	// paramsKindContextModule means param[0] is a context.Context and
	// param[1] is an api.Module.
	paramsKindContextModule paramsKind = 2
)
21
22
23
24 var (
25 moduleType = reflect.TypeOf((*api.Module)(nil)).Elem()
26 goContextType = reflect.TypeOf((*context.Context)(nil)).Elem()
27 errorType = reflect.TypeOf((*error)(nil)).Elem()
28 )
29
30
31
32 var _ api.GoModuleFunction = (*reflectGoModuleFunction)(nil)
33
34 type reflectGoModuleFunction struct {
35 fn *reflect.Value
36 params, results []ValueType
37 }
38
39
40 func (f *reflectGoModuleFunction) Call(ctx context.Context, mod api.Module, stack []uint64) {
41 callGoFunc(ctx, mod, f.fn, stack)
42 }
43
44
45 func (f *reflectGoModuleFunction) EqualTo(that interface{}) bool {
46 if f2, ok := that.(*reflectGoModuleFunction); !ok {
47 return false
48 } else {
49
50 return bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
51 }
52 }
53
54
55 var _ api.GoFunction = (*reflectGoFunction)(nil)
56
57 type reflectGoFunction struct {
58 fn *reflect.Value
59 pk paramsKind
60 params, results []ValueType
61 }
62
63
64 func (f *reflectGoFunction) EqualTo(that interface{}) bool {
65 if f2, ok := that.(*reflectGoFunction); !ok {
66 return false
67 } else {
68
69 return f.pk == f2.pk &&
70 bytes.Equal(f.params, f2.params) && bytes.Equal(f.results, f2.results)
71 }
72 }
73
74
75 func (f *reflectGoFunction) Call(ctx context.Context, stack []uint64) {
76 if f.pk == paramsKindNoContext {
77 ctx = nil
78 }
79 callGoFunc(ctx, nil, f.fn, stack)
80 }
81
82
83
84 func callGoFunc(ctx context.Context, mod api.Module, fn *reflect.Value, stack []uint64) {
85 tp := fn.Type()
86
87 var in []reflect.Value
88 pLen := tp.NumIn()
89 if pLen != 0 {
90 in = make([]reflect.Value, pLen)
91
92 i := 0
93 if ctx != nil {
94 in[0] = newContextVal(ctx)
95 i++
96 }
97 if mod != nil {
98 in[1] = newModuleVal(mod)
99 i++
100 }
101
102 for j := 0; i < pLen; i++ {
103 next := tp.In(i)
104 val := reflect.New(next).Elem()
105 k := next.Kind()
106 raw := stack[j]
107 j++
108
109 switch k {
110 case reflect.Float32:
111 val.SetFloat(float64(math.Float32frombits(uint32(raw))))
112 case reflect.Float64:
113 val.SetFloat(math.Float64frombits(raw))
114 case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
115 val.SetUint(raw)
116 case reflect.Int32, reflect.Int64:
117 val.SetInt(int64(raw))
118 default:
119 panic(fmt.Errorf("BUG: param[%d] has an invalid type: %v", i, k))
120 }
121 in[i] = val
122 }
123 }
124
125
126 for i, ret := range fn.Call(in) {
127 switch ret.Kind() {
128 case reflect.Float32:
129 stack[i] = uint64(math.Float32bits(float32(ret.Float())))
130 case reflect.Float64:
131 stack[i] = math.Float64bits(ret.Float())
132 case reflect.Uint32, reflect.Uint64, reflect.Uintptr:
133 stack[i] = ret.Uint()
134 case reflect.Int32, reflect.Int64:
135 stack[i] = uint64(ret.Int())
136 default:
137 panic(fmt.Errorf("BUG: result[%d] has an invalid type: %v", i, ret.Kind()))
138 }
139 }
140 }
141
142 func newContextVal(ctx context.Context) reflect.Value {
143 val := reflect.New(goContextType).Elem()
144 val.Set(reflect.ValueOf(ctx))
145 return val
146 }
147
148 func newModuleVal(m api.Module) reflect.Value {
149 val := reflect.New(moduleType).Elem()
150 val.Set(reflect.ValueOf(m))
151 return val
152 }
153
154
155
156
157
158 func MustParseGoReflectFuncCode(fn interface{}) Code {
159 _, _, code, err := parseGoReflectFunc(fn)
160 if err != nil {
161 panic(err)
162 }
163 return code
164 }
165
166 func parseGoReflectFunc(fn interface{}) (params, results []ValueType, code Code, err error) {
167 fnV := reflect.ValueOf(fn)
168 p := fnV.Type()
169
170 if fnV.Kind() != reflect.Func {
171 err = fmt.Errorf("kind != func: %s", fnV.Kind().String())
172 return
173 }
174
175 pk, kindErr := kind(p)
176 if kindErr != nil {
177 err = kindErr
178 return
179 }
180
181 pOffset := 0
182 switch pk {
183 case paramsKindNoContext:
184 case paramsKindContext:
185 pOffset = 1
186 case paramsKindContextModule:
187 pOffset = 2
188 }
189
190 pCount := p.NumIn() - pOffset
191 if pCount > 0 {
192 params = make([]ValueType, pCount)
193 }
194 for i := 0; i < len(params); i++ {
195 pI := p.In(i + pOffset)
196 if t, ok := getTypeOf(pI.Kind()); ok {
197 params[i] = t
198 continue
199 }
200
201
202 var arg0Type reflect.Type
203 if hc := pI.Implements(moduleType); hc {
204 arg0Type = moduleType
205 } else if gc := pI.Implements(goContextType); gc {
206 arg0Type = goContextType
207 }
208
209 if arg0Type != nil {
210 err = fmt.Errorf("param[%d] is a %s, which may be defined only once as param[0]", i+pOffset, arg0Type)
211 } else {
212 err = fmt.Errorf("param[%d] is unsupported: %s", i+pOffset, pI.Kind())
213 }
214 return
215 }
216
217 rCount := p.NumOut()
218 if rCount > 0 {
219 results = make([]ValueType, rCount)
220 }
221 for i := 0; i < len(results); i++ {
222 rI := p.Out(i)
223 if t, ok := getTypeOf(rI.Kind()); ok {
224 results[i] = t
225 continue
226 }
227
228
229 if rI.Implements(errorType) {
230 err = fmt.Errorf("result[%d] is an error, which is unsupported", i)
231 } else {
232 err = fmt.Errorf("result[%d] is unsupported: %s", i, rI.Kind())
233 }
234 return
235 }
236
237 code = Code{}
238 if pk == paramsKindContextModule {
239 code.GoFunc = &reflectGoModuleFunction{fn: &fnV, params: params, results: results}
240 } else {
241 code.GoFunc = &reflectGoFunction{pk: pk, fn: &fnV, params: params, results: results}
242 }
243 return
244 }
245
246 func kind(p reflect.Type) (paramsKind, error) {
247 pCount := p.NumIn()
248 if pCount > 0 && p.In(0).Kind() == reflect.Interface {
249 p0 := p.In(0)
250 if p0.Implements(moduleType) {
251 return 0, errors.New("invalid signature: api.Module parameter must be preceded by context.Context")
252 } else if p0.Implements(goContextType) {
253 if pCount >= 2 && p.In(1).Implements(moduleType) {
254 return paramsKindContextModule, nil
255 }
256 return paramsKindContext, nil
257 }
258 }
259
260
261 return paramsKindNoContext, nil
262 }
263
264 func getTypeOf(kind reflect.Kind) (ValueType, bool) {
265 switch kind {
266 case reflect.Float64:
267 return ValueTypeF64, true
268 case reflect.Float32:
269 return ValueTypeF32, true
270 case reflect.Int32, reflect.Uint32:
271 return ValueTypeI32, true
272 case reflect.Int64, reflect.Uint64:
273 return ValueTypeI64, true
274 case reflect.Uintptr:
275 return ValueTypeExternref, true
276 default:
277 return 0x00, false
278 }
279 }
280
View as plain text