fix(db): use connection pool to avoid corrupted writes

Kieran Klukas created

Change summary

internal/app/app.go         |  6 +
internal/db/connect.go      | 96 ++++++++++++++++++++++++++++++++++----
internal/db/connect_test.go | 54 +++++++++++++++++++++
3 files changed, 142 insertions(+), 14 deletions(-)

Detailed changes

internal/app/app.go 🔗

@@ -113,10 +113,12 @@ func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore) (*App, er
 
 	go mcp.Initialize(ctx, app.Permissions, store)
 
-	// cleanup database upon app shutdown
+	// Release the shared database connection on shutdown. The pool
+	// closes the underlying *sql.DB when the last reference is released.
+	dataDir := cfg.Options.DataDirectory
 	app.cleanupFuncs = append(
 		app.cleanupFuncs,
-		func(context.Context) error { return conn.Close() },
+		func(context.Context) error { return db.Release(dataDir) },
 		func(ctx context.Context) error { return mcp.Close(ctx) },
 	)
 

internal/db/connect.go 🔗

@@ -38,41 +38,113 @@ func init() {
 	}
 }
 
-// Connect opens a SQLite database connection and runs migrations.
+// connEntry is one pool slot: a shared *sql.DB plus the number of
+type connEntry struct {
+	db       *sql.DB // handle returned to every Connect caller for this path
+	refCount int     // Connect calls not yet matched by a Release
+}
+
+var (
+	pool   = make(map[string]*connEntry) // open connections keyed by absolute crush.db path
+	poolMu sync.Mutex                   // guards pool and the refCount of its entries
+)
+
+// Connect returns the shared *sql.DB for the crush.db file inside
+// dataDir. The first call for a path opens the database, pings it and
+// applies the goose migrations; later calls for the same absolute path
+// reuse that handle and bump its reference count. Every successful
+// Connect must be paired with a [Release]; an empty dataDir is an error.
 func Connect(ctx context.Context, dataDir string) (*sql.DB, error) {
 	if dataDir == "" {
 		return nil, fmt.Errorf("data.dir is not set")
 	}
+
 	dbPath := filepath.Join(dataDir, "crush.db")
 
-	db, err := openDB(dbPath)
+	// Resolve to an absolute path so that different relative paths to
+	// the same file share a single connection.
+	absPath, err := filepath.Abs(dbPath)
+	if err != nil {
+		absPath = dbPath // Abs failed; key the pool by the path as given
+	}
+
+	poolMu.Lock() // held until return, so the open, ping and migrations below run under the lock
+	defer poolMu.Unlock()
+
+	if entry, ok := pool[absPath]; ok {
+		entry.refCount++
+		return entry.db, nil // already open: no new ping or migration run
+	}
+
+	conn, err := openDB(dbPath)
 	if err != nil {
 		return nil, err
 	}
 
-	// Serialize all access through a single connection. SQLite serializes
-	// writes at the file level anyway, and allowing multiple pool
-	// connections to interleave writes/checkpoints (especially under
-	// concurrent sub-agents) has caused WAL/header desync resulting in
-	// SQLITE_NOTADB (26) on the next open.
-	db.SetMaxOpenConns(1)
+	// Serialize all access through a single connection. SQLite
+	// serializes writes at the file level anyway, and allowing multiple
+	// pool connections to interleave writes/checkpoints (especially
+	// under concurrent sub-agents) has caused WAL/header desync
+	// resulting in SQLITE_NOTADB (26) on the next open.
+	conn.SetMaxOpenConns(1)
 
-	if err = db.PingContext(ctx); err != nil {
-		db.Close()
+	if err = conn.PingContext(ctx); err != nil {
+		conn.Close()
 		return nil, fmt.Errorf("failed to connect to database: %w", err)
 	}
 
 	if err := initGoose(); err != nil {
+		conn.Close() // not in the pool yet, so close here before returning the error
 		slog.Error("Failed to initialize goose", "error", err)
 		return nil, fmt.Errorf("failed to initialize goose: %w", err)
 	}
 
-	if err := goose.Up(db, "migrations"); err != nil {
+	if err := goose.Up(conn, "migrations"); err != nil {
+		conn.Close()
 		slog.Error("Failed to apply migrations", "error", err)
 		return nil, fmt.Errorf("failed to apply migrations: %w", err)
 	}
 
-	return db, nil
+	pool[absPath] = &connEntry{db: conn, refCount: 1} // stored only after migrations succeeded
+	return conn, nil
+}
+
+// Release drops one reference to the pooled connection for dataDir,
+// using the same crush.db path resolution as Connect. The connection is
+// closed and removed once the count hits zero; unknown paths are a no-op.
+func Release(dataDir string) error {
+	dbPath := filepath.Join(dataDir, "crush.db")
+	absPath, err := filepath.Abs(dbPath)
+	if err != nil {
+		absPath = dbPath // Abs failed; look up by the path as given
+	}
+
+	poolMu.Lock()
+	defer poolMu.Unlock()
+
+	entry, ok := pool[absPath]
+	if !ok {
+		return nil // nothing pooled for this path
+	}
+
+	entry.refCount--
+	if entry.refCount > 0 {
+		return nil // other holders remain; leave the connection open
+	}
+
+	delete(pool, absPath)
+	return entry.db.Close()
+}
+
+// ResetPool closes every pooled connection, whatever its reference
+// count, and empties the pool. It is meant for tests needing a clean pool.
+func ResetPool() {
+	poolMu.Lock()
+	defer poolMu.Unlock()
+	for path, entry := range pool {
+		entry.db.Close()   // the Close error is discarded
+		delete(pool, path) // deleting the current key while ranging over a map is allowed in Go
+	}
 }
 
 func initGoose() error {

internal/db/connect_test.go 🔗

@@ -0,0 +1,54 @@
+package db
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestConnect_SharesConnectionForSameDataDir(t *testing.T) {
+	t.Cleanup(ResetPool) // closes anything still pooled if an assertion fails early
+
+	dataDir := t.TempDir()
+
+	conn1, err := Connect(context.Background(), dataDir)
+	require.NoError(t, err)
+
+	conn2, err := Connect(context.Background(), dataDir)
+	require.NoError(t, err)
+
+	require.Same(t, conn1, conn2, "should return the same *sql.DB for the same data dir")
+
+	// Two references are held, so the first Release only decrements the count.
+	require.NoError(t, Release(dataDir))
+	require.NoError(t, conn1.PingContext(context.Background()), "connection should still be usable after partial release")
+
+	// The second Release drops the count to zero, which closes the connection.
+	require.NoError(t, Release(dataDir))
+	require.Error(t, conn1.PingContext(context.Background()), "connection should be closed after final release")
+}
+
+func TestConnect_SeparateConnectionsForDifferentDataDirs(t *testing.T) {
+	t.Cleanup(ResetPool) // closes anything still pooled if an assertion fails early
+
+	dir1 := t.TempDir() // each TempDir call yields a distinct directory, hence a distinct crush.db path
+	dir2 := t.TempDir()
+
+	conn1, err := Connect(context.Background(), dir1)
+	require.NoError(t, err)
+
+	conn2, err := Connect(context.Background(), dir2)
+	require.NoError(t, err)
+
+	require.NotSame(t, conn1, conn2, "different data dirs should get different connections")
+
+	require.NoError(t, Release(dir1)) // one Connect per dir, so each Release is the final one
+	require.NoError(t, Release(dir2))
+}
+
+func TestRelease_NoopForUnknownDataDir(t *testing.T) {
+	t.Cleanup(ResetPool)
+
+	require.NoError(t, Release("/nonexistent/path"), "releasing unknown data dir should not error") // no Connect was made for this path in this test
+}