1// Package sqlite3 wraps the C SQLite API.
2package sqlite3
3
4import (
5 "context"
6 "math/bits"
7 "os"
8 "sync"
9 "unsafe"
10
11 "github.com/tetratelabs/wazero"
12 "github.com/tetratelabs/wazero/api"
13
14 "github.com/ncruces/go-sqlite3/internal/util"
15 "github.com/ncruces/go-sqlite3/vfs"
16)
17
// Configure SQLite Wasm.
//
// Importing package embed initializes [Binary]
// with an appropriate build of SQLite:
//
//	import _ "github.com/ncruces/go-sqlite3/embed"
//
// These variables are read once, by the first call to [Initialize]
// (which also happens implicitly when the first connection is opened),
// so they must be set before then; later changes have no effect.
var (
	Binary []byte // Wasm binary to load; takes precedence over Path.
	Path   string // Path to load the binary from; used only when Binary is nil.

	// RuntimeConfig configures the wazero runtime. If nil, a default is used
	// that caps memory at 32MB on 32-bit platforms (256MB otherwise) and
	// enables api.CoreFeaturesV2.
	RuntimeConfig wazero.RuntimeConfig
)
30
31// Initialize decodes and compiles the SQLite Wasm binary.
32// This is called implicitly when the first connection is openned,
33// but is potentially slow, so you may want to call it at a more convenient time.
34func Initialize() error {
35 instance.once.Do(compileSQLite)
36 return instance.err
37}
38
// instance holds the package-level state produced by compileSQLite.
// once ensures compileSQLite runs at most once (see Initialize).
var instance struct {
	runtime  wazero.Runtime        // shared runtime every connection instantiates into
	compiled wazero.CompiledModule // the compiled SQLite Wasm module
	err      error                 // outcome of compileSQLite; returned by every Initialize call
	once     sync.Once
}
45
46func compileSQLite() {
47 ctx := context.Background()
48 cfg := RuntimeConfig
49 if cfg == nil {
50 cfg = wazero.NewRuntimeConfig()
51 if bits.UintSize < 64 {
52 cfg = cfg.WithMemoryLimitPages(512) // 32MB
53 } else {
54 cfg = cfg.WithMemoryLimitPages(4096) // 256MB
55 }
56 cfg = cfg.WithCoreFeatures(api.CoreFeaturesV2)
57 }
58
59 instance.runtime = wazero.NewRuntimeWithConfig(ctx, cfg)
60
61 env := instance.runtime.NewHostModuleBuilder("env")
62 env = vfs.ExportHostFunctions(env)
63 env = exportCallbacks(env)
64 _, instance.err = env.Instantiate(ctx)
65 if instance.err != nil {
66 return
67 }
68
69 bin := Binary
70 if bin == nil && Path != "" {
71 bin, instance.err = os.ReadFile(Path)
72 if instance.err != nil {
73 return
74 }
75 }
76 if bin == nil {
77 instance.err = util.NoBinaryErr
78 return
79 }
80
81 instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
82}
83
// sqlite is one instantiation of the compiled SQLite module, together with
// the per-instance state used to call into it.
type sqlite struct {
	ctx context.Context // context passed to every call into the module
	mod api.Module      // the instantiated module
	// funcs caches exported functions taken by getfn and returned by putfn.
	// Slots are keyed by the address of the name's string data
	// (unsafe.StringData). A lookup therefore only hits when the same string
	// data is passed again, presumably the same string literal at the call site.
	funcs struct {
		fn   [32]api.Function
		id   [32]*byte
		mask uint32 // bit i set: slot i is marked in use by putfn
	}
	stack [9]stk_t // argument/result buffer for call; holds at most 9 values
}
94
95func instantiateSQLite() (sqlt *sqlite, err error) {
96 if err := Initialize(); err != nil {
97 return nil, err
98 }
99
100 sqlt = new(sqlite)
101 sqlt.ctx = util.NewContext(context.Background())
102
103 sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
104 instance.compiled, wazero.NewModuleConfig().WithName(""))
105 if err != nil {
106 return nil, err
107 }
108 if sqlt.getfn("sqlite3_progress_handler_go") == nil {
109 return nil, util.BadBinaryErr
110 }
111 return sqlt, nil
112}
113
114func (sqlt *sqlite) close() error {
115 return sqlt.mod.Close(sqlt.ctx)
116}
117
118func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
119 if rc == _OK {
120 return nil
121 }
122
123 if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM {
124 panic(util.OOMErr)
125 }
126
127 if handle != 0 {
128 var msg, query string
129 if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
130 msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
131 switch {
132 case msg == "not an error":
133 msg = ""
134 case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
135 msg = ""
136 }
137 }
138
139 if len(sql) != 0 {
140 if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
141 query = sql[0][i:]
142 }
143 }
144
145 if msg != "" || query != "" {
146 return &Error{code: rc, msg: msg, sql: query}
147 }
148 }
149 return xErrorCode(rc)
150}
151
152func (sqlt *sqlite) getfn(name string) api.Function {
153 c := &sqlt.funcs
154 p := unsafe.StringData(name)
155 for i := range c.id {
156 if c.id[i] == p {
157 c.id[i] = nil
158 c.mask &^= uint32(1) << i
159 return c.fn[i]
160 }
161 }
162 return sqlt.mod.ExportedFunction(name)
163}
164
165func (sqlt *sqlite) putfn(name string, fn api.Function) {
166 c := &sqlt.funcs
167 p := unsafe.StringData(name)
168 i := bits.TrailingZeros32(^c.mask)
169 if i < 32 {
170 c.id[i] = p
171 c.fn[i] = fn
172 c.mask |= uint32(1) << i
173 } else {
174 c.id[0] = p
175 c.fn[0] = fn
176 c.mask = uint32(1)
177 }
178}
179
180func (sqlt *sqlite) call(name string, params ...stk_t) stk_t {
181 copy(sqlt.stack[:], params)
182 fn := sqlt.getfn(name)
183 err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
184 if err != nil {
185 panic(err)
186 }
187 sqlt.putfn(name, fn)
188 return stk_t(sqlt.stack[0])
189}
190
191func (sqlt *sqlite) free(ptr ptr_t) {
192 if ptr == 0 {
193 return
194 }
195 sqlt.call("sqlite3_free", stk_t(ptr))
196}
197
198func (sqlt *sqlite) new(size int64) ptr_t {
199 ptr := ptr_t(sqlt.call("sqlite3_malloc64", stk_t(size)))
200 if ptr == 0 && size != 0 {
201 panic(util.OOMErr)
202 }
203 return ptr
204}
205
206func (sqlt *sqlite) realloc(ptr ptr_t, size int64) ptr_t {
207 ptr = ptr_t(sqlt.call("sqlite3_realloc64", stk_t(ptr), stk_t(size)))
208 if ptr == 0 && size != 0 {
209 panic(util.OOMErr)
210 }
211 return ptr
212}
213
214func (sqlt *sqlite) newBytes(b []byte) ptr_t {
215 if len(b) == 0 {
216 return 0
217 }
218 ptr := sqlt.new(int64(len(b)))
219 util.WriteBytes(sqlt.mod, ptr, b)
220 return ptr
221}
222
223func (sqlt *sqlite) newString(s string) ptr_t {
224 ptr := sqlt.new(int64(len(s)) + 1)
225 util.WriteString(sqlt.mod, ptr, s)
226 return ptr
227}
228
// arenaSize is the size, in bytes, of the base block each arena bump-allocates
// from (allocated in newArena, consumed by arena.new).
const arenaSize = 4096
230
231func (sqlt *sqlite) newArena() arena {
232 return arena{
233 sqlt: sqlt,
234 base: sqlt.new(arenaSize),
235 }
236}
237
// arena is a scratch allocator over a sqlite instance. Small requests are
// bump-allocated from one base block; larger ones fall back to sqlt.new and
// are tracked so that free (or a mark's reset) can release them.
type arena struct {
	sqlt *sqlite  // owning instance; nil once the arena has been freed
	ptrs []ptr_t  // overflow allocations obtained from sqlt.new
	base ptr_t    // start of the arenaSize-byte bump block
	next int32    // offset of the next unused byte within the base block
}
244
245func (a *arena) free() {
246 if a.sqlt == nil {
247 return
248 }
249 for _, ptr := range a.ptrs {
250 a.sqlt.free(ptr)
251 }
252 a.sqlt.free(a.base)
253 a.sqlt = nil
254}
255
256func (a *arena) mark() (reset func()) {
257 ptrs := len(a.ptrs)
258 next := a.next
259 return func() {
260 rest := a.ptrs[ptrs:]
261 for _, ptr := range a.ptrs[:ptrs] {
262 a.sqlt.free(ptr)
263 }
264 a.ptrs = rest
265 a.next = next
266 }
267}
268
269func (a *arena) new(size int64) ptr_t {
270 // Align the next address, to 4 or 8 bytes.
271 if size&7 != 0 {
272 a.next = (a.next + 3) &^ 3
273 } else {
274 a.next = (a.next + 7) &^ 7
275 }
276 if size <= arenaSize-int64(a.next) {
277 ptr := a.base + ptr_t(a.next)
278 a.next += int32(size)
279 return ptr_t(ptr)
280 }
281 ptr := a.sqlt.new(size)
282 a.ptrs = append(a.ptrs, ptr)
283 return ptr_t(ptr)
284}
285
286func (a *arena) bytes(b []byte) ptr_t {
287 if len(b) == 0 {
288 return 0
289 }
290 ptr := a.new(int64(len(b)))
291 util.WriteBytes(a.sqlt.mod, ptr, b)
292 return ptr
293}
294
295func (a *arena) string(s string) ptr_t {
296 ptr := a.new(int64(len(s)) + 1)
297 util.WriteString(a.sqlt.mod, ptr, s)
298 return ptr
299}
300
// exportCallbacks adds the go_* host functions to env through the
// util.ExportFunc* helpers, and returns env for chaining.
//
// Each line binds an export name to the Go callback that implements it. The
// helper's suffix presumably encodes the Wasm signature: result first, then
// parameters, with V = no result, I = i32, J = i64. TODO confirm against
// internal/util.
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
	// Handlers, hooks, tracing, authorization and logging.
	util.ExportFuncII(env, "go_progress_handler", progressCallback)
	util.ExportFuncIII(env, "go_busy_timeout", timeoutCallback)
	util.ExportFuncIII(env, "go_busy_handler", busyCallback)
	util.ExportFuncII(env, "go_commit_hook", commitCallback)
	util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)
	util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback)
	util.ExportFuncIIIII(env, "go_wal_hook", walCallback)
	util.ExportFuncIIIII(env, "go_trace", traceCallback)
	util.ExportFuncIIIIII(env, "go_autovacuum_pages", autoVacuumCallback)
	util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback)
	util.ExportFuncVIII(env, "go_log", logCallback)
	// Functions and collations.
	util.ExportFuncVI(env, "go_destroy", destroyCallback)
	util.ExportFuncVIIII(env, "go_func", funcCallback)
	util.ExportFuncVIIIII(env, "go_step", stepCallback)
	util.ExportFuncVIIII(env, "go_value", valueCallback)
	util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
	util.ExportFuncVIIII(env, "go_collation_needed", collationCallback)
	util.ExportFuncIIIIII(env, "go_compare", compareCallback)
	// Virtual tables (go_vtab_*); create and connect share one callback factory.
	util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate))
	util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect))
	util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
	util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback)
	util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)
	util.ExportFuncIIIII(env, "go_vtab_update", vtabUpdateCallback)
	util.ExportFuncIII(env, "go_vtab_rename", vtabRenameCallback)
	util.ExportFuncIIIII(env, "go_vtab_find_function", vtabFindFuncCallback)
	util.ExportFuncII(env, "go_vtab_begin", vtabBeginCallback)
	util.ExportFuncII(env, "go_vtab_sync", vtabSyncCallback)
	util.ExportFuncII(env, "go_vtab_commit", vtabCommitCallback)
	util.ExportFuncII(env, "go_vtab_rollback", vtabRollbackCallback)
	util.ExportFuncIII(env, "go_vtab_savepoint", vtabSavepointCallback)
	util.ExportFuncIII(env, "go_vtab_release", vtabReleaseCallback)
	util.ExportFuncIII(env, "go_vtab_rollback_to", vtabRollbackToCallback)
	util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback)
	// Virtual table cursors (go_cur_*).
	util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback)
	util.ExportFuncII(env, "go_cur_close", cursorCloseCallback)
	util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback)
	util.ExportFuncII(env, "go_cur_next", cursorNextCallback)
	util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback)
	util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback)
	util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback)
	return env
}