1// Package db provides database operations for the Shelley AI coding agent.
2package db
3
4import (
5 "context"
6 "database/sql"
7 "fmt"
8 "runtime"
9 "strings"
10 "time"
11)
12
// Pool is a pool of SQLite connections.
//
// Most of the database/sql connection management is deliberately bypassed,
// because its semantics are a poor match for SQLite.
//
// SQLite allows a single writer, so exactly one connection is set aside for
// writing and every other connection is used for reading.
type Pool struct {
	db      *sql.DB        // handle the pooled connections were obtained from
	writer  chan *sql.Conn // holds the one write connection while it is idle
	readers chan *sql.Conn // holds the read-only connections while they are idle
}
25
26func NewPool(dataSourceName string, readerCount int) (*Pool, error) {
27 if dataSourceName == ":memory:" {
28 return nil, fmt.Errorf(":memory: is not supported (because multiple conns are needed); use a temp file")
29 }
30 // TODO: a caller could override PRAGMA query_only.
31 // Consider opening two *sql.DBs, one configured as read-only,
32 // to ensure read-only transactions are always such.
33 db, err := sql.Open("sqlite", dataSourceName)
34 if err != nil {
35 return nil, fmt.Errorf("NewPool: %w", err)
36 }
37 numConns := readerCount + 1
38 if err := InitPoolDB(db, numConns); err != nil {
39 return nil, fmt.Errorf("NewPool: %w", err)
40 }
41
42 var conns []*sql.Conn
43 for i := 0; i < numConns; i++ {
44 conn, err := db.Conn(context.Background())
45 if err != nil {
46 db.Close()
47 return nil, fmt.Errorf("NewPool: %w", err)
48 }
49 conns = append(conns, conn)
50 }
51
52 p := &Pool{
53 db: db,
54 writer: make(chan *sql.Conn, 1),
55 readers: make(chan *sql.Conn, readerCount),
56 }
57 p.writer <- conns[0]
58 for _, conn := range conns[1:] {
59 if _, err := conn.ExecContext(context.Background(), "PRAGMA query_only=1;"); err != nil {
60 db.Close()
61 return nil, fmt.Errorf("NewPool query_only: %w", err)
62 }
63 p.readers <- conn
64 }
65
66 return p, nil
67}
68
// InitPoolDB pins the database/sql pool of db to numConns connections and
// runs the per-connection PRAGMAs (WAL journal mode, busy timeout, foreign
// keys) on each of them.
//
// Idle and open limits are both set to numConns and connection expiry is
// disabled. All numConns connections are checked out at the same time,
// configured, and only then handed back to the database/sql pool.
//
// On any error db is closed before the error is returned.
func InitPoolDB(db *sql.DB, numConns int) error {
	db.SetMaxIdleConns(numConns)
	db.SetMaxOpenConns(numConns)
	db.SetConnMaxLifetime(-1)
	db.SetConnMaxIdleTime(-1)

	ctx := context.Background()
	pragmas := [...]string{
		"PRAGMA journal_mode=wal;",
		"PRAGMA busy_timeout=1000;",
		"PRAGMA foreign_keys=ON;",
	}

	// Connections stay checked out until every one has been configured;
	// presumably this keeps db.Conn from handing back an already
	// configured connection.
	var held []*sql.Conn
	for idx := 0; idx < numConns; idx++ {
		c, err := db.Conn(ctx)
		if err != nil {
			db.Close()
			return fmt.Errorf("InitPoolDB: %w", err)
		}
		for _, pragma := range pragmas {
			_, err := c.ExecContext(ctx, pragma)
			if err == nil {
				continue
			}
			db.Close()
			return fmt.Errorf("InitPoolDB %d: %w", idx, err)
		}
		held = append(held, c)
	}

	// Hand the configured connections back to the database/sql pool.
	for _, c := range held {
		err := c.Close()
		if err != nil {
			db.Close()
			return fmt.Errorf("InitPoolDB: %w", err)
		}
	}
	return nil
}
105
106func (p *Pool) Close() error {
107 return p.db.Close()
108}
109
// ctxKeyType is the unexported type of CtxKey, so that the key cannot
// collide with context keys created by other packages.
type ctxKeyType int

