db.go

  1// Package db provides database operations for the Shelley AI coding agent.
  2package db
  3
  4//go:generate go tool github.com/sqlc-dev/sqlc/cmd/sqlc generate -f ../sqlc.yaml
  5
  6import (
  7	"context"
  8	"crypto/rand"
  9	"database/sql"
 10	"embed"
 11	"encoding/json"
 12	"errors"
 13	"fmt"
 14	"log/slog"
 15	"os"
 16	"path/filepath"
 17	"regexp"
 18	"sort"
 19	"strconv"
 20	"strings"
 21
 22	"github.com/google/uuid"
 23	"shelley.exe.dev/db/generated"
 24
 25	_ "modernc.org/sqlite"
 26)
 27
// schemaFS holds the SQL migration files from the schema/ directory, embedded
// at build time. Migrate lists it and runMigration reads individual files.
//
//go:embed schema/*.sql
var schemaFS embed.FS
 30
 31// generateConversationID generates a conversation ID in the format "cXXXXXX"
 32// where X are random alphanumeric characters
 33func generateConversationID() (string, error) {
 34	text := rand.Text()
 35	if len(text) < 6 {
 36		return "", fmt.Errorf("rand.Text() returned insufficient characters: %d", len(text))
 37	}
 38	return "c" + text[:6], nil
 39}
 40
// DB wraps the database connection pool and provides high-level operations on
// conversations, messages, LLM requests and models. Create one with New and
// release it with Close.
type DB struct {
	pool *Pool // read work goes through pool.Rx, writes through pool.Tx
}
 45
