1package server
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "io"
8 "log/slog"
9 "net/http"
10 "net/http/httptest"
11 "os"
12 "strings"
13 "sync"
14 "testing"
15 "time"
16
17 "shelley.exe.dev/claudetool"
18 "shelley.exe.dev/db"
19 "shelley.exe.dev/db/generated"
20 "shelley.exe.dev/llm"
21 "shelley.exe.dev/llm/ant"
22 "shelley.exe.dev/models"
23)
24
25// ClaudeTestHarness extends TestHarness with Claude-specific functionality
// ClaudeTestHarness extends TestHarness with Claude-specific functionality
type ClaudeTestHarness struct {
	t       *testing.T
	db      *db.DB        // backing test database; messages are read directly from it
	server  *Server       // server under test; handlers are invoked directly (no listener)
	cleanup func()        // tears down the test database; called from Close
	convID  string        // current conversation ID; set by NewConversation
	timeout time.Duration // deadline for Wait* polling loops
	llmService       *ant.Service
	requestTokens    []uint64 // Track total tokens for each request; guarded by mu
	lastMessageCount int      // Track message count after last operation; guarded by mu
	mu               sync.Mutex
}
38
39// NewClaudeTestHarness creates a test harness that uses the real Claude API
40func NewClaudeTestHarness(t *testing.T) *ClaudeTestHarness {
41 t.Helper()
42
43 apiKey := os.Getenv("ANTHROPIC_API_KEY")
44 if apiKey == "" {
45 t.Skip("ANTHROPIC_API_KEY not set, skipping Claude test")
46 }
47
48 database, cleanup := setupTestDB(t)
49
50 // Create Claude service with HTTP recorder to track token usage
51 h := &ClaudeTestHarness{
52 t: t,
53 db: database,
54 cleanup: cleanup,
55 timeout: 60 * time.Second, // Longer timeout for real API calls
56 requestTokens: make([]uint64, 0),
57 }
58
59 // Create HTTP client with custom transport for token tracking
60 httpc := &http.Client{
61 Transport: &tokenTrackingTransport{
62 base: http.DefaultTransport,
63 recordToken: h.recordHTTPResponse,
64 },
65 }
66
67 service := &ant.Service{
68 APIKey: apiKey,
69 Model: ant.Claude45Haiku, // Use cheaper model for testing
70 HTTPC: httpc,
71 }
72 h.llmService = service
73
74 llmManager := &claudeLLMManager{service: service}
75 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
76
77 // Set up tools - bash for testing tool cancellation
78 toolSetConfig := claudetool.ToolSetConfig{
79 WorkingDir: t.TempDir(),
80 EnableBrowser: false,
81 }
82
83 server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "claude", "", nil)
84 h.server = server
85
86 return h
87}
88
89// tokenTrackingTransport wraps an HTTP transport to track token usage from responses
90type tokenTrackingTransport struct {
91 base http.RoundTripper
92 recordToken func(responseBody []byte, statusCode int)
93}
94
95func (t *tokenTrackingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
96 resp, err := t.base.RoundTrip(req)
97 if err != nil {
98 return resp, err
99 }
100
101 // Read and restore the response body
102 body, _ := io.ReadAll(resp.Body)
103 resp.Body.Close()
104 resp.Body = io.NopCloser(bytes.NewReader(body))
105
106 t.recordToken(body, resp.StatusCode)
107 return resp, nil
108}
109
110// recordHTTPResponse is a callback to record HTTP responses for token tracking
111func (h *ClaudeTestHarness) recordHTTPResponse(responseBody []byte, statusCode int) {
112 h.t.Logf("HTTP callback: status=%d, responseLen=%d", statusCode, len(responseBody))
113
114 if statusCode != http.StatusOK || responseBody == nil {
115 return
116 }
117
118 // Parse response to get token usage (including cache tokens)
119 var resp struct {
120 Usage struct {
121 InputTokens uint64 `json:"input_tokens"`
122 OutputTokens uint64 `json:"output_tokens"`
123 CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
124 CacheReadInputTokens uint64 `json:"cache_read_input_tokens"`
125 } `json:"usage"`
126 }
127 if jsonErr := json.Unmarshal(responseBody, &resp); jsonErr == nil {
128 // Total tokens = input + cache_creation + cache_read (this represents total context)
129 totalTokens := resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens
130 h.mu.Lock()
131 h.requestTokens = append(h.requestTokens, totalTokens)
132 h.mu.Unlock()
133 h.t.Logf("Recorded request: input=%d, cache_creation=%d, cache_read=%d, total=%d",
134 resp.Usage.InputTokens, resp.Usage.CacheCreationInputTokens, resp.Usage.CacheReadInputTokens, totalTokens)
135 } else {
136 h.t.Logf("Failed to parse response: %v", jsonErr)
137 }
138}
139
140// GetRequestTokens returns a copy of recorded request token counts
141func (h *ClaudeTestHarness) GetRequestTokens() []uint64 {
142 h.mu.Lock()
143 defer h.mu.Unlock()
144 tokens := make([]uint64, len(h.requestTokens))
145 copy(tokens, h.requestTokens)
146 return tokens
147}
148
149// VerifyTokensNonDecreasing checks that tokens don't decrease below a baseline
150// This verifies that context is being preserved across requests
151func (h *ClaudeTestHarness) VerifyTokensNonDecreasing() {
152 h.t.Helper()
153 tokens := h.GetRequestTokens()
154 if len(tokens) == 0 {
155 h.t.Log("No tokens recorded, skipping token verification")
156 return
157 }
158
159 h.t.Logf("Token progression: %v", tokens)
160
161 // Find the baseline (first substantial token count, skipping small slug generation requests)
162 // Slug generation requests have ~100-200 tokens, conversation requests have 4000+
163 var baseline uint64
164 for _, t := range tokens {
165 if t > 1000 { // Skip small requests like slug generation
166 baseline = t
167 break
168 }
169 }
170
171 if baseline == 0 {
172 h.t.Log("No substantial baseline found, skipping token verification")
173 return
174 }
175
176 // Verify no substantial request drops significantly below baseline (allow 10% variance for caching)
177 minAllowed := baseline * 9 / 10
178 for i, t := range tokens {
179 if t > 1000 && t < minAllowed { // Only check substantial requests
180 h.t.Errorf("Token count at index %d dropped significantly: %d < %d (baseline=%d)", i, t, minAllowed, baseline)
181 }
182 }
183}
184
185// Close cleans up the test harness resources
186func (h *ClaudeTestHarness) Close() {
187 h.cleanup()
188}
189
190// NewConversation starts a new conversation with Claude
191func (h *ClaudeTestHarness) NewConversation(msg, cwd string) *ClaudeTestHarness {
192 h.t.Helper()
193
194 chatReq := ChatRequest{
195 Message: msg,
196 Model: "claude",
197 Cwd: cwd,
198 }
199 chatBody, _ := json.Marshal(chatReq)
200
201 req := httptest.NewRequest("POST", "/api/conversations/new", strings.NewReader(string(chatBody)))
202 req.Header.Set("Content-Type", "application/json")
203 w := httptest.NewRecorder()
204
205 h.server.handleNewConversation(w, req)
206 if w.Code != http.StatusCreated {
207 h.t.Fatalf("NewConversation: expected status 201, got %d: %s", w.Code, w.Body.String())
208 }
209
210 var resp struct {
211 ConversationID string `json:"conversation_id"`
212 }
213 if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
214 h.t.Fatalf("NewConversation: failed to parse response: %v", err)
215 }
216 h.convID = resp.ConversationID
217
218 // Reset lastMessageCount - new conversation starts fresh
219 h.mu.Lock()
220 h.lastMessageCount = 0
221 h.mu.Unlock()
222
223 return h
224}
225
226// Chat sends a message to the current conversation
227func (h *ClaudeTestHarness) Chat(msg string) *ClaudeTestHarness {
228 h.t.Helper()
229
230 if h.convID == "" {
231 h.t.Fatal("Chat: no conversation started, call NewConversation first")
232 }
233
234 // Record message count before sending
235 h.mu.Lock()
236 h.lastMessageCount = len(h.GetMessagesUnsafe())
237 h.mu.Unlock()
238
239 chatReq := ChatRequest{
240 Message: msg,
241 Model: "claude",
242 }
243 chatBody, _ := json.Marshal(chatReq)
244
245 req := httptest.NewRequest("POST", "/api/conversation/"+h.convID+"/chat", strings.NewReader(string(chatBody)))
246 req.Header.Set("Content-Type", "application/json")
247 w := httptest.NewRecorder()
248
249 h.server.handleChatConversation(w, req, h.convID)
250 if w.Code != http.StatusAccepted {
251 h.t.Fatalf("Chat: expected status 202, got %d: %s", w.Code, w.Body.String())
252 }
253 return h
254}
255
256// GetMessagesUnsafe gets messages without locking (internal use only)
257func (h *ClaudeTestHarness) GetMessagesUnsafe() []generated.Message {
258 var messages []generated.Message
259 h.db.Queries(context.Background(), func(q *generated.Queries) error {
260 var qerr error
261 messages, qerr = q.ListMessages(context.Background(), h.convID)
262 return qerr
263 })
264 return messages
265}
266
267// Cancel cancels the current conversation
268func (h *ClaudeTestHarness) Cancel() *ClaudeTestHarness {
269 h.t.Helper()
270
271 if h.convID == "" {
272 h.t.Fatal("Cancel: no conversation started")
273 }
274
275 req := httptest.NewRequest("POST", "/api/conversation/"+h.convID+"/cancel", nil)
276 w := httptest.NewRecorder()
277
278 h.server.handleCancelConversation(w, req, h.convID)
279 if w.Code != http.StatusOK {
280 h.t.Fatalf("Cancel: expected status 200, got %d: %s", w.Code, w.Body.String())
281 }
282 return h
283}
284
285// WaitForAgentWorking waits until the agent is working (tool call started)
286func (h *ClaudeTestHarness) WaitForAgentWorking() *ClaudeTestHarness {
287 h.t.Helper()
288
289 deadline := time.Now().Add(h.timeout)
290 for time.Now().Before(deadline) {
291 if h.isAgentWorking() {
292 return h
293 }
294 time.Sleep(100 * time.Millisecond)
295 }
296
297 h.t.Fatal("WaitForAgentWorking: timed out waiting for agent to start working")
298 return h
299}
300
301// isAgentWorking checks if the agent is currently working
302func (h *ClaudeTestHarness) isAgentWorking() bool {
303 var messages []generated.Message
304 err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
305 var qerr error
306 messages, qerr = q.ListMessages(context.Background(), h.convID)
307 return qerr
308 })
309 if err != nil {
310 return false
311 }
312
313 // Look for an assistant message with tool use that doesn't have a corresponding result
314 for i := len(messages) - 1; i >= 0; i-- {
315 msg := messages[i]
316 if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
317 continue
318 }
319
320 var llmMsg llm.Message
321 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
322 continue
323 }
324
325 // Check if this assistant message has tool use
326 for _, content := range llmMsg.Content {
327 if content.Type == llm.ContentTypeToolUse {
328 // Check if there's a corresponding tool result
329 hasResult := false
330 for j := i + 1; j < len(messages); j++ {
331 nextMsg := messages[j]
332 if nextMsg.Type == string(db.MessageTypeUser) && nextMsg.LlmData != nil {
333 var userMsg llm.Message
334 if err := json.Unmarshal([]byte(*nextMsg.LlmData), &userMsg); err != nil {
335 continue
336 }
337 for _, c := range userMsg.Content {
338 if c.Type == llm.ContentTypeToolResult && c.ToolUseID == content.ID {
339 hasResult = true
340 break
341 }
342 }
343 }
344 if hasResult {
345 break
346 }
347 }
348 if !hasResult {
349 return true // Tool is in progress
350 }
351 }
352 }
353 }
354
355 return false
356}
357
358// WaitResponse waits for the assistant's text response (end of turn)
359// It waits for a NEW response after the last Chat/NewConversation call
360func (h *ClaudeTestHarness) WaitResponse() string {
361 h.t.Helper()
362
363 if h.convID == "" {
364 h.t.Fatal("WaitResponse: no conversation started")
365 }
366
367 h.mu.Lock()
368 minMessageCount := h.lastMessageCount
369 h.mu.Unlock()
370
371 deadline := time.Now().Add(h.timeout)
372 for time.Now().Before(deadline) {
373 var messages []generated.Message
374 err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
375 var qerr error
376 messages, qerr = q.ListMessages(context.Background(), h.convID)
377 return qerr
378 })
379 if err != nil {
380 h.t.Fatalf("WaitResponse: failed to get messages: %v", err)
381 }
382
383 // Look for an assistant message with end_of_turn that came AFTER minMessageCount
384 // Start from the end to find the most recent one
385 for i := len(messages) - 1; i >= 0 && i >= minMessageCount; i-- {
386 msg := messages[i]
387 if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
388 continue
389 }
390
391 var llmMsg llm.Message
392 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
393 continue
394 }
395
396 if llmMsg.EndOfTurn {
397 for _, content := range llmMsg.Content {
398 if content.Type == llm.ContentTypeText {
399 // Update lastMessageCount for the next wait
400 h.mu.Lock()
401 h.lastMessageCount = len(messages)
402 h.mu.Unlock()
403 return content.Text
404 }
405 }
406 }
407 }
408
409 time.Sleep(100 * time.Millisecond)
410 }
411
412 h.t.Fatalf("WaitResponse: timed out waiting for response (lastMessageCount=%d)", minMessageCount)
413 return ""
414}
415
416// WaitToolResult waits for a tool result and returns its text content
417func (h *ClaudeTestHarness) WaitToolResult() string {
418 h.t.Helper()
419
420 if h.convID == "" {
421 h.t.Fatal("WaitToolResult: no conversation started")
422 }
423
424 deadline := time.Now().Add(h.timeout)
425 for time.Now().Before(deadline) {
426 var messages []generated.Message
427 err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
428 var qerr error
429 messages, qerr = q.ListMessages(context.Background(), h.convID)
430 return qerr
431 })
432 if err != nil {
433 h.t.Fatalf("WaitToolResult: failed to get messages: %v", err)
434 }
435
436 for _, msg := range messages {
437 if msg.Type != string(db.MessageTypeUser) || msg.LlmData == nil {
438 continue
439 }
440
441 var llmMsg llm.Message
442 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
443 continue
444 }
445
446 for _, content := range llmMsg.Content {
447 if content.Type == llm.ContentTypeToolResult {
448 for _, result := range content.ToolResult {
449 if result.Type == llm.ContentTypeText && result.Text != "" {
450 return result.Text
451 }
452 }
453 }
454 }
455 }
456
457 time.Sleep(100 * time.Millisecond)
458 }
459
460 h.t.Fatalf("WaitToolResult: timed out waiting for tool result")
461 return ""
462}
463
464// ConversationID returns the current conversation ID
465func (h *ClaudeTestHarness) ConversationID() string {
466 return h.convID
467}
468
469// GetMessages returns all messages in the conversation
470func (h *ClaudeTestHarness) GetMessages() []generated.Message {
471 var messages []generated.Message
472 err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
473 var qerr error
474 messages, qerr = q.ListMessages(context.Background(), h.convID)
475 return qerr
476 })
477 if err != nil {
478 h.t.Fatalf("GetMessages: failed to get messages: %v", err)
479 }
480 return messages
481}
482
483// HasCancelledToolResult checks if there's a cancelled tool result in the conversation
484func (h *ClaudeTestHarness) HasCancelledToolResult() bool {
485 messages := h.GetMessages()
486 for _, msg := range messages {
487 if msg.Type != string(db.MessageTypeUser) || msg.LlmData == nil {
488 continue
489 }
490
491 var llmMsg llm.Message
492 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
493 continue
494 }
495
496 for _, content := range llmMsg.Content {
497 if content.Type == llm.ContentTypeToolResult && content.ToolError {
498 for _, result := range content.ToolResult {
499 if result.Type == llm.ContentTypeText && strings.Contains(result.Text, "cancelled") {
500 return true
501 }
502 }
503 }
504 }
505 }
506 return false
507}
508
509// HasCancellationMessage checks if there's a cancellation message in the conversation
510func (h *ClaudeTestHarness) HasCancellationMessage() bool {
511 messages := h.GetMessages()
512 for _, msg := range messages {
513 if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
514 continue
515 }
516
517 var llmMsg llm.Message
518 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
519 continue
520 }
521
522 for _, content := range llmMsg.Content {
523 if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "Operation cancelled") {
524 return true
525 }
526 }
527 }
528 return false
529}
530
531// claudeLLMManager is an LLMProvider that returns the Claude service
532type claudeLLMManager struct {
533 service llm.Service
534}
535
// GetService returns the single wrapped Claude service regardless of
// modelID; it never returns an error.
func (m *claudeLLMManager) GetService(modelID string) (llm.Service, error) {
	return m.service, nil
}
539
// GetAvailableModels returns the fixed set of model IDs this test manager
// advertises: the main "claude" model and the haiku slug model.
func (m *claudeLLMManager) GetAvailableModels() []string {
	return []string{"claude", "claude-haiku-4.5"}
}
543
544func (m *claudeLLMManager) HasModel(modelID string) bool {
545 return modelID == "claude" || modelID == "claude-haiku-4.5"
546}
547
548func (m *claudeLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
549 if modelID == "claude-haiku-4.5" {
550 return &models.ModelInfo{DisplayName: "Claude Haiku", Tags: "slug"}
551 }
552 return nil
553}
554
// RefreshCustomModels is a no-op for the test manager; there are no custom
// models to refresh.
func (m *claudeLLMManager) RefreshCustomModels() error {
	return nil
}
558
559// TestClaudeCancelDuringToolCall tests cancellation during tool execution with Claude
560func TestClaudeCancelDuringToolCall(t *testing.T) {
561 h := NewClaudeTestHarness(t)
562 defer h.Close()
563
564 // Start a conversation that triggers a slow bash command
565 h.NewConversation("Please run the bash command: sleep 10", "")
566
567 // Wait for the tool to start executing
568 h.WaitForAgentWorking()
569 t.Log("Agent is working on tool call")
570
571 // Cancel the conversation
572 h.Cancel()
573 t.Log("Cancelled conversation")
574
575 // Wait a bit for cancellation to complete
576 time.Sleep(500 * time.Millisecond)
577
578 // Verify cancellation was recorded properly
579 if !h.HasCancelledToolResult() {
580 t.Error("expected cancelled tool result to be recorded")
581 }
582
583 if !h.HasCancellationMessage() {
584 t.Error("expected cancellation message to be recorded")
585 }
586
587 messages := h.GetMessages()
588 t.Logf("Total messages after cancellation: %d", len(messages))
589
590 // Verify tokens are maintained
591 h.VerifyTokensNonDecreasing()
592}
593
// TestClaudeCancelDuringLLMCall tests cancellation during LLM API call with Claude.
// Unlike the tool-call variant, cancellation here lands while the API request
// is still in flight, so only a cancellation message (no tool result) is expected.
func TestClaudeCancelDuringLLMCall(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start a conversation with a message that will take some time to process
	h.NewConversation("Please write a very detailed essay about the history of computing, covering at least 10 major milestones.", "")

	// Wait briefly for the request to be sent to Claude
	time.Sleep(500 * time.Millisecond)

	// Cancel during the LLM call
	h.Cancel()
	t.Log("Cancelled during LLM call")

	// Wait for cancellation
	time.Sleep(500 * time.Millisecond)

	// Verify cancellation message exists
	if !h.HasCancellationMessage() {
		t.Error("expected cancellation message to be recorded")
	}

	messages := h.GetMessages()
	t.Logf("Total messages after cancellation: %d", len(messages))

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
623
// TestClaudeCancelDuringLLMCallThenResume tests cancellation during LLM API call
// and then resuming. The "BLUE42" marker planted before the cancel lets the test
// verify the conversation context survived the interruption.
func TestClaudeCancelDuringLLMCallThenResume(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start a conversation with context we can verify later
	h.NewConversation("Remember this code: BLUE42. Write a long essay about colors.", "")

	// Wait briefly for the request to be sent to Claude
	time.Sleep(300 * time.Millisecond)

	// Cancel during the LLM call (before response arrives)
	h.Cancel()
	t.Log("Cancelled during LLM call")
	time.Sleep(500 * time.Millisecond)

	if !h.HasCancellationMessage() {
		t.Error("expected cancellation message to be recorded")
	}

	tokensAfterCancel := h.GetRequestTokens()
	t.Logf("Tokens after cancel: %v", tokensAfterCancel)

	// Now resume and verify context is preserved
	h.Chat("What was the code I asked you to remember? Just tell me the code.")
	response := h.WaitResponse()
	t.Logf("Response after resume: %s", response)

	// Verify context was preserved - Claude should remember BLUE42
	if !strings.Contains(strings.ToUpper(response), "BLUE42") {
		t.Errorf("expected response to contain BLUE42, got: %s", response)
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
660
// TestClaudeCancelDuringLLMCallMultipleTimes tests multiple cancellations during
// LLM calls: three consecutive send-then-cancel cycles, followed by a normal
// turn that must still succeed without API errors about malformed history.
func TestClaudeCancelDuringLLMCallMultipleTimes(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// First: cancel during LLM call
	h.NewConversation("Write a very long detailed story about space exploration.", "")
	time.Sleep(300 * time.Millisecond)
	h.Cancel()
	t.Log("First cancel during LLM")
	time.Sleep(500 * time.Millisecond)

	// Second: cancel during LLM call again
	h.Chat("Write a very long detailed story about ocean exploration.")
	time.Sleep(300 * time.Millisecond)
	h.Cancel()
	t.Log("Second cancel during LLM")
	time.Sleep(500 * time.Millisecond)

	// Third: cancel during LLM call again
	h.Chat("Write a very long detailed story about mountain climbing.")
	time.Sleep(300 * time.Millisecond)
	h.Cancel()
	t.Log("Third cancel during LLM")
	time.Sleep(500 * time.Millisecond)

	// Now resume normally - the conversation should still work
	h.Chat("Just say 'conversation recovered' and nothing else.")
	response := h.WaitResponse()
	t.Logf("Response after multiple cancels: %s", response)

	// Verify the conversation is functional - response should not indicate an error
	lowerResp := strings.ToLower(response)
	if strings.Contains(lowerResp, "error") || strings.Contains(lowerResp, "invalid") {
		t.Errorf("response may indicate an error: %s", response)
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
701
// TestClaudeCancelDuringLLMCallAndVerifyMessageStructure verifies message
// structure after LLM cancellation: the user message must be persisted even
// though no model response arrived, a cancellation agent message must exist,
// and a follow-up turn must not trip API message-format validation.
func TestClaudeCancelDuringLLMCallAndVerifyMessageStructure(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	h.NewConversation("Write a very long detailed story about a wizard.", "")
	time.Sleep(300 * time.Millisecond)
	h.Cancel()
	time.Sleep(500 * time.Millisecond)

	// Check message structure
	messages := h.GetMessages()
	t.Logf("Messages after LLM cancel: %d", len(messages))

	// Should have: system message, user message, cancellation message
	// The user message should be recorded even if Claude didn't respond
	userMessageFound := false
	cancelMessageFound := false

	for _, msg := range messages {
		t.Logf("Message type: %s", msg.Type)
		if msg.Type == string(db.MessageTypeUser) {
			userMessageFound = true
		}
		if msg.Type == string(db.MessageTypeAgent) && msg.LlmData != nil {
			var llmMsg llm.Message
			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err == nil {
				for _, content := range llmMsg.Content {
					if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "cancelled") {
						cancelMessageFound = true
					}
				}
			}
		}
	}

	if !userMessageFound {
		t.Error("expected user message to be recorded")
	}
	if !cancelMessageFound {
		t.Error("expected cancellation message to be recorded")
	}

	// Now send a follow-up and verify no API errors about message format
	h.Chat("Just say hello.")
	response := h.WaitResponse()
	t.Logf("Follow-up response: %s", response)

	// Response should not indicate an error
	lowerResp := strings.ToLower(response)
	if strings.Contains(lowerResp, "error") || strings.Contains(lowerResp, "invalid") {
		t.Errorf("response may indicate API error: %s", response)
	}

	h.VerifyTokensNonDecreasing()
}
758
// TestClaudeResumeAfterCancellation tests that a conversation can be resumed
// after cancellation: cancel mid-tool-call, then send a new turn and assert
// a response arrives and the message count grows.
func TestClaudeResumeAfterCancellation(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start a conversation
	h.NewConversation("Please run: sleep 5", "")

	// Wait for tool to start
	h.WaitForAgentWorking()
	t.Log("Agent started tool call")

	// Cancel
	h.Cancel()
	t.Log("Cancelled")
	time.Sleep(500 * time.Millisecond)

	// Verify cancellation
	if !h.HasCancellationMessage() {
		t.Error("expected cancellation message")
	}

	messagesAfterCancel := len(h.GetMessages())
	t.Logf("Messages after cancel: %d", messagesAfterCancel)

	// Resume the conversation
	h.Chat("Hello, let's continue. Please just say 'resumed' and nothing else.")

	// Wait for response
	response := h.WaitResponse()
	t.Logf("Response after resume: %s", response)

	// Verify we got more messages
	messagesAfterResume := len(h.GetMessages())
	t.Logf("Messages after resume: %d", messagesAfterResume)

	if messagesAfterResume <= messagesAfterCancel {
		t.Error("expected more messages after resume")
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
802
// TestClaudeTokensMonotonicallyIncreasing tests that token count increases when resuming.
// With prompt caching, total tokens = input + cache_creation + cache_read, so each
// later turn's total should be at least the first turn's (growing history).
func TestClaudeTokensMonotonicallyIncreasing(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// First conversation turn
	h.NewConversation("Hello, please respond with 'first response' and nothing else.", "")
	h.WaitResponse()
	time.Sleep(500 * time.Millisecond) // Wait for any pending operations

	tokens1 := h.GetRequestTokens()
	if len(tokens1) == 0 {
		t.Skip("No token data recorded (API may not be returning it)")
	}
	lastToken1 := tokens1[len(tokens1)-1]
	t.Logf("First turn total tokens: %d", lastToken1)

	// Second conversation turn
	h.Chat("Now please respond with 'second response' and nothing else.")
	h.WaitResponse()
	time.Sleep(500 * time.Millisecond)

	tokens2 := h.GetRequestTokens()
	if len(tokens2) <= len(tokens1) {
		t.Fatal("expected more requests in second turn")
	}
	lastToken2 := tokens2[len(tokens2)-1]
	t.Logf("Second turn total tokens: %d", lastToken2)

	// With prompt caching, tokens should increase or stay similar
	// The key is that we're still sending context (total should be meaningful)
	if lastToken2 < lastToken1 {
		t.Errorf("tokens decreased significantly: first=%d, second=%d", lastToken1, lastToken2)
	}

	// Third turn
	h.Chat("Third turn - respond with 'third response' only.")
	h.WaitResponse()
	time.Sleep(500 * time.Millisecond)

	tokens3 := h.GetRequestTokens()
	if len(tokens3) <= len(tokens2) {
		t.Fatal("expected more requests in third turn")
	}
	lastToken3 := tokens3[len(tokens3)-1]
	t.Logf("Third turn total tokens: %d", lastToken3)

	// Each subsequent turn should have at least as many tokens as the first turn
	// (because we're including more conversation history)
	if lastToken3 < lastToken1 {
		t.Errorf("third turn has fewer tokens than first: first=%d, third=%d", lastToken1, lastToken3)
	}

	t.Logf("Token progression: %d -> %d -> %d", lastToken1, lastToken2, lastToken3)
}
859
// TestClaudeResumeAfterCancellationPreservesContext tests context preservation
// after cancellation: a secret word planted before a cancelled tool call must
// still be recallable when the conversation resumes.
func TestClaudeResumeAfterCancellationPreservesContext(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start with specific context
	h.NewConversation("Remember this secret word: ELEPHANT. I will ask you about it later. For now, just acknowledge with 'understood'.", "")
	response1 := h.WaitResponse()
	t.Logf("First response: %s", response1)

	tokens1 := h.GetRequestTokens()
	if len(tokens1) == 0 {
		t.Skip("No token data recorded")
	}
	t.Logf("Tokens after first exchange: %v", tokens1)

	// Start a slow command to trigger cancellation
	h.Chat("Run this command: sleep 10")
	h.WaitForAgentWorking()

	// Cancel
	h.Cancel()
	time.Sleep(500 * time.Millisecond)

	tokensAfterCancel := h.GetRequestTokens()
	t.Logf("Tokens after cancel: %v", tokensAfterCancel)

	// Resume and ask about the secret word
	h.Chat("What was the secret word I told you to remember?")
	response2 := h.WaitResponse()
	t.Logf("Response after resume: %s", response2)

	tokensAfterResume := h.GetRequestTokens()
	t.Logf("Tokens after resume: %v", tokensAfterResume)

	// Check that the response mentions ELEPHANT
	if !strings.Contains(strings.ToUpper(response2), "ELEPHANT") {
		t.Errorf("expected response to mention ELEPHANT, got: %s", response2)
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
903
// TestClaudeMultipleCancellations tests multiple cancellations in a row:
// two tool-call cancellations followed by a normal turn that must complete.
func TestClaudeMultipleCancellations(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// First cancellation
	h.NewConversation("Run: sleep 10", "")
	h.WaitForAgentWorking()
	h.Cancel()
	time.Sleep(300 * time.Millisecond)

	if !h.HasCancellationMessage() {
		t.Error("expected first cancellation message")
	}

	// Second cancellation
	h.Chat("Run: sleep 10")
	time.Sleep(2 * time.Second) // Wait for Claude to respond and start tool
	h.Cancel()
	time.Sleep(300 * time.Millisecond)

	// Third: complete normally
	h.Chat("Just say 'done' and nothing else.")
	response := h.WaitResponse()
	t.Logf("Final response: %s", response)

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
933
// TestClaudeCancelImmediately tests cancelling immediately (50ms) after sending
// a message — likely before the LLM request even completes — then resuming.
func TestClaudeCancelImmediately(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	h.NewConversation("Write a very long essay about everything.", "")

	// Cancel immediately
	time.Sleep(50 * time.Millisecond)
	h.Cancel()

	time.Sleep(500 * time.Millisecond)

	// Should still be able to resume
	h.Chat("Just say 'hello'")
	response := h.WaitResponse()
	t.Logf("Response after immediate cancel: %s", response)

	if response == "" {
		t.Error("expected a response after resuming from immediate cancel")
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
959
// TestClaudeCancelWithPendingToolResult tests that missing tool results are
// handled properly when a tool call is cancelled mid-flight and the
// conversation resumes (exercises the insertMissingToolResults logic).
func TestClaudeCancelWithPendingToolResult(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// This tests the insertMissingToolResults logic
	h.NewConversation("Run: sleep 20", "")
	h.WaitForAgentWorking()

	// Cancel during tool execution
	h.Cancel()
	time.Sleep(500 * time.Millisecond)

	// Resume - this should handle the missing tool result
	h.Chat("Please just say 'recovered' if you can hear me.")
	response := h.WaitResponse()
	t.Logf("Recovery response: %s", response)

	// The conversation should have recovered
	// Claude should not complain about bad messages
	if strings.Contains(strings.ToLower(response), "error") {
		t.Errorf("response indicates an error, which may mean message handling failed: %s", response)
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
987
// TestClaudeCancelDuringLLMCallRapidFire tests rapid cancellations during LLM
// calls: three send/cancel cycles with only ~100ms between send and cancel,
// then a normal turn to confirm the conversation remains usable.
func TestClaudeCancelDuringLLMCallRapidFire(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Send message and cancel as fast as possible, multiple times
	for i := 0; i < 3; i++ {
		if i == 0 {
			h.NewConversation("Write a long story.", "")
		} else {
			h.Chat("Write another long story.")
		}
		time.Sleep(100 * time.Millisecond)
		h.Cancel()
		time.Sleep(200 * time.Millisecond)
		t.Logf("Rapid cancel %d complete", i+1)
	}

	// Now do a normal conversation
	h.Chat("Just say 'stable' and nothing else.")
	response := h.WaitResponse()
	t.Logf("Final response after rapid cancels: %s", response)

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
1014
// TestClaudeCancelDuringLLMCallWithToolUseResponse tests cancel when Claude is
// about to use a tool: the timing aims for the window after the model returns
// tool_use but before the tool actually runs (best-effort — timing-dependent).
func TestClaudeCancelDuringLLMCallWithToolUseResponse(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Ask Claude to use a tool - the response will contain tool_use
	// Cancel before the tool actually executes
	h.NewConversation("Run: echo hello world", "")

	// Wait just enough for the LLM request to be sent but not for tool execution
	time.Sleep(500 * time.Millisecond)

	// Cancel - this might catch the LLM responding with tool_use but before tool execution
	h.Cancel()
	time.Sleep(500 * time.Millisecond)

	t.Logf("Cancelled during potential tool_use response")

	// Resume and verify conversation works
	h.Chat("Just say 'ok' if you can hear me.")
	response := h.WaitResponse()
	t.Logf("Response: %s", response)

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}