Detailed changes
@@ -111,4 +111,5 @@ type streamResponseForTS struct {
Messages []apiMessageForTS `json:"messages"`
Conversation generated.Conversation `json:"conversation"`
ConversationState *conversationStateForTS `json:"conversation_state,omitempty"`
+ Heartbeat bool `json:"heartbeat,omitempty"`
}
@@ -998,6 +998,8 @@ func (s *Server) handleCancelConversation(w http.ResponseWriter, r *http.Request
}
// handleStreamConversation handles GET /conversation/<id>/stream
+// Query parameters:
+// - last_sequence_id: Resume from this sequence ID (skip messages up to and including this ID)
func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request, conversationID string) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@@ -1006,59 +1008,140 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request
ctx := r.Context()
+ // Parse last_sequence_id for resuming streams
+ lastSeqID := int64(-1)
+ if lastSeqStr := r.URL.Query().Get("last_sequence_id"); lastSeqStr != "" {
+ if parsed, err := strconv.ParseInt(lastSeqStr, 10, 64); err == nil {
+ lastSeqID = parsed
+ }
+ }
+
// Set up SSE headers
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
- // Get current messages and conversation data
+ // For fresh connections, get messages BEFORE calling getOrCreateConversationManager.
+ // This is important because getOrCreateConversationManager may create a system prompt
+ // message during hydration, and we want to return the messages as they were before.
var messages []generated.Message
var conversation generated.Conversation
- err := s.db.Queries(ctx, func(q *generated.Queries) error {
- var err error
- messages, err = q.ListMessages(ctx, conversationID)
+ if lastSeqID < 0 {
+ err := s.db.Queries(ctx, func(q *generated.Queries) error {
+ var err error
+ messages, err = q.ListMessages(ctx, conversationID)
+ if err != nil {
+ return err
+ }
+ conversation, err = q.GetConversation(ctx, conversationID)
+ return err
+ })
if err != nil {
+ s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
+ }
+ // Update lastSeqID based on messages we're sending
+ if len(messages) > 0 {
+ lastSeqID = messages[len(messages)-1].SequenceID
+ }
+ } else {
+ // Resuming - just get conversation metadata
+ err := s.db.Queries(ctx, func(q *generated.Queries) error {
+ var err error
+ conversation, err = q.GetConversation(ctx, conversationID)
return err
+ })
+ if err != nil {
+ s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ return
}
- conversation, err = q.GetConversation(ctx, conversationID)
- return err
- })
- if err != nil {
- s.logger.Error("Failed to get conversation data", "conversationID", conversationID, "error", err)
- http.Error(w, "Internal server error", http.StatusInternalServerError)
- return
}
// Get or create conversation manager to access working state
manager, err := s.getOrCreateConversationManager(ctx, conversationID)
if err != nil {
s.logger.Error("Failed to get conversation manager", "conversationID", conversationID, "error", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
- // Send current messages, conversation data, and conversation state
- apiMessages := toAPIMessages(messages)
- streamData := StreamResponse{
- Messages: apiMessages,
- Conversation: conversation,
- ConversationState: &ConversationState{
- ConversationID: conversationID,
- Working: manager.IsAgentWorking(),
- Model: manager.GetModel(),
- },
- ContextWindowSize: calculateContextWindowSize(apiMessages),
+ // Send initial response
+ if len(messages) > 0 {
+ // Fresh connection - send all messages
+ apiMessages := toAPIMessages(messages)
+ streamData := StreamResponse{
+ Messages: apiMessages,
+ Conversation: conversation,
+ ConversationState: &ConversationState{
+ ConversationID: conversationID,
+ Working: manager.IsAgentWorking(),
+ Model: manager.GetModel(),
+ },
+ ContextWindowSize: calculateContextWindowSize(apiMessages),
+ }
+ data, _ := json.Marshal(streamData)
+ fmt.Fprintf(w, "data: %s\n\n", data)
+ w.(http.Flusher).Flush()
+ } else {
+ // Either resuming or no messages yet - send current state as heartbeat
+ streamData := StreamResponse{
+ Conversation: conversation,
+ ConversationState: &ConversationState{
+ ConversationID: conversationID,
+ Working: manager.IsAgentWorking(),
+ Model: manager.GetModel(),
+ },
+ Heartbeat: true,
+ }
+ data, _ := json.Marshal(streamData)
+ fmt.Fprintf(w, "data: %s\n\n", data)
+ w.(http.Flusher).Flush()
}
- data, _ := json.Marshal(streamData)
- fmt.Fprintf(w, "data: %s\n\n", data)
- w.(http.Flusher).Flush()
// Subscribe to new messages after the last one we sent
- last := int64(-1)
- if len(messages) > 0 {
- last = messages[len(messages)-1].SequenceID
- }
- next := manager.subpub.Subscribe(ctx, last)
+ next := manager.subpub.Subscribe(ctx, lastSeqID)
+
+ // Start heartbeat goroutine - sends state every 30 seconds if no other messages
+ heartbeatDone := make(chan struct{})
+ go func() {
+ ticker := time.NewTicker(30 * time.Second)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-heartbeatDone:
+ return
+ case <-ticker.C:
+ // Get current conversation state for heartbeat
+ var conv generated.Conversation
+ err := s.db.Queries(ctx, func(q *generated.Queries) error {
+ var err error
+ conv, err = q.GetConversation(ctx, conversationID)
+ return err
+ })
+ if err != nil {
+ continue // Skip heartbeat on error
+ }
+
+ heartbeat := StreamResponse{
+ Conversation: conv,
+ ConversationState: &ConversationState{
+ ConversationID: conversationID,
+ Working: manager.IsAgentWorking(),
+ Model: manager.GetModel(),
+ },
+ Heartbeat: true,
+ }
+ manager.subpub.Broadcast(heartbeat)
+ }
+ }
+ }()
+ defer close(heartbeatDone)
+
for {
streamData, cont := next()
if !cont {
@@ -63,6 +63,8 @@ type StreamResponse struct {
ContextWindowSize uint64 `json:"context_window_size,omitempty"`
// ConversationListUpdate is set when another conversation in the list changed
ConversationListUpdate *ConversationListUpdate `json:"conversation_list_update,omitempty"`
+ // Heartbeat indicates this is a heartbeat message (no new data, just keeping connection alive)
+ Heartbeat bool `json:"heartbeat,omitempty"`
}
// LLMProvider is an interface for getting LLM services
@@ -0,0 +1,156 @@
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "shelley.exe.dev/claudetool"
+ "shelley.exe.dev/db"
+ "shelley.exe.dev/llm"
+ "shelley.exe.dev/loop"
+)
+
+// TestStreamResumeWithLastSequenceID verifies that using last_sequence_id
+// parameter skips sending messages and sends a heartbeat instead.
+func TestStreamResumeWithLastSequenceID(t *testing.T) {
+ database, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Create a conversation with some messages
+ conv, err := database.CreateConversation(ctx, nil, true, nil, nil)
+ if err != nil {
+ t.Fatalf("Failed to create conversation: %v", err)
+ }
+
+ // Add a user message
+ userMsg := llm.Message{
+ Role: llm.MessageRoleUser,
+ Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}},
+ }
+ msg1, err := database.CreateMessage(ctx, db.CreateMessageParams{
+ ConversationID: conv.ConversationID,
+ Type: db.MessageTypeUser,
+ LLMData: userMsg,
+ })
+ if err != nil {
+ t.Fatalf("Failed to create user message: %v", err)
+ }
+
+ // Add an agent message
+ agentMsg := llm.Message{
+ Role: llm.MessageRoleAssistant,
+ Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}},
+ EndOfTurn: true,
+ }
+ msg2, err := database.CreateMessage(ctx, db.CreateMessageParams{
+ ConversationID: conv.ConversationID,
+ Type: db.MessageTypeAgent,
+ LLMData: agentMsg,
+ })
+ if err != nil {
+ t.Fatalf("Failed to create agent message: %v", err)
+ }
+
+ // Create server
+ predictableService := loop.NewPredictableService()
+ llmManager := &testLLMManager{service: predictableService}
+ toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false}
+ server := NewServer(database, llmManager, toolSetConfig, nil, true, "", "predictable", "", nil)
+
+ mux := http.NewServeMux()
+ server.RegisterRoutes(mux)
+
+ // Test 1: Fresh connection (no last_sequence_id) - should get all messages
+ t.Run("fresh_connection", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx)
+ req.Header.Set("Accept", "text/event-stream")
+
+ w := newResponseRecorderWithClose()
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ mux.ServeHTTP(w, req)
+ }()
+
+ time.Sleep(300 * time.Millisecond)
+ w.Close()
+ cancel()
+ <-done
+
+ body := w.Body.String()
+ if !strings.HasPrefix(body, "data: ") {
+ t.Fatalf("Expected SSE data, got: %s", body)
+ }
+
+ jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ")
+ var response StreamResponse
+ if err := json.Unmarshal([]byte(jsonData), &response); err != nil {
+ t.Fatalf("Failed to parse response: %v", err)
+ }
+
+ if len(response.Messages) != 2 {
+ t.Errorf("Expected 2 messages, got %d", len(response.Messages))
+ }
+ if response.Heartbeat {
+ t.Error("Fresh connection should not be a heartbeat")
+ }
+ })
+
+ // Test 2: Resume with last_sequence_id - should get heartbeat with no messages
+ t.Run("resume_connection", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ // Use the sequence ID of the last message
+ req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream?last_sequence_id="+string(rune('0'+msg2.SequenceID)), nil).WithContext(ctx)
+ req.Header.Set("Accept", "text/event-stream")
+
+ w := newResponseRecorderWithClose()
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ mux.ServeHTTP(w, req)
+ }()
+
+ time.Sleep(300 * time.Millisecond)
+ w.Close()
+ cancel()
+ <-done
+
+ body := w.Body.String()
+ if !strings.HasPrefix(body, "data: ") {
+ t.Fatalf("Expected SSE data, got: %s", body)
+ }
+
+ jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ")
+ var response StreamResponse
+ if err := json.Unmarshal([]byte(jsonData), &response); err != nil {
+ t.Fatalf("Failed to parse response: %v", err)
+ }
+
+ if len(response.Messages) != 0 {
+ t.Errorf("Expected 0 messages when resuming, got %d", len(response.Messages))
+ }
+ if !response.Heartbeat {
+ t.Error("Resume connection should be a heartbeat")
+ }
+ if response.ConversationState == nil {
+ t.Error("Expected ConversationState in heartbeat")
+ }
+ })
+
+ // Suppress unused variable warnings
+ _ = msg1
+}
@@ -611,6 +611,8 @@ function ChatInterface({
const overflowMenuRef = useRef<HTMLDivElement>(null);
const reconnectTimeoutRef = useRef<number | null>(null);
const periodicRetryRef = useRef<number | null>(null);
+ const heartbeatTimeoutRef = useRef<number | null>(null);
+ const lastSequenceIdRef = useRef<number>(-1);
const userScrolledRef = useRef(false);
// Load messages and set up streaming
@@ -639,6 +641,11 @@ function ChatInterface({
if (periodicRetryRef.current) {
clearInterval(periodicRetryRef.current);
}
+ if (heartbeatTimeoutRef.current) {
+ clearTimeout(heartbeatTimeoutRef.current);
+ }
+ // Reset sequence ID when conversation changes
+ lastSequenceIdRef.current = -1;
};
}, [conversationId]);
@@ -740,6 +747,22 @@ function ChatInterface({
}
};
+ // Reset heartbeat timeout - called on every message received
+ const resetHeartbeatTimeout = () => {
+ if (heartbeatTimeoutRef.current) {
+ clearTimeout(heartbeatTimeoutRef.current);
+ }
+ // If we don't receive any message (including heartbeat) within 60 seconds, reconnect
+ heartbeatTimeoutRef.current = window.setTimeout(() => {
+ console.warn("No heartbeat received in 60 seconds, reconnecting...");
+ if (eventSourceRef.current) {
+ eventSourceRef.current.close();
+ eventSourceRef.current = null;
+ }
+ setupMessageStream();
+ }, 60000);
+ };
+
const setupMessageStream = () => {
if (!conversationId) return;
@@ -747,18 +770,39 @@ function ChatInterface({
eventSourceRef.current.close();
}
- const eventSource = api.createMessageStream(conversationId);
+ // Clear any existing heartbeat timeout
+ if (heartbeatTimeoutRef.current) {
+ clearTimeout(heartbeatTimeoutRef.current);
+ }
+
+ // Use last_sequence_id to resume from where we left off (avoids resending all messages)
+ const lastSeqId = lastSequenceIdRef.current;
+ const eventSource = api.createMessageStream(
+ conversationId,
+ lastSeqId >= 0 ? lastSeqId : undefined,
+ );
eventSourceRef.current = eventSource;
eventSource.onmessage = (event) => {
+ // Reset heartbeat timeout on every message
+ resetHeartbeatTimeout();
+
try {
const streamResponse: StreamResponse = JSON.parse(event.data);
const incomingMessages = Array.isArray(streamResponse.messages)
? streamResponse.messages
: [];
+ // Track the latest sequence ID for reconnection
+ if (incomingMessages.length > 0) {
+ const maxSeqId = Math.max(...incomingMessages.map((m) => m.sequence_id));
+ if (maxSeqId > lastSequenceIdRef.current) {
+ lastSequenceIdRef.current = maxSeqId;
+ }
+ }
+
// Merge new messages without losing existing ones.
- // If no new messages (e.g., only conversation/slug update), keep existing list.
+ // If no new messages (e.g., only conversation/slug update or heartbeat), keep existing list.
if (incomingMessages.length > 0) {
setMessages((prev) => {
const byId = new Map<string, Message>();
@@ -816,6 +860,12 @@ function ChatInterface({
eventSourceRef.current = null;
}
+ // Clear heartbeat timeout on error
+ if (heartbeatTimeoutRef.current) {
+ clearTimeout(heartbeatTimeoutRef.current);
+ heartbeatTimeoutRef.current = null;
+ }
+
// Backoff delays: 1s, 2s, 5s, then show disconnected but keep retrying periodically
const delays = [1000, 2000, 5000];
@@ -858,6 +908,8 @@ function ChatInterface({
clearInterval(periodicRetryRef.current);
periodicRetryRef.current = null;
}
+ // Start heartbeat timeout monitoring
+ resetHeartbeatTimeout();
};
};
@@ -49,6 +49,7 @@ export interface StreamResponseForTS {
messages: ApiMessageForTS[] | null;
conversation: Conversation;
conversation_state?: ConversationStateForTS | null;
+ heartbeat?: boolean;
}
export interface ConversationWithStateForTS {
@@ -106,8 +106,12 @@ class ApiService {
}
}
- createMessageStream(conversationId: string): EventSource {
- return new EventSource(`${this.baseUrl}/conversation/${conversationId}/stream`);
+ createMessageStream(conversationId: string, lastSequenceId?: number): EventSource {
+ let url = `${this.baseUrl}/conversation/${conversationId}/stream`;
+ if (lastSequenceId !== undefined && lastSequenceId >= 0) {
+ url += `?last_sequence_id=${lastSequenceId}`;
+ }
+ return new EventSource(url);
}
async cancelConversation(conversationId: string): Promise<void> {
@@ -65,6 +65,7 @@ export interface StreamResponse extends Omit<StreamResponseForTS, "messages"> {
messages: Message[];
context_window_size?: number;
conversation_list_update?: ConversationListUpdate;
+ heartbeat?: boolean;
}
// Link represents a custom link that can be added to the UI