1 package wasm
2
3 import (
4 "context"
5 "math"
6 "testing"
7 "unsafe"
8
9 "github.com/tetratelabs/wazero/api"
10 "github.com/tetratelabs/wazero/internal/testing/require"
11 )
12
13
// testCtx is a context distinct from context.Background, carrying one
// arbitrary value. Tests compare against it to confirm that the context
// handed to a Go host function is the exact one the caller supplied.
var testCtx = context.WithValue(
	context.Background(),
	struct{}{}, // arbitrary comparable key; only testCtx's identity matters
	"arbitrary",
)
15
16 func Test_parseGoFunc(t *testing.T) {
17 tests := []struct {
18 name string
19 input interface{}
20 expectNeedsModule bool
21 expectedType *FunctionType
22 }{
23 {
24 name: "() -> ()",
25 input: func() {},
26 expectedType: &FunctionType{},
27 },
28 {
29 name: "(ctx) -> ()",
30 input: func(context.Context) {},
31 expectedType: &FunctionType{},
32 },
33 {
34 name: "(ctx, mod) -> ()",
35 input: func(context.Context, api.Module) {},
36 expectNeedsModule: true,
37 expectedType: &FunctionType{},
38 },
39 {
40 name: "all supported params and i32 result",
41 input: func(uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
42 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
43 },
44 {
45 name: "all supported params and i32 result - (ctx)",
46 input: func(context.Context, uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
47 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
48 },
49 {
50 name: "all supported params and i32 result - (ctx, mod)",
51 input: func(context.Context, api.Module, uint32, uint64, float32, float64, uintptr) uint32 { return 0 },
52 expectNeedsModule: true,
53 expectedType: &FunctionType{Params: []ValueType{i32, i64, f32, f64, externref}, Results: []ValueType{i32}},
54 },
55 }
56 for _, tt := range tests {
57 tc := tt
58
59 t.Run(tc.name, func(t *testing.T) {
60 paramTypes, resultTypes, code, err := parseGoReflectFunc(tc.input)
61 require.NoError(t, err)
62 _, isModuleFunc := code.GoFunc.(api.GoModuleFunction)
63 require.Equal(t, tc.expectNeedsModule, isModuleFunc)
64 require.Equal(t, tc.expectedType, &FunctionType{Params: paramTypes, Results: resultTypes})
65 })
66 }
67 }
68
69 func Test_parseGoFunc_Errors(t *testing.T) {
70 tests := []struct {
71 name string
72 input interface{}
73 expectedErr string
74 }{
75 {
76 name: "module no context",
77 input: func(api.Module) {},
78 expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context",
79 },
80 {
81 name: "not a func",
82 input: struct{}{},
83 expectedErr: "kind != func: struct",
84 },
85 {
86 name: "unsupported param",
87 input: func(context.Context, uint32, string) {},
88 expectedErr: "param[2] is unsupported: string",
89 },
90 {
91 name: "unsupported result",
92 input: func() string { return "" },
93 expectedErr: "result[0] is unsupported: string",
94 },
95 {
96 name: "error result",
97 input: func() error { return nil },
98 expectedErr: "result[0] is an error, which is unsupported",
99 },
100 {
101 name: "incorrect order",
102 input: func(api.Module, context.Context) error { return nil },
103 expectedErr: "invalid signature: api.Module parameter must be preceded by context.Context",
104 },
105 {
106 name: "multiple context.Context",
107 input: func(context.Context, uint64, context.Context) error { return nil },
108 expectedErr: "param[2] is a context.Context, which may be defined only once as param[0]",
109 },
110 {
111 name: "multiple wasm.Module",
112 input: func(context.Context, api.Module, uint64, api.Module) error { return nil },
113 expectedErr: "param[3] is a api.Module, which may be defined only once as param[0]",
114 },
115 }
116
117 for _, tt := range tests {
118 tc := tt
119
120 t.Run(tc.name, func(t *testing.T) {
121 _, _, _, err := parseGoReflectFunc(tc.input)
122 require.EqualError(t, err, tc.expectedErr)
123 })
124 }
125 }
126
127 func Test_callGoFunc(t *testing.T) {
128 tPtr := uintptr(unsafe.Pointer(t))
129 inst := &ModuleInstance{}
130
131 tests := []struct {
132 name string
133 input interface{}
134 inputParams, expectedResults []uint64
135 }{
136 {
137 name: "() -> ()",
138 input: func() {},
139 },
140 {
141 name: "(ctx) -> ()",
142 input: func(ctx context.Context) {
143 require.Equal(t, testCtx, ctx)
144 },
145 },
146 {
147 name: "(ctx, mod) -> ()",
148 input: func(ctx context.Context, m api.Module) {
149 require.Equal(t, testCtx, ctx)
150 require.Equal(t, inst, m)
151 },
152 },
153 {
154 name: "all supported params and i32 result",
155 input: func(v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
156 require.Equal(t, tPtr, v)
157 require.Equal(t, uint32(math.MaxUint32), w)
158 require.Equal(t, uint64(math.MaxUint64), x)
159 require.Equal(t, float32(math.MaxFloat32), y)
160 require.Equal(t, math.MaxFloat64, z)
161 return 100
162 },
163 inputParams: []uint64{
164 api.EncodeExternref(tPtr),
165 math.MaxUint32,
166 math.MaxUint64,
167 api.EncodeF32(math.MaxFloat32),
168 api.EncodeF64(math.MaxFloat64),
169 },
170 expectedResults: []uint64{100},
171 },
172 {
173 name: "all supported params and i32 result - (ctx)",
174 input: func(ctx context.Context, v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
175 require.Equal(t, testCtx, ctx)
176 require.Equal(t, tPtr, v)
177 require.Equal(t, uint32(math.MaxUint32), w)
178 require.Equal(t, uint64(math.MaxUint64), x)
179 require.Equal(t, float32(math.MaxFloat32), y)
180 require.Equal(t, math.MaxFloat64, z)
181 return 100
182 },
183 inputParams: []uint64{
184 api.EncodeExternref(tPtr),
185 math.MaxUint32,
186 math.MaxUint64,
187 api.EncodeF32(math.MaxFloat32),
188 api.EncodeF64(math.MaxFloat64),
189 },
190 expectedResults: []uint64{100},
191 },
192 {
193 name: "all supported params and i32 result - (ctx, mod)",
194 input: func(ctx context.Context, m api.Module, v uintptr, w uint32, x uint64, y float32, z float64) uint32 {
195 require.Equal(t, testCtx, ctx)
196 require.Equal(t, inst, m)
197 require.Equal(t, tPtr, v)
198 require.Equal(t, uint32(math.MaxUint32), w)
199 require.Equal(t, uint64(math.MaxUint64), x)
200 require.Equal(t, float32(math.MaxFloat32), y)
201 require.Equal(t, math.MaxFloat64, z)
202 return 100
203 },
204 inputParams: []uint64{
205 api.EncodeExternref(tPtr),
206 math.MaxUint32,
207 math.MaxUint64,
208 api.EncodeF32(math.MaxFloat32),
209 api.EncodeF64(math.MaxFloat64),
210 },
211 expectedResults: []uint64{100},
212 },
213 }
214 for _, tt := range tests {
215 tc := tt
216
217 t.Run(tc.name, func(t *testing.T) {
218 _, _, code, err := parseGoReflectFunc(tc.input)
219 require.NoError(t, err)
220
221 resultLen := len(tc.expectedResults)
222 stackLen := len(tc.inputParams)
223 if resultLen > stackLen {
224 stackLen = resultLen
225 }
226 stack := make([]uint64, stackLen)
227 copy(stack, tc.inputParams)
228
229 switch code.GoFunc.(type) {
230 case api.GoFunction:
231 code.GoFunc.(api.GoFunction).Call(testCtx, stack)
232 case api.GoModuleFunction:
233 code.GoFunc.(api.GoModuleFunction).Call(testCtx, inst, stack)
234 default:
235 t.Fatal("unexpected type.")
236 }
237
238 var results []uint64
239 if resultLen > 0 {
240 results = stack[:resultLen]
241 }
242 require.Equal(t, tc.expectedResults, results)
243 })
244 }
245 }
246
View as plain text