sqlite.go

  1// Package sqlite3 wraps the C SQLite API.
  2package sqlite3
  3
  4import (
  5	"context"
  6	"math/bits"
  7	"os"
  8	"sync"
  9	"unsafe"
 10
 11	"github.com/tetratelabs/wazero"
 12	"github.com/tetratelabs/wazero/api"
 13
 14	"github.com/ncruces/go-sqlite3/internal/util"
 15	"github.com/ncruces/go-sqlite3/vfs"
 16)
 17
// Configure SQLite Wasm.
//
// Importing package embed initializes [Binary]
// with an appropriate build of SQLite:
//
//	import _ "github.com/ncruces/go-sqlite3/embed"
//
// These variables are read once, the first time [Initialize] runs
// (explicitly or when the first connection is opened);
// changing them afterwards has no effect.
var (
	Binary []byte // Wasm binary to load.
	Path   string // Path to load the binary from; only used when Binary is nil.

	// RuntimeConfig, when non-nil, is used as-is to create the wazero runtime.
	// When nil, a default config is used that enables api.CoreFeaturesV2 and
	// limits memory to 512 pages (32MB) on 32-bit platforms, 4096 pages (256MB) otherwise.
	RuntimeConfig wazero.RuntimeConfig
)
 30
 31// Initialize decodes and compiles the SQLite Wasm binary.
 32// This is called implicitly when the first connection is openned,
 33// but is potentially slow, so you may want to call it at a more convenient time.
 34func Initialize() error {
 35	instance.once.Do(compileSQLite)
 36	return instance.err
 37}
 38
// instance holds the process-wide wazero runtime and the compiled SQLite
// module shared by every connection. It is populated by compileSQLite,
// which runs under once (see Initialize).
var instance struct {
	runtime  wazero.Runtime
	compiled wazero.CompiledModule
	err      error // result of compileSQLite; returned by every Initialize call
	once     sync.Once
}
 45
 46func compileSQLite() {
 47	ctx := context.Background()
 48	cfg := RuntimeConfig
 49	if cfg == nil {
 50		cfg = wazero.NewRuntimeConfig()
 51		if bits.UintSize < 64 {
 52			cfg = cfg.WithMemoryLimitPages(512) // 32MB
 53		} else {
 54			cfg = cfg.WithMemoryLimitPages(4096) // 256MB
 55		}
 56		cfg = cfg.WithCoreFeatures(api.CoreFeaturesV2)
 57	}
 58
 59	instance.runtime = wazero.NewRuntimeWithConfig(ctx, cfg)
 60
 61	env := instance.runtime.NewHostModuleBuilder("env")
 62	env = vfs.ExportHostFunctions(env)
 63	env = exportCallbacks(env)
 64	_, instance.err = env.Instantiate(ctx)
 65	if instance.err != nil {
 66		return
 67	}
 68
 69	bin := Binary
 70	if bin == nil && Path != "" {
 71		bin, instance.err = os.ReadFile(Path)
 72		if instance.err != nil {
 73			return
 74		}
 75	}
 76	if bin == nil {
 77		instance.err = util.NoBinaryErr
 78		return
 79	}
 80
 81	instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
 82}
 83
