From 2b4f3e1d525421f69e03eab1999b86b9afea8c8c Mon Sep 17 00:00:00 2001 From: Philip Zeyliger Date: Thu, 29 Jan 2026 22:47:30 +0000 Subject: [PATCH] shelley: add heartbeat to conversation stream, support resume Prompt: In a new worktree, make it so that the conversation stream has some sort of periodic heartbeat (e.g., every minute), and the client retries if it hasn't seen anything in a minute. perhaps keep sending the state of the conversation (whether the agent is running), since that's the most likely thing to get confused... Note that we might want to make the stream endpoint not start at the beginning but start at whatever point the client has, so as to avoid double-sending stuff. Add a periodic heartbeat (every 30 seconds) to the SSE conversation stream to keep connections alive and provide current conversation state. This helps with: 1. Detecting stale connections - the client reconnects if no message (including heartbeat) is received within 60 seconds 2. Keeping proxies/load balancers from timing out idle connections 3. Syncing conversation state (working status) even when no messages flow Also add support for resuming streams via the last_sequence_id query parameter. When provided, the server skips sending historical messages (which the client already has) and just sends the current state as a heartbeat. This avoids re-sending potentially large message histories on reconnection. Server changes: - Add Heartbeat field to StreamResponse - Add last_sequence_id query parameter to stream endpoint - Start heartbeat goroutine that broadcasts state every 30 seconds - Restructure handleStreamConversation to query messages before creating conversation manager (preserves existing behavior for fresh connections) Client changes: - Track last sequence ID from received messages - Pass last_sequence_id to stream endpoint on reconnection - Add 60-second heartbeat timeout that triggers reconnection - Update types for heartbeat field Co-authored-by: Shelley --- cmd/go2ts.go | 1 + server/handlers.go | 143 +++++++++++++++++++------ server/server.go | 2 + server/stream_heartbeat_test.go | 156 ++++++++++++++++++++++++++++ ui/src/components/ChatInterface.tsx | 56 +++++++++- ui/src/generated-types.ts | 1 + ui/src/services/api.ts | 8 +- ui/src/types.ts | 1 + 8 files changed, 334 insertions(+), 34 deletions(-) create mode 100644 server/stream_heartbeat_test.go diff --git a/cmd/go2ts.go b/cmd/go2ts.go index d5b7c8cacee71381b113550a0107a7bed23c2403..299325aa3038869772e853b1c030a2112937d265 100644 --- a/cmd/go2ts.go +++ b/cmd/go2ts.go @@ -111,4 +111,5 @@ type streamResponseForTS struct { Messages []apiMessageForTS `json:"messages"` Conversation generated.Conversation `json:"conversation"` ConversationState *conversationStateForTS `json:"conversation_state,omitempty"` + Heartbeat bool `json:"heartbeat,omitempty"` } diff --git a/server/handlers.go b/server/handlers.go index d43e6e780e6242a418ba438cc541d2ae5258db21..6b71f60f9fe91fd91d6094eeddd247386ff61153 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -998,6 +998,8 @@ func (s *Server) handleCancelConversation(w http.ResponseWriter, r *http.Request } // handleStreamConversation handles GET /conversation//stream +// Query parameters: +// - last_sequence_id: Resume from this sequence ID (skip messages up to and including this ID) func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request, conversationID string) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -1006,59 +1008,140 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request ctx := r.Context() + // Parse last_sequence_id for resuming streams + lastSeqID := int64(-1) + if lastSeqStr := r.URL.Query().Get("last_sequence_id"); lastSeqStr != "" { + if parsed, err := strconv.ParseInt(lastSeqStr, 10, 64); err == nil { + lastSeqID = parsed + } + } + // Set up SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") - // Get current messages and conversation data + // For fresh connections, get messages BEFORE calling getOrCreateConversationManager. + // This is important because getOrCreateConversationManager may create a system prompt + // message during hydration, and we want to return the messages as they were before. var messages []generated.Message var conversation generated.Conversation - err := s.db.Queries(ctx, func(q *generated.Queries) error { - var err error - messages, err = q.ListMessages(ctx, conversationID) + if lastSeqID < 0 { + err := s.db.Queries(ctx, func(q *generated.Queries) error { + var err error + messages, err = q.ListMessages(ctx, conversationID) + if err != nil { + return err + } + conversation, err = q.GetConversation(ctx, conversationID) + return err + }) if err != nil { + s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + // Update lastSeqID based on messages we're sending + if len(messages) > 0 { + lastSeqID = messages[len(messages)-1].SequenceID + } + } else { + // Resuming - just get conversation metadata + err := s.db.Queries(ctx, func(q *generated.Queries) error { + var err error + conversation, err = q.GetConversation(ctx, conversationID) return err + }) + if err != nil { + s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return } - conversation, err = q.GetConversation(ctx, conversationID) - return err - }) - if err != nil { - s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return } // Get or create conversation manager to access working state manager, err := s.getOrCreateConversationManager(ctx, conversationID) if err != nil { s.logger.Error("Failed to get conversation manager", "conversationID", conversationID, "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) return } - // Send current messages, conversation data, and conversation state - apiMessages := toAPIMessages(messages) - streamData := StreamResponse{ - Messages: apiMessages, - Conversation: conversation, - ConversationState: &ConversationState{ - ConversationID: conversationID, - Working: manager.IsAgentWorking(), - Model: manager.GetModel(), - }, - ContextWindowSize: calculateContextWindowSize(apiMessages), + // Send initial response + if len(messages) > 0 { + // Fresh connection - send all messages + apiMessages := toAPIMessages(messages) + streamData := StreamResponse{ + Messages: apiMessages, + Conversation: conversation, + ConversationState: &ConversationState{ + ConversationID: conversationID, + Working: manager.IsAgentWorking(), + Model: manager.GetModel(), + }, + ContextWindowSize: calculateContextWindowSize(apiMessages), + } + data, _ := json.Marshal(streamData) + fmt.Fprintf(w, "data: %s\n\n", data) + w.(http.Flusher).Flush() + } else { + // Either resuming or no messages yet - send current state as heartbeat + streamData := StreamResponse{ + Conversation: conversation, + ConversationState: &ConversationState{ + ConversationID: conversationID, + Working: manager.IsAgentWorking(), + Model: manager.GetModel(), + }, + Heartbeat: true, + } + data, _ := json.Marshal(streamData) + fmt.Fprintf(w, "data: %s\n\n", data) + w.(http.Flusher).Flush() } - data, _ := json.Marshal(streamData) - fmt.Fprintf(w, "data: %s\n\n", data) - w.(http.Flusher).Flush() // Subscribe to new messages after the last one we sent - last := int64(-1) - if len(messages) > 0 { - last = messages[len(messages)-1].SequenceID - } - next := manager.subpub.Subscribe(ctx, last) + next := manager.subpub.Subscribe(ctx, lastSeqID) + + // Start heartbeat goroutine - sends state every 30 seconds if no other messages + heartbeatDone := make(chan struct{}) + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-heartbeatDone: + return + case <-ticker.C: + // Get current conversation state for heartbeat + var conv generated.Conversation + err := s.db.Queries(ctx, func(q *generated.Queries) error { + var err error + conv, err = q.GetConversation(ctx, conversationID) + return err + }) + if err != nil { + continue // Skip heartbeat on error + } + + heartbeat := StreamResponse{ + Conversation: conv, + ConversationState: &ConversationState{ + ConversationID: conversationID, + Working: manager.IsAgentWorking(), + Model: manager.GetModel(), + }, + Heartbeat: true, + } + manager.subpub.Broadcast(heartbeat) + } + } + }() + defer close(heartbeatDone) + for { streamData, cont := next() if !cont { diff --git a/server/server.go b/server/server.go index cd61516ef9147911e4f56419f076c4df331644c1..1cc08dada95c66b789cf89a57b2c622cbb2e710b 100644 --- a/server/server.go +++ b/server/server.go @@ -63,6 +63,8 @@ type StreamResponse struct { ContextWindowSize uint64 `json:"context_window_size,omitempty"` // ConversationListUpdate is set when another conversation in the list changed ConversationListUpdate *ConversationListUpdate `json:"conversation_list_update,omitempty"` + // Heartbeat indicates this is a heartbeat message (no new data, just keeping connection alive) + Heartbeat bool `json:"heartbeat,omitempty"` } // LLMProvider is an interface for getting LLM services diff --git a/server/stream_heartbeat_test.go b/server/stream_heartbeat_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1e8b6563d6d239c3ffd3891bcbfab642f997c78d --- /dev/null +++ b/server/stream_heartbeat_test.go @@ -0,0 +1,156 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "shelley.exe.dev/claudetool" + "shelley.exe.dev/db" + "shelley.exe.dev/llm" + "shelley.exe.dev/loop" +) + +// TestStreamResumeWithLastSequenceID verifies that using last_sequence_id +// parameter skips sending messages and sends a heartbeat instead. +func TestStreamResumeWithLastSequenceID(t *testing.T) { + database, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + + // Create a conversation with some messages + conv, err := database.CreateConversation(ctx, nil, true, nil, nil) + if err != nil { + t.Fatalf("Failed to create conversation: %v", err) + } + + // Add a user message + userMsg := llm.Message{ + Role: llm.MessageRoleUser, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}}, + } + msg1, err := database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeUser, + LLMData: userMsg, + }) + if err != nil { + t.Fatalf("Failed to create user message: %v", err) + } + + // Add an agent message + agentMsg := llm.Message{ + Role: llm.MessageRoleAssistant, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}}, + EndOfTurn: true, + } + msg2, err := database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeAgent, + LLMData: agentMsg, + }) + if err != nil { + t.Fatalf("Failed to create agent message: %v", err) + } + + // Create server + predictableService := loop.NewPredictableService() + llmManager := &testLLMManager{service: predictableService} + toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false} + server := NewServer(database, llmManager, toolSetConfig, nil, true, "", "predictable", "", nil) + + mux := http.NewServeMux() + server.RegisterRoutes(mux) + + // Test 1: Fresh connection (no last_sequence_id) - should get all messages + t.Run("fresh_connection", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx) + req.Header.Set("Accept", "text/event-stream") + + w := newResponseRecorderWithClose() + + done := make(chan struct{}) + go func() { + defer close(done) + mux.ServeHTTP(w, req) + }() + + time.Sleep(300 * time.Millisecond) + w.Close() + cancel() + <-done + + body := w.Body.String() + if !strings.HasPrefix(body, "data: ") { + t.Fatalf("Expected SSE data, got: %s", body) + } + + jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ") + var response StreamResponse + if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if len(response.Messages) != 2 { + t.Errorf("Expected 2 messages, got %d", len(response.Messages)) + } + if response.Heartbeat { + t.Error("Fresh connection should not be a heartbeat") + } + }) + + // Test 2: Resume with last_sequence_id - should get heartbeat with no messages + t.Run("resume_connection", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Use the sequence ID of the last message + req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream?last_sequence_id="+string(rune('0'+msg2.SequenceID)), nil).WithContext(ctx) + req.Header.Set("Accept", "text/event-stream") + + w := newResponseRecorderWithClose() + + done := make(chan struct{}) + go func() { + defer close(done) + mux.ServeHTTP(w, req) + }() + + time.Sleep(300 * time.Millisecond) + w.Close() + cancel() + <-done + + body := w.Body.String() + if !strings.HasPrefix(body, "data: ") { + t.Fatalf("Expected SSE data, got: %s", body) + } + + jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ") + var response StreamResponse + if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if len(response.Messages) != 0 { + t.Errorf("Expected 0 messages when resuming, got %d", len(response.Messages)) + } + if !response.Heartbeat { + t.Error("Resume connection should be a heartbeat") + } + if response.ConversationState == nil { + t.Error("Expected ConversationState in heartbeat") + } + }) + + // Suppress unused variable warnings + _ = msg1 +} diff --git a/ui/src/components/ChatInterface.tsx b/ui/src/components/ChatInterface.tsx index a4193b9eb856c01e54d381b861bdde14da238920..d162a19bd8ce410dce2cd1db36f3bbe53766a72f 100644 --- a/ui/src/components/ChatInterface.tsx +++ b/ui/src/components/ChatInterface.tsx @@ -611,6 +611,8 @@ function ChatInterface({ const overflowMenuRef = useRef(null); const reconnectTimeoutRef = useRef(null); const periodicRetryRef = useRef(null); + const heartbeatTimeoutRef = useRef(null); + const lastSequenceIdRef = useRef(-1); const userScrolledRef = useRef(false); // Load messages and set up streaming @@ -639,6 +641,11 @@ function ChatInterface({ if (periodicRetryRef.current) { clearInterval(periodicRetryRef.current); } + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); + } + // Reset sequence ID when conversation changes + lastSequenceIdRef.current = -1; }; }, [conversationId]); @@ -740,6 +747,22 @@ function ChatInterface({ } }; + // Reset heartbeat timeout - called on every message received + const resetHeartbeatTimeout = () => { + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); + } + // If we don't receive any message (including heartbeat) within 60 seconds, reconnect + heartbeatTimeoutRef.current = window.setTimeout(() => { + console.warn("No heartbeat received in 60 seconds, reconnecting..."); + if (eventSourceRef.current) { + eventSourceRef.current.close(); + eventSourceRef.current = null; + } + setupMessageStream(); + }, 60000); + }; + const setupMessageStream = () => { if (!conversationId) return; @@ -747,18 +770,39 @@ function ChatInterface({ eventSourceRef.current.close(); } - const eventSource = api.createMessageStream(conversationId); + // Clear any existing heartbeat timeout + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); + } + + // Use last_sequence_id to resume from where we left off (avoids resending all messages) + const lastSeqId = lastSequenceIdRef.current; + const eventSource = api.createMessageStream( + conversationId, + lastSeqId >= 0 ? lastSeqId : undefined, + ); eventSourceRef.current = eventSource; eventSource.onmessage = (event) => { + // Reset heartbeat timeout on every message + resetHeartbeatTimeout(); + try { const streamResponse: StreamResponse = JSON.parse(event.data); const incomingMessages = Array.isArray(streamResponse.messages) ? streamResponse.messages : []; + // Track the latest sequence ID for reconnection + if (incomingMessages.length > 0) { + const maxSeqId = Math.max(...incomingMessages.map((m) => m.sequence_id)); + if (maxSeqId > lastSequenceIdRef.current) { + lastSequenceIdRef.current = maxSeqId; + } + } + // Merge new messages without losing existing ones. - // If no new messages (e.g., only conversation/slug update), keep existing list. + // If no new messages (e.g., only conversation/slug update or heartbeat), keep existing list. if (incomingMessages.length > 0) { setMessages((prev) => { const byId = new Map(); @@ -816,6 +860,12 @@ function ChatInterface({ eventSourceRef.current = null; } + // Clear heartbeat timeout on error + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); + heartbeatTimeoutRef.current = null; + } + // Backoff delays: 1s, 2s, 5s, then show disconnected but keep retrying periodically const delays = [1000, 2000, 5000]; @@ -858,6 +908,8 @@ function ChatInterface({ clearInterval(periodicRetryRef.current); periodicRetryRef.current = null; } + // Start heartbeat timeout monitoring + resetHeartbeatTimeout(); }; }; diff --git a/ui/src/generated-types.ts b/ui/src/generated-types.ts index 313984d49ae9dfd4e0d54a1346689e8f4b96b7ef..80be151a3f520d8e62098fc48fb4294c88be5785 100644 --- a/ui/src/generated-types.ts +++ b/ui/src/generated-types.ts @@ -49,6 +49,7 @@ export interface StreamResponseForTS { messages: ApiMessageForTS[] | null; conversation: Conversation; conversation_state?: ConversationStateForTS | null; + heartbeat?: boolean; } export interface ConversationWithStateForTS { diff --git a/ui/src/services/api.ts b/ui/src/services/api.ts index 9b11789991294c8604c3a150d74c409f72f563c4..ffc1dba4a985f3b6c275affa4c586d69415f96a4 100644 --- a/ui/src/services/api.ts +++ b/ui/src/services/api.ts @@ -106,8 +106,12 @@ class ApiService { } } - createMessageStream(conversationId: string): EventSource { - return new EventSource(`${this.baseUrl}/conversation/${conversationId}/stream`); + createMessageStream(conversationId: string, lastSequenceId?: number): EventSource { + let url = `${this.baseUrl}/conversation/${conversationId}/stream`; + if (lastSequenceId !== undefined && lastSequenceId >= 0) { + url += `?last_sequence_id=${lastSequenceId}`; + } + return new EventSource(url); } async cancelConversation(conversationId: string): Promise { diff --git a/ui/src/types.ts b/ui/src/types.ts index f29dcbfe9e19852ca4422861e367465e20a4a4f9..6fb365cf2aebbd9823b2288217cb7e0e124e749e 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -65,6 +65,7 @@ export interface StreamResponse extends Omit { messages: Message[]; context_window_size?: number; conversation_list_update?: ConversationListUpdate; + heartbeat?: boolean; } // Link represents a custom link that can be added to the UI