txn.go

  1package sqlite3
  2
  3import (
  4	"context"
  5	"math/rand"
  6	"runtime"
  7	"strconv"
  8	"strings"
  9
 10	"github.com/tetratelabs/wazero/api"
 11
 12	"github.com/ncruces/go-sqlite3/internal/util"
 13)
 14
// Txn is an in-progress database transaction.
//
// A Txn is obtained from [Conn.Begin], [Conn.BeginConcurrent],
// [Conn.BeginImmediate] or [Conn.BeginExclusive], and is finished with
// [Txn.Commit], [Txn.Rollback] or (typically deferred) [Txn.End].
//
// https://sqlite.org/lang_transaction.html
type Txn struct {
	c *Conn // connection the transaction statements are executed on
}
 21
 22// Begin starts a deferred transaction.
 23// It panics if a transaction is in-progress.
 24// For nested transactions, use [Conn.Savepoint].
 25//
 26// https://sqlite.org/lang_transaction.html
 27func (c *Conn) Begin() Txn {
 28	// BEGIN even if interrupted.
 29	err := c.exec(`BEGIN DEFERRED`)
 30	if err != nil {
 31		panic(err)
 32	}
 33	return Txn{c}
 34}
 35
 36// BeginConcurrent starts a concurrent transaction.
 37//
 38// Experimental: requires a custom build of SQLite.
 39//
 40// https://sqlite.org/cgi/src/doc/begin-concurrent/doc/begin_concurrent.md
 41func (c *Conn) BeginConcurrent() (Txn, error) {
 42	err := c.Exec(`BEGIN CONCURRENT`)
 43	if err != nil {
 44		return Txn{}, err
 45	}
 46	return Txn{c}, nil
 47}
 48
 49// BeginImmediate starts an immediate transaction.
 50//
 51// https://sqlite.org/lang_transaction.html
 52func (c *Conn) BeginImmediate() (Txn, error) {
 53	err := c.Exec(`BEGIN IMMEDIATE`)
 54	if err != nil {
 55		return Txn{}, err
 56	}
 57	return Txn{c}, nil
 58}
 59
 60// BeginExclusive starts an exclusive transaction.
 61//
 62// https://sqlite.org/lang_transaction.html
 63func (c *Conn) BeginExclusive() (Txn, error) {
 64	err := c.Exec(`BEGIN EXCLUSIVE`)
 65	if err != nil {
 66		return Txn{}, err
 67	}
 68	return Txn{c}, nil
 69}
 70
