stream_heartbeat_test.go

  1package server
  2
import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"
	"time"

	"shelley.exe.dev/claudetool"
	"shelley.exe.dev/db"
	"shelley.exe.dev/llm"
	"shelley.exe.dev/loop"
)
 17
 18// TestStreamResumeWithLastSequenceID verifies that using last_sequence_id
 19// parameter skips sending messages and sends a heartbeat instead.
 20func TestStreamResumeWithLastSequenceID(t *testing.T) {
 21	database, cleanup := setupTestDB(t)
 22	defer cleanup()
 23
 24	ctx := context.Background()
 25
 26	// Create a conversation with some messages
 27	conv, err := database.CreateConversation(ctx, nil, true, nil, nil)
 28	if err != nil {
 29		t.Fatalf("Failed to create conversation: %v", err)
 30	}
 31
 32	// Add a user message
 33	userMsg := llm.Message{
 34		Role:    llm.MessageRoleUser,
 35		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}},
 36	}
 37	msg1, err := database.CreateMessage(ctx, db.CreateMessageParams{
 38		ConversationID: conv.ConversationID,
 39		Type:           db.MessageTypeUser,
 40		LLMData:        userMsg,
 41	})
 42	if err != nil {
 43		t.Fatalf("Failed to create user message: %v", err)
 44	}
 45
 46	// Add an agent message
 47	agentMsg := llm.Message{
 48		Role:      llm.MessageRoleAssistant,
 49		Content:   []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}},
 50		EndOfTurn: true,
 51	}
 52	msg2, err := database.CreateMessage(ctx, db.CreateMessageParams{
 53		ConversationID: conv.ConversationID,
 54		Type:           db.MessageTypeAgent,
 55		LLMData:        agentMsg,
 56	})
 57	if err != nil {
 58		t.Fatalf("Failed to create agent message: %v", err)
 59	}
 60
 61	// Create server
 62	predictableService := loop.NewPredictableService()
 63	llmManager := &testLLMManager{service: predictableService}
 64	toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false}
 65	server := NewServer(database, llmManager, toolSetConfig, nil, true, "", "predictable", "", nil)
 66
 67	mux := http.NewServeMux()
 68	server.RegisterRoutes(mux)
 69
 70	// Test 1: Fresh connection (no last_sequence_id) - should get all messages
 71	t.Run("fresh_connection", func(t *testing.T) {
 72		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 73		defer cancel()
 74
 75		req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx)
 76		req.Header.Set("Accept", "text/event-stream")
 77
 78		w := newResponseRecorderWithClose()
 79
 80		done := make(chan struct{})
 81		go func() {
 82			defer close(done)
 83			mux.ServeHTTP(w, req)
 84		}()
 85
 86		time.Sleep(300 * time.Millisecond)
 87		w.Close()
 88		cancel()
 89		<-done
 90
 91		body := w.Body.String()
 92		if !strings.HasPrefix(body, "data: ") {
 93			t.Fatalf("Expected SSE data, got: %s", body)
 94		}
 95
 96		jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ")
 97		var response StreamResponse
 98		if err := json.Unmarshal([]byte(jsonData), &response); err != nil {
 99			t.Fatalf("Failed to parse response: %v", err)
100		}
101
102		if len(response.Messages) != 2 {
103			t.Errorf("Expected 2 messages, got %d", len(response.Messages))
104		}
105		if response.Heartbeat {
106			t.Error("Fresh connection should not be a heartbeat")
107		}
108	})
109
110	// Test 2: Resume with last_sequence_id - should get heartbeat with no messages
111	t.Run("resume_connection", func(t *testing.T) {
112		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
113		defer cancel()
114
115		// Use the sequence ID of the last message
116		req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream?last_sequence_id="+string(rune('0'+msg2.SequenceID)), nil).WithContext(ctx)
117		req.Header.Set("Accept", "text/event-stream")
118
119		w := newResponseRecorderWithClose()
120
121		done := make(chan struct{})
122		go func() {
123			defer close(done)
124			mux.ServeHTTP(w, req)
125		}()
126
127		time.Sleep(300 * time.Millisecond)
128		w.Close()
129		cancel()
130		<-done
131
132		body := w.Body.String()
133		if !strings.HasPrefix(body, "data: ") {
134			t.Fatalf("Expected SSE data, got: %s", body)
135		}
136
137		jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ")
138		var response StreamResponse
139		if err := json.Unmarshal([]byte(jsonData), &response); err != nil {
140			t.Fatalf("Failed to parse response: %v", err)
141		}
142
143		if len(response.Messages) != 0 {
144			t.Errorf("Expected 0 messages when resuming, got %d", len(response.Messages))
145		}
146		if !response.Heartbeat {
147			t.Error("Resume connection should be a heartbeat")
148		}
149		if response.ConversationState == nil {
150			t.Error("Expected ConversationState in heartbeat")
151		}
152	})
153
154	// Suppress unused variable warnings
155	_ = msg1
156}