diff --git a/cmd/go2ts.go b/cmd/go2ts.go index 2f1c383bb02532b41178af221b2f2607761920a6..d5b7c8cacee71381b113550a0107a7bed23c2403 100644 --- a/cmd/go2ts.go +++ b/cmd/go2ts.go @@ -91,6 +91,7 @@ type apiMessageForTS struct { type conversationStateForTS struct { ConversationID string `json:"conversation_id"` Working bool `json:"working"` + Model string `json:"model,omitempty"` } type conversationWithStateForTS struct { @@ -102,6 +103,7 @@ type conversationWithStateForTS struct { Cwd *string `json:"cwd"` Archived bool `json:"archived"` ParentConversationID *string `json:"parent_conversation_id"` + Model *string `json:"model"` Working bool `json:"working"` } diff --git a/db/conversations_test.go b/db/conversations_test.go index b91aa53c6f068e195972871e84d103701940acb5..84f404dc0eb1950469d80479ebda5fa21ace0e8d 100644 --- a/db/conversations_test.go +++ b/db/conversations_test.go @@ -33,7 +33,7 @@ func TestConversationService_Create(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - conv, err := db.CreateConversation(ctx, tt.slug, true, nil) + conv, err := db.CreateConversation(ctx, tt.slug, true, nil, nil) if err != nil { t.Errorf("Create() error = %v", err) return @@ -73,7 +73,7 @@ func TestConversationService_GetByID(t *testing.T) { defer cancel() // Create a test conversation - created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -108,7 +108,7 @@ func TestConversationService_GetBySlug(t *testing.T) { defer cancel() // Create a test conversation with slug - created, err := db.CreateConversation(ctx, stringPtr("test-slug"), true, nil) + created, err := db.CreateConversation(ctx, stringPtr("test-slug"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -143,7 +143,7 @@ func TestConversationService_UpdateSlug(t *testing.T) { defer cancel() // Create a test conversation - created, err := db.CreateConversation(ctx, nil, true, nil) + created, err := db.CreateConversation(ctx, nil, true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -177,7 +177,7 @@ func TestConversationService_List(t *testing.T) { // Create multiple test conversations for i := 0; i < 5; i++ { slug := stringPtr("conversation-" + string(rune('a'+i))) - _, err := db.CreateConversation(ctx, slug, true, nil) + _, err := db.CreateConversation(ctx, slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation %d: %v", i, err) } @@ -209,7 +209,7 @@ func TestConversationService_Search(t *testing.T) { // Create test conversations with different slugs testCases := []string{"project-alpha", "project-beta", "work-task", "personal-note"} for _, slug := range testCases { - _, err := db.CreateConversation(ctx, stringPtr(slug), true, nil) + _, err := db.CreateConversation(ctx, stringPtr(slug), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation with slug %s: %v", slug, err) } @@ -243,7 +243,7 @@ func TestConversationService_Touch(t *testing.T) { defer cancel() // Create a test conversation - created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -278,7 +278,7 @@ func TestConversationService_Delete(t *testing.T) { defer cancel() // Create a test conversation - created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -324,7 +324,7 @@ func TestConversationService_Count(t *testing.T) { // Create test conversations for i := 0; i < 3; i++ { - _, err := db.CreateConversation(ctx, stringPtr("conversation-"+string(rune('a'+i))), true, nil) + _, err := db.CreateConversation(ctx, stringPtr("conversation-"+string(rune('a'+i))), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation %d: %v", i, err) } @@ -354,13 +354,13 @@ func TestConversationService_MultipleNullSlugs(t *testing.T) { defer cancel() // Create multiple conversations with null slugs - this should not fail - conv1, err := db.CreateConversation(ctx, nil, true, nil) + conv1, err := db.CreateConversation(ctx, nil, true, nil, nil) if err != nil { t.Errorf("Create() first conversation error = %v", err) return } - conv2, err := db.CreateConversation(ctx, nil, true, nil) + conv2, err := db.CreateConversation(ctx, nil, true, nil, nil) if err != nil { t.Errorf("Create() second conversation error = %v", err) return @@ -389,14 +389,14 @@ func TestConversationService_SlugUniquenessWhenNotNull(t *testing.T) { defer cancel() // Create first conversation with a slug - _, err := db.CreateConversation(ctx, stringPtr("unique-slug"), true, nil) + _, err := db.CreateConversation(ctx, stringPtr("unique-slug"), true, nil, nil) if err != nil { t.Errorf("Create() first conversation error = %v", err) return } // Try to create second conversation with the same slug - this should fail - _, err = db.CreateConversation(ctx, stringPtr("unique-slug"), true, nil) + _, err = db.CreateConversation(ctx, stringPtr("unique-slug"), true, nil, nil) if err == nil { t.Error("Expected error when creating conversation with duplicate slug") return @@ -416,7 +416,7 @@ func TestConversationService_ArchiveUnarchive(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -450,12 +450,12 @@ func TestConversationService_ListArchivedConversations(t *testing.T) { defer cancel() // Create test conversations - conv1, err := db.CreateConversation(ctx, stringPtr("test-conversation-1"), true, nil) + conv1, err := db.CreateConversation(ctx, stringPtr("test-conversation-1"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation 1: %v", err) } - conv2, err := db.CreateConversation(ctx, stringPtr("test-conversation-2"), true, nil) + conv2, err := db.CreateConversation(ctx, stringPtr("test-conversation-2"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation 2: %v", err) } @@ -498,12 +498,12 @@ func TestConversationService_SearchArchivedConversations(t *testing.T) { defer cancel() // Create test conversations - conv1, err := db.CreateConversation(ctx, stringPtr("test-conversation-search-1"), true, nil) + conv1, err := db.CreateConversation(ctx, stringPtr("test-conversation-search-1"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation 1: %v", err) } - conv2, err := db.CreateConversation(ctx, stringPtr("another-conversation"), true, nil) + conv2, err := db.CreateConversation(ctx, stringPtr("another-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation 2: %v", err) } @@ -544,7 +544,7 @@ func TestConversationService_DeleteConversation(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-to-delete"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-to-delete"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -580,7 +580,7 @@ func TestConversationService_UpdateConversationCwd(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-cwd"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-cwd"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } diff --git a/db/db.go b/db/db.go index 9f7e3fdd71d1cfe13a88407d09abdaeaba175476..05571114f221fe1f7aaca85478bece701183204f 100644 --- a/db/db.go +++ b/db/db.go @@ -233,7 +233,7 @@ func WithTxRes[T any](db *DB, ctx context.Context, fn func(*generated.Queries) ( // Conversation methods (moved from ConversationService) // CreateConversation creates a new conversation with an optional slug -func (db *DB) CreateConversation(ctx context.Context, slug *string, userInitiated bool, cwd *string) (*generated.Conversation, error) { +func (db *DB) CreateConversation(ctx context.Context, slug *string, userInitiated bool, cwd, model *string) (*generated.Conversation, error) { conversationID, err := generateConversationID() if err != nil { return nil, fmt.Errorf("failed to generate conversation ID: %w", err) @@ -246,6 +246,7 @@ func (db *DB) CreateConversation(ctx context.Context, slug *string, userInitiate Slug: slug, UserInitiated: userInitiated, Cwd: cwd, + Model: model, }) return err }) @@ -360,6 +361,18 @@ func (db *DB) UpdateConversationCwd(ctx context.Context, conversationID, cwd str }) } +// UpdateConversationModel sets the model for a conversation that doesn't have one yet. +// This is used to backfill the model for conversations created before the model column existed. +func (db *DB) UpdateConversationModel(ctx context.Context, conversationID, model string) error { + return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error { + q := generated.New(tx.Conn()) + return q.UpdateConversationModel(ctx, generated.UpdateConversationModelParams{ + Model: &model, + ConversationID: conversationID, + }) + }) +} + // Message methods (moved from MessageService) // MessageType represents the type of message diff --git a/db/db_test.go b/db/db_test.go index 23becd394d77408533a889f1209500bf6f2679c4..3d4cfdd0ff4e3a8d20353be56dbb5946643008ba 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -122,6 +122,7 @@ func TestDB_WithTx(t *testing.T) { ConversationID: "test-conv-1", Slug: stringPtr("test-slug"), UserInitiated: true, + Model: nil, }) return err }) @@ -227,7 +228,7 @@ func TestLLMRequestPrefixDeduplication(t *testing.T) { // Create a conversation first slug := "test-prefix-conv" - conv, err := db.CreateConversation(ctx, &slug, true, nil) + conv, err := db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -343,7 +344,7 @@ func TestLLMRequestNoPrefixForShortOverlap(t *testing.T) { defer cancel() slug := "test-short-conv" - conv, err := db.CreateConversation(ctx, &slug, true, nil) + conv, err := db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -437,7 +438,7 @@ func TestLLMRequestRealisticConversation(t *testing.T) { defer cancel() slug := "test-realistic-conv" - conv, err := db.CreateConversation(ctx, &slug, true, nil) + conv, err := db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -573,7 +574,7 @@ func TestLLMRequestOpenAIStyle(t *testing.T) { defer cancel() slug := "test-openai-conv" - conv, err := db.CreateConversation(ctx, &slug, true, nil) + conv, err := db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } diff --git a/db/generated/conversations.sql.go b/db/generated/conversations.sql.go index f5ab40b65afdcd2b71c8b1f5238c8c4584ea630f..d7d146d88d7edf490485f0c8a47f4bd48e47251f 100644 --- a/db/generated/conversations.sql.go +++ b/db/generated/conversations.sql.go @@ -13,7 +13,7 @@ const archiveConversation = `-- name: ArchiveConversation :one UPDATE conversations SET archived = TRUE, updated_at = CURRENT_TIMESTAMP WHERE conversation_id = ? -RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id +RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model ` func (q *Queries) ArchiveConversation(ctx context.Context, conversationID string) (Conversation, error) { @@ -28,6 +28,7 @@ func (q *Queries) ArchiveConversation(ctx context.Context, conversationID string &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } @@ -55,9 +56,9 @@ func (q *Queries) CountConversations(ctx context.Context) (int64, error) { } const createConversation = `-- name: CreateConversation :one -INSERT INTO conversations (conversation_id, slug, user_initiated, cwd) -VALUES (?, ?, ?, ?) -RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id +INSERT INTO conversations (conversation_id, slug, user_initiated, cwd, model) +VALUES (?, ?, ?, ?, ?) +RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model ` type CreateConversationParams struct { @@ -65,6 +66,7 @@ type CreateConversationParams struct { Slug *string `json:"slug"` UserInitiated bool `json:"user_initiated"` Cwd *string `json:"cwd"` + Model *string `json:"model"` } func (q *Queries) CreateConversation(ctx context.Context, arg CreateConversationParams) (Conversation, error) { @@ -73,6 +75,7 @@ func (q *Queries) CreateConversation(ctx context.Context, arg CreateConversation arg.Slug, arg.UserInitiated, arg.Cwd, + arg.Model, ) var i Conversation err := row.Scan( @@ -84,6 +87,7 @@ func (q *Queries) CreateConversation(ctx context.Context, arg CreateConversation &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } @@ -91,7 +95,7 @@ func (q *Queries) CreateConversation(ctx context.Context, arg CreateConversation const createSubagentConversation = `-- name: CreateSubagentConversation :one INSERT INTO conversations (conversation_id, slug, user_initiated, cwd, parent_conversation_id) VALUES (?, ?, FALSE, ?, ?) -RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id +RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model ` type CreateSubagentConversationParams struct { @@ -118,6 +122,7 @@ func (q *Queries) CreateSubagentConversation(ctx context.Context, arg CreateSuba &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } @@ -133,7 +138,7 @@ func (q *Queries) DeleteConversation(ctx context.Context, conversationID string) } const getConversation = `-- name: GetConversation :one -SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations +SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations WHERE conversation_id = ? ` @@ -149,12 +154,13 @@ func (q *Queries) GetConversation(ctx context.Context, conversationID string) (C &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } const getConversationBySlug = `-- name: GetConversationBySlug :one -SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations +SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations WHERE slug = ? ` @@ -170,12 +176,13 @@ func (q *Queries) GetConversationBySlug(ctx context.Context, slug *string) (Conv &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } const getConversationBySlugAndParent = `-- name: GetConversationBySlugAndParent :one -SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations +SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations WHERE slug = ? AND parent_conversation_id = ? ` @@ -196,12 +203,13 @@ func (q *Queries) GetConversationBySlugAndParent(ctx context.Context, arg GetCon &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } const getSubagents = `-- name: GetSubagents :many -SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations +SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations WHERE parent_conversation_id = ? ORDER BY created_at ASC ` @@ -224,6 +232,7 @@ func (q *Queries) GetSubagents(ctx context.Context, parentConversationID *string &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ); err != nil { return nil, err } @@ -239,7 +248,7 @@ func (q *Queries) GetSubagents(ctx context.Context, parentConversationID *string } const listArchivedConversations = `-- name: ListArchivedConversations :many -SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations +SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations WHERE archived = TRUE ORDER BY updated_at DESC LIMIT ? OFFSET ? @@ -268,6 +277,7 @@ func (q *Queries) ListArchivedConversations(ctx context.Context, arg ListArchive &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ); err != nil { return nil, err } @@ -283,7 +293,7 @@ func (q *Queries) ListArchivedConversations(ctx context.Context, arg ListArchive } const listConversations = `-- name: ListConversations :many -SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations +SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations WHERE archived = FALSE AND parent_conversation_id IS NULL ORDER BY updated_at DESC LIMIT ? OFFSET ? @@ -312,6 +322,7 @@ func (q *Queries) ListConversations(ctx context.Context, arg ListConversationsPa &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ); err != nil { return nil, err } @@ -327,7 +338,7 @@ func (q *Queries) ListConversations(ctx context.Context, arg ListConversationsPa } const searchArchivedConversations = `-- name: SearchArchivedConversations :many -SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations +SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations WHERE slug LIKE '%' || ? || '%' AND archived = TRUE ORDER BY updated_at DESC LIMIT ? OFFSET ? @@ -357,6 +368,7 @@ func (q *Queries) SearchArchivedConversations(ctx context.Context, arg SearchArc &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ); err != nil { return nil, err } @@ -372,7 +384,7 @@ func (q *Queries) SearchArchivedConversations(ctx context.Context, arg SearchArc } const searchConversations = `-- name: SearchConversations :many -SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations +SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations WHERE slug LIKE '%' || ? || '%' AND archived = FALSE AND parent_conversation_id IS NULL ORDER BY updated_at DESC LIMIT ? OFFSET ? @@ -402,6 +414,7 @@ func (q *Queries) SearchConversations(ctx context.Context, arg SearchConversatio &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ); err != nil { return nil, err } @@ -417,7 +430,7 @@ func (q *Queries) SearchConversations(ctx context.Context, arg SearchConversatio } const searchConversationsWithMessages = `-- name: SearchConversationsWithMessages :many -SELECT DISTINCT c.conversation_id, c.slug, c.user_initiated, c.created_at, c.updated_at, c.cwd, c.archived, c.parent_conversation_id FROM conversations c +SELECT DISTINCT c.conversation_id, c.slug, c.user_initiated, c.created_at, c.updated_at, c.cwd, c.archived, c.parent_conversation_id, c.model FROM conversations c LEFT JOIN messages m ON c.conversation_id = m.conversation_id AND m.type IN ('user', 'agent') WHERE c.archived = FALSE AND ( @@ -463,6 +476,7 @@ func (q *Queries) SearchConversationsWithMessages(ctx context.Context, arg Searc &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ); err != nil { return nil, err } @@ -481,7 +495,7 @@ const unarchiveConversation = `-- name: UnarchiveConversation :one UPDATE conversations SET archived = FALSE, updated_at = CURRENT_TIMESTAMP WHERE conversation_id = ? -RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id +RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model ` func (q *Queries) UnarchiveConversation(ctx context.Context, conversationID string) (Conversation, error) { @@ -496,6 +510,7 @@ func (q *Queries) UnarchiveConversation(ctx context.Context, conversationID stri &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } @@ -504,7 +519,7 @@ const updateConversationCwd = `-- name: UpdateConversationCwd :one UPDATE conversations SET cwd = ?, updated_at = CURRENT_TIMESTAMP WHERE conversation_id = ? -RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id +RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model ` type UpdateConversationCwdParams struct { @@ -524,15 +539,32 @@ func (q *Queries) UpdateConversationCwd(ctx context.Context, arg UpdateConversat &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } +const updateConversationModel = `-- name: UpdateConversationModel :exec +UPDATE conversations +SET model = ? +WHERE conversation_id = ? AND model IS NULL +` + +type UpdateConversationModelParams struct { + Model *string `json:"model"` + ConversationID string `json:"conversation_id"` +} + +func (q *Queries) UpdateConversationModel(ctx context.Context, arg UpdateConversationModelParams) error { + _, err := q.db.ExecContext(ctx, updateConversationModel, arg.Model, arg.ConversationID) + return err +} + const updateConversationSlug = `-- name: UpdateConversationSlug :one UPDATE conversations SET slug = ?, updated_at = CURRENT_TIMESTAMP WHERE conversation_id = ? -RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id +RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model ` type UpdateConversationSlugParams struct { @@ -552,6 +584,7 @@ func (q *Queries) UpdateConversationSlug(ctx context.Context, arg UpdateConversa &i.Cwd, &i.Archived, &i.ParentConversationID, + &i.Model, ) return i, err } diff --git a/db/generated/models.go b/db/generated/models.go index 4fcfda92a6481fb51e28ddd27f9cf5b6c669242f..5d66089b07b8fe84874c7ecbc9ad6c879f56cbca 100644 --- a/db/generated/models.go +++ b/db/generated/models.go @@ -17,6 +17,7 @@ type Conversation struct { Cwd *string `json:"cwd"` Archived bool `json:"archived"` ParentConversationID *string `json:"parent_conversation_id"` + Model *string `json:"model"` } type LlmRequest struct { diff --git a/db/messages_test.go b/db/messages_test.go index fb5024dc993027380edbafd6e28476c56ee18030..9b4b5634e8aa4a83f3a44ae87d094814bc6094bf 100644 --- a/db/messages_test.go +++ b/db/messages_test.go @@ -21,7 +21,7 @@ func TestMessageService_Create(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -116,7 +116,7 @@ func TestMessageService_GetByID(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -162,7 +162,7 @@ func TestMessageService_ListByConversation(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -216,7 +216,7 @@ func TestMessageService_ListByType(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -263,7 +263,7 @@ func TestMessageService_GetLatest(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -310,7 +310,7 @@ func TestMessageService_Delete(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -351,7 +351,7 @@ func TestMessageService_CountInConversation(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -408,7 +408,7 @@ func TestMessageService_CountByType(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } @@ -465,7 +465,7 @@ func TestMessageService_ListMessagesByConversationPaginated(t *testing.T) { defer cancel() // Create a test conversation - conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-paginated"), true, nil) + conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-paginated"), true, nil, nil) if err != nil { t.Fatalf("Failed to create test conversation: %v", err) } diff --git a/db/query/conversations.sql b/db/query/conversations.sql index 22b2ddfae90c2d33fb30ece19a2e3cc5869e5bdc..a54c7f5e3aeec29e93bdac5ddf7b875ded22357e 100644 --- a/db/query/conversations.sql +++ b/db/query/conversations.sql @@ -1,6 +1,6 @@ -- name: CreateConversation :one -INSERT INTO conversations (conversation_id, slug, user_initiated, cwd) -VALUES (?, ?, ?, ?) +INSERT INTO conversations (conversation_id, slug, user_initiated, cwd, model) +VALUES (?, ?, ?, ?, ?) RETURNING *; -- name: GetConversation :one @@ -102,3 +102,8 @@ ORDER BY created_at ASC; -- name: GetConversationBySlugAndParent :one SELECT * FROM conversations WHERE slug = ? AND parent_conversation_id = ?; + +-- name: UpdateConversationModel :exec +UPDATE conversations +SET model = ? +WHERE conversation_id = ? AND model IS NULL; diff --git a/db/schema/013-add-model.sql b/db/schema/013-add-model.sql new file mode 100644 index 0000000000000000000000000000000000000000..04f2c5076e0e8a27afddc7855b93ad1ca4388e72 --- /dev/null +++ b/db/schema/013-add-model.sql @@ -0,0 +1,4 @@ +-- Add model column to conversations table +-- This stores the LLM model used for the conversation + +ALTER TABLE conversations ADD COLUMN model TEXT; diff --git a/server/cancel_test.go b/server/cancel_test.go index 6917229091fe93cfb1e4a428f3ce69f710e5de5a..4c660bdf98ee062e0040cbe0c4a45105d488d49f 100644 --- a/server/cancel_test.go +++ b/server/cancel_test.go @@ -54,7 +54,7 @@ func TestCancelWithPredictableModel(t *testing.T) { server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } @@ -251,7 +251,7 @@ func TestCancelWithNoActiveConversation(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create a conversation but don't start it - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } @@ -289,7 +289,7 @@ func TestCancelDuringTextGeneration(t *testing.T) { logger := slog.Default() server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } diff --git a/server/conversation_by_slug_test.go b/server/conversation_by_slug_test.go index fda368c1a95e5971348f45b619c598d0407c1e1a..8e085c30b4cb5d19222d658956fc60347b9968d9 100644 --- a/server/conversation_by_slug_test.go +++ b/server/conversation_by_slug_test.go @@ -15,7 +15,7 @@ func TestGetConversationBySlug(t *testing.T) { // Create a conversation with a slug slug := "my-test-slug" - conv, err := h.db.CreateConversation(t.Context(), &slug, true, nil) + conv, err := h.db.CreateConversation(t.Context(), &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } diff --git a/server/conversation_flow_test.go b/server/conversation_flow_test.go index f39be30eb90895c89efb3d5bd4a438bcbf979d9d..4d17e97a6161e31652f10b3576b45acca347480a 100644 --- a/server/conversation_flow_test.go +++ b/server/conversation_flow_test.go @@ -29,7 +29,7 @@ func TestMessageQueuedDuringThinking(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } @@ -160,7 +160,7 @@ func TestContextPreservedAfterCancel(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } diff --git a/server/conversation_list_stream_test.go b/server/conversation_list_stream_test.go index 322c6cf41254e9dc616f96a35063986c5835944a..d9a2358264dc9f22f4bf8fb849546fadb59ec504 100644 --- a/server/conversation_list_stream_test.go +++ b/server/conversation_list_stream_test.go @@ -27,7 +27,7 @@ func TestConversationStreamReceivesListUpdateForNewConversation(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create a conversation to subscribe to - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } @@ -125,11 +125,11 @@ func TestConversationStreamReceivesListUpdateForRename(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create two conversations - conv1, err := database.CreateConversation(context.Background(), nil, true, nil) + conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation 1: %v", err) } - conv2, err := database.CreateConversation(context.Background(), nil, true, nil) + conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation 2: %v", err) } @@ -216,11 +216,11 @@ func TestConversationStreamReceivesListUpdateForDelete(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create two conversations - conv1, err := database.CreateConversation(context.Background(), nil, true, nil) + conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation 1: %v", err) } - conv2, err := database.CreateConversation(context.Background(), nil, true, nil) + conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation 2: %v", err) } @@ -306,11 +306,11 @@ func TestConversationStreamReceivesListUpdateForArchive(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create two conversations - conv1, err := database.CreateConversation(context.Background(), nil, true, nil) + conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation 1: %v", err) } - conv2, err := database.CreateConversation(context.Background(), nil, true, nil) + conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation 2: %v", err) } diff --git a/server/conversation_state_test.go b/server/conversation_state_test.go index 2ac5f24f004b97fc3cc463c925f51d9bd2364e4f..cbc75814be2927ab14ddba78dba718a5f86d3812 100644 --- a/server/conversation_state_test.go +++ b/server/conversation_state_test.go @@ -49,7 +49,7 @@ func TestConversationStateAfterServerRestart(t *testing.T) { ctx := context.Background() // Create a conversation with some messages (simulating previous activity) - conv, err := database.CreateConversation(ctx, nil, true, nil) + conv, err := database.CreateConversation(ctx, nil, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -145,3 +145,98 @@ func TestConversationStateAfterServerRestart(t *testing.T) { t.Errorf("Expected 2 messages, got %d", len(response.Messages)) } } + +// TestModelRestorationAfterServerRestart verifies that when a conversation is +// resumed after a server restart, the model is correctly loaded from the database +// and reported in the ConversationState. +func TestModelRestorationAfterServerRestart(t *testing.T) { + database, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + + // Create a conversation with a specific model + modelID := "claude-sonnet-4-20250514" + conv, err := database.CreateConversation(ctx, nil, true, nil, &modelID) + if err != nil { + t.Fatalf("Failed to create conversation: %v", err) + } + + // Add a user message + userMsg := llm.Message{ + Role: llm.MessageRoleUser, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}}, + } + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeUser, + LLMData: userMsg, + }) + if err != nil { + t.Fatalf("Failed to create user message: %v", err) + } + + // Add an agent message + agentMsg := llm.Message{ + Role: llm.MessageRoleAssistant, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}}, + EndOfTurn: true, + } + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeAgent, + LLMData: agentMsg, + }) + if err != nil { + t.Fatalf("Failed to create agent message: %v", err) + } + + // Create a NEW server (simulating server restart or different browser session) + predictableService := loop.NewPredictableService() + llmManager := &testLLMManager{service: predictableService} + toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false} + server := NewServer(database, llmManager, toolSetConfig, nil, true, "", "predictable", "", nil) + + mux := http.NewServeMux() + server.RegisterRoutes(mux) + + // Make a streaming request + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx) + req.Header.Set("Accept", "text/event-stream") + + w := newResponseRecorderWithClose() + + done := make(chan struct{}) + go func() { + defer close(done) + mux.ServeHTTP(w, req) + }() + + time.Sleep(500 * time.Millisecond) + w.Close() + cancel() + <-done + + // Parse the first SSE message + body := w.Body.String() + if !strings.HasPrefix(body, "data: ") { + t.Fatalf("Expected SSE data, got: %s", body) + } + + jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ") + var response StreamResponse + if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + // Verify conversation state includes the model from the database + if response.ConversationState == nil { + t.Fatal("Expected ConversationState in response") + } + if response.ConversationState.Model != modelID { + t.Errorf("Expected Model='%s', got '%s'", modelID, response.ConversationState.Model) + } +} diff --git a/server/convo.go b/server/convo.go index b94b37526b4db470b5883e66449af74ddcf9cfe2..51d4f1e9c24530038286cba6ce4df5e0fe73085d 100644 --- a/server/convo.go +++ b/server/convo.go @@ -82,6 +82,7 @@ func (cm *ConversationManager) SetAgentWorking(working bool) { cm.agentWorking = working onStateChange := cm.onStateChange convID := cm.conversationID + modelID := cm.modelID cm.mu.Unlock() cm.logger.Debug("agent working state changed", "working", working) @@ -89,6 +90,7 @@ func (cm *ConversationManager) SetAgentWorking(working bool) { onStateChange(ConversationState{ ConversationID: convID, Working: working, + Model: modelID, }) } } @@ -100,6 +102,13 @@ func (cm *ConversationManager) IsAgentWorking() bool { return cm.agentWorking } +// GetModel returns the model ID used by this conversation. +func (cm *ConversationManager) GetModel() string { + cm.mu.Lock() + defer cm.mu.Unlock() + return cm.modelID +} + // Hydrate loads conversation state from the database, generating a system prompt if missing. func (cm *ConversationManager) Hydrate(ctx context.Context) error { cm.mu.Lock() @@ -133,6 +142,12 @@ func (cm *ConversationManager) Hydrate(ctx context.Context) error { } cm.cwd = cwd + // Load model from conversation if available + var modelID string + if conversation.Model != nil { + modelID = *conversation.Model + } + // Generate system prompt if missing: // - For user-initiated conversations: full system prompt // - For subagent conversations (has parent): minimal subagent prompt @@ -162,8 +177,12 @@ func (cm *ConversationManager) Hydrate(ctx context.Context) error { cm.hasConversationEvents = len(history) > 0 cm.lastActivity = time.Now() cm.hydrated = true + cm.modelID = modelID cm.mu.Unlock() + if modelID != "" { + cm.logger.Info("Loaded model from conversation", "model", modelID) + } cm.logSystemPromptState(system, len(messages)) return nil @@ -403,6 +422,8 @@ func (cm *ConversationManager) ensureLoop(service llm.Service, modelID string) e } return nil } + // Check if we need to persist the model (for conversations created before model column existed) + needsPersist := cm.modelID == "" && modelID != "" cm.loop = loopInstance cm.loopCancel = cancel cm.loopCtx = processCtx @@ -412,6 +433,13 @@ func (cm *ConversationManager) ensureLoop(service llm.Service, modelID string) e cm.system = nil cm.mu.Unlock() + // Persist model for legacy conversations + if needsPersist { + if err := db.UpdateConversationModel(context.Background(), conversationID, modelID); err != nil { + logger.Error("failed to persist model for legacy conversation", "error", err) + } + } + go func() { if err := loopInstance.Go(processCtx); err != nil && err != context.DeadlineExceeded && err != context.Canceled { if logger != nil { diff --git a/server/duplicate_tool_result_test.go b/server/duplicate_tool_result_test.go index cc860933347b056afdd8bdf49318faf148aded13..6b4556dc02ffc206e8730ea984018de6b8c03c77 100644 --- a/server/duplicate_tool_result_test.go +++ b/server/duplicate_tool_result_test.go @@ -41,7 +41,7 @@ func TestCancelAfterToolCompletesCreatesDuplicateToolResult(t *testing.T) { server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } diff --git a/server/handlers.go b/server/handlers.go index b6cfa46e03137f5e043a3832983d4499cddaf2a9..cea774008df65743ed884a7320455e47c8df114a 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -706,7 +706,7 @@ func (s *Server) handleNewConversation(w http.ResponseWriter, r *http.Request) { if req.Cwd != "" { cwdPtr = &req.Cwd } - conversation, err := s.db.CreateConversation(ctx, nil, true, cwdPtr) + conversation, err := s.db.CreateConversation(ctx, nil, true, cwdPtr, &modelID) if err != nil { s.logger.Error("Failed to create conversation", "error", err) http.Error(w, "Internal server error", http.StatusInternalServerError) @@ -854,6 +854,7 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request ConversationState: &ConversationState{ ConversationID: conversationID, Working: manager.IsAgentWorking(), + Model: manager.GetModel(), }, ContextWindowSize: calculateContextWindowSize(apiMessages), } diff --git a/server/handlers_test.go b/server/handlers_test.go index 8359dc077e67ae4f166c2bce45192040ba49b863..98b9a99b5a27476295acac72408ebc2a7441b902 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -46,7 +46,7 @@ func TestHandleArchivedConversations(t *testing.T) { // Create a test conversation and archive it ctx := context.Background() slug := "test-conversation" - conv, err := h.db.CreateConversation(ctx, &slug, true, nil) + conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -104,7 +104,7 @@ func TestHandleArchiveConversation(t *testing.T) { // Create a test conversation ctx := context.Background() slug := "test-conversation" - conv, err := h.db.CreateConversation(ctx, &slug, true, nil) + conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -157,7 +157,7 @@ func TestHandleUnarchiveConversation(t *testing.T) { // Create a test conversation and archive it ctx := context.Background() slug := "test-conversation" - conv, err := h.db.CreateConversation(ctx, &slug, true, nil) + conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -215,7 +215,7 @@ func TestHandleDeleteConversation(t *testing.T) { // Create a test conversation ctx := context.Background() slug := "test-conversation" - conv, err := h.db.CreateConversation(ctx, &slug, true, nil) + conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -274,7 +274,7 @@ func TestHandleRenameConversation(t *testing.T) { // Create a test conversation ctx := context.Background() slug := "test-conversation" - conv, err := h.db.CreateConversation(ctx, &slug, true, nil) + conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } diff --git a/server/message_bandwidth_test.go b/server/message_bandwidth_test.go index becf80706a21a7967702e1e1b5cb068f560ee233..2bda7868a793d45dcd8c392c989b19e2b6f54567 100644 --- a/server/message_bandwidth_test.go +++ b/server/message_bandwidth_test.go @@ -29,7 +29,7 @@ func TestMessageSentOnlyOnce(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } @@ -163,7 +163,7 @@ func TestContextWindowSizeInSSE(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } diff --git a/server/orphan_tool_result_test.go b/server/orphan_tool_result_test.go index bd837a196b48358b8c2d70978c640e8386722fc9..b2b8baa9ccca4cf912bf86a3037ec5ec3b2c3b08 100644 --- a/server/orphan_tool_result_test.go +++ b/server/orphan_tool_result_test.go @@ -50,7 +50,7 @@ func TestOrphanToolResultAfterCancellation(t *testing.T) { server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } @@ -233,7 +233,7 @@ func TestOrphanToolResultFiltering(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } diff --git a/server/server.go b/server/server.go index e96f34fb95651f1dc7ef7951c29a6a06817df167..6a7bfad1639279a08808166b188115ddfaa0dc9d 100644 --- a/server/server.go +++ b/server/server.go @@ -46,6 +46,7 @@ type APIMessage struct { type ConversationState struct { ConversationID string `json:"conversation_id"` Working bool `json:"working"` + Model string `json:"model,omitempty"` } // ConversationWithState combines a conversation with its working state. diff --git a/server/sse_immediacy_test.go b/server/sse_immediacy_test.go index 432aa4ec9d67857f5089cc7e8de383cc2b850e2c..f215a09befa23efcc11b98a1a6f608867b53fce9 100644 --- a/server/sse_immediacy_test.go +++ b/server/sse_immediacy_test.go @@ -81,7 +81,7 @@ func TestSSEUserMessageAppearsImmediately(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } @@ -222,7 +222,7 @@ func TestSSEUserMessageWithRealHTTPServer(t *testing.T) { srv := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create conversation - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } @@ -332,7 +332,7 @@ func TestSSEUserMessageWithExistingConnection(t *testing.T) { server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil) // Create conversation and get a manager (simulating an established SSE connection) - conversation, err := database.CreateConversation(context.Background(), nil, true, nil) + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("failed to create conversation: %v", err) } diff --git a/slug/slug_test.go b/slug/slug_test.go index 6ef78739742092889950714aa9b9fe3020b87d7a..2fa4108006d1bd3ab651236c659e8127d61188f0 100644 --- a/slug/slug_test.go +++ b/slug/slug_test.go @@ -138,7 +138,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) { })) // Create first conversation to establish the base slug - conv1, err := database.CreateConversation(ctx, nil, true, nil) + conv1, err := database.CreateConversation(ctx, nil, true, nil, nil) if err != nil { t.Fatalf("Failed to create first conversation: %v", err) } @@ -153,7 +153,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) { } // Create second conversation - conv2, err := database.CreateConversation(ctx, nil, true, nil) + conv2, err := database.CreateConversation(ctx, nil, true, nil, nil) if err != nil { t.Fatalf("Failed to create second conversation: %v", err) } @@ -168,7 +168,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) { } // Create third conversation - conv3, err := database.CreateConversation(ctx, nil, true, nil) + conv3, err := database.CreateConversation(ctx, nil, true, nil, nil) if err != nil { t.Fatalf("Failed to create third conversation: %v", err) } diff --git a/test/anthropic_test.go b/test/anthropic_test.go index ccf2e22f2a4ec23e589c60ec8d8b16d747aecef0..2c2722817c19052731e7c6aef9fd26cd5bf230b9 100644 --- a/test/anthropic_test.go +++ b/test/anthropic_test.go @@ -72,7 +72,7 @@ func TestWithAnthropicAPI(t *testing.T) { // Create a conversation // Using database directly instead of service slug := "claude-test" - conv, err := database.CreateConversation(context.Background(), &slug, true, nil) + conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -181,7 +181,7 @@ func TestWithAnthropicAPI(t *testing.T) { // Create a conversation // Using database directly instead of service slug := "tool-test" - conv, err := database.CreateConversation(context.Background(), &slug, true, nil) + conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -247,7 +247,7 @@ func TestWithAnthropicAPI(t *testing.T) { // Using database directly instead of service // Using database directly instead of service slug := "stream-test" - conv, err := database.CreateConversation(context.Background(), &slug, true, nil) + conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } diff --git a/test/server_test.go b/test/server_test.go index 2fa0f7101d65459bc6822add6b2c22fb0f102099..7293a4bd86ebd56b38b817886d50caba80797422 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -73,7 +73,7 @@ func TestServerEndToEnd(t *testing.T) { // Create a conversation // Using database directly instead of service slug := "test-conversation" - conv, err := database.CreateConversation(context.Background(), &slug, true, nil) + conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -107,7 +107,7 @@ func TestServerEndToEnd(t *testing.T) { // Create a conversation // Using database directly instead of service slug := "chat-test" - conv, err := database.CreateConversation(context.Background(), &slug, true, nil) + conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -170,7 +170,7 @@ func TestServerEndToEnd(t *testing.T) { // Using database directly instead of service // Using database directly instead of service slug := "stream-test" - conv, err := database.CreateConversation(context.Background(), &slug, true, nil) + conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -226,7 +226,7 @@ func TestServerEndToEnd(t *testing.T) { ctx := context.Background() // Create a conversation without a slug - conv, err := database.CreateConversation(ctx, nil, true, nil) + conv, err := database.CreateConversation(ctx, nil, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -390,7 +390,7 @@ func TestConversationCleanup(t *testing.T) { // Create a conversation // Using database directly instead of service - conv, err := database.CreateConversation(context.Background(), nil, true, nil) + conv, err := database.CreateConversation(context.Background(), nil, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -558,7 +558,7 @@ func TestSlugEndToEnd(t *testing.T) { // Create a conversation with a specific slug ctx := context.Background() testSlug := "test-conversation-slug" - conv, err := database.CreateConversation(ctx, &testSlug, true, nil) + conv, err := database.CreateConversation(ctx, &testSlug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } @@ -620,7 +620,7 @@ func TestSSEIncrementalUpdates(t *testing.T) { // Create a conversation with initial message slug := "test-sse" - conv, err := database.CreateConversation(context.Background(), &slug, true, nil) + conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil) if err != nil { t.Fatalf("Failed to create conversation: %v", err) } diff --git a/ui/src/components/ChatInterface.tsx b/ui/src/components/ChatInterface.tsx index 8c26f0a4d6e452007462ae7ee1f73a5aba35239d..9e2c9cf3676849662f7cba993462c679c870299d 100644 --- a/ui/src/components/ChatInterface.tsx +++ b/ui/src/components/ChatInterface.tsx @@ -379,6 +379,7 @@ function AnimatedWorkingStatus() { interface ConversationStateUpdate { conversation_id: string; working: boolean; + model?: string; } interface ChatInterfaceProps { @@ -703,6 +704,10 @@ function ChatInterface({ // Update local state if this is for our conversation if (streamResponse.conversation_state.conversation_id === conversationId) { setAgentWorking(streamResponse.conversation_state.working); + // Update selected model from conversation (ensures consistency across sessions) + if (streamResponse.conversation_state.model) { + setSelectedModel(streamResponse.conversation_state.model); + } } } diff --git a/ui/src/generated-types.ts b/ui/src/generated-types.ts index 1c736e05065d9804685b8092ad37bfc0c29347ef..313984d49ae9dfd4e0d54a1346689e8f4b96b7ef 100644 --- a/ui/src/generated-types.ts +++ b/ui/src/generated-types.ts @@ -12,6 +12,7 @@ export interface Conversation { cwd: string | null; archived: boolean; parent_conversation_id: string | null; + model: string | null; } export interface Usage { @@ -41,6 +42,7 @@ export interface ApiMessageForTS { export interface ConversationStateForTS { conversation_id: string; working: boolean; + model?: string; } export interface StreamResponseForTS { @@ -58,6 +60,7 @@ export interface ConversationWithStateForTS { cwd: string | null; archived: boolean; parent_conversation_id: string | null; + model: string | null; working: boolean; }