// End calls either [Txn.Commit] or [Txn.Rollback]
// depending on whether *error points to a nil or non-nil error.
//
// A panic recovered by End is treated like an error: the transaction is
// rolled back and the panic is then re-raised. If COMMIT fails, its error
// is stored in *errp and End falls through to the rollback path. If the
// rollback itself fails, End panics with the rollback error. When the
// connection reports autocommit mode (no transaction is active), End
// executes neither COMMIT nor ROLLBACK.
//
// This is meant to be deferred:
//
//	func doWork(db *sqlite3.Conn) (err error) {
//		tx := db.Begin()
//		defer tx.End(&err)
//
//		// ... do work in the transaction
//	}
//
// https://sqlite.org/lang_transaction.html
func (tx Txn) End(errp *error) {
	// recover only has an effect when called directly by the deferred
	// function, which is why End itself must be the deferred call.
	recovered := recover()
	if recovered != nil {
		// Re-raise the panic after the rollback below has run.
		defer panic(recovered)
	}

	// *errp is dereferenced before recovered is checked, so errp must not be nil.
	if *errp == nil && recovered == nil {
		// Success path.
		if tx.c.GetAutocommit() { // There is nothing to commit.
			return
		}
		*errp = tx.Commit()
		if *errp == nil {
			return
		}
		// Fall through to the error path.
	}

	// Error path.
	if tx.c.GetAutocommit() { // There is nothing to rollback.
		return
	}
	err := tx.Rollback()
	if err != nil {
		panic(err)
	}
}
111
112// Commit commits the transaction.
113//
114// https://sqlite.org/lang_transaction.html
115func (tx Txn) Commit() error {
116	return tx.c.Exec(`COMMIT`)
117}
118
119// Rollback rolls back the transaction,
120// even if the connection has been interrupted.
121//
122// https://sqlite.org/lang_transaction.html
123func (tx Txn) Rollback() error {
124	// ROLLBACK even if interrupted.
125	return tx.c.exec(`ROLLBACK`)
126}
127
// Savepoint is a marker within a transaction
// that allows for partial rollback.
//
// Create one with [Conn.Savepoint]; finish it with [Savepoint.Release],
// optionally after one or more calls to [Savepoint.Rollback].
//
// https://sqlite.org/lang_savepoint.html
type Savepoint struct {
	c    *Conn  // connection the savepoint was established on
	name string // identifier already quoted by QuoteIdentifier in Conn.Savepoint
}
136
// Savepoint establishes a new transaction savepoint.
// It panics if the SAVEPOINT statement fails.
//
// The savepoint is named after the calling function (or
// "sqlite3.Savepoint" when that cannot be determined), followed by an
// underscore and a random number, and quoted with [QuoteIdentifier].
//
// https://sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
	// callerName must be called directly from here: it skips a fixed
	// number of stack frames to find our caller.
	prefix := callerName()
	if prefix == "" {
		prefix = "sqlite3.Savepoint"
	}
	// Names can be reused, but a random suffix makes catching bugs more likely.
	raw := prefix + "_" + strconv.Itoa(int(rand.Int31()))
	name := QuoteIdentifier(raw)

	// As in Begin, the unexported exec runs even if interrupted.
	if err := c.exec(`SAVEPOINT ` + name); err != nil {
		panic(err)
	}
	return Savepoint{c: c, name: name}
}
154
// callerName returns the fully qualified function name of the caller of
// the function that called callerName (that is, the code that called
// [Conn.Savepoint]). Frames belonging to database/sql or to this module's
// driver package are skipped. If every inspected frame is such a wrapper,
// the last one inspected is returned. At most 8 program counters are
// collected. It returns "" when no frame is available.
func callerName() (name string) {
	var pcs [8]uintptr
	// Skip runtime.Callers, callerName, and callerName's direct caller.
	count := runtime.Callers(3, pcs[:])
	if count <= 0 {
		return ""
	}
	frames := runtime.CallersFrames(pcs[:count])
	for {
		frame, more := frames.Next()
		fn := frame.Function
		wrapper := strings.HasPrefix(fn, "database/sql.") ||
			strings.HasPrefix(fn, "github.com/ncruces/go-sqlite3/driver.")
		if !more || !wrapper {
			return fn
		}
	}
}
169
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//
// A panic recovered by Release is treated like an error: the changes are
// rolled back, the savepoint is released, and the panic is then re-raised.
// If RELEASE fails on the success path, its error is stored in *errp and
// Release falls through to the rollback path. If the rollback fails,
// Release panics with that error. When the connection reports autocommit
// mode (no transaction is active), Release executes no statement.
//
// This is meant to be deferred:
//
//	func doWork(db *sqlite3.Conn) (err error) {
//		savept := db.Savepoint()
//		defer savept.Release(&err)
//
//		// ... do work in the transaction
//	}
//
// https://sqlite.org/lang_savepoint.html
func (s Savepoint) Release(errp *error) {
	// recover only has an effect when called directly by the deferred
	// function, which is why Release itself must be the deferred call.
	recovered := recover()
	if recovered != nil {
		// Re-raise the panic after the rollback below has run.
		defer panic(recovered)
	}

	// *errp is dereferenced before recovered is checked, so errp must not be nil.
	if *errp == nil && recovered == nil {
		// Success path.
		if s.c.GetAutocommit() { // There is nothing to commit.
			return
		}
		*errp = s.c.Exec(`RELEASE ` + s.name)
		if *errp == nil {
			return
		}
		// Fall through to the error path.
	}

	// Error path.
	if s.c.GetAutocommit() { // There is nothing to rollback.
		return
	}
	// ROLLBACK and RELEASE even if interrupted.
	// ROLLBACK TO leaves the savepoint in place, so it is released explicitly.
	err := s.c.exec(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
	if err != nil {
		panic(err)
	}
}
209
210// Rollback rolls the transaction back to the savepoint,
211// even if the connection has been interrupted.
212// Rollback does not release the savepoint.
213//
214// https://sqlite.org/lang_transaction.html
215func (s Savepoint) Rollback() error {
216	// ROLLBACK even if interrupted.
217	return s.c.exec(`ROLLBACK TO ` + s.name)
218}
219
220// TxnState determines the transaction state of a database.
221//
222// https://sqlite.org/c3ref/txn_state.html
223func (c *Conn) TxnState(schema string) TxnState {
224	var ptr ptr_t
225	if schema != "" {
226		defer c.arena.mark()()
227		ptr = c.arena.string(schema)
228	}
229	return TxnState(c.call("sqlite3_txn_state", stk_t(c.handle), stk_t(ptr)))
230}
231
232// CommitHook registers a callback function to be invoked
233// whenever a transaction is committed.
234// Return true to allow the commit operation to continue normally.
235//
236// https://sqlite.org/c3ref/commit_hook.html
237func (c *Conn) CommitHook(cb func() (ok bool)) {
238	var enable int32
239	if cb != nil {
240		enable = 1
241	}
242	c.call("sqlite3_commit_hook_go", stk_t(c.handle), stk_t(enable))
243	c.commit = cb
244}
245
246// RollbackHook registers a callback function to be invoked
247// whenever a transaction is rolled back.
248//
249// https://sqlite.org/c3ref/commit_hook.html
250func (c *Conn) RollbackHook(cb func()) {
251	var enable int32
252	if cb != nil {
253		enable = 1
254	}
255	c.call("sqlite3_rollback_hook_go", stk_t(c.handle), stk_t(enable))
256	c.rollback = cb
257}
258
259// UpdateHook registers a callback function to be invoked
260// whenever a row is updated, inserted or deleted in a rowid table.
261//
262// https://sqlite.org/c3ref/update_hook.html
263func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) {
264	var enable int32
265	if cb != nil {
266		enable = 1
267	}
268	c.call("sqlite3_update_hook_go", stk_t(c.handle), stk_t(enable))
269	c.update = cb
270}
271
// commitCallback forwards a commit-hook notification for database handle
// pDB to the [Conn] stored in ctx. It returns 1 (request a rollback) only
// when that Conn owns pDB, has a commit callback, and the callback returns
// false. Otherwise it returns 0.
func commitCallback(ctx context.Context, mod api.Module, pDB ptr_t) (rollback int32) {
	c, ok := ctx.Value(connKey{}).(*Conn)
	if !ok || c.handle != pDB || c.commit == nil {
		return 0
	}
	if c.commit() {
		return 0
	}
	return 1
}
280
// rollbackCallback forwards a rollback-hook notification for database
// handle pDB to the [Conn] stored in ctx, if that Conn owns pDB and has a
// rollback callback registered.
func rollbackCallback(ctx context.Context, mod api.Module, pDB ptr_t) {
	c, ok := ctx.Value(connKey{}).(*Conn)
	if !ok || c.handle != pDB || c.rollback == nil {
		return
	}
	c.rollback()
}
286
// updateCallback forwards an update-hook notification for database handle
// pDB to the [Conn] stored in ctx. The schema and table names are read
// from module memory (each limited to _MAX_NAME) only when that Conn owns
// pDB and has an update callback registered.
func updateCallback(ctx context.Context, mod api.Module, pDB ptr_t, action AuthorizerActionCode, zSchema, zTabName ptr_t, rowid int64) {
	c, ok := ctx.Value(connKey{}).(*Conn)
	if !ok || c.handle != pDB || c.update == nil {
		return
	}
	schemaName := util.ReadString(mod, zSchema, _MAX_NAME)
	tableName := util.ReadString(mod, zTabName, _MAX_NAME)
	c.update(action, schemaName, tableName, rowid)
}
294
295// CacheFlush flushes caches to disk mid-transaction.
296//
297// https://sqlite.org/c3ref/db_cacheflush.html
298func (c *Conn) CacheFlush() error {
299	rc := res_t(c.call("sqlite3_db_cacheflush", stk_t(c.handle)))
300	return c.error(rc)
301}