Detailed changes
@@ -91,6 +91,7 @@ type apiMessageForTS struct {
type conversationStateForTS struct {
ConversationID string `json:"conversation_id"`
Working bool `json:"working"`
+ Model string `json:"model,omitempty"`
}
type conversationWithStateForTS struct {
@@ -102,6 +103,7 @@ type conversationWithStateForTS struct {
Cwd *string `json:"cwd"`
Archived bool `json:"archived"`
ParentConversationID *string `json:"parent_conversation_id"`
+ Model *string `json:"model"`
Working bool `json:"working"`
}
@@ -33,7 +33,7 @@ func TestConversationService_Create(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- conv, err := db.CreateConversation(ctx, tt.slug, true, nil)
+ conv, err := db.CreateConversation(ctx, tt.slug, true, nil, nil)
if err != nil {
t.Errorf("Create() error = %v", err)
return
@@ -73,7 +73,7 @@ func TestConversationService_GetByID(t *testing.T) {
defer cancel()
// Create a test conversation
- created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -108,7 +108,7 @@ func TestConversationService_GetBySlug(t *testing.T) {
defer cancel()
// Create a test conversation with slug
- created, err := db.CreateConversation(ctx, stringPtr("test-slug"), true, nil)
+ created, err := db.CreateConversation(ctx, stringPtr("test-slug"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -143,7 +143,7 @@ func TestConversationService_UpdateSlug(t *testing.T) {
defer cancel()
// Create a test conversation
- created, err := db.CreateConversation(ctx, nil, true, nil)
+ created, err := db.CreateConversation(ctx, nil, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -177,7 +177,7 @@ func TestConversationService_List(t *testing.T) {
// Create multiple test conversations
for i := 0; i < 5; i++ {
slug := stringPtr("conversation-" + string(rune('a'+i)))
- _, err := db.CreateConversation(ctx, slug, true, nil)
+ _, err := db.CreateConversation(ctx, slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation %d: %v", i, err)
}
@@ -209,7 +209,7 @@ func TestConversationService_Search(t *testing.T) {
// Create test conversations with different slugs
testCases := []string{"project-alpha", "project-beta", "work-task", "personal-note"}
for _, slug := range testCases {
- _, err := db.CreateConversation(ctx, stringPtr(slug), true, nil)
+ _, err := db.CreateConversation(ctx, stringPtr(slug), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation with slug %s: %v", slug, err)
}
@@ -243,7 +243,7 @@ func TestConversationService_Touch(t *testing.T) {
defer cancel()
// Create a test conversation
- created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -278,7 +278,7 @@ func TestConversationService_Delete(t *testing.T) {
defer cancel()
// Create a test conversation
- created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ created, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -324,7 +324,7 @@ func TestConversationService_Count(t *testing.T) {
// Create test conversations
for i := 0; i < 3; i++ {
- _, err := db.CreateConversation(ctx, stringPtr("conversation-"+string(rune('a'+i))), true, nil)
+ _, err := db.CreateConversation(ctx, stringPtr("conversation-"+string(rune('a'+i))), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation %d: %v", i, err)
}
@@ -354,13 +354,13 @@ func TestConversationService_MultipleNullSlugs(t *testing.T) {
defer cancel()
// Create multiple conversations with null slugs - this should not fail
- conv1, err := db.CreateConversation(ctx, nil, true, nil)
+ conv1, err := db.CreateConversation(ctx, nil, true, nil, nil)
if err != nil {
t.Errorf("Create() first conversation error = %v", err)
return
}
- conv2, err := db.CreateConversation(ctx, nil, true, nil)
+ conv2, err := db.CreateConversation(ctx, nil, true, nil, nil)
if err != nil {
t.Errorf("Create() second conversation error = %v", err)
return
@@ -389,14 +389,14 @@ func TestConversationService_SlugUniquenessWhenNotNull(t *testing.T) {
defer cancel()
// Create first conversation with a slug
- _, err := db.CreateConversation(ctx, stringPtr("unique-slug"), true, nil)
+ _, err := db.CreateConversation(ctx, stringPtr("unique-slug"), true, nil, nil)
if err != nil {
t.Errorf("Create() first conversation error = %v", err)
return
}
// Try to create second conversation with the same slug - this should fail
- _, err = db.CreateConversation(ctx, stringPtr("unique-slug"), true, nil)
+ _, err = db.CreateConversation(ctx, stringPtr("unique-slug"), true, nil, nil)
if err == nil {
t.Error("Expected error when creating conversation with duplicate slug")
return
@@ -416,7 +416,7 @@ func TestConversationService_ArchiveUnarchive(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -450,12 +450,12 @@ func TestConversationService_ListArchivedConversations(t *testing.T) {
defer cancel()
// Create test conversations
- conv1, err := db.CreateConversation(ctx, stringPtr("test-conversation-1"), true, nil)
+ conv1, err := db.CreateConversation(ctx, stringPtr("test-conversation-1"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation 1: %v", err)
}
- conv2, err := db.CreateConversation(ctx, stringPtr("test-conversation-2"), true, nil)
+ conv2, err := db.CreateConversation(ctx, stringPtr("test-conversation-2"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation 2: %v", err)
}
@@ -498,12 +498,12 @@ func TestConversationService_SearchArchivedConversations(t *testing.T) {
defer cancel()
// Create test conversations
- conv1, err := db.CreateConversation(ctx, stringPtr("test-conversation-search-1"), true, nil)
+ conv1, err := db.CreateConversation(ctx, stringPtr("test-conversation-search-1"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation 1: %v", err)
}
- conv2, err := db.CreateConversation(ctx, stringPtr("another-conversation"), true, nil)
+ conv2, err := db.CreateConversation(ctx, stringPtr("another-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation 2: %v", err)
}
@@ -544,7 +544,7 @@ func TestConversationService_DeleteConversation(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-to-delete"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-to-delete"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -580,7 +580,7 @@ func TestConversationService_UpdateConversationCwd(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-cwd"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-cwd"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -233,7 +233,7 @@ func WithTxRes[T any](db *DB, ctx context.Context, fn func(*generated.Queries) (
// Conversation methods (moved from ConversationService)
// CreateConversation creates a new conversation with an optional slug
-func (db *DB) CreateConversation(ctx context.Context, slug *string, userInitiated bool, cwd *string) (*generated.Conversation, error) {
+func (db *DB) CreateConversation(ctx context.Context, slug *string, userInitiated bool, cwd, model *string) (*generated.Conversation, error) {
conversationID, err := generateConversationID()
if err != nil {
return nil, fmt.Errorf("failed to generate conversation ID: %w", err)
@@ -246,6 +246,7 @@ func (db *DB) CreateConversation(ctx context.Context, slug *string, userInitiate
Slug: slug,
UserInitiated: userInitiated,
Cwd: cwd,
+ Model: model,
})
return err
})
@@ -360,6 +361,18 @@ func (db *DB) UpdateConversationCwd(ctx context.Context, conversationID, cwd str
})
}
+// UpdateConversationModel sets the model for a conversation that doesn't have one yet.
+// This is used to backfill the model for conversations created before the model column existed.
+func (db *DB) UpdateConversationModel(ctx context.Context, conversationID, model string) error {
+ return db.pool.Tx(ctx, func(ctx context.Context, tx *Tx) error {
+ q := generated.New(tx.Conn())
+ return q.UpdateConversationModel(ctx, generated.UpdateConversationModelParams{
+ Model: &model,
+ ConversationID: conversationID,
+ })
+ })
+}
+
// Message methods (moved from MessageService)
// MessageType represents the type of message
@@ -122,6 +122,7 @@ func TestDB_WithTx(t *testing.T) {
ConversationID: "test-conv-1",
Slug: stringPtr("test-slug"),
UserInitiated: true,
+ Model: nil,
})
return err
})
@@ -227,7 +228,7 @@ func TestLLMRequestPrefixDeduplication(t *testing.T) {
// Create a conversation first
slug := "test-prefix-conv"
- conv, err := db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -343,7 +344,7 @@ func TestLLMRequestNoPrefixForShortOverlap(t *testing.T) {
defer cancel()
slug := "test-short-conv"
- conv, err := db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -437,7 +438,7 @@ func TestLLMRequestRealisticConversation(t *testing.T) {
defer cancel()
slug := "test-realistic-conv"
- conv, err := db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -573,7 +574,7 @@ func TestLLMRequestOpenAIStyle(t *testing.T) {
defer cancel()
slug := "test-openai-conv"
- conv, err := db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -13,7 +13,7 @@ const archiveConversation = `-- name: ArchiveConversation :one
UPDATE conversations
SET archived = TRUE, updated_at = CURRENT_TIMESTAMP
WHERE conversation_id = ?
-RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id
+RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model
`
func (q *Queries) ArchiveConversation(ctx context.Context, conversationID string) (Conversation, error) {
@@ -28,6 +28,7 @@ func (q *Queries) ArchiveConversation(ctx context.Context, conversationID string
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
@@ -55,9 +56,9 @@ func (q *Queries) CountConversations(ctx context.Context) (int64, error) {
}
const createConversation = `-- name: CreateConversation :one
-INSERT INTO conversations (conversation_id, slug, user_initiated, cwd)
-VALUES (?, ?, ?, ?)
-RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id
+INSERT INTO conversations (conversation_id, slug, user_initiated, cwd, model)
+VALUES (?, ?, ?, ?, ?)
+RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model
`
type CreateConversationParams struct {
@@ -65,6 +66,7 @@ type CreateConversationParams struct {
Slug *string `json:"slug"`
UserInitiated bool `json:"user_initiated"`
Cwd *string `json:"cwd"`
+ Model *string `json:"model"`
}
func (q *Queries) CreateConversation(ctx context.Context, arg CreateConversationParams) (Conversation, error) {
@@ -73,6 +75,7 @@ func (q *Queries) CreateConversation(ctx context.Context, arg CreateConversation
arg.Slug,
arg.UserInitiated,
arg.Cwd,
+ arg.Model,
)
var i Conversation
err := row.Scan(
@@ -84,6 +87,7 @@ func (q *Queries) CreateConversation(ctx context.Context, arg CreateConversation
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
@@ -91,7 +95,7 @@ func (q *Queries) CreateConversation(ctx context.Context, arg CreateConversation
const createSubagentConversation = `-- name: CreateSubagentConversation :one
INSERT INTO conversations (conversation_id, slug, user_initiated, cwd, parent_conversation_id)
VALUES (?, ?, FALSE, ?, ?)
-RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id
+RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model
`
type CreateSubagentConversationParams struct {
@@ -118,6 +122,7 @@ func (q *Queries) CreateSubagentConversation(ctx context.Context, arg CreateSuba
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
@@ -133,7 +138,7 @@ func (q *Queries) DeleteConversation(ctx context.Context, conversationID string)
}
const getConversation = `-- name: GetConversation :one
-SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations
+SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations
WHERE conversation_id = ?
`
@@ -149,12 +154,13 @@ func (q *Queries) GetConversation(ctx context.Context, conversationID string) (C
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
const getConversationBySlug = `-- name: GetConversationBySlug :one
-SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations
+SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations
WHERE slug = ?
`
@@ -170,12 +176,13 @@ func (q *Queries) GetConversationBySlug(ctx context.Context, slug *string) (Conv
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
const getConversationBySlugAndParent = `-- name: GetConversationBySlugAndParent :one
-SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations
+SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations
WHERE slug = ? AND parent_conversation_id = ?
`
@@ -196,12 +203,13 @@ func (q *Queries) GetConversationBySlugAndParent(ctx context.Context, arg GetCon
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
const getSubagents = `-- name: GetSubagents :many
-SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations
+SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations
WHERE parent_conversation_id = ?
ORDER BY created_at ASC
`
@@ -224,6 +232,7 @@ func (q *Queries) GetSubagents(ctx context.Context, parentConversationID *string
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
); err != nil {
return nil, err
}
@@ -239,7 +248,7 @@ func (q *Queries) GetSubagents(ctx context.Context, parentConversationID *string
}
const listArchivedConversations = `-- name: ListArchivedConversations :many
-SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations
+SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations
WHERE archived = TRUE
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
@@ -268,6 +277,7 @@ func (q *Queries) ListArchivedConversations(ctx context.Context, arg ListArchive
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
); err != nil {
return nil, err
}
@@ -283,7 +293,7 @@ func (q *Queries) ListArchivedConversations(ctx context.Context, arg ListArchive
}
const listConversations = `-- name: ListConversations :many
-SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations
+SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations
WHERE archived = FALSE AND parent_conversation_id IS NULL
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
@@ -312,6 +322,7 @@ func (q *Queries) ListConversations(ctx context.Context, arg ListConversationsPa
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
); err != nil {
return nil, err
}
@@ -327,7 +338,7 @@ func (q *Queries) ListConversations(ctx context.Context, arg ListConversationsPa
}
const searchArchivedConversations = `-- name: SearchArchivedConversations :many
-SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations
+SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations
WHERE slug LIKE '%' || ? || '%' AND archived = TRUE
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
@@ -357,6 +368,7 @@ func (q *Queries) SearchArchivedConversations(ctx context.Context, arg SearchArc
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
); err != nil {
return nil, err
}
@@ -372,7 +384,7 @@ func (q *Queries) SearchArchivedConversations(ctx context.Context, arg SearchArc
}
const searchConversations = `-- name: SearchConversations :many
-SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id FROM conversations
+SELECT conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model FROM conversations
WHERE slug LIKE '%' || ? || '%' AND archived = FALSE AND parent_conversation_id IS NULL
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
@@ -402,6 +414,7 @@ func (q *Queries) SearchConversations(ctx context.Context, arg SearchConversatio
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
); err != nil {
return nil, err
}
@@ -417,7 +430,7 @@ func (q *Queries) SearchConversations(ctx context.Context, arg SearchConversatio
}
const searchConversationsWithMessages = `-- name: SearchConversationsWithMessages :many
-SELECT DISTINCT c.conversation_id, c.slug, c.user_initiated, c.created_at, c.updated_at, c.cwd, c.archived, c.parent_conversation_id FROM conversations c
+SELECT DISTINCT c.conversation_id, c.slug, c.user_initiated, c.created_at, c.updated_at, c.cwd, c.archived, c.parent_conversation_id, c.model FROM conversations c
LEFT JOIN messages m ON c.conversation_id = m.conversation_id AND m.type IN ('user', 'agent')
WHERE c.archived = FALSE
AND (
@@ -463,6 +476,7 @@ func (q *Queries) SearchConversationsWithMessages(ctx context.Context, arg Searc
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
); err != nil {
return nil, err
}
@@ -481,7 +495,7 @@ const unarchiveConversation = `-- name: UnarchiveConversation :one
UPDATE conversations
SET archived = FALSE, updated_at = CURRENT_TIMESTAMP
WHERE conversation_id = ?
-RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id
+RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model
`
func (q *Queries) UnarchiveConversation(ctx context.Context, conversationID string) (Conversation, error) {
@@ -496,6 +510,7 @@ func (q *Queries) UnarchiveConversation(ctx context.Context, conversationID stri
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
@@ -504,7 +519,7 @@ const updateConversationCwd = `-- name: UpdateConversationCwd :one
UPDATE conversations
SET cwd = ?, updated_at = CURRENT_TIMESTAMP
WHERE conversation_id = ?
-RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id
+RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model
`
type UpdateConversationCwdParams struct {
@@ -524,15 +539,32 @@ func (q *Queries) UpdateConversationCwd(ctx context.Context, arg UpdateConversat
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
+const updateConversationModel = `-- name: UpdateConversationModel :exec
+UPDATE conversations
+SET model = ?
+WHERE conversation_id = ? AND model IS NULL
+`
+
+type UpdateConversationModelParams struct {
+ Model *string `json:"model"`
+ ConversationID string `json:"conversation_id"`
+}
+
+func (q *Queries) UpdateConversationModel(ctx context.Context, arg UpdateConversationModelParams) error {
+ _, err := q.db.ExecContext(ctx, updateConversationModel, arg.Model, arg.ConversationID)
+ return err
+}
+
const updateConversationSlug = `-- name: UpdateConversationSlug :one
UPDATE conversations
SET slug = ?, updated_at = CURRENT_TIMESTAMP
WHERE conversation_id = ?
-RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id
+RETURNING conversation_id, slug, user_initiated, created_at, updated_at, cwd, archived, parent_conversation_id, model
`
type UpdateConversationSlugParams struct {
@@ -552,6 +584,7 @@ func (q *Queries) UpdateConversationSlug(ctx context.Context, arg UpdateConversa
&i.Cwd,
&i.Archived,
&i.ParentConversationID,
+ &i.Model,
)
return i, err
}
@@ -17,6 +17,7 @@ type Conversation struct {
Cwd *string `json:"cwd"`
Archived bool `json:"archived"`
ParentConversationID *string `json:"parent_conversation_id"`
+ Model *string `json:"model"`
}
type LlmRequest struct {
@@ -21,7 +21,7 @@ func TestMessageService_Create(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -116,7 +116,7 @@ func TestMessageService_GetByID(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -162,7 +162,7 @@ func TestMessageService_ListByConversation(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -216,7 +216,7 @@ func TestMessageService_ListByType(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -263,7 +263,7 @@ func TestMessageService_GetLatest(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -310,7 +310,7 @@ func TestMessageService_Delete(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -351,7 +351,7 @@ func TestMessageService_CountInConversation(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -408,7 +408,7 @@ func TestMessageService_CountByType(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -465,7 +465,7 @@ func TestMessageService_ListMessagesByConversationPaginated(t *testing.T) {
defer cancel()
// Create a test conversation
- conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-paginated"), true, nil)
+ conv, err := db.CreateConversation(ctx, stringPtr("test-conversation-paginated"), true, nil, nil)
if err != nil {
t.Fatalf("Failed to create test conversation: %v", err)
}
@@ -1,6 +1,6 @@
-- name: CreateConversation :one
-INSERT INTO conversations (conversation_id, slug, user_initiated, cwd)
-VALUES (?, ?, ?, ?)
+INSERT INTO conversations (conversation_id, slug, user_initiated, cwd, model)
+VALUES (?, ?, ?, ?, ?)
RETURNING *;
-- name: GetConversation :one
@@ -102,3 +102,8 @@ ORDER BY created_at ASC;
-- name: GetConversationBySlugAndParent :one
SELECT * FROM conversations
WHERE slug = ? AND parent_conversation_id = ?;
+
+-- name: UpdateConversationModel :exec
+UPDATE conversations
+SET model = ?
+WHERE conversation_id = ? AND model IS NULL;
@@ -0,0 +1,4 @@
+-- Add model column to conversations table
+-- This stores the LLM model used for the conversation
+
+ALTER TABLE conversations ADD COLUMN model TEXT;
@@ -54,7 +54,7 @@ func TestCancelWithPredictableModel(t *testing.T) {
server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -251,7 +251,7 @@ func TestCancelWithNoActiveConversation(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create a conversation but don't start it
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -289,7 +289,7 @@ func TestCancelDuringTextGeneration(t *testing.T) {
logger := slog.Default()
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -15,7 +15,7 @@ func TestGetConversationBySlug(t *testing.T) {
// Create a conversation with a slug
slug := "my-test-slug"
- conv, err := h.db.CreateConversation(t.Context(), &slug, true, nil)
+ conv, err := h.db.CreateConversation(t.Context(), &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -29,7 +29,7 @@ func TestMessageQueuedDuringThinking(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -160,7 +160,7 @@ func TestContextPreservedAfterCancel(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -27,7 +27,7 @@ func TestConversationStreamReceivesListUpdateForNewConversation(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create a conversation to subscribe to
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -125,11 +125,11 @@ func TestConversationStreamReceivesListUpdateForRename(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create two conversations
- conv1, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation 1: %v", err)
}
- conv2, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation 2: %v", err)
}
@@ -216,11 +216,11 @@ func TestConversationStreamReceivesListUpdateForDelete(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create two conversations
- conv1, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation 1: %v", err)
}
- conv2, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation 2: %v", err)
}
@@ -306,11 +306,11 @@ func TestConversationStreamReceivesListUpdateForArchive(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create two conversations
- conv1, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation 1: %v", err)
}
- conv2, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation 2: %v", err)
}
@@ -49,7 +49,7 @@ func TestConversationStateAfterServerRestart(t *testing.T) {
ctx := context.Background()
// Create a conversation with some messages (simulating previous activity)
- conv, err := database.CreateConversation(ctx, nil, true, nil)
+ conv, err := database.CreateConversation(ctx, nil, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -145,3 +145,98 @@ func TestConversationStateAfterServerRestart(t *testing.T) {
t.Errorf("Expected 2 messages, got %d", len(response.Messages))
}
}
+
+// TestModelRestorationAfterServerRestart verifies that when a conversation is
+// resumed after a server restart, the model is correctly loaded from the database
+// and reported in the ConversationState.
+func TestModelRestorationAfterServerRestart(t *testing.T) {
+ database, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Create a conversation with a specific model
+ modelID := "claude-sonnet-4-20250514"
+ conv, err := database.CreateConversation(ctx, nil, true, nil, &modelID)
+ if err != nil {
+ t.Fatalf("Failed to create conversation: %v", err)
+ }
+
+ // Add a user message
+ userMsg := llm.Message{
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}},
+ }
+ _, err = database.CreateMessage(ctx, db.CreateMessageParams{
+ ConversationID: conv.ConversationID,
+ Type: db.MessageTypeUser,
+ LLMData: userMsg,
+ })
+ if err != nil {
+ t.Fatalf("Failed to create user message: %v", err)
+ }
+
+ // Add an agent message
+ agentMsg := llm.Message{
+ Role: llm.MessageRoleAssistant,
+ Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}},
+ EndOfTurn: true,
+ }
+ _, err = database.CreateMessage(ctx, db.CreateMessageParams{
+ ConversationID: conv.ConversationID,
+ Type: db.MessageTypeAgent,
+ LLMData: agentMsg,
+ })
+ if err != nil {
+ t.Fatalf("Failed to create agent message: %v", err)
+ }
+
+ // Create a NEW server (simulating server restart or different browser session)
+ predictableService := loop.NewPredictableService()
+ llmManager := &testLLMManager{service: predictableService}
+ toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false}
+ server := NewServer(database, llmManager, toolSetConfig, nil, true, "", "predictable", "", nil)
+
+ mux := http.NewServeMux()
+ server.RegisterRoutes(mux)
+
+ // Make a streaming request
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx)
+ req.Header.Set("Accept", "text/event-stream")
+
+ w := newResponseRecorderWithClose()
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ mux.ServeHTTP(w, req)
+ }()
+
+ time.Sleep(500 * time.Millisecond)
+ w.Close()
+ cancel()
+ <-done
+
+ // Parse the first SSE message
+ body := w.Body.String()
+ if !strings.HasPrefix(body, "data: ") {
+ t.Fatalf("Expected SSE data, got: %s", body)
+ }
+
+ jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ")
+ var response StreamResponse
+ if err := json.Unmarshal([]byte(jsonData), &response); err != nil {
+ t.Fatalf("Failed to parse response: %v", err)
+ }
+
+ // Verify conversation state includes the model from the database
+ if response.ConversationState == nil {
+ t.Fatal("Expected ConversationState in response")
+ }
+ if response.ConversationState.Model != modelID {
+ t.Errorf("Expected Model='%s', got '%s'", modelID, response.ConversationState.Model)
+ }
+}
@@ -82,6 +82,7 @@ func (cm *ConversationManager) SetAgentWorking(working bool) {
cm.agentWorking = working
onStateChange := cm.onStateChange
convID := cm.conversationID
+ modelID := cm.modelID
cm.mu.Unlock()
cm.logger.Debug("agent working state changed", "working", working)
@@ -89,6 +90,7 @@ func (cm *ConversationManager) SetAgentWorking(working bool) {
onStateChange(ConversationState{
ConversationID: convID,
Working: working,
+ Model: modelID,
})
}
}
@@ -100,6 +102,13 @@ func (cm *ConversationManager) IsAgentWorking() bool {
return cm.agentWorking
}
+// GetModel returns the model ID used by this conversation.
+func (cm *ConversationManager) GetModel() string {
+ cm.mu.Lock()
+ defer cm.mu.Unlock()
+ return cm.modelID
+}
+
// Hydrate loads conversation state from the database, generating a system prompt if missing.
func (cm *ConversationManager) Hydrate(ctx context.Context) error {
cm.mu.Lock()
@@ -133,6 +142,12 @@ func (cm *ConversationManager) Hydrate(ctx context.Context) error {
}
cm.cwd = cwd
+ // Load model from conversation if available
+ var modelID string
+ if conversation.Model != nil {
+ modelID = *conversation.Model
+ }
+
// Generate system prompt if missing:
// - For user-initiated conversations: full system prompt
// - For subagent conversations (has parent): minimal subagent prompt
@@ -162,8 +177,12 @@ func (cm *ConversationManager) Hydrate(ctx context.Context) error {
cm.hasConversationEvents = len(history) > 0
cm.lastActivity = time.Now()
cm.hydrated = true
+ cm.modelID = modelID
cm.mu.Unlock()
+ if modelID != "" {
+ cm.logger.Info("Loaded model from conversation", "model", modelID)
+ }
cm.logSystemPromptState(system, len(messages))
return nil
@@ -403,6 +422,8 @@ func (cm *ConversationManager) ensureLoop(service llm.Service, modelID string) e
}
return nil
}
+ // Check if we need to persist the model (for conversations created before model column existed)
+ needsPersist := cm.modelID == "" && modelID != ""
cm.loop = loopInstance
cm.loopCancel = cancel
cm.loopCtx = processCtx
@@ -412,6 +433,13 @@ func (cm *ConversationManager) ensureLoop(service llm.Service, modelID string) e
cm.system = nil
cm.mu.Unlock()
+ // Persist model for legacy conversations
+ if needsPersist {
+ if err := db.UpdateConversationModel(context.Background(), conversationID, modelID); err != nil {
+ logger.Error("failed to persist model for legacy conversation", "error", err)
+ }
+ }
+
go func() {
if err := loopInstance.Go(processCtx); err != nil && err != context.DeadlineExceeded && err != context.Canceled {
if logger != nil {
@@ -41,7 +41,7 @@ func TestCancelAfterToolCompletesCreatesDuplicateToolResult(t *testing.T) {
server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -706,7 +706,7 @@ func (s *Server) handleNewConversation(w http.ResponseWriter, r *http.Request) {
if req.Cwd != "" {
cwdPtr = &req.Cwd
}
- conversation, err := s.db.CreateConversation(ctx, nil, true, cwdPtr)
+ conversation, err := s.db.CreateConversation(ctx, nil, true, cwdPtr, &modelID)
if err != nil {
s.logger.Error("Failed to create conversation", "error", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
@@ -854,6 +854,7 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request
ConversationState: &ConversationState{
ConversationID: conversationID,
Working: manager.IsAgentWorking(),
+ Model: manager.GetModel(),
},
ContextWindowSize: calculateContextWindowSize(apiMessages),
}
@@ -46,7 +46,7 @@ func TestHandleArchivedConversations(t *testing.T) {
// Create a test conversation and archive it
ctx := context.Background()
slug := "test-conversation"
- conv, err := h.db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -104,7 +104,7 @@ func TestHandleArchiveConversation(t *testing.T) {
// Create a test conversation
ctx := context.Background()
slug := "test-conversation"
- conv, err := h.db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -157,7 +157,7 @@ func TestHandleUnarchiveConversation(t *testing.T) {
// Create a test conversation and archive it
ctx := context.Background()
slug := "test-conversation"
- conv, err := h.db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -215,7 +215,7 @@ func TestHandleDeleteConversation(t *testing.T) {
// Create a test conversation
ctx := context.Background()
slug := "test-conversation"
- conv, err := h.db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -274,7 +274,7 @@ func TestHandleRenameConversation(t *testing.T) {
// Create a test conversation
ctx := context.Background()
slug := "test-conversation"
- conv, err := h.db.CreateConversation(ctx, &slug, true, nil)
+ conv, err := h.db.CreateConversation(ctx, &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -29,7 +29,7 @@ func TestMessageSentOnlyOnce(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -163,7 +163,7 @@ func TestContextWindowSizeInSSE(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -50,7 +50,7 @@ func TestOrphanToolResultAfterCancellation(t *testing.T) {
server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -233,7 +233,7 @@ func TestOrphanToolResultFiltering(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -46,6 +46,7 @@ type APIMessage struct {
type ConversationState struct {
ConversationID string `json:"conversation_id"`
Working bool `json:"working"`
+ Model string `json:"model,omitempty"`
}
// ConversationWithState combines a conversation with its working state.
@@ -81,7 +81,7 @@ func TestSSEUserMessageAppearsImmediately(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -222,7 +222,7 @@ func TestSSEUserMessageWithRealHTTPServer(t *testing.T) {
srv := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create conversation
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -332,7 +332,7 @@ func TestSSEUserMessageWithExistingConnection(t *testing.T) {
server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
// Create conversation and get a manager (simulating an established SSE connection)
- conversation, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("failed to create conversation: %v", err)
}
@@ -138,7 +138,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}))
// Create first conversation to establish the base slug
- conv1, err := database.CreateConversation(ctx, nil, true, nil)
+ conv1, err := database.CreateConversation(ctx, nil, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create first conversation: %v", err)
}
@@ -153,7 +153,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}
// Create second conversation
- conv2, err := database.CreateConversation(ctx, nil, true, nil)
+ conv2, err := database.CreateConversation(ctx, nil, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create second conversation: %v", err)
}
@@ -168,7 +168,7 @@ func TestGenerateSlug_DatabaseIntegration(t *testing.T) {
}
// Create third conversation
- conv3, err := database.CreateConversation(ctx, nil, true, nil)
+ conv3, err := database.CreateConversation(ctx, nil, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create third conversation: %v", err)
}
@@ -72,7 +72,7 @@ func TestWithAnthropicAPI(t *testing.T) {
// Create a conversation
// Using database directly instead of service
slug := "claude-test"
- conv, err := database.CreateConversation(context.Background(), &slug, true, nil)
+ conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -181,7 +181,7 @@ func TestWithAnthropicAPI(t *testing.T) {
// Create a conversation
// Using database directly instead of service
slug := "tool-test"
- conv, err := database.CreateConversation(context.Background(), &slug, true, nil)
+ conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -247,7 +247,7 @@ func TestWithAnthropicAPI(t *testing.T) {
// Using database directly instead of service
// Using database directly instead of service
slug := "stream-test"
- conv, err := database.CreateConversation(context.Background(), &slug, true, nil)
+ conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -73,7 +73,7 @@ func TestServerEndToEnd(t *testing.T) {
// Create a conversation
// Using database directly instead of service
slug := "test-conversation"
- conv, err := database.CreateConversation(context.Background(), &slug, true, nil)
+ conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -107,7 +107,7 @@ func TestServerEndToEnd(t *testing.T) {
// Create a conversation
// Using database directly instead of service
slug := "chat-test"
- conv, err := database.CreateConversation(context.Background(), &slug, true, nil)
+ conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -170,7 +170,7 @@ func TestServerEndToEnd(t *testing.T) {
// Using database directly instead of service
// Using database directly instead of service
slug := "stream-test"
- conv, err := database.CreateConversation(context.Background(), &slug, true, nil)
+ conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -226,7 +226,7 @@ func TestServerEndToEnd(t *testing.T) {
ctx := context.Background()
// Create a conversation without a slug
- conv, err := database.CreateConversation(ctx, nil, true, nil)
+ conv, err := database.CreateConversation(ctx, nil, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -390,7 +390,7 @@ func TestConversationCleanup(t *testing.T) {
// Create a conversation
// Using database directly instead of service
- conv, err := database.CreateConversation(context.Background(), nil, true, nil)
+ conv, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -558,7 +558,7 @@ func TestSlugEndToEnd(t *testing.T) {
// Create a conversation with a specific slug
ctx := context.Background()
testSlug := "test-conversation-slug"
- conv, err := database.CreateConversation(ctx, &testSlug, true, nil)
+ conv, err := database.CreateConversation(ctx, &testSlug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -620,7 +620,7 @@ func TestSSEIncrementalUpdates(t *testing.T) {
// Create a conversation with initial message
slug := "test-sse"
- conv, err := database.CreateConversation(context.Background(), &slug, true, nil)
+ conv, err := database.CreateConversation(context.Background(), &slug, true, nil, nil)
if err != nil {
t.Fatalf("Failed to create conversation: %v", err)
}
@@ -379,6 +379,7 @@ function AnimatedWorkingStatus() {
interface ConversationStateUpdate {
conversation_id: string;
working: boolean;
+ model?: string;
}
interface ChatInterfaceProps {
@@ -703,6 +704,10 @@ function ChatInterface({
// Update local state if this is for our conversation
if (streamResponse.conversation_state.conversation_id === conversationId) {
setAgentWorking(streamResponse.conversation_state.working);
+ // Update selected model from conversation (ensures consistency across sessions)
+ if (streamResponse.conversation_state.model) {
+ setSelectedModel(streamResponse.conversation_state.model);
+ }
}
}
@@ -12,6 +12,7 @@ export interface Conversation {
cwd: string | null;
archived: boolean;
parent_conversation_id: string | null;
+ model: string | null;
}
export interface Usage {
@@ -41,6 +42,7 @@ export interface ApiMessageForTS {
export interface ConversationStateForTS {
conversation_id: string;
working: boolean;
+ model?: string;
}
export interface StreamResponseForTS {
@@ -58,6 +60,7 @@ export interface ConversationWithStateForTS {
cwd: string | null;
archived: boolean;
parent_conversation_id: string | null;
+ model: string | null;
working: boolean;
}