// CtxKey is the context value key under which the current *Tx or *Rx is stored.
// In general this should not be used; plumb the tx through directly instead.
// It is exported for one exception: the slog package.
var CtxKey any = ctxKeyType(0)
116
117func checkNoTx(ctx context.Context, typ string) {
118 x := ctx.Value(CtxKey)
119 if x == nil {
120 return
121 }
122 orig := "unexpected"
123 switch x := x.(type) {
124 case *Tx:
125 orig = "Tx (" + x.caller + ")"
126 case *Rx:
127 orig = "Rx (" + x.caller + ")"
128 }
129 panic(typ + " inside " + orig)
130}
131
132// Exec executes a single statement outside of a transaction.
133// Useful in the rare case of PRAGMAs that cannot execute inside a tx,
134// such as PRAGMA wal_checkpoint.
135func (p *Pool) Exec(ctx context.Context, query string, args ...interface{}) error {
136 checkNoTx(ctx, "Tx")
137 var conn *sql.Conn
138 select {
139 case <-ctx.Done():
140 return fmt.Errorf("Pool.Exec: %w", ctx.Err())
141 case conn = <-p.writer:
142 }
143 var err error
144 defer func() {
145 p.writer <- conn
146 }()
147 _, err = conn.ExecContext(ctx, query, args...)
148 return wrapErr("pool.exec", err)
149}
150
151func (p *Pool) Tx(ctx context.Context, fn func(ctx context.Context, tx *Tx) error) error {
152 checkNoTx(ctx, "Tx")
153 var conn *sql.Conn
154 select {
155 case <-ctx.Done():
156 return fmt.Errorf("Tx: %w", ctx.Err())
157 case conn = <-p.writer:
158 }
159
160 // If the context is closed, we want BEGIN to succeed and then
161 // we roll it back later.
162 if _, err := conn.ExecContext(context.WithoutCancel(ctx), "BEGIN IMMEDIATE;"); err != nil {
163 if strings.Contains(err.Error(), "SQLITE_BUSY") {
164 p.writer <- conn
165 return fmt.Errorf("Tx begin: %w", err)
166 }
167 // unrecoverable error, this will lock everything up
168 return fmt.Errorf("Tx LEAK %w", err)
169 }
170 tx := &Tx{
171 Rx: &Rx{conn: conn, p: p, caller: callerOfCaller(1)},
172 Now: time.Now(),
173 }
174 tx.ctx = context.WithValue(ctx, CtxKey, tx)
175
176 var err error
177 defer func() {
178 if err == nil {
179 _, err = tx.conn.ExecContext(tx.ctx, "COMMIT;")
180 if err != nil {
181 err = fmt.Errorf("Tx: commit: %w", err)
182 }
183 }
184 if err != nil {
185 err = p.rollback(tx.ctx, "Tx", err, tx.conn)
186 // always return conn,
187 // either the entire database is closed or the conn is fine.
188 }
189 tx.p.writer <- conn
190 }()
191 if ctxErr := tx.ctx.Err(); ctxErr != nil {
192 return ctxErr // fast path for canceled context
193 }
194 err = fn(tx.ctx, tx)
195
196 return err
197}
198
199func (p *Pool) Rx(ctx context.Context, fn func(ctx context.Context, rx *Rx) error) error {
200 checkNoTx(ctx, "Rx")
201 var conn *sql.Conn
202 select {
203 case <-ctx.Done():
204 return ctx.Err()
205 case conn = <-p.readers:
206 }
207
208 // If the context is closed, we want BEGIN to succeed and then
209 // we roll it back later.
210 if _, err := conn.ExecContext(context.WithoutCancel(ctx), "BEGIN;"); err != nil {
211 if strings.Contains(err.Error(), "SQLITE_BUSY") {
212 p.readers <- conn
213 return fmt.Errorf("Rx begin: %w", err)
214 }
215 // an unrecoverable error, e.g. tx-inside-tx misuse or IOERR
216 return fmt.Errorf("Rx LEAK: %w", err)
217 }
218 rx := &Rx{conn: conn, p: p, caller: callerOfCaller(1)}
219 rx.ctx = context.WithValue(ctx, CtxKey, rx)
220
221 var err error
222 defer func() {
223 err = p.rollback(rx.ctx, "Rx", err, rx.conn)
224 // always return conn,
225 // either the entire database is closed or the conn is fine.
226 rx.p.readers <- conn
227 }()
228 if ctxErr := rx.ctx.Err(); ctxErr != nil {
229 return ctxErr // fast path for canceled context
230 }
231 err = fn(rx.ctx, rx)
232 return err
233}
234
235func (p *Pool) rollback(ctx context.Context, txType string, txErr error, conn *sql.Conn) error {
236 // Even if the context is cancelled,
237 // we still need to rollback to finish up the transaction.
238 _, err := conn.ExecContext(context.WithoutCancel(ctx), "ROLLBACK;")
239 if err != nil && !strings.Contains(err.Error(), "no transaction is active") {
240 // There are a few cases where an error during a transaction
241 // will be reported as a rollback error:
242 // https://sqlite.org/lang_transaction.html#response_to_errors_within_a_transaction
243 // In good operation, we should never see any of these.
244 //
245 // TODO: confirm this check works on all sqlite drivers.
246 if !strings.Contains(err.Error(), "SQLITE_BUSY") {
247 conn.Close()
248 p.db.Close()
249 }
250 return fmt.Errorf("%s: %v: rollback failed: %w", txType, txErr, err)
251 }
252 return txErr
253}
254
255type Tx struct {
256 *Rx
257 Now time.Time
258}
259
260func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
261 res, err := tx.conn.ExecContext(tx.ctx, query, args...)
262 return res, wrapErr("exec", err)
263}
264
265type Rx struct {
266 ctx context.Context
267 conn *sql.Conn
268 p *Pool
269 caller string // for debugging
270}
271
272func (rx *Rx) Context() context.Context {
273 return rx.ctx
274}
275
276func (rx *Rx) Query(query string, args ...interface{}) (*sql.Rows, error) {
277 rows, err := rx.conn.QueryContext(rx.ctx, query, args...)
278 return rows, wrapErr("query", err)
279}
280
281func (rx *Rx) QueryRow(query string, args ...interface{}) *Row {
282 rows, err := rx.conn.QueryContext(rx.ctx, query, args...)
283 return &Row{err: err, rows: rows}
284}
285
286// Conn returns the underlying sql.Conn for use with external libraries like sqlc
287func (rx *Rx) Conn() *sql.Conn {
288 return rx.conn
289}
290
// Row is the counterpart of *sql.Row, with a more useful error: the error
// returned from Scan is annotated by wrapErr.
type Row struct {
	err  error     // error from running the query, reported by Scan
	rows *sql.Rows // result set from which Scan reads the first row
}
296
297func (r *Row) Scan(dest ...any) error {
298 if r.err != nil {
299 return wrapErr("QueryRow", r.err)
300 }
301
302 defer r.rows.Close()
303 if !r.rows.Next() {
304 if err := r.rows.Err(); err != nil {
305 return wrapErr("QueryRow.Scan", err)
306 }
307 return wrapErr("QueryRow.Scan", sql.ErrNoRows)
308 }
309 err := r.rows.Scan(dest...)
310 if err != nil {
311 return wrapErr("QueryRow.Scan", err)
312 }
313 return wrapErr("QueryRow.Scan", r.rows.Close())
314}
315
316func wrapErr(prefix string, err error) error {
317 if err == nil {
318 return nil
319 }
320 return fmt.Errorf("%s: %s: %w", callerOfCaller(2), prefix, err)
321}
322
// callerOfCaller returns the name of a function further up the call stack,
// with everything up to and including the last '/' of its package path
// removed. With depth 1 this is the function that called callerOfCaller's
// caller; each additional unit of depth looks one frame further up.
// It returns "unknown" when no such frame can be resolved.
func callerOfCaller(depth int) string {
	// Frames skipped before depth is applied:
	// runtime.Callers, callerOfCaller, our caller (e.g. wrapErr or Rx).
	const baseSkip = 3

	var pcs [3]uintptr
	name := "unknown"
	if n := runtime.Callers(baseSkip+depth-1, pcs[:]); n > 0 {
		frames := runtime.CallersFrames(pcs[:n])
		if first, _ := frames.Next(); first.Function != "" {
			name = first.Function
		}
		// This is a special case.
		//
		// People are expected to wrap the Tx/Rx objects in another
		// domain-specific Tx/Rx object, which almost certainly has matching
		// Tx/Rx methods. Those are not useful for debugging, so when the
		// frame is such a method the next frame up is reported instead.
		if strings.HasSuffix(name, ".Tx") || strings.HasSuffix(name, ".Rx") {
			if next, more := frames.Next(); more && next.Function != "" {
				name = next.Function
			}
		}
	}
	// LastIndexByte yields -1 when there is no '/', which keeps the whole name.
	return name[strings.LastIndexByte(name, '/')+1:]
}