From d440d2614f7386f38f64ebf2bf02f5dcaca7ac4c Mon Sep 17 00:00:00 2001 From: Philip Zeyliger Date: Sun, 25 Jan 2026 09:30:41 -0800 Subject: [PATCH] Shelley: Add continue conversation feature for long contexts We've had a spirited conversation in Discord about how to do compaction. This is one possible approach, which I'm sure needs some fine-tuning, but is harmless in that it's opt-in. Prompt: In addition to the warning sign about high context, when the user hits 100k tokens, open the little window up, and add a button there to "continue in new conversation". When that is pressed, start a new conversation. The initial user prompt should be: "Continue the conversation with slug <slug>. Here are the user and agent messages so far (including tool inputs up to ~250 characters and tool outputs up to ~250 characters); use sqlite to look up additional details." I want the munging to happen on the server side, so create a new endpoint that creates a new conversation, but references a pre-existing one, so this works. 
When a conversation reaches 100k tokens: - The context usage popup auto-opens (once per conversation) - Shows a 'Continue in new conversation' button - Clicking it creates a new conversation with a summary of the previous one as the initial prompt The summary includes: - Reference to the source conversation slug - User and agent messages (text content) - Tool calls with inputs truncated to ~250 chars - Tool results truncated to ~250 chars - Instruction to use sqlite for additional details Server changes: - POST /api/conversations/continue endpoint - db.ListMessages() function to get all messages UI changes: - ContextUsageBar auto-opens at 100k tokens - Continue button in the popup - App.tsx handler to navigate to new conversation Co-authored-by: Shelley --- db/db.go | 12 ++ server/handlers.go | 211 ++++++++++++++++++++++++++++ server/server.go | 3 +- test/server_test.go | 189 +++++++++++++++++++++++++ ui/src/App.tsx | 21 +++ ui/src/components/ChatInterface.tsx | 79 ++++++++++- ui/src/services/api.ts | 20 +++ 7 files changed, 532 insertions(+), 3 deletions(-) diff --git a/db/db.go b/db/db.go index 3b4543d62ae4639d5912eca96fcf213ea57937b1..480a6f2b0d7a57695307b3d74bac539e39405a90 100644 --- a/db/db.go +++ b/db/db.go @@ -500,6 +500,18 @@ func (db *DB) ListMessagesByConversationPaginated(ctx context.Context, conversat return messages, err } +// ListMessages retrieves all messages in a conversation ordered by sequence +func (db *DB) ListMessages(ctx context.Context, conversationID string) ([]generated.Message, error) { + var messages []generated.Message + err := db.pool.Rx(ctx, func(ctx context.Context, rx *Rx) error { + q := generated.New(rx.Conn()) + var err error + messages, err = q.ListMessages(ctx, conversationID) + return err + }) + return messages, err +} + // ListMessagesForContext retrieves messages that should be sent to the LLM (excludes excluded_from_context=true) func (db *DB) ListMessagesForContext(ctx context.Context, conversationID string) 
([]generated.Message, error) { var messages []generated.Message diff --git a/server/handlers.go b/server/handlers.go index cea774008df65743ed884a7320455e47c8df114a..528bda07bda562656533256d515481ccbfc9d7ae 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -20,6 +20,7 @@ import ( "time" "shelley.exe.dev/claudetool/browse" + "shelley.exe.dev/db" "shelley.exe.dev/db/generated" "shelley.exe.dev/llm" "shelley.exe.dev/models" @@ -773,6 +774,216 @@ func (s *Server) handleNewConversation(w http.ResponseWriter, r *http.Request) { }) } +// ContinueConversationRequest represents the request to continue a conversation in a new one +type ContinueConversationRequest struct { + SourceConversationID string `json:"source_conversation_id"` + Model string `json:"model,omitempty"` + Cwd string `json:"cwd,omitempty"` +} + +// handleContinueConversation handles POST /api/conversations/continue +// Creates a new conversation with a summary of the source conversation as the initial message +func (s *Server) handleContinueConversation(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + ctx := r.Context() + + // Parse request + var req ContinueConversationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid JSON", http.StatusBadRequest) + return + } + + if req.SourceConversationID == "" { + http.Error(w, "source_conversation_id is required", http.StatusBadRequest) + return + } + + // Get source conversation + sourceConv, err := s.db.GetConversationByID(ctx, req.SourceConversationID) + if err != nil { + s.logger.Error("Failed to get source conversation", "conversationID", req.SourceConversationID, "error", err) + http.Error(w, "Source conversation not found", http.StatusNotFound) + return + } + + // Get messages from source conversation + messages, err := s.db.ListMessages(ctx, req.SourceConversationID) + if err != nil { + 
s.logger.Error("Failed to get messages", "conversationID", req.SourceConversationID, "error", err) + http.Error(w, "Failed to get messages", http.StatusInternalServerError) + return + } + + // Build summary message + sourceSlug := "unknown" + if sourceConv.Slug != nil { + sourceSlug = *sourceConv.Slug + } + summary := buildConversationSummary(sourceSlug, messages) + + // Get LLM service for the requested model + modelID := req.Model + if modelID == "" && sourceConv.Model != nil { + modelID = *sourceConv.Model + } + if modelID == "" { + modelID = "qwen3-coder-fireworks" + } + + llmService, err := s.llmManager.GetService(modelID) + if err != nil { + s.logger.Error("Unsupported model requested", "model", modelID, "error", err) + http.Error(w, fmt.Sprintf("Unsupported model: %s", modelID), http.StatusBadRequest) + return + } + + // Create new conversation with cwd from request or source conversation + var cwdPtr *string + if req.Cwd != "" { + cwdPtr = &req.Cwd + } else if sourceConv.Cwd != nil { + cwdPtr = sourceConv.Cwd + } + conversation, err := s.db.CreateConversation(ctx, nil, true, cwdPtr, &modelID) + if err != nil { + s.logger.Error("Failed to create conversation", "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + conversationID := conversation.ConversationID + + // Notify conversation list subscribers about the new conversation + go s.publishConversationListUpdate(ConversationListUpdate{ + Type: "update", + Conversation: conversation, + }) + + // Get or create conversation manager + manager, err := s.getOrCreateConversationManager(ctx, conversationID) + if errors.Is(err, errConversationModelMismatch) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err != nil { + s.logger.Error("Failed to get conversation manager", "conversationID", conversationID, "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Create user message with the 
summary + userMessage := llm.Message{ + Role: llm.MessageRoleUser, + Content: []llm.Content{ + {Type: llm.ContentTypeText, Text: summary}, + }, + } + + firstMessage, err := manager.AcceptUserMessage(ctx, llmService, modelID, userMessage) + if errors.Is(err, errConversationModelMismatch) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err != nil { + s.logger.Error("Failed to accept user message", "conversationID", conversationID, "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Generate slug for the new conversation + if firstMessage { + ctxNoCancel := context.WithoutCancel(ctx) + go func() { + slugCtx, cancel := context.WithTimeout(ctxNoCancel, 15*time.Second) + defer cancel() + _, err := slug.GenerateSlug(slugCtx, s.llmManager, s.db, s.logger, conversationID, summary, modelID) + if err != nil { + s.logger.Warn("Failed to generate slug for conversation", "conversationID", conversationID, "error", err) + } else { + go s.notifySubscribers(ctxNoCancel, conversationID) + } + }() + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "accepted", + "conversation_id": conversationID, + }) +} + +// buildConversationSummary creates a summary of messages from a conversation +// for use as the initial prompt in a continuation conversation +func buildConversationSummary(slug string, messages []generated.Message) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Continue the conversation with slug %q. 
Here are the user and agent messages so far (including tool inputs up to ~250 characters and tool outputs up to ~250 characters); use sqlite to look up additional details.\n\n", slug)) + + for _, msg := range messages { + if msg.Type != string(db.MessageTypeUser) && msg.Type != string(db.MessageTypeAgent) { + continue + } + + if msg.LlmData == nil { + continue + } + + var llmMsg llm.Message + if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil { + continue + } + + var role string + if msg.Type == string(db.MessageTypeUser) { + role = "User" + } else { + role = "Agent" + } + + for _, content := range llmMsg.Content { + switch content.Type { + case llm.ContentTypeText: + if content.Text != "" { + sb.WriteString(fmt.Sprintf("%s: %s\n\n", role, content.Text)) + } + case llm.ContentTypeToolUse: + inputStr := string(content.ToolInput) + if len(inputStr) > 250 { + inputStr = inputStr[:250] + "..." + } + sb.WriteString(fmt.Sprintf("%s: [Tool: %s] %s\n\n", role, content.ToolName, inputStr)) + case llm.ContentTypeToolResult: + // Get the text content from tool result + var resultText string + for _, res := range content.ToolResult { + if res.Type == llm.ContentTypeText && res.Text != "" { + resultText = res.Text + break + } + } + if len(resultText) > 250 { + resultText = resultText[:250] + "..." 
+ } + if resultText != "" { + errStr := "" + if content.ToolError { + errStr = " (error)" + } + sb.WriteString(fmt.Sprintf("%s: [Tool Result%s] %s\n\n", role, errStr, resultText)) + } + case llm.ContentTypeThinking: + // Skip thinking blocks - they're internal + } + } + } + + return sb.String() +} + // handleCancelConversation handles POST /conversation//cancel func (s *Server) handleCancelConversation(w http.ResponseWriter, r *http.Request, conversationID string) { if r.Method != http.MethodPost { diff --git a/server/server.go b/server/server.go index 15e3d4ee9036e9c56f377641c44f770608c85b4f..53cec6ed8be2a930343ec3a0011b10a524f6c50b 100644 --- a/server/server.go +++ b/server/server.go @@ -251,7 +251,8 @@ func (s *Server) RegisterRoutes(mux *http.ServeMux) { // API routes - wrap with gzip where beneficial mux.Handle("/api/conversations", gzipHandler(http.HandlerFunc(s.handleConversations))) mux.Handle("/api/conversations/archived", gzipHandler(http.HandlerFunc(s.handleArchivedConversations))) - mux.Handle("/api/conversations/new", http.HandlerFunc(s.handleNewConversation)) // Small response + mux.Handle("/api/conversations/new", http.HandlerFunc(s.handleNewConversation)) // Small response + mux.Handle("/api/conversations/continue", http.HandlerFunc(s.handleContinueConversation)) // Small response mux.Handle("/api/conversation/", http.StripPrefix("/api/conversation", s.conversationMux())) mux.Handle("/api/conversation-by-slug/", gzipHandler(http.HandlerFunc(s.handleConversationBySlug))) mux.Handle("/api/validate-cwd", http.HandlerFunc(s.handleValidateCwd)) // Small response diff --git a/test/server_test.go b/test/server_test.go index 7293a4bd86ebd56b38b817886d50caba80797422..1e27af8e3be76d97966af573145d474afc331335 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -1268,3 +1268,192 @@ func TestSubagentEndToEnd(t *testing.T) { } t.Logf("Subagent conversation has %d messages", len(subConvData.Messages)) } + +func TestContinueConversation(t *testing.T) { + 
// Create temporary database + tempDB := t.TempDir() + "/test.db" + database, err := db.New(db.Config{DSN: tempDB}) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer database.Close() + + // Run migrations + if err := database.Migrate(context.Background()); err != nil { + t.Fatalf("Failed to migrate database: %v", err) + } + + // Create logger + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + // Create LLM service manager + llmManager := server.NewLLMServiceManager(&server.LLMConfig{Logger: logger, DB: database}) + + // Set up tools config + toolSetConfig := claudetool.ToolSetConfig{ + WorkingDir: t.TempDir(), + EnableBrowser: false, + } + + // Create server + svr := server.NewServer(database, llmManager, toolSetConfig, logger, false, "", "", "", nil) + + // Set up HTTP server + mux := http.NewServeMux() + svr.RegisterRoutes(mux) + testServer := httptest.NewServer(mux) + defer testServer.Close() + + ctx := context.Background() + + // Create source conversation with a slug and some messages + sourceSlug := "source-conversation" + cwd := "/tmp/testdir" + model := "predictable" + sourceConv, err := database.CreateConversation(ctx, &sourceSlug, true, &cwd, &model) + if err != nil { + t.Fatalf("Failed to create source conversation: %v", err) + } + + // Add some messages to the source conversation + userMessage := llm.Message{ + Role: llm.MessageRoleUser, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello, this is a test message"}}, + } + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: sourceConv.ConversationID, + Type: db.MessageTypeUser, + LLMData: userMessage, + }) + if err != nil { + t.Fatalf("Failed to create user message: %v", err) + } + + agentMessage := llm.Message{ + Role: llm.MessageRoleAssistant, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello! 
How can I help you?"}}, + } + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: sourceConv.ConversationID, + Type: db.MessageTypeAgent, + LLMData: agentMessage, + }) + if err != nil { + t.Fatalf("Failed to create agent message: %v", err) + } + + // Create a tool use message + toolMessage := llm.Message{ + Role: llm.MessageRoleAssistant, + Content: []llm.Content{{ + Type: llm.ContentTypeToolUse, + ToolName: "bash", + ToolInput: json.RawMessage(`{"command": "echo hello world this is a long command that should be truncated if it exceeds the limit"}`), + }}, + } + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: sourceConv.ConversationID, + Type: db.MessageTypeAgent, + LLMData: toolMessage, + }) + if err != nil { + t.Fatalf("Failed to create tool message: %v", err) + } + + // Test the continue conversation endpoint + reqBody := map[string]string{ + "source_conversation_id": sourceConv.ConversationID, + "model": "predictable", + } + body, _ := json.Marshal(reqBody) + + resp, err := http.Post(testServer.URL+"/api/conversations/continue", "application/json", bytes.NewBuffer(body)) + if err != nil { + t.Fatalf("Failed to continue conversation: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("Expected status 201, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + newConversationID, ok := result["conversation_id"].(string) + if !ok || newConversationID == "" { + t.Fatal("Response should contain conversation_id") + } + + // Verify new conversation was created + newConv, err := database.GetConversationByID(ctx, newConversationID) + if err != nil { + t.Fatalf("Failed to get new conversation: %v", err) + } + + // Verify the new conversation inherited the cwd + if 
newConv.Cwd == nil || *newConv.Cwd != cwd { + t.Errorf("Expected cwd %s, got %v", cwd, newConv.Cwd) + } + + // Verify the new conversation has a user message with the summary + messages, err := database.ListMessages(ctx, newConversationID) + if err != nil { + t.Fatalf("Failed to list messages: %v", err) + } + + if len(messages) < 1 { + t.Fatal("Expected at least 1 message in new conversation") + } + + // Find the user message with the summary (may be after system prompt) + var summaryText string + for _, msg := range messages { + if msg.Type != string(db.MessageTypeUser) { + continue + } + if msg.LlmData == nil { + continue + } + var llmMsg llm.Message + if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil { + continue + } + for _, content := range llmMsg.Content { + if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "Continue the conversation") { + summaryText = content.Text + break + } + } + if summaryText != "" { + break + } + } + + if summaryText == "" { + t.Fatal("Could not find summary message in new conversation") + } + + if !strings.Contains(summaryText, sourceSlug) { + t.Errorf("Summary should reference source conversation slug %q, got: %s", sourceSlug, summaryText) + } + + if !strings.Contains(summaryText, "Hello, this is a test message") { + t.Error("Summary should contain user message text") + } + + if !strings.Contains(summaryText, "Hello! 
How can I help you?") { + t.Error("Summary should contain agent message text") + } + + if !strings.Contains(summaryText, "Tool: bash") { + t.Error("Summary should contain tool name") + } + + t.Logf("Successfully continued conversation from %s to %s", sourceConv.ConversationID, newConversationID) +} diff --git a/ui/src/App.tsx b/ui/src/App.tsx index 955847977129e4273b3c4dbaf8698cdb82b01fe2..0b699b2f310d880fd9e0151bc22c3ff3802891ec 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -317,6 +317,26 @@ function App() { } }; + const handleContinueConversation = async ( + sourceConversationId: string, + model: string, + cwd?: string, + ) => { + try { + const response = await api.continueConversation(sourceConversationId, model, cwd); + const newConversationId = response.conversation_id; + + // Fetch the new conversation details + const updatedConvs = await api.getConversations(); + setConversations(updatedConvs); + setCurrentConversationId(newConversationId); + } catch (err) { + console.error("Failed to continue conversation:", err); + setError("Failed to continue conversation"); + throw err; + } + }; + return (
{/* Conversations drawer */} @@ -347,6 +367,7 @@ function App() { onConversationListUpdate={handleConversationListUpdate} onConversationStateUpdate={handleConversationStateUpdate} onFirstMessage={handleFirstMessage} + onContinueConversation={handleContinueConversation} mostRecentCwd={mostRecentCwd} isDrawerCollapsed={drawerCollapsed} onToggleDrawerCollapse={toggleDrawerCollapsed} diff --git a/ui/src/components/ChatInterface.tsx b/ui/src/components/ChatInterface.tsx index 9e2c9cf3676849662f7cba993462c679c870299d..8d7dc6d10e8c311d852eb9715142cf6c8c335310 100644 --- a/ui/src/components/ChatInterface.tsx +++ b/ui/src/components/ChatInterface.tsx @@ -31,11 +31,20 @@ import { useVersionChecker } from "./VersionChecker"; interface ContextUsageBarProps { contextWindowSize: number; maxContextTokens: number; + conversationId?: string | null; + onContinueConversation?: () => void; } -function ContextUsageBar({ contextWindowSize, maxContextTokens }: ContextUsageBarProps) { +function ContextUsageBar({ + contextWindowSize, + maxContextTokens, + conversationId, + onContinueConversation, +}: ContextUsageBarProps) { const [showPopup, setShowPopup] = useState(false); + const [continuing, setContinuing] = useState(false); const barRef = useRef(null); + const hasAutoOpenedRef = useRef(null); const percentage = maxContextTokens > 0 ? 
(contextWindowSize / maxContextTokens) * 100 : 0; const clampedPercentage = Math.min(percentage, 100); @@ -57,6 +66,18 @@ function ContextUsageBar({ contextWindowSize, maxContextTokens }: ContextUsageBa setShowPopup(!showPopup); }; + // Auto-open popup when hitting 100k tokens (once per conversation) + useEffect(() => { + if ( + showLongConversationWarning && + conversationId && + hasAutoOpenedRef.current !== conversationId + ) { + hasAutoOpenedRef.current = conversationId; + setShowPopup(true); + } + }, [showLongConversationWarning, conversationId]); + // Close popup when clicking outside useEffect(() => { if (!showPopup) return; @@ -86,6 +107,17 @@ function ContextUsageBar({ contextWindowSize, maxContextTokens }: ContextUsageBa } }, [showPopup]); + const handleContinue = async () => { + if (continuing || !onContinueConversation) return; + setContinuing(true); + try { + await onContinueConversation(); + setShowPopup(false); + } finally { + setContinuing(false); + } + }; + return (
{showPopup && popupPosition && ( @@ -114,6 +146,26 @@ function ContextUsageBar({ contextWindowSize, maxContextTokens }: ContextUsageBa For best results, start a new conversation.
)} + {onContinueConversation && conversationId && ( + + )}
)}
@@ -391,6 +443,11 @@ interface ChatInterfaceProps { onConversationListUpdate?: (update: ConversationListUpdate) => void; onConversationStateUpdate?: (state: ConversationStateUpdate) => void; onFirstMessage?: (message: string, model: string, cwd?: string) => Promise; + onContinueConversation?: ( + sourceConversationId: string, + model: string, + cwd?: string, + ) => Promise; mostRecentCwd?: string | null; isDrawerCollapsed?: boolean; onToggleDrawerCollapse?: () => void; @@ -407,6 +464,7 @@ function ChatInterface({ onConversationListUpdate, onConversationStateUpdate, onFirstMessage, + onContinueConversation, mostRecentCwd, isDrawerCollapsed, onToggleDrawerCollapse, @@ -860,6 +918,16 @@ function ChatInterface({ } }; + // Handler to continue conversation in a new one + const handleContinueConversation = async () => { + if (!conversationId || !onContinueConversation) return; + await onContinueConversation( + conversationId, + selectedModel, + currentConversation?.cwd || selectedCwd || undefined, + ); + }; + const getDisplayTitle = () => { return currentConversation?.slug || "Shelley"; }; @@ -1464,6 +1532,10 @@ function ChatInterface({ maxContextTokens={ models.find((m) => m.id === selectedModel)?.max_context_tokens || 200000 } + conversationId={conversationId} + onContinueConversation={ + onContinueConversation ? handleContinueConversation : undefined + } />
) : // Idle state - show ready message, or configuration for empty conversation @@ -1527,13 +1599,16 @@ function ChatInterface({ maxContextTokens={ models.find((m) => m.id === selectedModel)?.max_context_tokens || 200000 } + conversationId={conversationId} + onContinueConversation={ + onContinueConversation ? handleContinueConversation : undefined + } /> )} - {/* Message input */} {/* Message input */} { + const response = await fetch(`${this.baseUrl}/conversations/continue`, { + method: "POST", + headers: this.postHeaders, + body: JSON.stringify({ + source_conversation_id: sourceConversationId, + model: model || "", + cwd: cwd || "", + }), + }); + if (!response.ok) { + throw new Error(`Failed to continue conversation: ${response.statusText}`); + } + return response.json(); + } + async getConversation(conversationId: string): Promise { const response = await fetch(`${this.baseUrl}/conversation/${conversationId}`); if (!response.ok) {