engine_cache.go

  1package wazevo
  2
  3import (
  4	"bytes"
  5	"context"
  6	"crypto/sha256"
  7	"encoding/binary"
  8	"fmt"
  9	"hash/crc32"
 10	"io"
 11	"runtime"
 12	"unsafe"
 13
 14	"github.com/tetratelabs/wazero/experimental"
 15	"github.com/tetratelabs/wazero/internal/engine/wazevo/backend"
 16	"github.com/tetratelabs/wazero/internal/engine/wazevo/ssa"
 17	"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
 18	"github.com/tetratelabs/wazero/internal/filecache"
 19	"github.com/tetratelabs/wazero/internal/platform"
 20	"github.com/tetratelabs/wazero/internal/u32"
 21	"github.com/tetratelabs/wazero/internal/u64"
 22	"github.com/tetratelabs/wazero/internal/wasm"
 23)
 24
 25var crc = crc32.MakeTable(crc32.Castagnoli)
 26
 27// fileCacheKey returns a key for the file cache.
 28// In order to avoid collisions with the existing compiler, we do not use m.ID directly,
 29// but instead we rehash it with magic.
 30func fileCacheKey(m *wasm.Module) (ret filecache.Key) {
 31	s := sha256.New()
 32	s.Write(m.ID[:])
 33	s.Write(magic)
 34	// Write the CPU features so that we can cache the compiled module for the same CPU.
 35	// This prevents the incompatible CPU features from being used.
 36	cpu := platform.CpuFeatures.Raw()
 37	// Reuse the `ret` buffer to write the first 8 bytes of the CPU features so that we can avoid the allocation.
 38	binary.LittleEndian.PutUint64(ret[:8], cpu)
 39	s.Write(ret[:8])
 40	// Finally, write the hash to the ret buffer.
 41	s.Sum(ret[:0])
 42	return
 43}
 44
 45func (e *engine) addCompiledModule(module *wasm.Module, cm *compiledModule) (err error) {
 46	e.addCompiledModuleToMemory(module, cm)
 47	if !module.IsHostModule && e.fileCache != nil {
 48		err = e.addCompiledModuleToCache(module, cm)
 49	}
 50	return
 51}
 52
 53func (e *engine) getCompiledModule(module *wasm.Module, listeners []experimental.FunctionListener, ensureTermination bool) (cm *compiledModule, ok bool, err error) {
 54	cm, ok = e.getCompiledModuleFromMemory(module)
 55	if ok {
 56		return
 57	}
 58	cm, ok, err = e.getCompiledModuleFromCache(module)
 59	if ok {
 60		cm.parent = e
 61		cm.module = module
 62		cm.sharedFunctions = e.sharedFunctions
 63		cm.ensureTermination = ensureTermination
 64		cm.offsets = wazevoapi.NewModuleContextOffsetData(module, len(listeners) > 0)
 65		if len(listeners) > 0 {
 66			cm.listeners = listeners
 67			cm.listenerBeforeTrampolines = make([]*byte, len(module.TypeSection))
 68			cm.listenerAfterTrampolines = make([]*byte, len(module.TypeSection))
 69			for i := range module.TypeSection {
 70				typ := &module.TypeSection[i]
 71				before, after := e.getListenerTrampolineForType(typ)
 72				cm.listenerBeforeTrampolines[i] = before
 73				cm.listenerAfterTrampolines[i] = after
 74			}
 75		}
 76		e.addCompiledModuleToMemory(module, cm)
 77		ssaBuilder := ssa.NewBuilder()
 78		machine := newMachine()
 79		be := backend.NewCompiler(context.Background(), machine, ssaBuilder)
 80		cm.executables.compileEntryPreambles(module, machine, be)
 81
 82		// Set the finalizer.
 83		e.setFinalizer(cm.executables, executablesFinalizer)
 84	}
 85	return
 86}
 87
 88func (e *engine) addCompiledModuleToMemory(m *wasm.Module, cm *compiledModule) {
 89	e.mux.Lock()
 90	defer e.mux.Unlock()
 91	e.compiledModules[m.ID] = cm
 92	if len(cm.executable) > 0 {
 93		e.addCompiledModuleToSortedList(cm)
 94	}
 95}
 96
 97func (e *engine) getCompiledModuleFromMemory(module *wasm.Module) (cm *compiledModule, ok bool) {
 98	e.mux.RLock()
 99	defer e.mux.RUnlock()
100	cm, ok = e.compiledModules[module.ID]
101	return
102}
103
104func (e *engine) addCompiledModuleToCache(module *wasm.Module, cm *compiledModule) (err error) {
105	if e.fileCache == nil || module.IsHostModule {
106		return
107	}
108	err = e.fileCache.Add(fileCacheKey(module), serializeCompiledModule(e.wazeroVersion, cm))
109	return
110}
111
112func (e *engine) getCompiledModuleFromCache(module *wasm.Module) (cm *compiledModule, hit bool, err error) {
113	if e.fileCache == nil || module.IsHostModule {
114		return
115	}
116
117	// Check if the entries exist in the external cache.
118	var cached io.ReadCloser
119	cached, hit, err = e.fileCache.Get(fileCacheKey(module))
120	if !hit || err != nil {
121		return
122	}
123
124	// Otherwise, we hit the cache on external cache.
125	// We retrieve *code structures from `cached`.
126	var staleCache bool
127	// Note: cached.Close is ensured to be called in deserializeCodes.
128	cm, staleCache, err = deserializeCompiledModule(e.wazeroVersion, cached)
129	if err != nil {
130		hit = false
131		return
132	} else if staleCache {
133		return nil, false, e.fileCache.Delete(fileCacheKey(module))
134	}
135	return
136}
137
// magic is the 6-byte header written at the start of every serialized wazevo
// module. It is also mixed into fileCacheKey.
var magic = []byte("WAZEVO")
139
140func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader {
141	buf := bytes.NewBuffer(nil)
142	// First 6 byte: WAZEVO header.
143	buf.Write(magic)
144	// Next 1 byte: length of version:
145	buf.WriteByte(byte(len(wazeroVersion)))
146	// Version of wazero.
147	buf.WriteString(wazeroVersion)
148	// Number of *code (== locally defined functions in the module): 4 bytes.
149	buf.Write(u32.LeBytes(uint32(len(cm.functionOffsets))))
150	for _, offset := range cm.functionOffsets {
151		// The offset of this function in the executable (8 bytes).
152		buf.Write(u64.LeBytes(uint64(offset)))
153	}
154	// The length of code segment (8 bytes).
155	buf.Write(u64.LeBytes(uint64(len(cm.executable))))
156	// Append the native code.
157	buf.Write(cm.executable)
158	// Append checksum.
159	checksum := crc32.Checksum(cm.executable, crc)
160	buf.Write(u32.LeBytes(checksum))
161	if sm := cm.sourceMap; len(sm.executableOffsets) > 0 {
162		buf.WriteByte(1) // indicates that source map is present.
163		l := len(sm.wasmBinaryOffsets)
164		buf.Write(u64.LeBytes(uint64(l)))
165		executableAddr := uintptr(unsafe.Pointer(&cm.executable[0]))
166		for i := 0; i < l; i++ {
167			buf.Write(u64.LeBytes(sm.wasmBinaryOffsets[i]))
168			// executableOffsets is absolute address, so we need to subtract executableAddr.
169			buf.Write(u64.LeBytes(uint64(sm.executableOffsets[i] - executableAddr)))
170		}
171	} else {
172		buf.WriteByte(0) // indicates that source map is not present.
173	}
174	return bytes.NewReader(buf.Bytes())
175}
176
177func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser) (cm *compiledModule, staleCache bool, err error) {
178	defer reader.Close()
179	cacheHeaderSize := len(magic) + 1 /* version size */ + len(wazeroVersion) + 4 /* number of functions */
180
181	// Read the header before the native code.
182	header := make([]byte, cacheHeaderSize)
183	n, err := reader.Read(header)
184	if err != nil {
185		return nil, false, fmt.Errorf("compilationcache: error reading header: %v", err)
186	}
187
188	if n != cacheHeaderSize {
189		return nil, false, fmt.Errorf("compilationcache: invalid header length: %d", n)
190	}
191
192	if !bytes.Equal(header[:len(magic)], magic) {
193		return nil, false, fmt.Errorf(
194			"compilationcache: invalid magic number: got %s but want %s", magic, header[:len(magic)])
195	}
196
197	// Check the version compatibility.
198	versionSize := int(header[len(magic)])
199
200	cachedVersionBegin, cachedVersionEnd := len(magic)+1, len(magic)+1+versionSize
201	if cachedVersionEnd >= len(header) {
202		staleCache = true
203		return
204	} else if cachedVersion := string(header[cachedVersionBegin:cachedVersionEnd]); cachedVersion != wazeroVersion {
205		staleCache = true
206		return
207	}
208
209	functionsNum := binary.LittleEndian.Uint32(header[len(header)-4:])
210	cm = &compiledModule{functionOffsets: make([]int, functionsNum), executables: &executables{}}
211
212	var eightBytes [8]byte
213	for i := uint32(0); i < functionsNum; i++ {
214		// Read the offset of each function in the executable.
215		var offset uint64
216		if offset, err = readUint64(reader, &eightBytes); err != nil {
217			err = fmt.Errorf("compilationcache: error reading func[%d] executable offset: %v", i, err)
218			return
219		}
220		cm.functionOffsets[i] = int(offset)
221	}
222
223	executableLen, err := readUint64(reader, &eightBytes)
224	if err != nil {
225		err = fmt.Errorf("compilationcache: error reading executable size: %v", err)
226		return
227	}
228
229	if executableLen > 0 {
230		executable, err := platform.MmapCodeSegment(int(executableLen))
231		if err != nil {
232			err = fmt.Errorf("compilationcache: error mmapping executable (len=%d): %v", executableLen, err)
233			return nil, false, err
234		}
235
236		_, err = io.ReadFull(reader, executable)
237		if err != nil {
238			err = fmt.Errorf("compilationcache: error reading executable (len=%d): %v", executableLen, err)
239			return nil, false, err
240		}
241
242		expected := crc32.Checksum(executable, crc)
243		if _, err = io.ReadFull(reader, eightBytes[:4]); err != nil {
244			return nil, false, fmt.Errorf("compilationcache: could not read checksum: %v", err)
245		} else if checksum := binary.LittleEndian.Uint32(eightBytes[:4]); expected != checksum {
246			return nil, false, fmt.Errorf("compilationcache: checksum mismatch (expected %d, got %d)", expected, checksum)
247		}
248
249		if runtime.GOARCH == "arm64" {
250			// On arm64, we cannot give all of rwx at the same time, so we change it to exec.
251			if err = platform.MprotectRX(executable); err != nil {
252				return nil, false, err
253			}
254		}
255		cm.executable = executable
256	}
257
258	if _, err := io.ReadFull(reader, eightBytes[:1]); err != nil {
259		return nil, false, fmt.Errorf("compilationcache: error reading source map presence: %v", err)
260	}
261
262	if eightBytes[0] == 1 {
263		sm := &cm.sourceMap
264		sourceMapLen, err := readUint64(reader, &eightBytes)
265		if err != nil {
266			err = fmt.Errorf("compilationcache: error reading source map length: %v", err)
267			return nil, false, err
268		}
269		executableOffset := uintptr(unsafe.Pointer(&cm.executable[0]))
270		for i := uint64(0); i < sourceMapLen; i++ {
271			wasmBinaryOffset, err := readUint64(reader, &eightBytes)
272			if err != nil {
273				err = fmt.Errorf("compilationcache: error reading source map[%d] wasm binary offset: %v", i, err)
274				return nil, false, err
275			}
276			executableRelativeOffset, err := readUint64(reader, &eightBytes)
277			if err != nil {
278				err = fmt.Errorf("compilationcache: error reading source map[%d] executable offset: %v", i, err)
279				return nil, false, err
280			}
281			sm.wasmBinaryOffsets = append(sm.wasmBinaryOffsets, wasmBinaryOffset)
282			// executableOffsets is absolute address, so we need to add executableOffset.
283			sm.executableOffsets = append(sm.executableOffsets, uintptr(executableRelativeOffset)+executableOffset)
284		}
285	}
286	return
287}
288
// readUint64 strictly reads a little-endian uint64 from reader, using b as the
// scratch buffer.
//
// It keeps reading until all 8 bytes arrive, so readers that return data in
// small pieces, or that return the final bytes together with io.EOF, are
// handled correctly. If the stream ends before 8 bytes were read (including
// immediately), io.EOF is returned. Any other read error is returned unchanged.
func readUint64(reader io.Reader, b *[8]byte) (uint64, error) {
	s := b[:]
	if _, err := io.ReadFull(reader, s); err != nil {
		if err == io.ErrUnexpectedEOF {
			// Preserve the documented contract: a partial read is reported as io.EOF.
			return 0, io.EOF
		}
		return 0, err
	}
	return binary.LittleEndian.Uint64(s), nil
}