// Config holds the settings New needs to open a database.
type Config struct {
	// DSN is the SQLite data source name: a database file path, optionally
	// followed by "?"-separated query parameters. New rejects "" and ":memory:".
	DSN string
}
 50
 51// New creates a new database connection with the given configuration
 52func New(cfg Config) (*DB, error) {
 53	if cfg.DSN == "" {
 54		return nil, fmt.Errorf("database DSN cannot be empty")
 55	}
 56
 57	if cfg.DSN == ":memory:" {
 58		return nil, fmt.Errorf(":memory: database not supported (requires multiple connections); use a temp file")
 59	}
 60
 61	// Ensure directory exists for file-based SQLite databases
 62	if cfg.DSN != ":memory:" {
 63		dir := filepath.Dir(cfg.DSN)
 64		if dir != "." && dir != "" {
 65			if err := os.MkdirAll(dir, 0o755); err != nil {
 66				return nil, fmt.Errorf("failed to create database directory: %w", err)
 67			}
 68		}
 69	}
 70
 71	// Create connection pool with 3 readers
 72	dsn := cfg.DSN
 73	if !strings.Contains(dsn, "?") {
 74		dsn += "?_foreign_keys=on"
 75	} else if !strings.Contains(dsn, "_foreign_keys") {
 76		dsn += "&_foreign_keys=on"
 77	}
 78
 79	pool, err := NewPool(dsn, 3)
 80	if err != nil {
 81		return nil, fmt.Errorf("failed to create connection pool: %w", err)
 82	}
 83
 84	return &DB{
 85		pool: pool,
 86	}, nil
 87}
 88
 89// Close closes the database connection pool
 90func (db *DB) Close() error {
 91	return db.pool.Close()
 92}
 93
 94// Migrate runs the database migrations
 95func (db *DB) Migrate(ctx context.Context) error {
 96	// Read all migration files
 97	entries, err := schemaFS.ReadDir("schema")
 98	if err != nil {
 99		return fmt.Errorf("failed to read schema directory: %w", err)
100	}
101
102	// Filter and validate migration files
103	var migrations []string
104	migrationPattern := regexp.MustCompile(`^(\d{3})-.*\.sql$`)
105	for _, entry := range entries {
106		if entry.IsDir() {
107			continue
108		}
109		if !migrationPattern.MatchString(entry.Name()) {
110			continue
111		}
112		migrations = append(migrations, entry.Name())
113	}
114
115	// Sort migrations by number
116	sort.Strings(migrations)
117
118	// Check for duplicate migration numbers
119	seenNumbers := make(map[string]string) // number -> filename
120	for _, migration := range migrations {
121		matches := migrationPattern.FindStringSubmatch(migration)
122		if len(matches) < 2 {
123			continue
124		}
125		num := matches[1]
126		if existing, ok := seenNumbers[num]; ok {
127			return fmt.Errorf("duplicate migration number %s: %s and %s", num, existing, migration)
128		}
129		seenNumbers[num] = migration
130	}
131
132	// Get executed migrations
133	executedMigrations := make(map[int]bool)
134	var tableName string
135	err = db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
136		row := rx.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations'")
137		return row.Scan(&tableName)
138	})
139
140	if err == nil {
141		// Migrations table exists, load executed migrations
142		err = db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
143			rows, err := rx.Query("SELECT migration_number FROM migrations")
144			if err != nil {
145				return fmt.Errorf("failed to query executed migrations: %w", err)
146			}
147			defer rows.Close()
148
149			for rows.Next() {
150				var migrationNumber int
151				if err := rows.Scan(&migrationNumber); err != nil {
152					return fmt.Errorf("failed to scan migration number: %w", err)
153				}
154				executedMigrations[migrationNumber] = true
155			}
156			return rows.Err()
157		})
158		if err != nil {
159			return fmt.Errorf("failed to load executed migrations: %w", err)
160		}
161	} else if !errors.Is(err, sql.ErrNoRows) {
162		// Migrations table doesn't exist - executedMigrations remains empty
163		slog.Info("migrations table not found, running all migrations")
164	}
165
166	// Run any migrations that haven't been executed
167	for _, migration := range migrations {
168		// Extract migration number from filename (e.g., "001-base.sql" -> 001)
169		matches := migrationPattern.FindStringSubmatch(migration)
170		if len(matches) != 2 {
171			return fmt.Errorf("invalid migration filename format: %s", migration)
172		}
173
174		migrationNumber, err := strconv.Atoi(matches[1])
175		if err != nil {
176			return fmt.Errorf("failed to parse migration number from %s: %w", migration, err)
177		}
178
179		if !executedMigrations[migrationNumber] {
180			slog.Info("running migration", "file", migration, "number", migrationNumber)
181			if err := db.runMigration(ctx, migration, migrationNumber); err != nil {
182				return err
183			}
184		}
185	}
186
187	return nil
188}
189
190// runMigration executes a single migration file within a transaction,
191// including recording it in the migrations table.
192func (db *DB) runMigration(ctx context.Context, filename string, migrationNumber int) error {
193	content, err := schemaFS.ReadFile("schema/" + filename)
194	if err != nil {
195		return fmt.Errorf("failed to read migration file %s: %w", filename, err)
196	}
197
198	return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
199		if _, err := tx.Exec(string(content)); err != nil {
200			return fmt.Errorf("failed to execute migration %s: %w", filename, err)
201		}
202
203		if _, err := tx.Exec("INSERT INTO migrations (migration_number, migration_name) VALUES (?, ?)", migrationNumber, filename); err != nil {
204			return fmt.Errorf("failed to record migration %s in migrations table: %w", filename, err)
205		}
206
207		return nil
208	})
209}
210
211// Pool returns the underlying connection pool for advanced operations
212func (db *DB) Pool() *Pool {
213	return db.pool
214}
215
216// WithTx runs a function within a database transaction
217func (db *DB) WithTx(ctx context.Context, fn func(*generated.Queries) error) error {
218	return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
219		queries := generated.New(tx.Conn())
220		return fn(queries)
221	})
222}
223
224// WithTxRes runs a function within a database transaction and returns a value
225func WithTxRes[T any](db *DB, ctx context.Context, fn func(*generated.Queries) (T, error)) (T, error) {
226	var result T
227	err := db.WithTx(ctx, func(queries *generated.Queries) error {
228		var err error
229		result, err = fn(queries)
230		return err
231	})
232	return result, err
233}
234
235// Conversation methods (moved from ConversationService)
236
237// CreateConversation creates a new conversation with an optional slug
238func (db *DB) CreateConversation(ctx context.Context, slug *string, userInitiated bool, cwd, model *string) (*generated.Conversation, error) {
239	conversationID, err := generateConversationID()
240	if err != nil {
241		return nil, fmt.Errorf("failed to generate conversation ID: %w", err)
242	}
243	var conversation generated.Conversation
244	err = db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
245		q := generated.New(tx.Conn())
246		conversation, err = q.CreateConversation(ctx, generated.CreateConversationParams{
247			ConversationID: conversationID,
248			Slug:           slug,
249			UserInitiated:  userInitiated,
250			Cwd:            cwd,
251			Model:          model,
252		})
253		return err
254	})
255	return &conversation, err
256}
257
258// GetConversationByID retrieves a conversation by its ID
259func (db *DB) GetConversationByID(ctx context.Context, conversationID string) (*generated.Conversation, error) {
260	var conversation generated.Conversation
261	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
262		q := generated.New(rx.Conn())
263		var err error
264		conversation, err = q.GetConversation(ctx, conversationID)
265		return err
266	})
267	if err == sql.ErrNoRows {
268		return nil, fmt.Errorf("conversation not found: %s", conversationID)
269	}
270	return &conversation, err
271}
272
273// GetConversationBySlug retrieves a conversation by its slug
274func (db *DB) GetConversationBySlug(ctx context.Context, slug string) (*generated.Conversation, error) {
275	var conversation generated.Conversation
276	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
277		q := generated.New(rx.Conn())
278		var err error
279		conversation, err = q.GetConversationBySlug(ctx, &slug)
280		return err
281	})
282	if err == sql.ErrNoRows {
283		return nil, fmt.Errorf("conversation not found with slug: %s", slug)
284	}
285	return &conversation, err
286}
287
288// ListConversations retrieves conversations with pagination
289func (db *DB) ListConversations(ctx context.Context, limit, offset int64) ([]generated.Conversation, error) {
290	var conversations []generated.Conversation
291	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
292		q := generated.New(rx.Conn())
293		var err error
294		conversations, err = q.ListConversations(ctx, generated.ListConversationsParams{
295			Limit:  limit,
296			Offset: offset,
297		})
298		return err
299	})
300	return conversations, err
301}
302
303// SearchConversations searches for conversations containing the given query in their slug
304func (db *DB) SearchConversations(ctx context.Context, query string, limit, offset int64) ([]generated.Conversation, error) {
305	queryPtr := &query
306	var conversations []generated.Conversation
307	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
308		q := generated.New(rx.Conn())
309		var err error
310		conversations, err = q.SearchConversations(ctx, generated.SearchConversationsParams{
311			Column1: queryPtr,
312			Limit:   limit,
313			Offset:  offset,
314		})
315		return err
316	})
317	return conversations, err
318}
319
320// SearchConversationsWithMessages searches for conversations containing the query in slug or message content
321func (db *DB) SearchConversationsWithMessages(ctx context.Context, query string, limit, offset int64) ([]generated.Conversation, error) {
322	queryPtr := &query
323	var conversations []generated.Conversation
324	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
325		q := generated.New(rx.Conn())
326		var err error
327		conversations, err = q.SearchConversationsWithMessages(ctx, generated.SearchConversationsWithMessagesParams{
328			Column1: queryPtr,
329			Column2: queryPtr,
330			Column3: queryPtr,
331			Limit:   limit,
332			Offset:  offset,
333		})
334		return err
335	})
336	return conversations, err
337}
338
339// UpdateConversationSlug updates the slug of a conversation
340func (db *DB) UpdateConversationSlug(ctx context.Context, conversationID, slug string) (*generated.Conversation, error) {
341	var conversation generated.Conversation
342	err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
343		q := generated.New(tx.Conn())
344		var err error
345		conversation, err = q.UpdateConversationSlug(ctx, generated.UpdateConversationSlugParams{
346			Slug:           &slug,
347			ConversationID: conversationID,
348		})
349		return err
350	})
351	return &conversation, err
352}
353
354// UpdateConversationCwd updates the working directory for a conversation
355func (db *DB) UpdateConversationCwd(ctx context.Context, conversationID, cwd string) error {
356	return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
357		q := generated.New(tx.Conn())
358		_, err := q.UpdateConversationCwd(ctx, generated.UpdateConversationCwdParams{
359			Cwd:            &cwd,
360			ConversationID: conversationID,
361		})
362		return err
363	})
364}
365
366// UpdateConversationModel sets the model for a conversation that doesn't have one yet.
367// This is used to backfill the model for conversations created before the model column existed.
368func (db *DB) UpdateConversationModel(ctx context.Context, conversationID, model string) error {
369	return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
370		q := generated.New(tx.Conn())
371		return q.UpdateConversationModel(ctx, generated.UpdateConversationModelParams{
372			Model:          &model,
373			ConversationID: conversationID,
374		})
375	})
376}
377
378// Message methods (moved from MessageService)
379
// MessageType identifies the kind of a stored message. CreateMessage persists
// it as its plain string value, and ListMessagesByType/CountMessagesByType
// filter on that same string.
type MessageType string
382
// Known MessageType values.
const (
	MessageTypeUser    MessageType = "user"
	MessageTypeAgent   MessageType = "agent"
	MessageTypeTool    MessageType = "tool"
	MessageTypeSystem  MessageType = "system"
	MessageTypeError   MessageType = "error"
	MessageTypeGitInfo MessageType = "gitinfo" // user-visible only, not sent to LLM
)
391
392// CreateMessageParams contains parameters for creating a message
393type CreateMessageParams struct {
394	ConversationID      string
395	Type                MessageType
396	LLMData             interface{} // Will be JSON marshalled
397	UserData            interface{} // Will be JSON marshalled
398	UsageData           interface{} // Will be JSON marshalled
399	DisplayData         interface{} // Will be JSON marshalled, tool-specific display content
400	ExcludedFromContext bool        // If true, message is stored but not sent to LLM
401}
402
403// CreateMessage creates a new message
404func (db *DB) CreateMessage(ctx context.Context, params CreateMessageParams) (*generated.Message, error) {
405	messageID := uuid.New().String()
406
407	// Marshal JSON fields
408	var llmDataJSON, userDataJSON, usageDataJSON, displayDataJSON *string
409
410	if params.LLMData != nil {
411		data, err := json.Marshal(params.LLMData)
412		if err != nil {
413			return nil, fmt.Errorf("failed to marshal LLM data: %w", err)
414		}
415		str := string(data)
416		llmDataJSON = &str
417	}
418
419	if params.UserData != nil {
420		data, err := json.Marshal(params.UserData)
421		if err != nil {
422			return nil, fmt.Errorf("failed to marshal user data: %w", err)
423		}
424		str := string(data)
425		userDataJSON = &str
426	}
427
428	if params.UsageData != nil {
429		data, err := json.Marshal(params.UsageData)
430		if err != nil {
431			return nil, fmt.Errorf("failed to marshal usage data: %w", err)
432		}
433		str := string(data)
434		usageDataJSON = &str
435	}
436
437	if params.DisplayData != nil {
438		data, err := json.Marshal(params.DisplayData)
439		if err != nil {
440			return nil, fmt.Errorf("failed to marshal display data: %w", err)
441		}
442		str := string(data)
443		displayDataJSON = &str
444	}
445
446	var message generated.Message
447	err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
448		q := generated.New(tx.Conn())
449
450		// Get next sequence_id for this conversation
451		sequenceID, err := q.GetNextSequenceID(ctx, params.ConversationID)
452		if err != nil {
453			return fmt.Errorf("failed to get next sequence ID: %w", err)
454		}
455
456		message, err = q.CreateMessage(ctx, generated.CreateMessageParams{
457			MessageID:           messageID,
458			ConversationID:      params.ConversationID,
459			SequenceID:          sequenceID,
460			Type:                string(params.Type),
461			LlmData:             llmDataJSON,
462			UserData:            userDataJSON,
463			UsageData:           usageDataJSON,
464			DisplayData:         displayDataJSON,
465			ExcludedFromContext: params.ExcludedFromContext,
466		})
467		return err
468	})
469	return &message, err
470}
471
472// GetMessageByID retrieves a message by its ID
473func (db *DB) GetMessageByID(ctx context.Context, messageID string) (*generated.Message, error) {
474	var message generated.Message
475	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
476		q := generated.New(rx.Conn())
477		var err error
478		message, err = q.GetMessage(ctx, messageID)
479		return err
480	})
481	if err == sql.ErrNoRows {
482		return nil, fmt.Errorf("message not found: %s", messageID)
483	}
484	return &message, err
485}
486
487// ListMessagesByConversationPaginated retrieves messages in a conversation with pagination
488func (db *DB) ListMessagesByConversationPaginated(ctx context.Context, conversationID string, limit, offset int64) ([]generated.Message, error) {
489	var messages []generated.Message
490	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
491		q := generated.New(rx.Conn())
492		var err error
493		messages, err = q.ListMessagesPaginated(ctx, generated.ListMessagesPaginatedParams{
494			ConversationID: conversationID,
495			Limit:          limit,
496			Offset:         offset,
497		})
498		return err
499	})
500	return messages, err
501}
502
503// ListMessages retrieves all messages in a conversation ordered by sequence
504func (db *DB) ListMessages(ctx context.Context, conversationID string) ([]generated.Message, error) {
505	var messages []generated.Message
506	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
507		q := generated.New(rx.Conn())
508		var err error
509		messages, err = q.ListMessages(ctx, conversationID)
510		return err
511	})
512	return messages, err
513}
514
515// ListMessagesForContext retrieves messages that should be sent to the LLM (excludes excluded_from_context=true)
516func (db *DB) ListMessagesForContext(ctx context.Context, conversationID string) ([]generated.Message, error) {
517	var messages []generated.Message
518	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
519		q := generated.New(rx.Conn())
520		var err error
521		messages, err = q.ListMessagesForContext(ctx, conversationID)
522		return err
523	})
524	return messages, err
525}
526
527// ListMessagesByType retrieves messages of a specific type in a conversation
528func (db *DB) ListMessagesByType(ctx context.Context, conversationID string, messageType MessageType) ([]generated.Message, error) {
529	var messages []generated.Message
530	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
531		q := generated.New(rx.Conn())
532		var err error
533		messages, err = q.ListMessagesByType(ctx, generated.ListMessagesByTypeParams{
534			ConversationID: conversationID,
535			Type:           string(messageType),
536		})
537		return err
538	})
539	return messages, err
540}
541
542// GetLatestMessage retrieves the latest message in a conversation
543func (db *DB) GetLatestMessage(ctx context.Context, conversationID string) (*generated.Message, error) {
544	var message generated.Message
545	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
546		q := generated.New(rx.Conn())
547		var err error
548		message, err = q.GetLatestMessage(ctx, conversationID)
549		return err
550	})
551	if err == sql.ErrNoRows {
552		return nil, fmt.Errorf("no messages found in conversation: %s", conversationID)
553	}
554	return &message, err
555}
556
557// CountMessagesByType returns the number of messages of a specific type in a conversation
558func (db *DB) CountMessagesByType(ctx context.Context, conversationID string, messageType MessageType) (int64, error) {
559	var count int64
560	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
561		q := generated.New(rx.Conn())
562		var err error
563		count, err = q.CountMessagesByType(ctx, generated.CountMessagesByTypeParams{
564			ConversationID: conversationID,
565			Type:           string(messageType),
566		})
567		return err
568	})
569	return count, err
570}
571
572// Queries provides read-only access to generated queries within a read transaction
573func (db *DB) Queries(ctx context.Context, fn func(*generated.Queries) error) error {
574	return db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
575		q := generated.New(rx.Conn())
576		return fn(q)
577	})
578}
579
580// QueriesTx provides read-write access to generated queries within a write transaction
581func (db *DB) QueriesTx(ctx context.Context, fn func(*generated.Queries) error) error {
582	return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
583		q := generated.New(tx.Conn())
584		return fn(q)
585	})
586}
587
588// ListArchivedConversations retrieves archived conversations with pagination
589func (db *DB) ListArchivedConversations(ctx context.Context, limit, offset int64) ([]generated.Conversation, error) {
590	var conversations []generated.Conversation
591	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
592		q := generated.New(rx.Conn())
593		var err error
594		conversations, err = q.ListArchivedConversations(ctx, generated.ListArchivedConversationsParams{
595			Limit:  limit,
596			Offset: offset,
597		})
598		return err
599	})
600	return conversations, err
601}
602
603// SearchArchivedConversations searches for archived conversations containing the given query in their slug
604func (db *DB) SearchArchivedConversations(ctx context.Context, query string, limit, offset int64) ([]generated.Conversation, error) {
605	queryPtr := &query
606	var conversations []generated.Conversation
607	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
608		q := generated.New(rx.Conn())
609		var err error
610		conversations, err = q.SearchArchivedConversations(ctx, generated.SearchArchivedConversationsParams{
611			Column1: queryPtr,
612			Limit:   limit,
613			Offset:  offset,
614		})
615		return err
616	})
617	return conversations, err
618}
619
620// ArchiveConversation archives a conversation
621func (db *DB) ArchiveConversation(ctx context.Context, conversationID string) (*generated.Conversation, error) {
622	var conversation generated.Conversation
623	err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
624		q := generated.New(tx.Conn())
625		var err error
626		conversation, err = q.ArchiveConversation(ctx, conversationID)
627		return err
628	})
629	return &conversation, err
630}
631
632// UnarchiveConversation unarchives a conversation
633func (db *DB) UnarchiveConversation(ctx context.Context, conversationID string) (*generated.Conversation, error) {
634	var conversation generated.Conversation
635	err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
636		q := generated.New(tx.Conn())
637		var err error
638		conversation, err = q.UnarchiveConversation(ctx, conversationID)
639		return err
640	})
641	return &conversation, err
642}
643
644// DeleteConversation deletes a conversation and all its messages
645func (db *DB) DeleteConversation(ctx context.Context, conversationID string) error {
646	return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
647		q := generated.New(tx.Conn())
648		// Delete messages first (foreign key constraint)
649		if err := q.DeleteConversationMessages(ctx, conversationID); err != nil {
650			return fmt.Errorf("failed to delete messages: %w", err)
651		}
652		return q.DeleteConversation(ctx, conversationID)
653	})
654}
655
656// CreateSubagentConversation creates a new subagent conversation with a parent
657func (db *DB) CreateSubagentConversation(ctx context.Context, slug, parentID string, cwd *string) (*generated.Conversation, error) {
658	conversationID, err := generateConversationID()
659	if err != nil {
660		return nil, fmt.Errorf("failed to generate conversation ID: %w", err)
661	}
662	var conversation generated.Conversation
663	err = db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
664		q := generated.New(tx.Conn())
665		conversation, err = q.CreateSubagentConversation(ctx, generated.CreateSubagentConversationParams{
666			ConversationID:       conversationID,
667			Slug:                 &slug,
668			Cwd:                  cwd,
669			ParentConversationID: &parentID,
670		})
671		return err
672	})
673	return &conversation, err
674}
675
676// GetSubagents retrieves all subagent conversations for a parent conversation
677func (db *DB) GetSubagents(ctx context.Context, parentID string) ([]generated.Conversation, error) {
678	var conversations []generated.Conversation
679	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
680		q := generated.New(rx.Conn())
681		var err error
682		conversations, err = q.GetSubagents(ctx, &parentID)
683		return err
684	})
685	return conversations, err
686}
687
688// GetConversationBySlugAndParent retrieves a subagent conversation by slug and parent ID
689func (db *DB) GetConversationBySlugAndParent(ctx context.Context, slug, parentID string) (*generated.Conversation, error) {
690	var conversation generated.Conversation
691	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
692		q := generated.New(rx.Conn())
693		var err error
694		conversation, err = q.GetConversationBySlugAndParent(ctx, generated.GetConversationBySlugAndParentParams{
695			Slug:                 &slug,
696			ParentConversationID: &parentID,
697		})
698		return err
699	})
700	if err == sql.ErrNoRows {
701		return nil, nil // Not found, return nil without error
702	}
703	return &conversation, err
704}
705
// SubagentDBAdapter adapts *DB to the claudetool.SubagentDB interface.
// NOTE(review): claudetool is not imported in this file, so interface
// conformance is not checked here at compile time.
type SubagentDBAdapter struct {
	DB *DB // database used to look up and create subagent conversations
}
710
711// GetOrCreateSubagentConversation implements claudetool.SubagentDB.
712// Returns the conversation ID and the actual slug used (may differ if a suffix was added).
713func (a *SubagentDBAdapter) GetOrCreateSubagentConversation(ctx context.Context, slug, parentID, cwd string) (string, string, error) {
714	// Try to find existing with exact slug
715	existing, err := a.DB.GetConversationBySlugAndParent(ctx, slug, parentID)
716	if err != nil {
717		return "", "", err
718	}
719	if existing != nil {
720		return existing.ConversationID, *existing.Slug, nil
721	}
722
723	// Try to create new, handling unique constraint violations by appending numbers
724	baseSlug := slug
725	actualSlug := slug
726	for attempt := 0; attempt < 100; attempt++ {
727		conv, err := a.DB.CreateSubagentConversation(ctx, actualSlug, parentID, &cwd)
728		if err == nil {
729			return conv.ConversationID, actualSlug, nil
730		}
731
732		// Check if this is a unique constraint violation
733		errLower := strings.ToLower(err.Error())
734		if strings.Contains(errLower, "unique constraint") ||
735			strings.Contains(errLower, "duplicate") {
736			// Try with a numeric suffix
737			actualSlug = fmt.Sprintf("%s-%d", baseSlug, attempt+1)
738			continue
739		}
740
741		// Some other error occurred
742		return "", "", err
743	}
744
745	return "", "", fmt.Errorf("failed to create unique subagent slug after 100 attempts")
746}
747
748// InsertLLMRequest inserts a new LLM request record
749func (db *DB) InsertLLMRequest(ctx context.Context, params generated.InsertLLMRequestParams) (*generated.LlmRequest, error) {
750	var request generated.LlmRequest
751	err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
752		q := generated.New(tx.Conn())
753
754		// If we have a conversation ID and request body, try to find common prefix
755		if params.ConversationID != nil && params.RequestBody != nil {
756			// Get the last request for this conversation
757			lastReq, err := q.GetLastRequestForConversation(ctx, params.ConversationID)
758			if err == nil {
759				// Found a previous request - compute common prefix
760				prefixLen, fullPrevBody := computeSharedPrefixLength(lastReq, *params.RequestBody)
761				if prefixLen > 0 {
762					// Store only the suffix
763					suffix := (*params.RequestBody)[prefixLen:]
764					params.RequestBody = &suffix
765					params.PrefixRequestID = &lastReq.ID
766					prefixLen64 := int64(prefixLen)
767					params.PrefixLength = &prefixLen64
768					_ = fullPrevBody // silence unused warning, used for computing prefix
769				}
770			}
771			// If no previous request found or error, just store the full body
772		}
773
774		var err error
775		request, err = q.InsertLLMRequest(ctx, params)
776		return err
777	})
778	return &request, err
779}
780
781// computeSharedPrefixLength computes the length of the shared prefix between
782// the full previous request body (reconstructed by walking the chain) and the new request body.
783// It returns the prefix length and the fully reconstructed previous body.
784func computeSharedPrefixLength(prevReq generated.LlmRequest, newBody string) (int, string) {
785	// Get the stored body (which may be just a suffix if prevReq has a prefix reference)
786	prevBody := ""
787	if prevReq.RequestBody != nil {
788		prevBody = *prevReq.RequestBody
789	}
790
791	// If the previous request has a prefix reference, we need to account for that
792	// by prepending the prefix length worth of bytes from the new body.
793	// This works because in a conversation, request N+1 typically starts with
794	// all of request N plus new content at the end.
795	if prevReq.PrefixLength != nil && *prevReq.PrefixLength > 0 {
796		// The previous request's full body would be:
797		// [first prefix_length bytes that match its parent] + [stored suffix]
798		// If the new body is a continuation, its first prefix_length bytes
799		// should match those same bytes.
800		prefixLen := int(*prevReq.PrefixLength)
801		if prefixLen <= len(newBody) {
802			prevBody = newBody[:prefixLen] + prevBody
803		}
804	}
805
806	// Compute byte-by-byte shared prefix between reconstructed prevBody and newBody
807	minLen := len(prevBody)
808	if len(newBody) < minLen {
809		minLen = len(newBody)
810	}
811
812	prefixLen := 0
813	for i := 0; i < minLen; i++ {
814		if prevBody[i] != newBody[i] {
815			break
816		}
817		prefixLen++
818	}
819
820	// Only use prefix deduplication if we save meaningful space
821	// (at least 100 bytes saved)
822	if prefixLen < 100 {
823		return 0, prevBody
824	}
825
826	return prefixLen, prevBody
827}
828
829// ListRecentLLMRequests returns the most recent LLM requests
830func (db *DB) ListRecentLLMRequests(ctx context.Context, limit int64) ([]generated.ListRecentLLMRequestsRow, error) {
831	var requests []generated.ListRecentLLMRequestsRow
832	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
833		q := generated.New(rx.Conn())
834		var err error
835		requests, err = q.ListRecentLLMRequests(ctx, limit)
836		return err
837	})
838	return requests, err
839}
840
841// GetLLMRequestBody returns the raw request body for a request
842func (db *DB) GetLLMRequestBody(ctx context.Context, id int64) (*string, error) {
843	var body *string
844	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
845		q := generated.New(rx.Conn())
846		var err error
847		body, err = q.GetLLMRequestBody(ctx, id)
848		return err
849	})
850	return body, err
851}
852
853// GetLLMResponseBody returns the raw response body for a request
854func (db *DB) GetLLMResponseBody(ctx context.Context, id int64) (*string, error) {
855	var body *string
856	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
857		q := generated.New(rx.Conn())
858		var err error
859		body, err = q.GetLLMResponseBody(ctx, id)
860		return err
861	})
862	return body, err
863}
864
865// GetFullLLMRequestBody reconstructs the full request body for a request,
866// following the prefix chain if necessary.
867func (db *DB) GetFullLLMRequestBody(ctx context.Context, requestID int64) (string, error) {
868	var result string
869	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
870		q := generated.New(rx.Conn())
871		return reconstructRequestBody(ctx, q, requestID, &result)
872	})
873	return result, err
874}
875
876// reconstructRequestBody recursively reconstructs the full request body
877func reconstructRequestBody(ctx context.Context, q *generated.Queries, requestID int64, result *string) error {
878	req, err := q.GetLLMRequestByID(ctx, requestID)
879	if err != nil {
880		return err
881	}
882
883	suffix := ""
884	if req.RequestBody != nil {
885		suffix = *req.RequestBody
886	}
887
888	if req.PrefixRequestID == nil || req.PrefixLength == nil || *req.PrefixLength == 0 {
889		// No prefix reference - the stored body is the full body
890		*result = suffix
891		return nil
892	}
893
894	// Recursively get the parent's full body
895	var parentBody string
896	if err := reconstructRequestBody(ctx, q, *req.PrefixRequestID, &parentBody); err != nil {
897		return err
898	}
899
900	// The full body is the first prefix_length bytes from the parent + our suffix
901	prefixLen := int(*req.PrefixLength)
902	if prefixLen > len(parentBody) {
903		prefixLen = len(parentBody)
904	}
905	*result = parentBody[:prefixLen] + suffix
906	return nil
907}
908
909// GetModels returns all models from the database
910func (db *DB) GetModels(ctx context.Context) ([]generated.Model, error) {
911	var models []generated.Model
912	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
913		q := generated.New(rx.Conn())
914		var err error
915		models, err = q.GetModels(ctx)
916		return err
917	})
918	return models, err
919}
920
921// GetModel returns a model by ID
922func (db *DB) GetModel(ctx context.Context, modelID string) (*generated.Model, error) {
923	var model generated.Model
924	err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
925		q := generated.New(rx.Conn())
926		var err error
927		model, err = q.GetModel(ctx, modelID)
928		return err
929	})
930	if err != nil {
931		return nil, err
932	}
933	return &model, nil
934}
935
936// CreateModel creates a new model
937func (db *DB) CreateModel(ctx context.Context, params generated.CreateModelParams) (*generated.Model, error) {
938	var model generated.Model
939	err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
940		q := generated.New(tx.Conn())
941		var err error
942		model, err = q.CreateModel(ctx, params)
943		return err
944	})
945	if err != nil {
946		return nil, err
947	}
948	return &model, nil
949}
950
951// UpdateModel updates a model
952func (db *DB) UpdateModel(ctx context.Context, params generated.UpdateModelParams) (*generated.Model, error) {
953	var model generated.Model
954	err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
955		q := generated.New(tx.Conn())
956		var err error
957		model, err = q.UpdateModel(ctx, params)
958		return err
959	})
960	if err != nil {
961		return nil, err
962	}
963	return &model, nil
964}
965
966// DeleteModel deletes a model
967func (db *DB) DeleteModel(ctx context.Context, modelID string) error {
968	return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
969		q := generated.New(tx.Conn())
970		return q.DeleteModel(ctx, modelID)
971	})
972}