1// Package db provides database operations for the Shelley AI coding agent.
2package db
3
4//go:generate go tool github.com/sqlc-dev/sqlc/cmd/sqlc generate -f ../sqlc.yaml
5
6import (
7 "context"
8 "crypto/rand"
9 "database/sql"
10 "embed"
11 "encoding/json"
12 "errors"
13 "fmt"
14 "log/slog"
15 "os"
16 "path/filepath"
17 "regexp"
18 "sort"
19 "strconv"
20 "strings"
21
22 "github.com/google/uuid"
23 "shelley.exe.dev/db/generated"
24
25 _ "modernc.org/sqlite"
26)
27
// schemaFS embeds the SQL files under schema/. Migrate lists and reads the
// numbered migration scripts from it, and runMigration executes them.
//
//go:embed schema/*.sql
var schemaFS embed.FS
30
// generateConversationID builds a new conversation ID: the letter "c"
// followed by the first six characters of crypto/rand.Text(), which yields
// characters from the RFC 4648 base32 alphabet (A-Z, 2-7).
func generateConversationID() (string, error) {
	const randomChars = 6
	random := rand.Text()
	if n := len(random); n < randomChars {
		return "", fmt.Errorf("rand.Text() returned insufficient characters: %d", n)
	}
	return "c" + random[:randomChars], nil
}
40
41// DB wraps the database connection pool and provides high-level operations
42type DB struct {
43 pool *Pool
44}
45
46// Config holds database configuration
47type Config struct {
48 DSN string // Data Source Name for SQLite database
49}
50
51// New creates a new database connection with the given configuration
52func New(cfg Config) (*DB, error) {
53 if cfg.DSN == "" {
54 return nil, fmt.Errorf("database DSN cannot be empty")
55 }
56
57 if cfg.DSN == ":memory:" {
58 return nil, fmt.Errorf(":memory: database not supported (requires multiple connections); use a temp file")
59 }
60
61 // Ensure directory exists for file-based SQLite databases
62 if cfg.DSN != ":memory:" {
63 dir := filepath.Dir(cfg.DSN)
64 if dir != "." && dir != "" {
65 if err := os.MkdirAll(dir, 0o755); err != nil {
66 return nil, fmt.Errorf("failed to create database directory: %w", err)
67 }
68 }
69 }
70
71 // Create connection pool with 3 readers
72 dsn := cfg.DSN
73 if !strings.Contains(dsn, "?") {
74 dsn += "?_foreign_keys=on"
75 } else if !strings.Contains(dsn, "_foreign_keys") {
76 dsn += "&_foreign_keys=on"
77 }
78
79 pool, err := NewPool(dsn, 3)
80 if err != nil {
81 return nil, fmt.Errorf("failed to create connection pool: %w", err)
82 }
83
84 return &DB{
85 pool: pool,
86 }, nil
87}
88
89// Close closes the database connection pool
90func (db *DB) Close() error {
91 return db.pool.Close()
92}
93
94// Migrate runs the database migrations
95func (db *DB) Migrate(ctx context.Context) error {
96 // Read all migration files
97 entries, err := schemaFS.ReadDir("schema")
98 if err != nil {
99 return fmt.Errorf("failed to read schema directory: %w", err)
100 }
101
102 // Filter and validate migration files
103 var migrations []string
104 migrationPattern := regexp.MustCompile(`^(\d{3})-.*\.sql$`)
105 for _, entry := range entries {
106 if entry.IsDir() {
107 continue
108 }
109 if !migrationPattern.MatchString(entry.Name()) {
110 continue
111 }
112 migrations = append(migrations, entry.Name())
113 }
114
115 // Sort migrations by number
116 sort.Strings(migrations)
117
118 // Check for duplicate migration numbers
119 seenNumbers := make(map[string]string) // number -> filename
120 for _, migration := range migrations {
121 matches := migrationPattern.FindStringSubmatch(migration)
122 if len(matches) < 2 {
123 continue
124 }
125 num := matches[1]
126 if existing, ok := seenNumbers[num]; ok {
127 return fmt.Errorf("duplicate migration number %s: %s and %s", num, existing, migration)
128 }
129 seenNumbers[num] = migration
130 }
131
132 // Get executed migrations
133 executedMigrations := make(map[int]bool)
134 var tableName string
135 err = db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
136 row := rx.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations'")
137 return row.Scan(&tableName)
138 })
139
140 if err == nil {
141 // Migrations table exists, load executed migrations
142 err = db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
143 rows, err := rx.Query("SELECT migration_number FROM migrations")
144 if err != nil {
145 return fmt.Errorf("failed to query executed migrations: %w", err)
146 }
147 defer rows.Close()
148
149 for rows.Next() {
150 var migrationNumber int
151 if err := rows.Scan(&migrationNumber); err != nil {
152 return fmt.Errorf("failed to scan migration number: %w", err)
153 }
154 executedMigrations[migrationNumber] = true
155 }
156 return rows.Err()
157 })
158 if err != nil {
159 return fmt.Errorf("failed to load executed migrations: %w", err)
160 }
161 } else if !errors.Is(err, sql.ErrNoRows) {
162 // Migrations table doesn't exist - executedMigrations remains empty
163 slog.Info("migrations table not found, running all migrations")
164 }
165
166 // Run any migrations that haven't been executed
167 for _, migration := range migrations {
168 // Extract migration number from filename (e.g., "001-base.sql" -> 001)
169 matches := migrationPattern.FindStringSubmatch(migration)
170 if len(matches) != 2 {
171 return fmt.Errorf("invalid migration filename format: %s", migration)
172 }
173
174 migrationNumber, err := strconv.Atoi(matches[1])
175 if err != nil {
176 return fmt.Errorf("failed to parse migration number from %s: %w", migration, err)
177 }
178
179 if !executedMigrations[migrationNumber] {
180 slog.Info("running migration", "file", migration, "number", migrationNumber)
181 if err := db.runMigration(ctx, migration, migrationNumber); err != nil {
182 return err
183 }
184 }
185 }
186
187 return nil
188}
189
190// runMigration executes a single migration file within a transaction,
191// including recording it in the migrations table.
192func (db *DB) runMigration(ctx context.Context, filename string, migrationNumber int) error {
193 content, err := schemaFS.ReadFile("schema/" + filename)
194 if err != nil {
195 return fmt.Errorf("failed to read migration file %s: %w", filename, err)
196 }
197
198 return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
199 if _, err := tx.Exec(string(content)); err != nil {
200 return fmt.Errorf("failed to execute migration %s: %w", filename, err)
201 }
202
203 if _, err := tx.Exec("INSERT INTO migrations (migration_number, migration_name) VALUES (?, ?)", migrationNumber, filename); err != nil {
204 return fmt.Errorf("failed to record migration %s in migrations table: %w", filename, err)
205 }
206
207 return nil
208 })
209}
210
211// Pool returns the underlying connection pool for advanced operations
212func (db *DB) Pool() *Pool {
213 return db.pool
214}
215
216// WithTx runs a function within a database transaction
217func (db *DB) WithTx(ctx context.Context, fn func(*generated.Queries) error) error {
218 return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
219 queries := generated.New(tx.Conn())
220 return fn(queries)
221 })
222}
223
224// WithTxRes runs a function within a database transaction and returns a value
225func WithTxRes[T any](db *DB, ctx context.Context, fn func(*generated.Queries) (T, error)) (T, error) {
226 var result T
227 err := db.WithTx(ctx, func(queries *generated.Queries) error {
228 var err error
229 result, err = fn(queries)
230 return err
231 })
232 return result, err
233}
234
235// Conversation methods (moved from ConversationService)
236
237// CreateConversation creates a new conversation with an optional slug
238func (db *DB) CreateConversation(ctx context.Context, slug *string, userInitiated bool, cwd, model *string) (*generated.Conversation, error) {
239 conversationID, err := generateConversationID()
240 if err != nil {
241 return nil, fmt.Errorf("failed to generate conversation ID: %w", err)
242 }
243 var conversation generated.Conversation
244 err = db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
245 q := generated.New(tx.Conn())
246 conversation, err = q.CreateConversation(ctx, generated.CreateConversationParams{
247 ConversationID: conversationID,
248 Slug: slug,
249 UserInitiated: userInitiated,
250 Cwd: cwd,
251 Model: model,
252 })
253 return err
254 })
255 return &conversation, err
256}
257
258// GetConversationByID retrieves a conversation by its ID
259func (db *DB) GetConversationByID(ctx context.Context, conversationID string) (*generated.Conversation, error) {
260 var conversation generated.Conversation
261 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
262 q := generated.New(rx.Conn())
263 var err error
264 conversation, err = q.GetConversation(ctx, conversationID)
265 return err
266 })
267 if err == sql.ErrNoRows {
268 return nil, fmt.Errorf("conversation not found: %s", conversationID)
269 }
270 return &conversation, err
271}
272
273// GetConversationBySlug retrieves a conversation by its slug
274func (db *DB) GetConversationBySlug(ctx context.Context, slug string) (*generated.Conversation, error) {
275 var conversation generated.Conversation
276 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
277 q := generated.New(rx.Conn())
278 var err error
279 conversation, err = q.GetConversationBySlug(ctx, &slug)
280 return err
281 })
282 if err == sql.ErrNoRows {
283 return nil, fmt.Errorf("conversation not found with slug: %s", slug)
284 }
285 return &conversation, err
286}
287
288// ListConversations retrieves conversations with pagination
289func (db *DB) ListConversations(ctx context.Context, limit, offset int64) ([]generated.Conversation, error) {
290 var conversations []generated.Conversation
291 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
292 q := generated.New(rx.Conn())
293 var err error
294 conversations, err = q.ListConversations(ctx, generated.ListConversationsParams{
295 Limit: limit,
296 Offset: offset,
297 })
298 return err
299 })
300 return conversations, err
301}
302
303// SearchConversations searches for conversations containing the given query in their slug
304func (db *DB) SearchConversations(ctx context.Context, query string, limit, offset int64) ([]generated.Conversation, error) {
305 queryPtr := &query
306 var conversations []generated.Conversation
307 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
308 q := generated.New(rx.Conn())
309 var err error
310 conversations, err = q.SearchConversations(ctx, generated.SearchConversationsParams{
311 Column1: queryPtr,
312 Limit: limit,
313 Offset: offset,
314 })
315 return err
316 })
317 return conversations, err
318}
319
320// SearchConversationsWithMessages searches for conversations containing the query in slug or message content
321func (db *DB) SearchConversationsWithMessages(ctx context.Context, query string, limit, offset int64) ([]generated.Conversation, error) {
322 queryPtr := &query
323 var conversations []generated.Conversation
324 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
325 q := generated.New(rx.Conn())
326 var err error
327 conversations, err = q.SearchConversationsWithMessages(ctx, generated.SearchConversationsWithMessagesParams{
328 Column1: queryPtr,
329 Column2: queryPtr,
330 Column3: queryPtr,
331 Limit: limit,
332 Offset: offset,
333 })
334 return err
335 })
336 return conversations, err
337}
338
339// UpdateConversationSlug updates the slug of a conversation
340func (db *DB) UpdateConversationSlug(ctx context.Context, conversationID, slug string) (*generated.Conversation, error) {
341 var conversation generated.Conversation
342 err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
343 q := generated.New(tx.Conn())
344 var err error
345 conversation, err = q.UpdateConversationSlug(ctx, generated.UpdateConversationSlugParams{
346 Slug: &slug,
347 ConversationID: conversationID,
348 })
349 return err
350 })
351 return &conversation, err
352}
353
354// UpdateConversationCwd updates the working directory for a conversation
355func (db *DB) UpdateConversationCwd(ctx context.Context, conversationID, cwd string) error {
356 return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
357 q := generated.New(tx.Conn())
358 _, err := q.UpdateConversationCwd(ctx, generated.UpdateConversationCwdParams{
359 Cwd: &cwd,
360 ConversationID: conversationID,
361 })
362 return err
363 })
364}
365
366// UpdateConversationModel sets the model for a conversation that doesn't have one yet.
367// This is used to backfill the model for conversations created before the model column existed.
368func (db *DB) UpdateConversationModel(ctx context.Context, conversationID, model string) error {
369 return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
370 q := generated.New(tx.Conn())
371 return q.UpdateConversationModel(ctx, generated.UpdateConversationModelParams{
372 Model: &model,
373 ConversationID: conversationID,
374 })
375 })
376}
377
378// Message methods (moved from MessageService)
379
// MessageType classifies a stored message. Its string value is what
// CreateMessage writes to the message's type column.
type MessageType string

