1package server
2
import (
	"context"
	"encoding/json"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"shelley.exe.dev/claudetool"
	"shelley.exe.dev/db"
	"shelley.exe.dev/db/generated"
	"shelley.exe.dev/llm"
	"shelley.exe.dev/loop"
	"shelley.exe.dev/models"
)
20
21// setupTestDB creates a test database
22func setupTestDB(t *testing.T) (*db.DB, func()) {
23 t.Helper()
24 tmpDir := t.TempDir()
25 database, err := db.New(db.Config{DSN: tmpDir + "/test.db"})
26 if err != nil {
27 t.Fatalf("Failed to create test database: %v", err)
28 }
29
30 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
31 defer cancel()
32
33 if err := database.Migrate(ctx); err != nil {
34 t.Fatalf("Failed to migrate test database: %v", err)
35 }
36
37 return database, func() {
38 database.Close()
39 }
40}
41
42// waitFor polls a condition until it returns true or the timeout is reached.
43func waitFor(t *testing.T, timeout time.Duration, condition func() bool) {
44 t.Helper()
45 deadline := time.Now().Add(timeout)
46 for time.Now().Before(deadline) {
47 if condition() {
48 return
49 }
50 time.Sleep(10 * time.Millisecond)
51 }
52 t.Fatal("timed out waiting for condition")
53}
54
// TestCancelWithPredictableModel tests cancellation with the predictable model.
//
// End-to-end flow: start a chat whose message makes the predictable model issue
// a slow bash tool call ("sleep 5"), cancel mid-flight, verify that both a
// cancelled tool result and an end-of-turn message were persisted, and finally
// confirm the conversation can be resumed with a fresh message afterwards.
func TestCancelWithPredictableModel(t *testing.T) {
	// Create test database
	database, cleanup := setupTestDB(t)
	defer cleanup()

	predictableService := loop.NewPredictableService()
	llmManager := &testLLMManager{service: predictableService}
	logger := slog.Default()

	// Register the bash tool so the sleep command actually runs and can be cancelled
	toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false}
	server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)

	// Create conversation
	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
	if err != nil {
		t.Fatalf("failed to create conversation: %v", err)
	}
	conversationID := conversation.ConversationID

	// Start a conversation with a message that triggers a slow bash command.
	// NOTE(review): "bash: sleep 5" is presumably a directive the predictable
	// service parses into a bash tool call — confirm against loop.PredictableService.
	chatReq := ChatRequest{
		Message: "bash: sleep 5",
		Model:   "predictable",
	}
	chatBody, _ := json.Marshal(chatReq)

	req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(chatBody)))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	server.handleChatConversation(w, req, conversationID)

	if w.Code != http.StatusAccepted {
		t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
	}

	// Wait for agent to record an assistant message with tool use, so the
	// cancel below lands while the tool call is actually in flight.
	waitFor(t, 5*time.Second, func() bool {
		var messages []generated.Message
		err := database.Queries(context.Background(), func(q *generated.Queries) error {
			var qerr error
			messages, qerr = q.ListMessages(context.Background(), conversationID)
			return qerr
		})
		if err != nil || len(messages) < 2 {
			return false
		}
		// Check for assistant message with tool use
		for _, msg := range messages {
			if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
				continue
			}
			var llmMsg llm.Message
			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
				continue
			}
			for _, content := range llmMsg.Content {
				if content.Type == llm.ContentTypeToolUse {
					return true
				}
			}
		}
		return false
	})

	// Cancel the conversation
	cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
	cancelW := httptest.NewRecorder()

	server.handleCancelConversation(cancelW, cancelReq, conversationID)

	if cancelW.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", cancelW.Code, cancelW.Body.String())
	}

	var cancelResp map[string]string
	if err := json.Unmarshal(cancelW.Body.Bytes(), &cancelResp); err != nil {
		t.Fatalf("failed to parse cancel response: %v", err)
	}

	if cancelResp["status"] != "cancelled" {
		t.Errorf("expected status 'cancelled', got '%s'", cancelResp["status"])
	}

	// Wait for agent to stop working (cancellation complete)
	waitFor(t, 5*time.Second, func() bool {
		return !server.IsAgentWorking(conversationID)
	})

	// Verify that a cancelled tool result was recorded
	var messages []generated.Message
	err = database.Queries(context.Background(), func(q *generated.Queries) error {
		var qerr error
		messages, qerr = q.ListMessages(context.Background(), conversationID)
		return qerr
	})
	if err != nil {
		t.Fatalf("failed to get messages after cancel: %v", err)
	}

	// Should have: user message, assistant message with tool use, cancelled tool result, and end turn message
	if len(messages) < 4 {
		t.Fatalf("expected at least 4 messages after cancel, got %d", len(messages))
	}

	// Check that we have the cancelled tool result.
	// Scan newest-first since the cancellation artifacts are appended last.
	foundCancelledResult := false
	foundEndTurnMessage := false
	for i := len(messages) - 1; i >= 0; i-- {
		msg := messages[i]
		if msg.LlmData == nil {
			continue
		}

		var llmMsg llm.Message
		if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
			continue
		}

		// Check for cancelled tool result: an errored tool result whose text
		// mentions "cancelled".
		for _, content := range llmMsg.Content {
			if content.Type == llm.ContentTypeToolResult && content.ToolError {
				for _, result := range content.ToolResult {
					if result.Type == llm.ContentTypeText && strings.Contains(result.Text, "cancelled") {
						foundCancelledResult = true
						break
					}
				}
			}
		}

		// Check for end turn message: an agent message flagged EndOfTurn whose
		// text contains "Operation cancelled".
		if msg.Type == string(db.MessageTypeAgent) && llmMsg.EndOfTurn {
			for _, content := range llmMsg.Content {
				if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "Operation cancelled") {
					foundEndTurnMessage = true
					break
				}
			}
		}
	}

	if !foundCancelledResult {
		t.Error("expected to find cancelled tool result in conversation")
	}

	if !foundEndTurnMessage {
		t.Error("expected to find end turn message after cancellation")
	}

	// Test that conversation can be resumed after cancellation
	resumeReq := ChatRequest{
		Message: "echo: test after cancel",
		Model:   "predictable",
	}
	resumeBody, _ := json.Marshal(resumeReq)

	resumeChatReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(resumeBody)))
	resumeChatReq.Header.Set("Content-Type", "application/json")
	resumeW := httptest.NewRecorder()

	server.handleChatConversation(resumeW, resumeChatReq, conversationID)

	if resumeW.Code != http.StatusAccepted {
		t.Fatalf("expected status 202 for resume, got %d: %s", resumeW.Code, resumeW.Body.String())
	}

	// Wait for agent to finish processing the resumed conversation
	waitFor(t, 5*time.Second, func() bool {
		return !server.IsAgentWorking(conversationID)
	})

	// Verify conversation continued
	err = database.Queries(context.Background(), func(q *generated.Queries) error {
		var qerr error
		messages, qerr = q.ListMessages(context.Background(), conversationID)
		return qerr
	})
	if err != nil {
		t.Fatalf("failed to get messages after resume: %v", err)
	}

	// Should have additional messages from the resumed conversation
	if len(messages) < 5 {
		t.Fatalf("expected at least 5 messages after resume, got %d", len(messages))
	}

	// Check that we got the expected response: the predictable model's echo
	// directive should produce agent text containing the echoed phrase.
	foundContinueResponse := false
	for _, msg := range messages {
		if msg.Type != string(db.MessageTypeAgent) {
			continue
		}
		if msg.LlmData == nil {
			continue
		}
		var llmMsg llm.Message
		if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
			continue
		}
		for _, content := range llmMsg.Content {
			if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "test after cancel") {
				foundContinueResponse = true
				break
			}
		}
	}

	if !foundContinueResponse {
		t.Error("expected to find 'test after cancel' response")
	}
}
269
270// TestCancelWithNoActiveConversation tests cancelling when there's no active conversation
271func TestCancelWithNoActiveConversation(t *testing.T) {
272 database, cleanup := setupTestDB(t)
273 defer cleanup()
274
275 predictableService := loop.NewPredictableService()
276 llmManager := &testLLMManager{service: predictableService}
277 logger := slog.Default()
278
279 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
280
281 // Create a conversation but don't start it
282 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
283 if err != nil {
284 t.Fatalf("failed to create conversation: %v", err)
285 }
286 conversationID := conversation.ConversationID
287
288 // Try to cancel without any active loop
289 cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
290 cancelW := httptest.NewRecorder()
291
292 server.handleCancelConversation(cancelW, cancelReq, conversationID)
293
294 if cancelW.Code != http.StatusOK {
295 t.Fatalf("expected status 200, got %d: %s", cancelW.Code, cancelW.Body.String())
296 }
297
298 var cancelResp map[string]string
299 if err := json.Unmarshal(cancelW.Body.Bytes(), &cancelResp); err != nil {
300 t.Fatalf("failed to parse cancel response: %v", err)
301 }
302
303 if cancelResp["status"] != "no_active_conversation" {
304 t.Errorf("expected status 'no_active_conversation', got '%s'", cancelResp["status"])
305 }
306}
307
308// TestCancelDuringTextGeneration tests cancelling during text generation (no tool call)
309func TestCancelDuringTextGeneration(t *testing.T) {
310 database, cleanup := setupTestDB(t)
311 defer cleanup()
312
313 // Use delay: prefix to trigger slow response
314 predictableService := loop.NewPredictableService()
315
316 llmManager := &testLLMManager{service: predictableService}
317 logger := slog.Default()
318 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
319
320 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
321 if err != nil {
322 t.Fatalf("failed to create conversation: %v", err)
323 }
324 conversationID := conversation.ConversationID
325
326 // Start conversation with a delay to simulate slow text generation
327 chatReq := ChatRequest{
328 Message: "delay: 2",
329 Model: "predictable",
330 }
331 chatBody, _ := json.Marshal(chatReq)
332
333 req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(chatBody)))
334 req.Header.Set("Content-Type", "application/json")
335 w := httptest.NewRecorder()
336
337 server.handleChatConversation(w, req, conversationID)
338
339 if w.Code != http.StatusAccepted {
340 t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
341 }
342
343 // Wait for agent to start working
344 waitFor(t, 5*time.Second, func() bool {
345 return server.IsAgentWorking(conversationID)
346 })
347
348 // Cancel during text generation
349 cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
350 cancelW := httptest.NewRecorder()
351
352 server.handleCancelConversation(cancelW, cancelReq, conversationID)
353
354 if cancelW.Code != http.StatusOK {
355 t.Fatalf("expected status 200, got %d: %s", cancelW.Code, cancelW.Body.String())
356 }
357
358 // Wait for agent to stop working (cancellation complete)
359 waitFor(t, 5*time.Second, func() bool {
360 return !server.IsAgentWorking(conversationID)
361 })
362
363 // Verify that no cancelled tool result was added (since there was no tool call)
364 var messages []generated.Message
365 err = database.Queries(context.Background(), func(q *generated.Queries) error {
366 var qerr error
367 messages, qerr = q.ListMessages(context.Background(), conversationID)
368 return qerr
369 })
370 if err != nil {
371 t.Fatalf("failed to get messages: %v", err)
372 }
373
374 // Should only have user message (and possibly incomplete assistant message)
375 // Should NOT have a tool result message
376 for _, msg := range messages {
377 if msg.Type == string(db.MessageTypeUser) {
378 if msg.LlmData == nil {
379 continue
380 }
381 var llmMsg llm.Message
382 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
383 continue
384 }
385 for _, content := range llmMsg.Content {
386 if content.Type == llm.ContentTypeToolResult {
387 t.Error("did not expect tool result when cancelling during text generation")
388 }
389 }
390 }
391 }
392}
393
// testLLMManager is a simple test implementation of LLMProvider. It exposes a
// single model named "predictable" and hands back the one wrapped llm.Service
// for every request, so tests can drive the server with a deterministic model.
type testLLMManager struct {
	// service is the fixed llm.Service returned by GetService for any model ID.
	service llm.Service
}

// GetService returns the wrapped service regardless of modelID; it never errors.
func (m *testLLMManager) GetService(modelID string) (llm.Service, error) {
	return m.service, nil
}

// GetAvailableModels reports the single supported model, "predictable".
func (m *testLLMManager) GetAvailableModels() []string {
	return []string{"predictable"}
}

// HasModel reports whether modelID is exactly "predictable".
func (m *testLLMManager) HasModel(modelID string) bool {
	return modelID == "predictable"
}

// GetModelInfo always returns nil; tests here don't rely on model metadata.
func (m *testLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
	return nil
}

// RefreshCustomModels is a no-op for tests and always succeeds.
func (m *testLLMManager) RefreshCustomModels() error {
	return nil
}