pool.go

  1// Package db provides database operations for the Shelley AI coding agent.
  2package db
  3
  4import (
  5	"context"
  6	"database/sql"
  7	"fmt"
  8	"runtime"
  9	"strings"
 10	"time"
 11)
 12
 13// Pool is an SQLite connection pool.
 14//
 15// We deliberately minimize our use of database/sql machinery because
 16// the semantics do not match SQLite well.
 17//
 18// Instead, we choose a single connection to use for writing (because
 19// SQLite is single-writer) and use the rest as readers.
type Pool struct {
	// db is the underlying database/sql handle; all conns below come from it.
	db *sql.DB
	// writer holds the single write connection (capacity-1 channel).
	// Receiving from it grants exclusive write access; sending returns it.
	writer chan *sql.Conn
	// readers holds the remaining connections, each set to
	// PRAGMA query_only=1 by NewPool.
	readers chan *sql.Conn
}
 25
 26func NewPool(dataSourceName string, readerCount int) (*Pool, error) {
 27	if dataSourceName == ":memory:" {
 28		return nil, fmt.Errorf(":memory: is not supported (because multiple conns are needed); use a temp file")
 29	}
 30	// TODO: a caller could override PRAGMA query_only.
 31	// Consider opening two *sql.DBs, one configured as read-only,
 32	// to ensure read-only transactions are always such.
 33	db, err := sql.Open("sqlite", dataSourceName)
 34	if err != nil {
 35		return nil, fmt.Errorf("NewPool: %w", err)
 36	}
 37	numConns := readerCount + 1
 38	if err := InitPoolDB(db, numConns); err != nil {
 39		return nil, fmt.Errorf("NewPool: %w", err)
 40	}
 41
 42	var conns []*sql.Conn
 43	for i := 0; i < numConns; i++ {
 44		conn, err := db.Conn(context.Background())
 45		if err != nil {
 46			db.Close()
 47			return nil, fmt.Errorf("NewPool: %w", err)
 48		}
 49		conns = append(conns, conn)
 50	}
 51
 52	p := &Pool{
 53		db:      db,
 54		writer:  make(chan *sql.Conn, 1),
 55		readers: make(chan *sql.Conn, readerCount),
 56	}
 57	p.writer <- conns[0]
 58	for _, conn := range conns[1:] {
 59		if _, err := conn.ExecContext(context.Background(), "PRAGMA query_only=1;"); err != nil {
 60			db.Close()
 61			return nil, fmt.Errorf("NewPool query_only: %w", err)
 62		}
 63		p.readers <- conn
 64	}
 65
 66	return p, nil
 67}
 68
 69// InitPoolDB fixes the database/sql pool to a set of fixed connections.
 70func InitPoolDB(db *sql.DB, numConns int) error {
 71	db.SetMaxIdleConns(numConns)
 72	db.SetMaxOpenConns(numConns)
 73	db.SetConnMaxLifetime(-1)
 74	db.SetConnMaxIdleTime(-1)
 75
 76	initQueries := []string{
 77		"PRAGMA journal_mode=wal;",
 78		"PRAGMA busy_timeout=1000;",
 79		"PRAGMA foreign_keys=ON;",
 80	}
 81
 82	var conns []*sql.Conn
 83	for i := 0; i < numConns; i++ {
 84		conn, err := db.Conn(context.Background())
 85		if err != nil {
 86			db.Close()
 87			return fmt.Errorf("InitPoolDB: %w", err)
 88		}
 89		for _, q := range initQueries {
 90			if _, err := conn.ExecContext(context.Background(), q); err != nil {
 91				db.Close()
 92				return fmt.Errorf("InitPoolDB %d: %w", i, err)
 93			}
 94		}
 95		conns = append(conns, conn)
 96	}
 97	for _, conn := range conns {
 98		if err := conn.Close(); err != nil {
 99			db.Close()
100			return fmt.Errorf("InitPoolDB: %w", err)
101		}
102	}
103	return nil
104}
105
106func (p *Pool) Close() error {
107	return p.db.Close()
108}
109
// ctxKeyType is the private type of CtxKey, so the key cannot collide
// with context keys defined by other packages.
type ctxKeyType int

