1 package compiler
2
3 import (
4 "bytes"
5 "encoding/binary"
6 "fmt"
7 "io"
8 "runtime"
9
10 "github.com/tetratelabs/wazero/experimental"
11 "github.com/tetratelabs/wazero/internal/platform"
12 "github.com/tetratelabs/wazero/internal/u32"
13 "github.com/tetratelabs/wazero/internal/u64"
14 "github.com/tetratelabs/wazero/internal/wasm"
15 )
16
17 func (e *engine) deleteCompiledModule(module *wasm.Module) {
18 e.mux.Lock()
19 defer e.mux.Unlock()
20
21 delete(e.codes, module.ID)
22
23
24
25 }
26
27 func (e *engine) addCompiledModule(module *wasm.Module, cm *compiledModule, withGoFunc bool) (err error) {
28 e.addCompiledModuleToMemory(module, cm)
29 if !withGoFunc {
30 err = e.addCompiledModuleToCache(module, cm)
31 }
32 return
33 }
34
35 func (e *engine) getCompiledModule(module *wasm.Module, listeners []experimental.FunctionListener) (cm *compiledModule, ok bool, err error) {
36 cm, ok = e.getCompiledModuleFromMemory(module)
37 if ok {
38 return
39 }
40 cm, ok, err = e.getCompiledModuleFromCache(module)
41 if ok {
42 e.addCompiledModuleToMemory(module, cm)
43 if len(listeners) > 0 {
44
45 for i := range cm.functions {
46 cm.functions[i].listener = listeners[i]
47 }
48 }
49
50
51 e.setFinalizer(cm, releaseCompiledModule)
52 }
53 return
54 }
55
56 func (e *engine) addCompiledModuleToMemory(module *wasm.Module, cm *compiledModule) {
57 e.mux.Lock()
58 defer e.mux.Unlock()
59 e.codes[module.ID] = cm
60 }
61
62 func (e *engine) getCompiledModuleFromMemory(module *wasm.Module) (cm *compiledModule, ok bool) {
63 e.mux.RLock()
64 defer e.mux.RUnlock()
65 cm, ok = e.codes[module.ID]
66 return
67 }
68
69 func (e *engine) addCompiledModuleToCache(module *wasm.Module, cm *compiledModule) (err error) {
70 if e.fileCache == nil || module.IsHostModule {
71 return
72 }
73 err = e.fileCache.Add(module.ID, serializeCompiledModule(e.wazeroVersion, cm))
74 return
75 }
76
77 func (e *engine) getCompiledModuleFromCache(module *wasm.Module) (cm *compiledModule, hit bool, err error) {
78 if e.fileCache == nil || module.IsHostModule {
79 return
80 }
81
82
83 var cached io.ReadCloser
84 cached, hit, err = e.fileCache.Get(module.ID)
85 if !hit || err != nil {
86 return
87 }
88
89
90
91 var staleCache bool
92
93 cm, staleCache, err = deserializeCompiledModule(e.wazeroVersion, cached, module)
94 if err != nil {
95 hit = false
96 return
97 } else if staleCache {
98 return nil, false, e.fileCache.Delete(module.ID)
99 }
100
101 cm.source = module
102 return
103 }
104
105 var wazeroMagic = "WAZERO"
106
107 func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader {
108 buf := bytes.NewBuffer(nil)
109
110 buf.WriteString(wazeroMagic)
111
112 buf.WriteByte(byte(len(wazeroVersion)))
113
114 buf.WriteString(wazeroVersion)
115 if cm.ensureTermination {
116 buf.WriteByte(1)
117 } else {
118 buf.WriteByte(0)
119 }
120
121 buf.Write(u32.LeBytes(uint32(len(cm.functions))))
122 for i := 0; i < len(cm.functions); i++ {
123 f := &cm.functions[i]
124
125 buf.Write(u64.LeBytes(f.stackPointerCeil))
126
127 buf.Write(u64.LeBytes(uint64(f.executableOffset)))
128 }
129
130 buf.Write(u64.LeBytes(uint64(cm.executable.Len())))
131
132 buf.Write(cm.executable.Bytes())
133 return bytes.NewReader(buf.Bytes())
134 }
135
136 func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, module *wasm.Module) (cm *compiledModule, staleCache bool, err error) {
137 defer reader.Close()
138 cacheHeaderSize := len(wazeroMagic) + 1 + len(wazeroVersion) + 1 + 4
139
140
141 header := make([]byte, cacheHeaderSize)
142 n, err := reader.Read(header)
143 if err != nil {
144 return nil, false, fmt.Errorf("compilationcache: error reading header: %v", err)
145 }
146
147 if n != cacheHeaderSize {
148 return nil, false, fmt.Errorf("compilationcache: invalid header length: %d", n)
149 }
150
151
152 versionSize := int(header[len(wazeroMagic)])
153
154 cachedVersionBegin, cachedVersionEnd := len(wazeroMagic)+1, len(wazeroMagic)+1+versionSize
155 if cachedVersionEnd >= len(header) {
156 staleCache = true
157 return
158 } else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion {
159 staleCache = true
160 return
161 }
162
163 ensureTermination := header[cachedVersionEnd] != 0
164 functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
165 cm = &compiledModule{
166 compiledCode: new(compiledCode),
167 functions: make([]compiledFunction, functionsNum),
168 ensureTermination: ensureTermination,
169 }
170
171 imported := module.ImportFunctionCount
172
173 var eightBytes [8]byte
174 for i := uint32(0); i < functionsNum; i++ {
175 f := &cm.functions[i]
176 f.parent = cm.compiledCode
177
178
179 if f.stackPointerCeil, err = readUint64(reader, &eightBytes); err != nil {
180 err = fmt.Errorf("compilationcache: error reading func[%d] stack pointer ceil: %v", i, err)
181 return
182 }
183
184
185 var offset uint64
186 if offset, err = readUint64(reader, &eightBytes); err != nil {
187 err = fmt.Errorf("compilationcache: error reading func[%d] executable offset: %v", i, err)
188 return
189 }
190 f.executableOffset = uintptr(offset)
191 f.index = imported + i
192 }
193
194 executableLen, err := readUint64(reader, &eightBytes)
195 if err != nil {
196 err = fmt.Errorf("compilationcache: error reading executable size: %v", err)
197 return
198 }
199
200 if executableLen > 0 {
201 if err = cm.executable.Map(int(executableLen)); err != nil {
202 err = fmt.Errorf("compilationcache: error mmapping executable (len=%d): %v", executableLen, err)
203 return
204 }
205
206 _, err = io.ReadFull(reader, cm.executable.Bytes())
207 if err != nil {
208 err = fmt.Errorf("compilationcache: error reading executable (len=%d): %v", executableLen, err)
209 return
210 }
211
212 if runtime.GOARCH == "arm64" {
213
214 if err = platform.MprotectRX(cm.executable.Bytes()); err != nil {
215 return
216 }
217 }
218 }
219 return
220 }
221
222
223
// readUint64 reads exactly eight bytes from reader into the caller-supplied
// scratch buffer b and decodes them as a little-endian uint64.
//
// io.ReadFull is used so that readers which deliver data in pieces (a single
// Read may return fewer bytes than requested) are handled correctly. When the
// input ends before eight bytes are available, io.EOF is returned, whether
// zero or some bytes were read. Any other read error is returned unchanged.
//
// On success b is reset to all zeroes before returning, so it can be reused
// for the next read without carrying stale bytes.
func readUint64(reader io.Reader, b *[8]byte) (uint64, error) {
	s := b[:]
	if _, err := io.ReadFull(reader, s); err != nil {
		if err == io.ErrUnexpectedEOF {
			// Keep reporting truncated input as plain io.EOF, which is what
			// callers of this helper have always received.
			err = io.EOF
		}
		return 0, err
	}

	ret := binary.LittleEndian.Uint64(s)

	// Clear the scratch buffer for the next caller.
	*b = [8]byte{}
	return ret, nil
}
242
View as plain text