cancel_test.go

  1package server
  2
import (
	"context"
	"encoding/json"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"shelley.exe.dev/claudetool"
	"shelley.exe.dev/db"
	"shelley.exe.dev/db/generated"
	"shelley.exe.dev/llm"
	"shelley.exe.dev/loop"
	"shelley.exe.dev/models"
)
 20
 21// setupTestDB creates a test database
 22func setupTestDB(t *testing.T) (*db.DB, func()) {
 23	t.Helper()
 24	tmpDir := t.TempDir()
 25	database, err := db.New(db.Config{DSN: tmpDir + "/test.db"})
 26	if err != nil {
 27		t.Fatalf("Failed to create test database: %v", err)
 28	}
 29
 30	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 31	defer cancel()
 32
 33	if err := database.Migrate(ctx); err != nil {
 34		t.Fatalf("Failed to migrate test database: %v", err)
 35	}
 36
 37	return database, func() {
 38		database.Close()
 39	}
 40}
 41
 42// waitFor polls a condition until it returns true or the timeout is reached.
 43func waitFor(t *testing.T, timeout time.Duration, condition func() bool) {
 44	t.Helper()
 45	deadline := time.Now().Add(timeout)
 46	for time.Now().Before(deadline) {
 47		if condition() {
 48			return
 49		}
 50		time.Sleep(10 * time.Millisecond)
 51	}
 52	t.Fatal("timed out waiting for condition")
 53}
 54
 55// TestCancelWithPredictableModel tests cancellation with the predictable model
