diff --git a/internal/app/app.go b/internal/app/app.go index a167ca8638c8497a6d6f4260782ba334c6dbe0c3..b28b91851402892e8bb3aba0c0a007a5f1e9485d 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -113,10 +113,12 @@ func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore) (*App, er go mcp.Initialize(ctx, app.Permissions, store) - // cleanup database upon app shutdown + // Release the shared database connection on shutdown. The pool + // closes the underlying *sql.DB when the last reference is released. + dataDir := cfg.Options.DataDirectory app.cleanupFuncs = append( app.cleanupFuncs, - func(context.Context) error { return conn.Close() }, + func(context.Context) error { return db.Release(dataDir) }, func(ctx context.Context) error { return mcp.Close(ctx) }, ) diff --git a/internal/db/connect.go b/internal/db/connect.go index ef800c716efc44b137c163a188f599366f1c66e3..b247bb3f3807088b00fcc14fa5085843393a4ae6 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -38,41 +38,113 @@ func init() { } } -// Connect opens a SQLite database connection and runs migrations. +// connEntry holds a shared database connection and its reference count. +type connEntry struct { + db *sql.DB + refCount int +} + +var ( + pool = make(map[string]*connEntry) + poolMu sync.Mutex +) + +// Connect opens a SQLite database connection for the given data +// directory and runs migrations. If a connection to the same database +// file already exists, the existing connection is returned with its +// reference count incremented. Callers must pair each Connect with a +// [Release] when they no longer need the connection. func Connect(ctx context.Context, dataDir string) (*sql.DB, error) { if dataDir == "" { return nil, fmt.Errorf("data.dir is not set") } + dbPath := filepath.Join(dataDir, "crush.db") - db, err := openDB(dbPath) + // Resolve to an absolute path so that different relative paths to + // the same file share a single connection. + absPath, err := filepath.Abs(dbPath) + if err != nil { + absPath = dbPath + } + + poolMu.Lock() + defer poolMu.Unlock() + + if entry, ok := pool[absPath]; ok { + entry.refCount++ + return entry.db, nil + } + + conn, err := openDB(dbPath) if err != nil { return nil, err } - // Serialize all access through a single connection. SQLite serializes - // writes at the file level anyway, and allowing multiple pool - // connections to interleave writes/checkpoints (especially under - // concurrent sub-agents) has caused WAL/header desync resulting in - // SQLITE_NOTADB (26) on the next open. - db.SetMaxOpenConns(1) + // Serialize all access through a single connection. SQLite + // serializes writes at the file level anyway, and allowing multiple + // pool connections to interleave writes/checkpoints (especially + // under concurrent sub-agents) has caused WAL/header desync + // resulting in SQLITE_NOTADB (26) on the next open. + conn.SetMaxOpenConns(1) - if err = db.PingContext(ctx); err != nil { - db.Close() + if err = conn.PingContext(ctx); err != nil { + conn.Close() return nil, fmt.Errorf("failed to connect to database: %w", err) } if err := initGoose(); err != nil { + conn.Close() slog.Error("Failed to initialize goose", "error", err) return nil, fmt.Errorf("failed to initialize goose: %w", err) } - if err := goose.Up(db, "migrations"); err != nil { + if err := goose.Up(conn, "migrations"); err != nil { + conn.Close() slog.Error("Failed to apply migrations", "error", err) return nil, fmt.Errorf("failed to apply migrations: %w", err) } - return db, nil + pool[absPath] = &connEntry{db: conn, refCount: 1} + return conn, nil +} + +// Release decrements the reference count for the database at the given +// data directory. When the count reaches zero the underlying connection +// is closed and removed from the pool. +func Release(dataDir string) error { + dbPath := filepath.Join(dataDir, "crush.db") + absPath, err := filepath.Abs(dbPath) + if err != nil { + absPath = dbPath + } + + poolMu.Lock() + defer poolMu.Unlock() + + entry, ok := pool[absPath] + if !ok { + return nil + } + + entry.refCount-- + if entry.refCount > 0 { + return nil + } + + delete(pool, absPath) + return entry.db.Close() +} + +// ResetPool closes all pooled connections and clears the pool. This is +// intended for use in tests to ensure a clean state between test cases. +func ResetPool() { + poolMu.Lock() + defer poolMu.Unlock() + for path, entry := range pool { + entry.db.Close() + delete(pool, path) + } } func initGoose() error { diff --git a/internal/db/connect_test.go b/internal/db/connect_test.go new file mode 100644 index 0000000000000000000000000000000000000000..93c2af00216cb9076214b861eed230a45d7d9bd0 --- /dev/null +++ b/internal/db/connect_test.go @@ -0,0 +1,54 @@ +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConnect_SharesConnectionForSameDataDir(t *testing.T) { + t.Cleanup(ResetPool) + + dataDir := t.TempDir() + + conn1, err := Connect(context.Background(), dataDir) + require.NoError(t, err) + + conn2, err := Connect(context.Background(), dataDir) + require.NoError(t, err) + + require.Same(t, conn1, conn2, "should return the same *sql.DB for the same data dir") + + // Releasing once should not close the connection. + require.NoError(t, Release(dataDir)) + require.NoError(t, conn1.PingContext(context.Background()), "connection should still be usable after partial release") + + // Releasing again should close it. + require.NoError(t, Release(dataDir)) + require.Error(t, conn1.PingContext(context.Background()), "connection should be closed after final release") +} + +func TestConnect_SeparateConnectionsForDifferentDataDirs(t *testing.T) { + t.Cleanup(ResetPool) + + dir1 := t.TempDir() + dir2 := t.TempDir() + + conn1, err := Connect(context.Background(), dir1) + require.NoError(t, err) + + conn2, err := Connect(context.Background(), dir2) + require.NoError(t, err) + + require.NotSame(t, conn1, conn2, "different data dirs should get different connections") + + require.NoError(t, Release(dir1)) + require.NoError(t, Release(dir2)) +} + +func TestRelease_NoopForUnknownDataDir(t *testing.T) { + t.Cleanup(ResetPool) + + require.NoError(t, Release("/nonexistent/path"), "releasing unknown data dir should not error") +}