// CtxKey is the context value key used to store the current *Tx or *Rx.
// In general this should not be used; plumb the tx directly.
// It exists for one exception: the slog package.
var CtxKey any = ctxKeyType(0)
116
117func checkNoTx(ctx context.Context, typ string) {
118	x := ctx.Value(CtxKey)
119	if x == nil {
120		return
121	}
122	orig := "unexpected"
123	switch x := x.(type) {
124	case *Tx:
125		orig = "Tx (" + x.caller + ")"
126	case *Rx:
127		orig = "Rx (" + x.caller + ")"
128	}
129	panic(typ + " inside " + orig)
130}
131
132// Exec executes a single statement outside of a transaction.
133// Useful in the rare case of PRAGMAs that cannot execute inside a tx,
134// such as PRAGMA wal_checkpoint.
135func (p *Pool) Exec(ctx context.Context, query string, args ...interface{}) error {
136	checkNoTx(ctx, "Tx")
137	var conn *sql.Conn
138	select {
139	case <-ctx.Done():
140		return fmt.Errorf("Pool.Exec: %w", ctx.Err())
141	case conn = <-p.writer:
142	}
143	var err error
144	defer func() {
145		p.writer <- conn
146	}()
147	_, err = conn.ExecContext(ctx, query, args...)
148	return wrapErr("pool.exec", err)
149}
150
151func (p *Pool) Tx(ctx context.Context, fn func(ctx context.Context, tx *Tx) error) error {
152	checkNoTx(ctx, "Tx")
153	var conn *sql.Conn
154	select {
155	case <-ctx.Done():
156		return fmt.Errorf("Tx: %w", ctx.Err())
157	case conn = <-p.writer:
158	}
159
160	// If the context is closed, we want BEGIN to succeed and then
161	// we roll it back later.
162	if _, err := conn.ExecContext(context.WithoutCancel(ctx), "BEGIN IMMEDIATE;"); err != nil {
163		if strings.Contains(err.Error(), "SQLITE_BUSY") {
164			p.writer <- conn
165			return fmt.Errorf("Tx begin: %w", err)
166		}
167		// unrecoverable error, this will lock everything up
168		return fmt.Errorf("Tx LEAK %w", err)
169	}
170	tx := &Tx{
171		Rx:  &Rx{conn: conn, p: p, caller: callerOfCaller(1)},
172		Now: time.Now(),
173	}
174	tx.ctx = context.WithValue(ctx, CtxKey, tx)
175
176	var err error
177	defer func() {
178		if err == nil {
179			_, err = tx.conn.ExecContext(tx.ctx, "COMMIT;")
180			if err != nil {
181				err = fmt.Errorf("Tx: commit: %w", err)
182			}
183		}
184		if err != nil {
185			err = p.rollback(tx.ctx, "Tx", err, tx.conn)
186			// always return conn,
187			// either the entire database is closed or the conn is fine.
188		}
189		tx.p.writer <- conn
190	}()
191	if ctxErr := tx.ctx.Err(); ctxErr != nil {
192		return ctxErr // fast path for canceled context
193	}
194	err = fn(tx.ctx, tx)
195
196	return err
197}
198
199func (p *Pool) Rx(ctx context.Context, fn func(ctx context.Context, rx *Rx) error) error {
200	checkNoTx(ctx, "Rx")
201	var conn *sql.Conn
202	select {
203	case <-ctx.Done():
204		return ctx.Err()
205	case conn = <-p.readers:
206	}
207
208	// If the context is closed, we want BEGIN to succeed and then
209	// we roll it back later.
210	if _, err := conn.ExecContext(context.WithoutCancel(ctx), "BEGIN;"); err != nil {
211		if strings.Contains(err.Error(), "SQLITE_BUSY") {
212			p.readers <- conn
213			return fmt.Errorf("Rx begin: %w", err)
214		}
215		// an unrecoverable error, e.g. tx-inside-tx misuse or IOERR
216		return fmt.Errorf("Rx LEAK: %w", err)
217	}
218	rx := &Rx{conn: conn, p: p, caller: callerOfCaller(1)}
219	rx.ctx = context.WithValue(ctx, CtxKey, rx)
220
221	var err error
222	defer func() {
223		err = p.rollback(rx.ctx, "Rx", err, rx.conn)
224		// always return conn,
225		// either the entire database is closed or the conn is fine.
226		rx.p.readers <- conn
227	}()
228	if ctxErr := rx.ctx.Err(); ctxErr != nil {
229		return ctxErr // fast path for canceled context
230	}
231	err = fn(rx.ctx, rx)
232	return err
233}
234
235func (p *Pool) rollback(ctx context.Context, txType string, txErr error, conn *sql.Conn) error {
236	// Even if the context is cancelled,
237	// we still need to rollback to finish up the transaction.
238	_, err := conn.ExecContext(context.WithoutCancel(ctx), "ROLLBACK;")
239	if err != nil && !strings.Contains(err.Error(), "no transaction is active") {
240		// There are a few cases where an error during a transaction
241		// will be reported as a rollback error:
242		// 	https://sqlite.org/lang_transaction.html#response_to_errors_within_a_transaction
243		// In good operation, we should never see any of these.
244		//
245		// TODO: confirm this check works on all sqlite drivers.
246		if !strings.Contains(err.Error(), "SQLITE_BUSY") {
247			conn.Close()
248			p.db.Close()
249		}
250		return fmt.Errorf("%s: %v: rollback failed: %w", txType, txErr, err)
251	}
252	return txErr
253}
254
// Tx is a write transaction, created by Pool.Tx and handed to its callback.
// It embeds *Rx, so all read methods are available on it as well.
type Tx struct {
	*Rx
	// Now is the time captured by Pool.Tx when the transaction was started.
	Now time.Time
}
259
260func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
261	res, err := tx.conn.ExecContext(tx.ctx, query, args...)
262	return res, wrapErr("exec", err)
263}
264
// Rx is a read transaction, created by Pool.Rx and handed to its callback.
// It is also embedded in Tx.
type Rx struct {
	ctx    context.Context // transaction context; carries this Rx (or its Tx) under CtxKey
	conn   *sql.Conn       // pooled connection the transaction runs on
	p      *Pool           // pool the connection is returned to
	caller string          // for debugging
}
271
272func (rx *Rx) Context() context.Context {
273	return rx.ctx
274}
275
276func (rx *Rx) Query(query string, args ...interface{}) (*sql.Rows, error) {
277	rows, err := rx.conn.QueryContext(rx.ctx, query, args...)
278	return rows, wrapErr("query", err)
279}
280
281func (rx *Rx) QueryRow(query string, args ...interface{}) *Row {
282	rows, err := rx.conn.QueryContext(rx.ctx, query, args...)
283	return &Row{err: err, rows: rows}
284}
285
286// Conn returns the underlying sql.Conn for use with external libraries like sqlc
287func (rx *Rx) Conn() *sql.Conn {
288	return rx.conn
289}
290
291// Row is equivalent to *sql.Row, but we provide a more useful error.
type Row struct {
	err  error     // error from running the query, reported by Scan
	rows *sql.Rows // result set; Scan reads its first row and closes it
}
296
297func (r *Row) Scan(dest ...any) error {
298	if r.err != nil {
299		return wrapErr("QueryRow", r.err)
300	}
301
302	defer r.rows.Close()
303	if !r.rows.Next() {
304		if err := r.rows.Err(); err != nil {
305			return wrapErr("QueryRow.Scan", err)
306		}
307		return wrapErr("QueryRow.Scan", sql.ErrNoRows)
308	}
309	err := r.rows.Scan(dest...)
310	if err != nil {
311		return wrapErr("QueryRow.Scan", err)
312	}
313	return wrapErr("QueryRow.Scan", r.rows.Close())
314}
315
316func wrapErr(prefix string, err error) error {
317	if err == nil {
318		return nil
319	}
320	return fmt.Errorf("%s: %s: %w", callerOfCaller(2), prefix, err)
321}
322
// callerOfCaller returns a short function name for a frame up the stack,
// for use in debugging and error messages.
//
// With depth 1 it names the caller of the function that invoked
// callerOfCaller; each extra unit of depth moves one frame further up.
// Any package path up to the last '/' is stripped. It returns "unknown"
// when no frame name is available.
func callerOfCaller(depth int) string {
	const addedSkip = 3 // runtime.Callers, callerOfCaller, our caller (e.g. wrapErr or Rx)

	var pcs [3]uintptr
	name := "unknown"
	if n := runtime.Callers(addedSkip+depth-1, pcs[:]); n > 0 {
		frames := runtime.CallersFrames(pcs[:n])
		if first, _ := frames.Next(); first.Function != "" {
			name = first.Function
		}
		// This is a special case.
		//
		// We expect people to wrap the Tx/Rx objects
		// in another domain-specific Tx/Rx object. That means
		// they almost certainly have matching Tx/Rx methods,
		// which aren't useful for debugging. So if we see that,
		// we step one frame further up.
		isWrapper := strings.HasSuffix(name, ".Tx") || strings.HasSuffix(name, ".Rx")
		if isWrapper {
			if next, more := frames.Next(); more && next.Function != "" {
				name = next.Function
			}
		}
	}
	// Drop the import path, keeping only "pkg.Func".
	if slash := strings.LastIndexByte(name, '/'); slash >= 0 {
		name = name[slash+1:]
	}
	return name
}