// TestCancelWithPredictableModel exercises the full cancellation flow against
// the predictable LLM service: start a chat whose message triggers a slow bash
// tool call, cancel while the tool is running, verify that a cancelled tool
// result and an "Operation cancelled" end-of-turn message were recorded, and
// finally confirm the conversation can be resumed with a new message.
func TestCancelWithPredictableModel(t *testing.T) {
	// Create test database
	database, cleanup := setupTestDB(t)
	defer cleanup()

	predictableService := loop.NewPredictableService()
	llmManager := &testLLMManager{service: predictableService}
	logger := slog.Default()

	// Register the bash tool so the sleep command actually runs and can be cancelled
	toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false}
	server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)

	// Create conversation
	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
	if err != nil {
		t.Fatalf("failed to create conversation: %v", err)
	}
	conversationID := conversation.ConversationID

	// Start a conversation with a message that triggers a slow bash command.
	// The "bash:" prefix directs the predictable service to issue a tool call.
	chatReq := ChatRequest{
		Message: "bash: sleep 5",
		Model:   "predictable",
	}
	chatBody, _ := json.Marshal(chatReq)

	req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(chatBody)))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	server.handleChatConversation(w, req, conversationID)

	if w.Code != http.StatusAccepted {
		t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
	}

	// Wait for agent to record an assistant message with tool use, so the
	// cancel below lands while the tool is actually in flight.
	waitFor(t, 5*time.Second, func() bool {
		var messages []generated.Message
		err := database.Queries(context.Background(), func(q *generated.Queries) error {
			var qerr error
			messages, qerr = q.ListMessages(context.Background(), conversationID)
			return qerr
		})
		if err != nil || len(messages) < 2 {
			return false
		}
		// Check for assistant message with tool use
		for _, msg := range messages {
			if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
				continue
			}
			var llmMsg llm.Message
			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
				continue
			}
			for _, content := range llmMsg.Content {
				if content.Type == llm.ContentTypeToolUse {
					return true
				}
			}
		}
		return false
	})

	// Cancel the conversation
	cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
	cancelW := httptest.NewRecorder()

	server.handleCancelConversation(cancelW, cancelReq, conversationID)

	if cancelW.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", cancelW.Code, cancelW.Body.String())
	}

	var cancelResp map[string]string
	if err := json.Unmarshal(cancelW.Body.Bytes(), &cancelResp); err != nil {
		t.Fatalf("failed to parse cancel response: %v", err)
	}

	if cancelResp["status"] != "cancelled" {
		t.Errorf("expected status 'cancelled', got '%s'", cancelResp["status"])
	}

	// Wait for agent to stop working (cancellation complete)
	waitFor(t, 5*time.Second, func() bool {
		return !server.IsAgentWorking(conversationID)
	})

	// Verify that a cancelled tool result was recorded
	var messages []generated.Message
	err = database.Queries(context.Background(), func(q *generated.Queries) error {
		var qerr error
		messages, qerr = q.ListMessages(context.Background(), conversationID)
		return qerr
	})
	if err != nil {
		t.Fatalf("failed to get messages after cancel: %v", err)
	}

	// Should have: user message, assistant message with tool use, cancelled tool result, and end turn message
	if len(messages) < 4 {
		t.Fatalf("expected at least 4 messages after cancel, got %d", len(messages))
	}

	// Check that we have the cancelled tool result.
	// Scan newest-to-oldest since the cancellation artifacts are appended last.
	foundCancelledResult := false
	foundEndTurnMessage := false
	for i := len(messages) - 1; i >= 0; i-- {
		msg := messages[i]
		if msg.LlmData == nil {
			continue
		}

		var llmMsg llm.Message
		if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
			continue
		}

		// Check for cancelled tool result (a tool-error result whose text
		// mentions "cancelled")
		for _, content := range llmMsg.Content {
			if content.Type == llm.ContentTypeToolResult && content.ToolError {
				for _, result := range content.ToolResult {
					if result.Type == llm.ContentTypeText && strings.Contains(result.Text, "cancelled") {
						foundCancelledResult = true
						break
					}
				}
			}
		}

		// Check for end turn message
		if msg.Type == string(db.MessageTypeAgent) && llmMsg.EndOfTurn {
			for _, content := range llmMsg.Content {
				if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "Operation cancelled") {
					foundEndTurnMessage = true
					break
				}
			}
		}
	}

	if !foundCancelledResult {
		t.Error("expected to find cancelled tool result in conversation")
	}

	if !foundEndTurnMessage {
		t.Error("expected to find end turn message after cancellation")
	}

	// Test that conversation can be resumed after cancellation.
	// The "echo:" prefix makes the predictable service reply with the text.
	resumeReq := ChatRequest{
		Message: "echo: test after cancel",
		Model:   "predictable",
	}
	resumeBody, _ := json.Marshal(resumeReq)

	resumeChatReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(resumeBody)))
	resumeChatReq.Header.Set("Content-Type", "application/json")
	resumeW := httptest.NewRecorder()

	server.handleChatConversation(resumeW, resumeChatReq, conversationID)

	if resumeW.Code != http.StatusAccepted {
		t.Fatalf("expected status 202 for resume, got %d: %s", resumeW.Code, resumeW.Body.String())
	}

	// Wait for agent to finish processing the resumed conversation
	waitFor(t, 5*time.Second, func() bool {
		return !server.IsAgentWorking(conversationID)
	})

	// Verify conversation continued
	err = database.Queries(context.Background(), func(q *generated.Queries) error {
		var qerr error
		messages, qerr = q.ListMessages(context.Background(), conversationID)
		return qerr
	})
	if err != nil {
		t.Fatalf("failed to get messages after resume: %v", err)
	}

	// Should have additional messages from the resumed conversation
	if len(messages) < 5 {
		t.Fatalf("expected at least 5 messages after resume, got %d", len(messages))
	}

	// Check that we got the expected response echoed back by the agent
	foundContinueResponse := false
	for _, msg := range messages {
		if msg.Type != string(db.MessageTypeAgent) {
			continue
		}
		if msg.LlmData == nil {
			continue
		}
		var llmMsg llm.Message
		if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
			continue
		}
		for _, content := range llmMsg.Content {
			if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "test after cancel") {
				foundContinueResponse = true
				break
			}
		}
	}

	if !foundContinueResponse {
		t.Error("expected to find 'test after cancel' response")
	}
}
269
270// TestCancelWithNoActiveConversation tests cancelling when there's no active conversation
271func TestCancelWithNoActiveConversation(t *testing.T) {
272	database, cleanup := setupTestDB(t)
273	defer cleanup()
274
275	predictableService := loop.NewPredictableService()
276	llmManager := &testLLMManager{service: predictableService}
277	logger := slog.Default()
278
279	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
280
281	// Create a conversation but don't start it
282	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
283	if err != nil {
284		t.Fatalf("failed to create conversation: %v", err)
285	}
286	conversationID := conversation.ConversationID
287
288	// Try to cancel without any active loop
289	cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
290	cancelW := httptest.NewRecorder()
291
292	server.handleCancelConversation(cancelW, cancelReq, conversationID)
293
294	if cancelW.Code != http.StatusOK {
295		t.Fatalf("expected status 200, got %d: %s", cancelW.Code, cancelW.Body.String())
296	}
297
298	var cancelResp map[string]string
299	if err := json.Unmarshal(cancelW.Body.Bytes(), &cancelResp); err != nil {
300		t.Fatalf("failed to parse cancel response: %v", err)
301	}
302
303	if cancelResp["status"] != "no_active_conversation" {
304		t.Errorf("expected status 'no_active_conversation', got '%s'", cancelResp["status"])
305	}
306}
307
308// TestCancelDuringTextGeneration tests cancelling during text generation (no tool call)
// TestCancelDuringTextGeneration tests cancelling during text generation
// (no tool call): the cancel must complete cleanly and, since no tool was
// ever invoked, no synthetic cancelled tool result may be recorded.
func TestCancelDuringTextGeneration(t *testing.T) {
	database, cleanup := setupTestDB(t)
	defer cleanup()

	// Use delay: prefix to trigger slow response from the predictable service
	predictableService := loop.NewPredictableService()

	llmManager := &testLLMManager{service: predictableService}
	logger := slog.Default()
	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)

	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
	if err != nil {
		t.Fatalf("failed to create conversation: %v", err)
	}
	conversationID := conversation.ConversationID

	// Start conversation with a delay to simulate slow text generation,
	// giving the cancel below a window to land mid-generation.
	chatReq := ChatRequest{
		Message: "delay: 2",
		Model:   "predictable",
	}
	chatBody, _ := json.Marshal(chatReq)

	req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(chatBody)))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	server.handleChatConversation(w, req, conversationID)

	if w.Code != http.StatusAccepted {
		t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
	}

	// Wait for agent to start working
	waitFor(t, 5*time.Second, func() bool {
		return server.IsAgentWorking(conversationID)
	})

	// Cancel during text generation
	cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
	cancelW := httptest.NewRecorder()

	server.handleCancelConversation(cancelW, cancelReq, conversationID)

	if cancelW.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", cancelW.Code, cancelW.Body.String())
	}

	// Wait for agent to stop working (cancellation complete)
	waitFor(t, 5*time.Second, func() bool {
		return !server.IsAgentWorking(conversationID)
	})

	// Verify that no cancelled tool result was added (since there was no tool call)
	var messages []generated.Message
	err = database.Queries(context.Background(), func(q *generated.Queries) error {
		var qerr error
		messages, qerr = q.ListMessages(context.Background(), conversationID)
		return qerr
	})
	if err != nil {
		t.Fatalf("failed to get messages: %v", err)
	}

	// Should only have user message (and possibly incomplete assistant message)
	// Should NOT have a tool result message — tool results are stored on
	// user-type messages, so only those need to be inspected here.
	for _, msg := range messages {
		if msg.Type == string(db.MessageTypeUser) {
			if msg.LlmData == nil {
				continue
			}
			var llmMsg llm.Message
			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
				continue
			}
			for _, content := range llmMsg.Content {
				if content.Type == llm.ContentTypeToolResult {
					t.Error("did not expect tool result when cancelling during text generation")
				}
			}
		}
	}
}
393
394// testLLMManager is a simple test implementation of LLMProvider
// testLLMManager is a simple test implementation of LLMProvider: it always
// hands back the single wrapped service and only recognizes the
// "predictable" model ID.
type testLLMManager struct {
	service llm.Service // the one service returned for every model ID
}
398
399func (m *testLLMManager) GetService(modelID string) (llm.Service, error) {
400	return m.service, nil
401}
402
403func (m *testLLMManager) GetAvailableModels() []string {
404	return []string{"predictable"}
405}
406
407func (m *testLLMManager) HasModel(modelID string) bool {
408	return modelID == "predictable"
409}
410
// GetModelInfo always returns nil: the test manager carries no model
// metadata, and callers are expected to tolerate a nil result.
func (m *testLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
	return nil
}
414
// RefreshCustomModels is a no-op for the test manager; it always succeeds.
func (m *testLLMManager) RefreshCustomModels() error {
	return nil
}