// sqlite is one instance of the compiled SQLite module (created by
// instantiateSQLite), plus the state needed to call into it.
type sqlite struct {
	ctx   context.Context // passed to every call into mod, and to mod.Close
	mod   api.Module
	funcs struct { // small cache of exported functions; see getfn and putfn
		fn   [32]api.Function
		id   [32]*byte // unsafe.StringData of each cached function's name
		mask uint32    // bit i set when putfn has filled slot i
	}
	stack [9]stk_t // argument/result stack reused by call
}
 94
 95func instantiateSQLite() (sqlt *sqlite, err error) {
 96	if err := Initialize(); err != nil {
 97		return nil, err
 98	}
 99
100	sqlt = new(sqlite)
101	sqlt.ctx = util.NewContext(context.Background())
102
103	sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
104		instance.compiled, wazero.NewModuleConfig().WithName(""))
105	if err != nil {
106		return nil, err
107	}
108	if sqlt.getfn("sqlite3_progress_handler_go") == nil {
109		return nil, util.BadBinaryErr
110	}
111	return sqlt, nil
112}
113
114func (sqlt *sqlite) close() error {
115	return sqlt.mod.Close(sqlt.ctx)
116}
117
118func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
119	if rc == _OK {
120		return nil
121	}
122
123	if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM {
124		panic(util.OOMErr)
125	}
126
127	if handle != 0 {
128		var msg, query string
129		if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
130			msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
131			switch {
132			case msg == "not an error":
133				msg = ""
134			case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
135				msg = ""
136			}
137		}
138
139		if len(sql) != 0 {
140			if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
141				query = sql[0][i:]
142			}
143		}
144
145		if msg != "" || query != "" {
146			return &Error{code: rc, msg: msg, sql: query}
147		}
148	}
149	return xErrorCode(rc)
150}
151
// getfn returns the exported Wasm function called name.
//
// It first checks a small cache keyed by the address of name's bytes
// (unsafe.StringData), so a hit requires the caller to pass the same string
// data, e.g. the same string constant, as when the function was cached.
// A hit removes the entry (clears the id and the mask bit): the caller holds
// the function until it hands it back with putfn. A nested request for the
// same name while it is held therefore gets a separate Function from
// ExportedFunction. NOTE(review): presumably this is to avoid invoking one
// api.Function reentrantly; confirm.
// On a miss it falls back to sqlt.mod.ExportedFunction, whose nil result
// instantiateSQLite treats as a missing export.
func (sqlt *sqlite) getfn(name string) api.Function {
	c := &sqlt.funcs
	p := unsafe.StringData(name)
	// Every slot is scanned regardless of mask, so entries left behind by
	// putfn's overflow reset can still be found here.
	for i := range c.id {
		if c.id[i] == p {
			c.id[i] = nil
			c.mask &^= uint32(1) << i
			return c.fn[i]
		}
	}
	return sqlt.mod.ExportedFunction(name)
}
164
// putfn returns fn to the function cache under the address of name's bytes
// (see getfn).
//
// It fills the lowest slot whose mask bit is clear. When all 32 slots are
// taken, it overwrites slot 0 and resets mask to just that slot. Slots 1-31
// are not cleared, so getfn can still find their entries until putfn reuses
// those slots.
func (sqlt *sqlite) putfn(name string, fn api.Function) {
	c := &sqlt.funcs
	p := unsafe.StringData(name)
	// Index of the lowest zero bit in mask; 32 when every slot is in use.
	i := bits.TrailingZeros32(^c.mask)
	if i < 32 {
		c.id[i] = p
		c.fn[i] = fn
		c.mask |= uint32(1) << i
	} else {
		c.id[0] = p
		c.fn[0] = fn
		c.mask = uint32(1)
	}
}
179
180func (sqlt *sqlite) call(name string, params ...stk_t) stk_t {
181	copy(sqlt.stack[:], params)
182	fn := sqlt.getfn(name)
183	err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
184	if err != nil {
185		panic(err)
186	}
187	sqlt.putfn(name, fn)
188	return stk_t(sqlt.stack[0])
189}
190
191func (sqlt *sqlite) free(ptr ptr_t) {
192	if ptr == 0 {
193		return
194	}
195	sqlt.call("sqlite3_free", stk_t(ptr))
196}
197
198func (sqlt *sqlite) new(size int64) ptr_t {
199	ptr := ptr_t(sqlt.call("sqlite3_malloc64", stk_t(size)))
200	if ptr == 0 && size != 0 {
201		panic(util.OOMErr)
202	}
203	return ptr
204}
205
206func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t {
207	ptr = ptr_t(sqlt.call("sqlite3_realloc64", stk_t(ptr), stk_t(size)))
208	if ptr == 0 && size != 0 {
209		panic(util.OOMErr)
210	}
211	return ptr
212}
213
214func (sqlt *sqlite) newBytes(b []byte) ptr_t {
215	if len(b) == 0 {
216		return 0
217	}
218	ptr := sqlt.new(int64(len(b)))
219	util.WriteBytes(sqlt.mod, ptr, b)
220	return ptr
221}
222
223func (sqlt *sqlite) newString(s string) ptr_t {
224	ptr := sqlt.new(int64(len(s)) + 1)
225	util.WriteString(sqlt.mod, ptr, s)
226	return ptr
227}
228
229const arenaSize = 4096
230
231func (sqlt *sqlite) newArena() arena {
232	return arena{
233		sqlt: sqlt,
234		base: sqlt.new(arenaSize),
235	}
236}
237
// arena is a bump allocator over a single arenaSize-byte block of Wasm
// memory. Requests that do not fit in the block are allocated individually
// with sqlt.new and tracked in ptrs so they can be freed together.
type arena struct {
	sqlt *sqlite // nil once the arena has been freed
	ptrs []ptr_t // individual allocations that did not fit in the block
	base ptr_t   // start of the arenaSize-byte block
	next int32   // offset of the next free byte in the block
}
244
245func (a *arena) free() {
246	if a.sqlt == nil {
247		return
248	}
249	for _, ptr := range a.ptrs {
250		a.sqlt.free(ptr)
251	}
252	a.sqlt.free(a.base)
253	a.sqlt = nil
254}
255
256func (a *arena) mark() (reset func()) {
257	ptrs := len(a.ptrs)
258	next := a.next
259	return func() {
260		rest := a.ptrs[ptrs:]
261		for _, ptr := range a.ptrs[:ptrs] {
262			a.sqlt.free(ptr)
263		}
264		a.ptrs = rest
265		a.next = next
266	}
267}
268
269func (a *arena) new(size int64) ptr_t {
270	// Align the next address, to 4 or 8 bytes.
271	if size&7 != 0 {
272		a.next = (a.next + 3) &^ 3
273	} else {
274		a.next = (a.next + 7) &^ 7
275	}
276	if size <= arenaSize-int64(a.next) {
277		ptr := a.base + ptr_t(a.next)
278		a.next += int32(size)
279		return ptr_t(ptr)
280	}
281	ptr := a.sqlt.new(size)
282	a.ptrs = append(a.ptrs, ptr)
283	return ptr_t(ptr)
284}
285
286func (a *arena) bytes(b []byte) ptr_t {
287	if len(b) == 0 {
288		return 0
289	}
290	ptr := a.new(int64(len(b)))
291	util.WriteBytes(a.sqlt.mod, ptr, b)
292	return ptr
293}
294
295func (a *arena) string(s string) ptr_t {
296	ptr := a.new(int64(len(s)) + 1)
297	util.WriteString(a.sqlt.mod, ptr, s)
298	return ptr
299}
300
// exportCallbacks registers the Go callbacks that the SQLite Wasm build
// imports, as go_* functions on the env host module builder, and returns env.
//
// NOTE(review): judging by the helper names, the suffix after ExportFunc
// presumably spells the Wasm signature (result first, V for none, then one
// letter per parameter, I for i32 and J for i64). Confirm against util.
// The helpers' results are not used, so they presumably register on env in place.
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
	// Connection-level handlers and hooks.
	util.ExportFuncII(env, "go_progress_handler", progressCallback)
	util.ExportFuncIII(env, "go_busy_timeout", timeoutCallback)
	util.ExportFuncIII(env, "go_busy_handler", busyCallback)
	util.ExportFuncII(env, "go_commit_hook", commitCallback)
	util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)
	util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback)
	util.ExportFuncIIIII(env, "go_wal_hook", walCallback)
	util.ExportFuncIIIII(env, "go_trace", traceCallback)
	util.ExportFuncIIIIII(env, "go_autovacuum_pages", autoVacuumCallback)
	util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback)
	util.ExportFuncVIII(env, "go_log", logCallback)
	// User-defined functions and collations.
	util.ExportFuncVI(env, "go_destroy", destroyCallback)
	util.ExportFuncVIIII(env, "go_func", funcCallback)
	util.ExportFuncVIIIII(env, "go_step", stepCallback)
	util.ExportFuncVIIII(env, "go_value", valueCallback)
	util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
	util.ExportFuncVIIII(env, "go_collation_needed", collationCallback)
	util.ExportFuncIIIIII(env, "go_compare", compareCallback)
	// Virtual table methods.
	util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate))
	util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect))
	util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
	util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback)
	util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)
	util.ExportFuncIIIII(env, "go_vtab_update", vtabUpdateCallback)
	util.ExportFuncIII(env, "go_vtab_rename", vtabRenameCallback)
	util.ExportFuncIIIII(env, "go_vtab_find_function", vtabFindFuncCallback)
	util.ExportFuncII(env, "go_vtab_begin", vtabBeginCallback)
	util.ExportFuncII(env, "go_vtab_sync", vtabSyncCallback)
	util.ExportFuncII(env, "go_vtab_commit", vtabCommitCallback)
	util.ExportFuncII(env, "go_vtab_rollback", vtabRollbackCallback)
	util.ExportFuncIII(env, "go_vtab_savepoint", vtabSavepointCallback)
	util.ExportFuncIII(env, "go_vtab_release", vtabReleaseCallback)
	util.ExportFuncIII(env, "go_vtab_rollback_to", vtabRollbackToCallback)
	util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback)
	// Virtual table cursor methods.
	util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback)
	util.ExportFuncII(env, "go_cur_close", cursorCloseCallback)
	util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback)
	util.ExportFuncII(env, "go_cur_next", cursorNextCallback)
	util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback)
	util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback)
	util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback)
	return env
}