// The message types known to this package.
const (
	MessageTypeUser    MessageType = "user"
	MessageTypeAgent   MessageType = "agent"
	MessageTypeTool    MessageType = "tool"
	MessageTypeSystem  MessageType = "system"
	MessageTypeError   MessageType = "error"
	MessageTypeGitInfo MessageType = "gitinfo" // shown to the user only; never sent to the LLM
)

// CreateMessageParams holds the inputs to CreateMessage. Each payload field is
// optional: nil is passed through as no value, and anything else is
// JSON-encoded by CreateMessage before insertion.
type CreateMessageParams struct {
	ConversationID      string
	Type                MessageType
	LLMData             any  // JSON-encoded when non-nil
	UserData            any  // JSON-encoded when non-nil
	UsageData           any  // JSON-encoded when non-nil
	DisplayData         any  // JSON-encoded when non-nil; tool-specific display content
	ExcludedFromContext bool // true: stored, but not sent to the LLM
}
402
403// CreateMessage creates a new message
404func (db *DB) CreateMessage(ctx context.Context, params CreateMessageParams) (*generated.Message, error) {
405 messageID := uuid.New().String()
406
407 // Marshal JSON fields
408 var llmDataJSON, userDataJSON, usageDataJSON, displayDataJSON *string
409
410 if params.LLMData != nil {
411 data, err := json.Marshal(params.LLMData)
412 if err != nil {
413 return nil, fmt.Errorf("failed to marshal LLM data: %w", err)
414 }
415 str := string(data)
416 llmDataJSON = &str
417 }
418
419 if params.UserData != nil {
420 data, err := json.Marshal(params.UserData)
421 if err != nil {
422 return nil, fmt.Errorf("failed to marshal user data: %w", err)
423 }
424 str := string(data)
425 userDataJSON = &str
426 }
427
428 if params.UsageData != nil {
429 data, err := json.Marshal(params.UsageData)
430 if err != nil {
431 return nil, fmt.Errorf("failed to marshal usage data: %w", err)
432 }
433 str := string(data)
434 usageDataJSON = &str
435 }
436
437 if params.DisplayData != nil {
438 data, err := json.Marshal(params.DisplayData)
439 if err != nil {
440 return nil, fmt.Errorf("failed to marshal display data: %w", err)
441 }
442 str := string(data)
443 displayDataJSON = &str
444 }
445
446 var message generated.Message
447 err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
448 q := generated.New(tx.Conn())
449
450 // Get next sequence_id for this conversation
451 sequenceID, err := q.GetNextSequenceID(ctx, params.ConversationID)
452 if err != nil {
453 return fmt.Errorf("failed to get next sequence ID: %w", err)
454 }
455
456 message, err = q.CreateMessage(ctx, generated.CreateMessageParams{
457 MessageID: messageID,
458 ConversationID: params.ConversationID,
459 SequenceID: sequenceID,
460 Type: string(params.Type),
461 LlmData: llmDataJSON,
462 UserData: userDataJSON,
463 UsageData: usageDataJSON,
464 DisplayData: displayDataJSON,
465 ExcludedFromContext: params.ExcludedFromContext,
466 })
467 return err
468 })
469 return &message, err
470}
471
472// GetMessageByID retrieves a message by its ID
473func (db *DB) GetMessageByID(ctx context.Context, messageID string) (*generated.Message, error) {
474 var message generated.Message
475 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
476 q := generated.New(rx.Conn())
477 var err error
478 message, err = q.GetMessage(ctx, messageID)
479 return err
480 })
481 if err == sql.ErrNoRows {
482 return nil, fmt.Errorf("message not found: %s", messageID)
483 }
484 return &message, err
485}
486
487// ListMessagesByConversationPaginated retrieves messages in a conversation with pagination
488func (db *DB) ListMessagesByConversationPaginated(ctx context.Context, conversationID string, limit, offset int64) ([]generated.Message, error) {
489 var messages []generated.Message
490 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
491 q := generated.New(rx.Conn())
492 var err error
493 messages, err = q.ListMessagesPaginated(ctx, generated.ListMessagesPaginatedParams{
494 ConversationID: conversationID,
495 Limit: limit,
496 Offset: offset,
497 })
498 return err
499 })
500 return messages, err
501}
502
503// ListMessages retrieves all messages in a conversation ordered by sequence
504func (db *DB) ListMessages(ctx context.Context, conversationID string) ([]generated.Message, error) {
505 var messages []generated.Message
506 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
507 q := generated.New(rx.Conn())
508 var err error
509 messages, err = q.ListMessages(ctx, conversationID)
510 return err
511 })
512 return messages, err
513}
514
515// ListMessagesForContext retrieves messages that should be sent to the LLM (excludes excluded_from_context=true)
516func (db *DB) ListMessagesForContext(ctx context.Context, conversationID string) ([]generated.Message, error) {
517 var messages []generated.Message
518 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
519 q := generated.New(rx.Conn())
520 var err error
521 messages, err = q.ListMessagesForContext(ctx, conversationID)
522 return err
523 })
524 return messages, err
525}
526
527// ListMessagesByType retrieves messages of a specific type in a conversation
528func (db *DB) ListMessagesByType(ctx context.Context, conversationID string, messageType MessageType) ([]generated.Message, error) {
529 var messages []generated.Message
530 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
531 q := generated.New(rx.Conn())
532 var err error
533 messages, err = q.ListMessagesByType(ctx, generated.ListMessagesByTypeParams{
534 ConversationID: conversationID,
535 Type: string(messageType),
536 })
537 return err
538 })
539 return messages, err
540}
541
542// GetLatestMessage retrieves the latest message in a conversation
543func (db *DB) GetLatestMessage(ctx context.Context, conversationID string) (*generated.Message, error) {
544 var message generated.Message
545 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
546 q := generated.New(rx.Conn())
547 var err error
548 message, err = q.GetLatestMessage(ctx, conversationID)
549 return err
550 })
551 if err == sql.ErrNoRows {
552 return nil, fmt.Errorf("no messages found in conversation: %s", conversationID)
553 }
554 return &message, err
555}
556
557// CountMessagesByType returns the number of messages of a specific type in a conversation
558func (db *DB) CountMessagesByType(ctx context.Context, conversationID string, messageType MessageType) (int64, error) {
559 var count int64
560 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
561 q := generated.New(rx.Conn())
562 var err error
563 count, err = q.CountMessagesByType(ctx, generated.CountMessagesByTypeParams{
564 ConversationID: conversationID,
565 Type: string(messageType),
566 })
567 return err
568 })
569 return count, err
570}
571
572// Queries provides read-only access to generated queries within a read transaction
573func (db *DB) Queries(ctx context.Context, fn func(*generated.Queries) error) error {
574 return db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
575 q := generated.New(rx.Conn())
576 return fn(q)
577 })
578}
579
580// QueriesTx provides read-write access to generated queries within a write transaction
581func (db *DB) QueriesTx(ctx context.Context, fn func(*generated.Queries) error) error {
582 return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
583 q := generated.New(tx.Conn())
584 return fn(q)
585 })
586}
587
588// ListArchivedConversations retrieves archived conversations with pagination
589func (db *DB) ListArchivedConversations(ctx context.Context, limit, offset int64) ([]generated.Conversation, error) {
590 var conversations []generated.Conversation
591 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
592 q := generated.New(rx.Conn())
593 var err error
594 conversations, err = q.ListArchivedConversations(ctx, generated.ListArchivedConversationsParams{
595 Limit: limit,
596 Offset: offset,
597 })
598 return err
599 })
600 return conversations, err
601}
602
603// SearchArchivedConversations searches for archived conversations containing the given query in their slug
604func (db *DB) SearchArchivedConversations(ctx context.Context, query string, limit, offset int64) ([]generated.Conversation, error) {
605 queryPtr := &query
606 var conversations []generated.Conversation
607 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
608 q := generated.New(rx.Conn())
609 var err error
610 conversations, err = q.SearchArchivedConversations(ctx, generated.SearchArchivedConversationsParams{
611 Column1: queryPtr,
612 Limit: limit,
613 Offset: offset,
614 })
615 return err
616 })
617 return conversations, err
618}
619
620// ArchiveConversation archives a conversation
621func (db *DB) ArchiveConversation(ctx context.Context, conversationID string) (*generated.Conversation, error) {
622 var conversation generated.Conversation
623 err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
624 q := generated.New(tx.Conn())
625 var err error
626 conversation, err = q.ArchiveConversation(ctx, conversationID)
627 return err
628 })
629 return &conversation, err
630}
631
632// UnarchiveConversation unarchives a conversation
633func (db *DB) UnarchiveConversation(ctx context.Context, conversationID string) (*generated.Conversation, error) {
634 var conversation generated.Conversation
635 err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
636 q := generated.New(tx.Conn())
637 var err error
638 conversation, err = q.UnarchiveConversation(ctx, conversationID)
639 return err
640 })
641 return &conversation, err
642}
643
644// DeleteConversation deletes a conversation and all its messages
645func (db *DB) DeleteConversation(ctx context.Context, conversationID string) error {
646 return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
647 q := generated.New(tx.Conn())
648 // Delete messages first (foreign key constraint)
649 if err := q.DeleteConversationMessages(ctx, conversationID); err != nil {
650 return fmt.Errorf("failed to delete messages: %w", err)
651 }
652 return q.DeleteConversation(ctx, conversationID)
653 })
654}
655
656// CreateSubagentConversation creates a new subagent conversation with a parent
657func (db *DB) CreateSubagentConversation(ctx context.Context, slug, parentID string, cwd *string) (*generated.Conversation, error) {
658 conversationID, err := generateConversationID()
659 if err != nil {
660 return nil, fmt.Errorf("failed to generate conversation ID: %w", err)
661 }
662 var conversation generated.Conversation
663 err = db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
664 q := generated.New(tx.Conn())
665 conversation, err = q.CreateSubagentConversation(ctx, generated.CreateSubagentConversationParams{
666 ConversationID: conversationID,
667 Slug: &slug,
668 Cwd: cwd,
669 ParentConversationID: &parentID,
670 })
671 return err
672 })
673 return &conversation, err
674}
675
676// GetSubagents retrieves all subagent conversations for a parent conversation
677func (db *DB) GetSubagents(ctx context.Context, parentID string) ([]generated.Conversation, error) {
678 var conversations []generated.Conversation
679 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
680 q := generated.New(rx.Conn())
681 var err error
682 conversations, err = q.GetSubagents(ctx, &parentID)
683 return err
684 })
685 return conversations, err
686}
687
688// GetConversationBySlugAndParent retrieves a subagent conversation by slug and parent ID
689func (db *DB) GetConversationBySlugAndParent(ctx context.Context, slug, parentID string) (*generated.Conversation, error) {
690 var conversation generated.Conversation
691 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
692 q := generated.New(rx.Conn())
693 var err error
694 conversation, err = q.GetConversationBySlugAndParent(ctx, generated.GetConversationBySlugAndParentParams{
695 Slug: &slug,
696 ParentConversationID: &parentID,
697 })
698 return err
699 })
700 if err == sql.ErrNoRows {
701 return nil, nil // Not found, return nil without error
702 }
703 return &conversation, err
704}
705
706// SubagentDBAdapter adapts *DB to the claudetool.SubagentDB interface.
707type SubagentDBAdapter struct {
708 DB *DB
709}
710
711// GetOrCreateSubagentConversation implements claudetool.SubagentDB.
712// Returns the conversation ID and the actual slug used (may differ if a suffix was added).
713func (a *SubagentDBAdapter) GetOrCreateSubagentConversation(ctx context.Context, slug, parentID, cwd string) (string, string, error) {
714 // Try to find existing with exact slug
715 existing, err := a.DB.GetConversationBySlugAndParent(ctx, slug, parentID)
716 if err != nil {
717 return "", "", err
718 }
719 if existing != nil {
720 return existing.ConversationID, *existing.Slug, nil
721 }
722
723 // Try to create new, handling unique constraint violations by appending numbers
724 baseSlug := slug
725 actualSlug := slug
726 for attempt := 0; attempt < 100; attempt++ {
727 conv, err := a.DB.CreateSubagentConversation(ctx, actualSlug, parentID, &cwd)
728 if err == nil {
729 return conv.ConversationID, actualSlug, nil
730 }
731
732 // Check if this is a unique constraint violation
733 errLower := strings.ToLower(err.Error())
734 if strings.Contains(errLower, "unique constraint") ||
735 strings.Contains(errLower, "duplicate") {
736 // Try with a numeric suffix
737 actualSlug = fmt.Sprintf("%s-%d", baseSlug, attempt+1)
738 continue
739 }
740
741 // Some other error occurred
742 return "", "", err
743 }
744
745 return "", "", fmt.Errorf("failed to create unique subagent slug after 100 attempts")
746}
747
748// InsertLLMRequest inserts a new LLM request record
749func (db *DB) InsertLLMRequest(ctx context.Context, params generated.InsertLLMRequestParams) (*generated.LlmRequest, error) {
750 var request generated.LlmRequest
751 err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
752 q := generated.New(tx.Conn())
753
754 // If we have a conversation ID and request body, try to find common prefix
755 if params.ConversationID != nil && params.RequestBody != nil {
756 // Get the last request for this conversation
757 lastReq, err := q.GetLastRequestForConversation(ctx, params.ConversationID)
758 if err == nil {
759 // Found a previous request - compute common prefix
760 prefixLen, fullPrevBody := computeSharedPrefixLength(lastReq, *params.RequestBody)
761 if prefixLen > 0 {
762 // Store only the suffix
763 suffix := (*params.RequestBody)[prefixLen:]
764 params.RequestBody = &suffix
765 params.PrefixRequestID = &lastReq.ID
766 prefixLen64 := int64(prefixLen)
767 params.PrefixLength = &prefixLen64
768 _ = fullPrevBody // silence unused warning, used for computing prefix
769 }
770 }
771 // If no previous request found or error, just store the full body
772 }
773
774 var err error
775 request, err = q.InsertLLMRequest(ctx, params)
776 return err
777 })
778 return &request, err
779}
780
781// computeSharedPrefixLength computes the length of the shared prefix between
782// the full previous request body (reconstructed by walking the chain) and the new request body.
783// It returns the prefix length and the fully reconstructed previous body.
784func computeSharedPrefixLength(prevReq generated.LlmRequest, newBody string) (int, string) {
785 // Get the stored body (which may be just a suffix if prevReq has a prefix reference)
786 prevBody := ""
787 if prevReq.RequestBody != nil {
788 prevBody = *prevReq.RequestBody
789 }
790
791 // If the previous request has a prefix reference, we need to account for that
792 // by prepending the prefix length worth of bytes from the new body.
793 // This works because in a conversation, request N+1 typically starts with
794 // all of request N plus new content at the end.
795 if prevReq.PrefixLength != nil && *prevReq.PrefixLength > 0 {
796 // The previous request's full body would be:
797 // [first prefix_length bytes that match its parent] + [stored suffix]
798 // If the new body is a continuation, its first prefix_length bytes
799 // should match those same bytes.
800 prefixLen := int(*prevReq.PrefixLength)
801 if prefixLen <= len(newBody) {
802 prevBody = newBody[:prefixLen] + prevBody
803 }
804 }
805
806 // Compute byte-by-byte shared prefix between reconstructed prevBody and newBody
807 minLen := len(prevBody)
808 if len(newBody) < minLen {
809 minLen = len(newBody)
810 }
811
812 prefixLen := 0
813 for i := 0; i < minLen; i++ {
814 if prevBody[i] != newBody[i] {
815 break
816 }
817 prefixLen++
818 }
819
820 // Only use prefix deduplication if we save meaningful space
821 // (at least 100 bytes saved)
822 if prefixLen < 100 {
823 return 0, prevBody
824 }
825
826 return prefixLen, prevBody
827}
828
829// ListRecentLLMRequests returns the most recent LLM requests
830func (db *DB) ListRecentLLMRequests(ctx context.Context, limit int64) ([]generated.ListRecentLLMRequestsRow, error) {
831 var requests []generated.ListRecentLLMRequestsRow
832 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
833 q := generated.New(rx.Conn())
834 var err error
835 requests, err = q.ListRecentLLMRequests(ctx, limit)
836 return err
837 })
838 return requests, err
839}
840
841// GetLLMRequestBody returns the raw request body for a request
842func (db *DB) GetLLMRequestBody(ctx context.Context, id int64) (*string, error) {
843 var body *string
844 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
845 q := generated.New(rx.Conn())
846 var err error
847 body, err = q.GetLLMRequestBody(ctx, id)
848 return err
849 })
850 return body, err
851}
852
853// GetLLMResponseBody returns the raw response body for a request
854func (db *DB) GetLLMResponseBody(ctx context.Context, id int64) (*string, error) {
855 var body *string
856 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
857 q := generated.New(rx.Conn())
858 var err error
859 body, err = q.GetLLMResponseBody(ctx, id)
860 return err
861 })
862 return body, err
863}
864
865// GetFullLLMRequestBody reconstructs the full request body for a request,
866// following the prefix chain if necessary.
867func (db *DB) GetFullLLMRequestBody(ctx context.Context, requestID int64) (string, error) {
868 var result string
869 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
870 q := generated.New(rx.Conn())
871 return reconstructRequestBody(ctx, q, requestID, &result)
872 })
873 return result, err
874}
875
876// reconstructRequestBody recursively reconstructs the full request body
877func reconstructRequestBody(ctx context.Context, q *generated.Queries, requestID int64, result *string) error {
878 req, err := q.GetLLMRequestByID(ctx, requestID)
879 if err != nil {
880 return err
881 }
882
883 suffix := ""
884 if req.RequestBody != nil {
885 suffix = *req.RequestBody
886 }
887
888 if req.PrefixRequestID == nil || req.PrefixLength == nil || *req.PrefixLength == 0 {
889 // No prefix reference - the stored body is the full body
890 *result = suffix
891 return nil
892 }
893
894 // Recursively get the parent's full body
895 var parentBody string
896 if err := reconstructRequestBody(ctx, q, *req.PrefixRequestID, &parentBody); err != nil {
897 return err
898 }
899
900 // The full body is the first prefix_length bytes from the parent + our suffix
901 prefixLen := int(*req.PrefixLength)
902 if prefixLen > len(parentBody) {
903 prefixLen = len(parentBody)
904 }
905 *result = parentBody[:prefixLen] + suffix
906 return nil
907}
908
909// GetModels returns all models from the database
910func (db *DB) GetModels(ctx context.Context) ([]generated.Model, error) {
911 var models []generated.Model
912 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
913 q := generated.New(rx.Conn())
914 var err error
915 models, err = q.GetModels(ctx)
916 return err
917 })
918 return models, err
919}
920
921// GetModel returns a model by ID
922func (db *DB) GetModel(ctx context.Context, modelID string) (*generated.Model, error) {
923 var model generated.Model
924 err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error {
925 q := generated.New(rx.Conn())
926 var err error
927 model, err = q.GetModel(ctx, modelID)
928 return err
929 })
930 if err != nil {
931 return nil, err
932 }
933 return &model, nil
934}
935
936// CreateModel creates a new model
937func (db *DB) CreateModel(ctx context.Context, params generated.CreateModelParams) (*generated.Model, error) {
938 var model generated.Model
939 err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
940 q := generated.New(tx.Conn())
941 var err error
942 model, err = q.CreateModel(ctx, params)
943 return err
944 })
945 if err != nil {
946 return nil, err
947 }
948 return &model, nil
949}
950
951// UpdateModel updates a model
952func (db *DB) UpdateModel(ctx context.Context, params generated.UpdateModelParams) (*generated.Model, error) {
953 var model generated.Model
954 err := db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
955 q := generated.New(tx.Conn())
956 var err error
957 model, err = q.UpdateModel(ctx, params)
958 return err
959 })
960 if err != nil {
961 return nil, err
962 }
963 return &model, nil
964}
965
966// DeleteModel deletes a model
967func (db *DB) DeleteModel(ctx context.Context, modelID string) error {
968 return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
969 q := generated.New(tx.Conn())
970 return q.DeleteModel(ctx, modelID)
971 })
972}