diff --git a/cmd/go2ts.go b/cmd/go2ts.go index d5b7c8cacee71381b113550a0107a7bed23c2403..299325aa3038869772e853b1c030a2112937d265 100644 --- a/cmd/go2ts.go +++ b/cmd/go2ts.go @@ -111,4 +111,5 @@ type streamResponseForTS struct { Messages []apiMessageForTS `json:"messages"` Conversation generated.Conversation `json:"conversation"` ConversationState *conversationStateForTS `json:"conversation_state,omitempty"` + Heartbeat bool `json:"heartbeat,omitempty"` } diff --git a/server/handlers.go b/server/handlers.go index d43e6e780e6242a418ba438cc541d2ae5258db21..6b71f60f9fe91fd91d6094eeddd247386ff61153 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -998,6 +998,8 @@ func (s *Server) handleCancelConversation(w http.ResponseWriter, r *http.Request } // handleStreamConversation handles GET /conversation//stream +// Query parameters: +// - last_sequence_id: Resume from this sequence ID (skip messages up to and including this ID) func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request, conversationID string) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -1006,59 +1008,140 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request ctx := r.Context() + // Parse last_sequence_id for resuming streams + lastSeqID := int64(-1) + if lastSeqStr := r.URL.Query().Get("last_sequence_id"); lastSeqStr != "" { + if parsed, err := strconv.ParseInt(lastSeqStr, 10, 64); err == nil { + lastSeqID = parsed + } + } + // Set up SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") - // Get current messages and conversation data + // For fresh connections, get messages BEFORE calling getOrCreateConversationManager. + // This is important because getOrCreateConversationManager may create a system prompt + // message during hydration, and we want to return the messages as they were before. var messages []generated.Message var conversation generated.Conversation - err := s.db.Queries(ctx, func(q *generated.Queries) error { - var err error - messages, err = q.ListMessages(ctx, conversationID) + if lastSeqID < 0 { + err := s.db.Queries(ctx, func(q *generated.Queries) error { + var err error + messages, err = q.ListMessages(ctx, conversationID) + if err != nil { + return err + } + conversation, err = q.GetConversation(ctx, conversationID) + return err + }) if err != nil { + s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + // Update lastSeqID based on messages we're sending + if len(messages) > 0 { + lastSeqID = messages[len(messages)-1].SequenceID + } + } else { + // Resuming - just get conversation metadata + err := s.db.Queries(ctx, func(q *generated.Queries) error { + var err error + conversation, err = q.GetConversation(ctx, conversationID) return err + }) + if err != nil { + s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return } - conversation, err = q.GetConversation(ctx, conversationID) - return err - }) - if err != nil { - s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return } // Get or create conversation manager to access working state manager, err := s.getOrCreateConversationManager(ctx, conversationID) if err != nil { s.logger.Error("Failed to get conversation manager", "conversationID", conversationID, "error", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) return } - // Send current messages, conversation data, and conversation state - apiMessages := toAPIMessages(messages) - streamData := StreamResponse{ - Messages: apiMessages, - Conversation: conversation, - ConversationState: &ConversationState{ - ConversationID: conversationID, - Working: manager.IsAgentWorking(), - Model: manager.GetModel(), - }, - ContextWindowSize: calculateContextWindowSize(apiMessages), + // Send initial response + if len(messages) > 0 { + // Fresh connection - send all messages + apiMessages := toAPIMessages(messages) + streamData := StreamResponse{ + Messages: apiMessages, + Conversation: conversation, + ConversationState: &ConversationState{ + ConversationID: conversationID, + Working: manager.IsAgentWorking(), + Model: manager.GetModel(), + }, + ContextWindowSize: calculateContextWindowSize(apiMessages), + } + data, _ := json.Marshal(streamData) + fmt.Fprintf(w, "data: %s\n\n", data) + w.(http.Flusher).Flush() + } else { + // Either resuming or no messages yet - send current state as heartbeat + streamData := StreamResponse{ + Conversation: conversation, + ConversationState: &ConversationState{ + ConversationID: conversationID, + Working: manager.IsAgentWorking(), + Model: manager.GetModel(), + }, + Heartbeat: true, + } + data, _ := json.Marshal(streamData) + fmt.Fprintf(w, "data: %s\n\n", data) + w.(http.Flusher).Flush() } - data, _ := json.Marshal(streamData) - fmt.Fprintf(w, "data: %s\n\n", data) - w.(http.Flusher).Flush() // Subscribe to new messages after the last one we sent - last := int64(-1) - if len(messages) > 0 { - last = messages[len(messages)-1].SequenceID - } - next := manager.subpub.Subscribe(ctx, last) + next := manager.subpub.Subscribe(ctx, lastSeqID) + + // Start heartbeat goroutine - sends state every 30 seconds if no other messages + heartbeatDone := make(chan struct{}) + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-heartbeatDone: + return + case <-ticker.C: + // Get current conversation state for heartbeat + var conv generated.Conversation + err := s.db.Queries(ctx, func(q *generated.Queries) error { + var err error + conv, err = q.GetConversation(ctx, conversationID) + return err + }) + if err != nil { + continue // Skip heartbeat on error + } + + heartbeat := StreamResponse{ + Conversation: conv, + ConversationState: &ConversationState{ + ConversationID: conversationID, + Working: manager.IsAgentWorking(), + Model: manager.GetModel(), + }, + Heartbeat: true, + } + manager.subpub.Broadcast(heartbeat) + } + } + }() + defer close(heartbeatDone) + for { streamData, cont := next() if !cont { diff --git a/server/server.go b/server/server.go index cd61516ef9147911e4f56419f076c4df331644c1..1cc08dada95c66b789cf89a57b2c622cbb2e710b 100644 --- a/server/server.go +++ b/server/server.go @@ -63,6 +63,8 @@ type StreamResponse struct { ContextWindowSize uint64 `json:"context_window_size,omitempty"` // ConversationListUpdate is set when another conversation in the list changed ConversationListUpdate *ConversationListUpdate `json:"conversation_list_update,omitempty"` + // Heartbeat indicates this is a heartbeat message (no new data, just keeping connection alive) + Heartbeat bool `json:"heartbeat,omitempty"` } // LLMProvider is an interface for getting LLM services diff --git a/server/stream_heartbeat_test.go b/server/stream_heartbeat_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1e8b6563d6d239c3ffd3891bcbfab642f997c78d --- /dev/null +++ b/server/stream_heartbeat_test.go @@ -0,0 +1,156 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "shelley.exe.dev/claudetool" + "shelley.exe.dev/db" + "shelley.exe.dev/llm" + "shelley.exe.dev/loop" +) + +// TestStreamResumeWithLastSequenceID verifies that using last_sequence_id +// parameter skips sending messages and sends a heartbeat instead. +func TestStreamResumeWithLastSequenceID(t *testing.T) { + database, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + + // Create a conversation with some messages + conv, err := database.CreateConversation(ctx, nil, true, nil, nil) + if err != nil { + t.Fatalf("Failed to create conversation: %v", err) + } + + // Add a user message + userMsg := llm.Message{ + Role: llm.MessageRoleUser, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}}, + } + msg1, err := database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeUser, + LLMData: userMsg, + }) + if err != nil { + t.Fatalf("Failed to create user message: %v", err) + } + + // Add an agent message + agentMsg := llm.Message{ + Role: llm.MessageRoleAssistant, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}}, + EndOfTurn: true, + } + msg2, err := database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeAgent, + LLMData: agentMsg, + }) + if err != nil { + t.Fatalf("Failed to create agent message: %v", err) + } + + // Create server + predictableService := loop.NewPredictableService() + llmManager := &testLLMManager{service: predictableService} + toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false} + server := NewServer(database, llmManager, toolSetConfig, nil, true, "", "predictable", "", nil) + + mux := http.NewServeMux() + server.RegisterRoutes(mux) + + // Test 1: Fresh connection (no last_sequence_id) - should get all messages + t.Run("fresh_connection", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx) + req.Header.Set("Accept", "text/event-stream") + + w := newResponseRecorderWithClose() + + done := make(chan struct{}) + go func() { + defer close(done) + mux.ServeHTTP(w, req) + }() + + time.Sleep(300 * time.Millisecond) + w.Close() + cancel() + <-done + + body := w.Body.String() + if !strings.HasPrefix(body, "data: ") { + t.Fatalf("Expected SSE data, got: %s", body) + } + + jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ") + var response StreamResponse + if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if len(response.Messages) != 2 { + t.Errorf("Expected 2 messages, got %d", len(response.Messages)) + } + if response.Heartbeat { + t.Error("Fresh connection should not be a heartbeat") + } + }) + + // Test 2: Resume with last_sequence_id - should get heartbeat with no messages + t.Run("resume_connection", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // Use the sequence ID of the last message + req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream?last_sequence_id="+string(rune('0'+msg2.SequenceID)), nil).WithContext(ctx) + req.Header.Set("Accept", "text/event-stream") + + w := newResponseRecorderWithClose() + + done := make(chan struct{}) + go func() { + defer close(done) + mux.ServeHTTP(w, req) + }() + + time.Sleep(300 * time.Millisecond) + w.Close() + cancel() + <-done + + body := w.Body.String() + if !strings.HasPrefix(body, "data: ") { + t.Fatalf("Expected SSE data, got: %s", body) + } + + jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ") + var response StreamResponse + if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + if len(response.Messages) != 0 { + t.Errorf("Expected 0 messages when resuming, got %d", len(response.Messages)) + } + if !response.Heartbeat { + t.Error("Resume connection should be a heartbeat") + } + if response.ConversationState == nil { + t.Error("Expected ConversationState in heartbeat") + } + }) + + // Suppress unused variable warnings + _ = msg1 +} diff --git a/ui/src/components/ChatInterface.tsx b/ui/src/components/ChatInterface.tsx index a4193b9eb856c01e54d381b861bdde14da238920..d162a19bd8ce410dce2cd1db36f3bbe53766a72f 100644 --- a/ui/src/components/ChatInterface.tsx +++ b/ui/src/components/ChatInterface.tsx @@ -611,6 +611,8 @@ function ChatInterface({ const overflowMenuRef = useRef(null); const reconnectTimeoutRef = useRef(null); const periodicRetryRef = useRef(null); + const heartbeatTimeoutRef = useRef(null); + const lastSequenceIdRef = useRef(-1); const userScrolledRef = useRef(false); // Load messages and set up streaming @@ -639,6 +641,11 @@ function ChatInterface({ if (periodicRetryRef.current) { clearInterval(periodicRetryRef.current); } + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); + } + // Reset sequence ID when conversation changes + lastSequenceIdRef.current = -1; }; }, [conversationId]); @@ -740,6 +747,22 @@ function ChatInterface({ } }; + // Reset heartbeat timeout - called on every message received + const resetHeartbeatTimeout = () => { + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); + } + // If we don't receive any message (including heartbeat) within 60 seconds, reconnect + heartbeatTimeoutRef.current = window.setTimeout(() => { + console.warn("No heartbeat received in 60 seconds, reconnecting..."); + if (eventSourceRef.current) { + eventSourceRef.current.close(); + eventSourceRef.current = null; + } + setupMessageStream(); + }, 60000); + }; + const setupMessageStream = () => { if (!conversationId) return; @@ -747,18 +770,39 @@ function ChatInterface({ eventSourceRef.current.close(); } - const eventSource = api.createMessageStream(conversationId); + // Clear any existing heartbeat timeout + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); + } + + // Use last_sequence_id to resume from where we left off (avoids resending all messages) + const lastSeqId = lastSequenceIdRef.current; + const eventSource = api.createMessageStream( + conversationId, + lastSeqId >= 0 ? lastSeqId : undefined, + ); eventSourceRef.current = eventSource; eventSource.onmessage = (event) => { + // Reset heartbeat timeout on every message + resetHeartbeatTimeout(); + try { const streamResponse: StreamResponse = JSON.parse(event.data); const incomingMessages = Array.isArray(streamResponse.messages) ? streamResponse.messages : []; + // Track the latest sequence ID for reconnection + if (incomingMessages.length > 0) { + const maxSeqId = Math.max(...incomingMessages.map((m) => m.sequence_id)); + if (maxSeqId > lastSequenceIdRef.current) { + lastSequenceIdRef.current = maxSeqId; + } + } + // Merge new messages without losing existing ones. - // If no new messages (e.g., only conversation/slug update), keep existing list. + // If no new messages (e.g., only conversation/slug update or heartbeat), keep existing list. if (incomingMessages.length > 0) { setMessages((prev) => { const byId = new Map(); @@ -816,6 +860,12 @@ function ChatInterface({ eventSourceRef.current = null; } + // Clear heartbeat timeout on error + if (heartbeatTimeoutRef.current) { + clearTimeout(heartbeatTimeoutRef.current); + heartbeatTimeoutRef.current = null; + } + // Backoff delays: 1s, 2s, 5s, then show disconnected but keep retrying periodically const delays = [1000, 2000, 5000]; @@ -858,6 +908,8 @@ function ChatInterface({ clearInterval(periodicRetryRef.current); periodicRetryRef.current = null; } + // Start heartbeat timeout monitoring + resetHeartbeatTimeout(); }; }; diff --git a/ui/src/generated-types.ts b/ui/src/generated-types.ts index 313984d49ae9dfd4e0d54a1346689e8f4b96b7ef..80be151a3f520d8e62098fc48fb4294c88be5785 100644 --- a/ui/src/generated-types.ts +++ b/ui/src/generated-types.ts @@ -49,6 +49,7 @@ export interface StreamResponseForTS { messages: ApiMessageForTS[] | null; conversation: Conversation; conversation_state?: ConversationStateForTS | null; + heartbeat?: boolean; } export interface ConversationWithStateForTS { diff --git a/ui/src/services/api.ts b/ui/src/services/api.ts index 9b11789991294c8604c3a150d74c409f72f563c4..ffc1dba4a985f3b6c275affa4c586d69415f96a4 100644 --- a/ui/src/services/api.ts +++ b/ui/src/services/api.ts @@ -106,8 +106,12 @@ class ApiService { } } - createMessageStream(conversationId: string): EventSource { - return new EventSource(`${this.baseUrl}/conversation/${conversationId}/stream`); + createMessageStream(conversationId: string, lastSequenceId?: number): EventSource { + let url = `${this.baseUrl}/conversation/${conversationId}/stream`; + if (lastSequenceId !== undefined && lastSequenceId >= 0) { + url += `?last_sequence_id=${lastSequenceId}`; + } + return new EventSource(url); } async cancelConversation(conversationId: string): Promise { diff --git a/ui/src/types.ts b/ui/src/types.ts index f29dcbfe9e19852ca4422861e367465e20a4a4f9..6fb365cf2aebbd9823b2288217cb7e0e124e749e 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -65,6 +65,7 @@ export interface StreamResponse extends Omit { messages: Message[]; context_window_size?: number; conversation_list_update?: ConversationListUpdate; + heartbeat?: boolean; } // Link represents a custom link that can be added to the UI