cancel_claude_test.go

   1package server
   2
   3import (
   4	"bytes"
   5	"context"
   6	"encoding/json"
   7	"io"
   8	"log/slog"
   9	"net/http"
  10	"net/http/httptest"
  11	"os"
  12	"strings"
  13	"sync"
  14	"testing"
  15	"time"
  16
  17	"shelley.exe.dev/claudetool"
  18	"shelley.exe.dev/db"
  19	"shelley.exe.dev/db/generated"
  20	"shelley.exe.dev/llm"
  21	"shelley.exe.dev/llm/ant"
  22	"shelley.exe.dev/models"
  23)
  24
// ClaudeTestHarness extends TestHarness with Claude-specific functionality.
// It drives the server's HTTP handlers directly (via httptest) against a
// real Claude backend and records per-request token usage for assertions.
type ClaudeTestHarness struct {
	t                *testing.T
	db               *db.DB        // test database backing the server
	server           *Server       // server under test, wired to the Claude service
	cleanup          func()        // tears down the test database; invoked by Close
	convID           string        // current conversation ID, set by NewConversation
	timeout          time.Duration // deadline for Wait* polling loops
	llmService       *ant.Service  // the real Claude service used by the server
	requestTokens    []uint64      // Track total tokens for each request
	lastMessageCount int           // Track message count after last operation
	mu               sync.Mutex    // guards requestTokens and lastMessageCount
}
  38
  39// NewClaudeTestHarness creates a test harness that uses the real Claude API
  40func NewClaudeTestHarness(t *testing.T) *ClaudeTestHarness {
  41	t.Helper()
  42
  43	apiKey := os.Getenv("ANTHROPIC_API_KEY")
  44	if apiKey == "" {
  45		t.Skip("ANTHROPIC_API_KEY not set, skipping Claude test")
  46	}
  47
  48	database, cleanup := setupTestDB(t)
  49
  50	// Create Claude service with HTTP recorder to track token usage
  51	h := &ClaudeTestHarness{
  52		t:             t,
  53		db:            database,
  54		cleanup:       cleanup,
  55		timeout:       60 * time.Second, // Longer timeout for real API calls
  56		requestTokens: make([]uint64, 0),
  57	}
  58
  59	// Create HTTP client with custom transport for token tracking
  60	httpc := &http.Client{
  61		Transport: &tokenTrackingTransport{
  62			base:        http.DefaultTransport,
  63			recordToken: h.recordHTTPResponse,
  64		},
  65	}
  66
  67	service := &ant.Service{
  68		APIKey: apiKey,
  69		Model:  ant.Claude45Haiku, // Use cheaper model for testing
  70		HTTPC:  httpc,
  71	}
  72	h.llmService = service
  73
  74	llmManager := &claudeLLMManager{service: service}
  75	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
  76
  77	// Set up tools - bash for testing tool cancellation
  78	toolSetConfig := claudetool.ToolSetConfig{
  79		WorkingDir:    t.TempDir(),
  80		EnableBrowser: false,
  81	}
  82
  83	server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "claude", "", nil)
  84	h.server = server
  85
  86	return h
  87}
  88
// tokenTrackingTransport wraps an HTTP transport to track token usage from responses.
// It is installed as the Claude service's transport so every API response body
// can be inspected by the harness before the caller consumes it.
type tokenTrackingTransport struct {
	base        http.RoundTripper                          // underlying transport that actually performs the request
	recordToken func(responseBody []byte, statusCode int)  // callback invoked with each full response body
}
  94
  95func (t *tokenTrackingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  96	resp, err := t.base.RoundTrip(req)
  97	if err != nil {
  98		return resp, err
  99	}
 100
 101	// Read and restore the response body
 102	body, _ := io.ReadAll(resp.Body)
 103	resp.Body.Close()
 104	resp.Body = io.NopCloser(bytes.NewReader(body))
 105
 106	t.recordToken(body, resp.StatusCode)
 107	return resp, nil
 108}
 109
 110// recordHTTPResponse is a callback to record HTTP responses for token tracking
 111func (h *ClaudeTestHarness) recordHTTPResponse(responseBody []byte, statusCode int) {
 112	h.t.Logf("HTTP callback: status=%d, responseLen=%d", statusCode, len(responseBody))
 113
 114	if statusCode != http.StatusOK || responseBody == nil {
 115		return
 116	}
 117
 118	// Parse response to get token usage (including cache tokens)
 119	var resp struct {
 120		Usage struct {
 121			InputTokens              uint64 `json:"input_tokens"`
 122			OutputTokens             uint64 `json:"output_tokens"`
 123			CacheCreationInputTokens uint64 `json:"cache_creation_input_tokens"`
 124			CacheReadInputTokens     uint64 `json:"cache_read_input_tokens"`
 125		} `json:"usage"`
 126	}
 127	if jsonErr := json.Unmarshal(responseBody, &resp); jsonErr == nil {
 128		// Total tokens = input + cache_creation + cache_read (this represents total context)
 129		totalTokens := resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens
 130		h.mu.Lock()
 131		h.requestTokens = append(h.requestTokens, totalTokens)
 132		h.mu.Unlock()
 133		h.t.Logf("Recorded request: input=%d, cache_creation=%d, cache_read=%d, total=%d",
 134			resp.Usage.InputTokens, resp.Usage.CacheCreationInputTokens, resp.Usage.CacheReadInputTokens, totalTokens)
 135	} else {
 136		h.t.Logf("Failed to parse response: %v", jsonErr)
 137	}
 138}
 139
 140// GetRequestTokens returns a copy of recorded request token counts
 141func (h *ClaudeTestHarness) GetRequestTokens() []uint64 {
 142	h.mu.Lock()
 143	defer h.mu.Unlock()
 144	tokens := make([]uint64, len(h.requestTokens))
 145	copy(tokens, h.requestTokens)
 146	return tokens
 147}
 148
 149// VerifyTokensNonDecreasing checks that tokens don't decrease below a baseline
 150// This verifies that context is being preserved across requests
 151func (h *ClaudeTestHarness) VerifyTokensNonDecreasing() {
 152	h.t.Helper()
 153	tokens := h.GetRequestTokens()
 154	if len(tokens) == 0 {
 155		h.t.Log("No tokens recorded, skipping token verification")
 156		return
 157	}
 158
 159	h.t.Logf("Token progression: %v", tokens)
 160
 161	// Find the baseline (first substantial token count, skipping small slug generation requests)
 162	// Slug generation requests have ~100-200 tokens, conversation requests have 4000+
 163	var baseline uint64
 164	for _, t := range tokens {
 165		if t > 1000 { // Skip small requests like slug generation
 166			baseline = t
 167			break
 168		}
 169	}
 170
 171	if baseline == 0 {
 172		h.t.Log("No substantial baseline found, skipping token verification")
 173		return
 174	}
 175
 176	// Verify no substantial request drops significantly below baseline (allow 10% variance for caching)
 177	minAllowed := baseline * 9 / 10
 178	for i, t := range tokens {
 179		if t > 1000 && t < minAllowed { // Only check substantial requests
 180			h.t.Errorf("Token count at index %d dropped significantly: %d < %d (baseline=%d)", i, t, minAllowed, baseline)
 181		}
 182	}
 183}
 184
// Close cleans up the test harness resources by tearing down the test
// database created in NewClaudeTestHarness. Call via defer in each test.
func (h *ClaudeTestHarness) Close() {
	h.cleanup()
}
 189
 190// NewConversation starts a new conversation with Claude
 191func (h *ClaudeTestHarness) NewConversation(msg, cwd string) *ClaudeTestHarness {
 192	h.t.Helper()
 193
 194	chatReq := ChatRequest{
 195		Message: msg,
 196		Model:   "claude",
 197		Cwd:     cwd,
 198	}
 199	chatBody, _ := json.Marshal(chatReq)
 200
 201	req := httptest.NewRequest("POST", "/api/conversations/new", strings.NewReader(string(chatBody)))
 202	req.Header.Set("Content-Type", "application/json")
 203	w := httptest.NewRecorder()
 204
 205	h.server.handleNewConversation(w, req)
 206	if w.Code != http.StatusCreated {
 207		h.t.Fatalf("NewConversation: expected status 201, got %d: %s", w.Code, w.Body.String())
 208	}
 209
 210	var resp struct {
 211		ConversationID string `json:"conversation_id"`
 212	}
 213	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
 214		h.t.Fatalf("NewConversation: failed to parse response: %v", err)
 215	}
 216	h.convID = resp.ConversationID
 217
 218	// Reset lastMessageCount - new conversation starts fresh
 219	h.mu.Lock()
 220	h.lastMessageCount = 0
 221	h.mu.Unlock()
 222
 223	return h
 224}
 225
 226// Chat sends a message to the current conversation
 227func (h *ClaudeTestHarness) Chat(msg string) *ClaudeTestHarness {
 228	h.t.Helper()
 229
 230	if h.convID == "" {
 231		h.t.Fatal("Chat: no conversation started, call NewConversation first")
 232	}
 233
 234	// Record message count before sending
 235	h.mu.Lock()
 236	h.lastMessageCount = len(h.GetMessagesUnsafe())
 237	h.mu.Unlock()
 238
 239	chatReq := ChatRequest{
 240		Message: msg,
 241		Model:   "claude",
 242	}
 243	chatBody, _ := json.Marshal(chatReq)
 244
 245	req := httptest.NewRequest("POST", "/api/conversation/"+h.convID+"/chat", strings.NewReader(string(chatBody)))
 246	req.Header.Set("Content-Type", "application/json")
 247	w := httptest.NewRecorder()
 248
 249	h.server.handleChatConversation(w, req, h.convID)
 250	if w.Code != http.StatusAccepted {
 251		h.t.Fatalf("Chat: expected status 202, got %d: %s", w.Code, w.Body.String())
 252	}
 253	return h
 254}
 255
// GetMessagesUnsafe gets messages without locking (internal use only).
// The error from db.Queries is ignored: on failure the returned slice is
// simply nil/partial, which callers treat as "no messages yet".
// NOTE(review): "Unsafe" appears to refer to h.mu not being taken rather
// than the DB access itself — confirm callers only need the lock for
// lastMessageCount consistency.
func (h *ClaudeTestHarness) GetMessagesUnsafe() []generated.Message {
	var messages []generated.Message
	h.db.Queries(context.Background(), func(q *generated.Queries) error {
		var qerr error
		messages, qerr = q.ListMessages(context.Background(), h.convID)
		return qerr
	})
	return messages
}
 266
 267// Cancel cancels the current conversation
 268func (h *ClaudeTestHarness) Cancel() *ClaudeTestHarness {
 269	h.t.Helper()
 270
 271	if h.convID == "" {
 272		h.t.Fatal("Cancel: no conversation started")
 273	}
 274
 275	req := httptest.NewRequest("POST", "/api/conversation/"+h.convID+"/cancel", nil)
 276	w := httptest.NewRecorder()
 277
 278	h.server.handleCancelConversation(w, req, h.convID)
 279	if w.Code != http.StatusOK {
 280		h.t.Fatalf("Cancel: expected status 200, got %d: %s", w.Code, w.Body.String())
 281	}
 282	return h
 283}
 284
 285// WaitForAgentWorking waits until the agent is working (tool call started)
 286func (h *ClaudeTestHarness) WaitForAgentWorking() *ClaudeTestHarness {
 287	h.t.Helper()
 288
 289	deadline := time.Now().Add(h.timeout)
 290	for time.Now().Before(deadline) {
 291		if h.isAgentWorking() {
 292			return h
 293		}
 294		time.Sleep(100 * time.Millisecond)
 295	}
 296
 297	h.t.Fatal("WaitForAgentWorking: timed out waiting for agent to start working")
 298	return h
 299}
 300
// isAgentWorking reports whether the agent currently has a tool call in
// flight: an agent message contains a tool_use block for which no later
// user message carries a tool_result with a matching ToolUseID.
func (h *ClaudeTestHarness) isAgentWorking() bool {
	var messages []generated.Message
	err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
		var qerr error
		messages, qerr = q.ListMessages(context.Background(), h.convID)
		return qerr
	})
	if err != nil {
		// Treat a query failure as "not working"; callers poll and retry.
		return false
	}

	// Look for an assistant message with tool use that doesn't have a corresponding result.
	// Scan newest-first so the most recent tool call is considered first.
	for i := len(messages) - 1; i >= 0; i-- {
		msg := messages[i]
		if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
			continue
		}

		var llmMsg llm.Message
		if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
			// Undecodable payloads are skipped rather than treated as working.
			continue
		}

		// Check if this assistant message has tool use
		for _, content := range llmMsg.Content {
			if content.Type == llm.ContentTypeToolUse {
				// Check if there's a corresponding tool result in any LATER message.
				hasResult := false
				for j := i + 1; j < len(messages); j++ {
					nextMsg := messages[j]
					if nextMsg.Type == string(db.MessageTypeUser) && nextMsg.LlmData != nil {
						var userMsg llm.Message
						if err := json.Unmarshal([]byte(*nextMsg.LlmData), &userMsg); err != nil {
							continue
						}
						for _, c := range userMsg.Content {
							if c.Type == llm.ContentTypeToolResult && c.ToolUseID == content.ID {
								hasResult = true
								break
							}
						}
					}
					if hasResult {
						break
					}
				}
				if !hasResult {
					return true // Tool is in progress
				}
			}
		}
	}

	return false
}
 357
// WaitResponse waits for the assistant's text response (end of turn).
// It waits for a NEW response after the last Chat/NewConversation call by
// only considering messages at index >= lastMessageCount (the watermark
// recorded when the message was sent). On success it advances the watermark
// and returns the text; on timeout it fails the test.
func (h *ClaudeTestHarness) WaitResponse() string {
	h.t.Helper()

	if h.convID == "" {
		h.t.Fatal("WaitResponse: no conversation started")
	}

	// Snapshot the watermark under the lock; the poll loop reads the DB
	// without holding h.mu.
	h.mu.Lock()
	minMessageCount := h.lastMessageCount
	h.mu.Unlock()

	deadline := time.Now().Add(h.timeout)
	for time.Now().Before(deadline) {
		var messages []generated.Message
		err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
			var qerr error
			messages, qerr = q.ListMessages(context.Background(), h.convID)
			return qerr
		})
		if err != nil {
			h.t.Fatalf("WaitResponse: failed to get messages: %v", err)
		}

		// Look for an assistant message with end_of_turn that came AFTER minMessageCount
		// Start from the end to find the most recent one
		for i := len(messages) - 1; i >= 0 && i >= minMessageCount; i-- {
			msg := messages[i]
			if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
				continue
			}

			var llmMsg llm.Message
			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
				continue
			}

			if llmMsg.EndOfTurn {
				for _, content := range llmMsg.Content {
					if content.Type == llm.ContentTypeText {
						// Update lastMessageCount for the next wait
						h.mu.Lock()
						h.lastMessageCount = len(messages)
						h.mu.Unlock()
						return content.Text
					}
				}
			}
		}

		time.Sleep(100 * time.Millisecond)
	}

	h.t.Fatalf("WaitResponse: timed out waiting for response (lastMessageCount=%d)", minMessageCount)
	return ""
}
 415
// WaitToolResult waits for a tool result and returns its text content.
// It polls the conversation until any user message contains a tool_result
// block with non-empty text, failing the test on timeout.
// NOTE(review): the scan starts from the oldest message, so in a
// conversation with multiple tool calls this returns the FIRST stored
// result, not necessarily the latest — confirm that is the intent.
func (h *ClaudeTestHarness) WaitToolResult() string {
	h.t.Helper()

	if h.convID == "" {
		h.t.Fatal("WaitToolResult: no conversation started")
	}

	deadline := time.Now().Add(h.timeout)
	for time.Now().Before(deadline) {
		var messages []generated.Message
		err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
			var qerr error
			messages, qerr = q.ListMessages(context.Background(), h.convID)
			return qerr
		})
		if err != nil {
			h.t.Fatalf("WaitToolResult: failed to get messages: %v", err)
		}

		for _, msg := range messages {
			if msg.Type != string(db.MessageTypeUser) || msg.LlmData == nil {
				continue
			}

			var llmMsg llm.Message
			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
				continue
			}

			for _, content := range llmMsg.Content {
				if content.Type == llm.ContentTypeToolResult {
					for _, result := range content.ToolResult {
						if result.Type == llm.ContentTypeText && result.Text != "" {
							return result.Text
						}
					}
				}
			}
		}

		time.Sleep(100 * time.Millisecond)
	}

	h.t.Fatalf("WaitToolResult: timed out waiting for tool result")
	return ""
}
 463
// ConversationID returns the current conversation ID, or "" if no
// conversation has been started yet.
func (h *ClaudeTestHarness) ConversationID() string {
	return h.convID
}
 468
 469// GetMessages returns all messages in the conversation
 470func (h *ClaudeTestHarness) GetMessages() []generated.Message {
 471	var messages []generated.Message
 472	err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
 473		var qerr error
 474		messages, qerr = q.ListMessages(context.Background(), h.convID)
 475		return qerr
 476	})
 477	if err != nil {
 478		h.t.Fatalf("GetMessages: failed to get messages: %v", err)
 479	}
 480	return messages
 481}
 482
 483// HasCancelledToolResult checks if there's a cancelled tool result in the conversation
 484func (h *ClaudeTestHarness) HasCancelledToolResult() bool {
 485	messages := h.GetMessages()
 486	for _, msg := range messages {
 487		if msg.Type != string(db.MessageTypeUser) || msg.LlmData == nil {
 488			continue
 489		}
 490
 491		var llmMsg llm.Message
 492		if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
 493			continue
 494		}
 495
 496		for _, content := range llmMsg.Content {
 497			if content.Type == llm.ContentTypeToolResult && content.ToolError {
 498				for _, result := range content.ToolResult {
 499					if result.Type == llm.ContentTypeText && strings.Contains(result.Text, "cancelled") {
 500						return true
 501					}
 502				}
 503			}
 504		}
 505	}
 506	return false
 507}
 508
 509// HasCancellationMessage checks if there's a cancellation message in the conversation
 510func (h *ClaudeTestHarness) HasCancellationMessage() bool {
 511	messages := h.GetMessages()
 512	for _, msg := range messages {
 513		if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
 514			continue
 515		}
 516
 517		var llmMsg llm.Message
 518		if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
 519			continue
 520		}
 521
 522		for _, content := range llmMsg.Content {
 523			if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "Operation cancelled") {
 524				return true
 525			}
 526		}
 527	}
 528	return false
 529}
 530
 531// claudeLLMManager is an LLMProvider that returns the Claude service
 532type claudeLLMManager struct {
 533	service llm.Service
 534}
 535
 536func (m *claudeLLMManager) GetService(modelID string) (llm.Service, error) {
 537	return m.service, nil
 538}
 539
 540func (m *claudeLLMManager) GetAvailableModels() []string {
 541	return []string{"claude", "claude-haiku-4.5"}
 542}
 543
 544func (m *claudeLLMManager) HasModel(modelID string) bool {
 545	return modelID == "claude" || modelID == "claude-haiku-4.5"
 546}
 547
 548func (m *claudeLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
 549	if modelID == "claude-haiku-4.5" {
 550		return &models.ModelInfo{DisplayName: "Claude Haiku", Tags: "slug"}
 551	}
 552	return nil
 553}
 554
 555func (m *claudeLLMManager) RefreshCustomModels() error {
 556	return nil
 557}
 558
// TestClaudeCancelDuringToolCall tests cancellation during tool execution with Claude:
// a slow bash command is started, cancelled mid-run, and the test asserts that
// both a cancelled tool result and a cancellation message were persisted.
func TestClaudeCancelDuringToolCall(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start a conversation that triggers a slow bash command
	h.NewConversation("Please run the bash command: sleep 10", "")

	// Wait for the tool to start executing
	h.WaitForAgentWorking()
	t.Log("Agent is working on tool call")

	// Cancel the conversation
	h.Cancel()
	t.Log("Cancelled conversation")

	// Wait a bit for cancellation to complete
	// NOTE(review): fixed sleep — timing-sensitive against the real API.
	time.Sleep(500 * time.Millisecond)

	// Verify cancellation was recorded properly
	if !h.HasCancelledToolResult() {
		t.Error("expected cancelled tool result to be recorded")
	}

	if !h.HasCancellationMessage() {
		t.Error("expected cancellation message to be recorded")
	}

	messages := h.GetMessages()
	t.Logf("Total messages after cancellation: %d", len(messages))

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 593
// TestClaudeCancelDuringLLMCall tests cancellation during LLM API call with Claude:
// cancel fires while the model is still generating, and the test asserts that a
// cancellation message was persisted and token accounting held steady.
func TestClaudeCancelDuringLLMCall(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start a conversation with a message that will take some time to process
	h.NewConversation("Please write a very detailed essay about the history of computing, covering at least 10 major milestones.", "")

	// Wait briefly for the request to be sent to Claude
	// NOTE(review): fixed sleep — timing-sensitive against the real API.
	time.Sleep(500 * time.Millisecond)

	// Cancel during the LLM call
	h.Cancel()
	t.Log("Cancelled during LLM call")

	// Wait for cancellation
	time.Sleep(500 * time.Millisecond)

	// Verify cancellation message exists
	if !h.HasCancellationMessage() {
		t.Error("expected cancellation message to be recorded")
	}

	messages := h.GetMessages()
	t.Logf("Total messages after cancellation: %d", len(messages))

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 623
// TestClaudeCancelDuringLLMCallThenResume tests cancellation during LLM API call
// and then resuming: earlier context (the code BLUE42) must survive the cancel
// and be recalled in the resumed turn.
func TestClaudeCancelDuringLLMCallThenResume(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start a conversation with context we can verify later
	h.NewConversation("Remember this code: BLUE42. Write a long essay about colors.", "")

	// Wait briefly for the request to be sent to Claude
	time.Sleep(300 * time.Millisecond)

	// Cancel during the LLM call (before response arrives)
	h.Cancel()
	t.Log("Cancelled during LLM call")
	time.Sleep(500 * time.Millisecond)

	if !h.HasCancellationMessage() {
		t.Error("expected cancellation message to be recorded")
	}

	tokensAfterCancel := h.GetRequestTokens()
	t.Logf("Tokens after cancel: %v", tokensAfterCancel)

	// Now resume and verify context is preserved
	h.Chat("What was the code I asked you to remember? Just tell me the code.")
	response := h.WaitResponse()
	t.Logf("Response after resume: %s", response)

	// Verify context was preserved - Claude should remember BLUE42
	if !strings.Contains(strings.ToUpper(response), "BLUE42") {
		t.Errorf("expected response to contain BLUE42, got: %s", response)
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 660
// TestClaudeCancelDuringLLMCallMultipleTimes tests multiple cancellations during
// LLM calls: three cancel-mid-generation cycles followed by a normal turn that
// must still succeed, proving the conversation state never becomes corrupt.
func TestClaudeCancelDuringLLMCallMultipleTimes(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// First: cancel during LLM call
	h.NewConversation("Write a very long detailed story about space exploration.", "")
	time.Sleep(300 * time.Millisecond)
	h.Cancel()
	t.Log("First cancel during LLM")
	time.Sleep(500 * time.Millisecond)

	// Second: cancel during LLM call again
	h.Chat("Write a very long detailed story about ocean exploration.")
	time.Sleep(300 * time.Millisecond)
	h.Cancel()
	t.Log("Second cancel during LLM")
	time.Sleep(500 * time.Millisecond)

	// Third: cancel during LLM call again
	h.Chat("Write a very long detailed story about mountain climbing.")
	time.Sleep(300 * time.Millisecond)
	h.Cancel()
	t.Log("Third cancel during LLM")
	time.Sleep(500 * time.Millisecond)

	// Now resume normally - the conversation should still work
	h.Chat("Just say 'conversation recovered' and nothing else.")
	response := h.WaitResponse()
	t.Logf("Response after multiple cancels: %s", response)

	// Verify the conversation is functional - response should not indicate an error
	// (a weak heuristic: the model could legitimately use these words)
	lowerResp := strings.ToLower(response)
	if strings.Contains(lowerResp, "error") || strings.Contains(lowerResp, "invalid") {
		t.Errorf("response may indicate an error: %s", response)
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 701
// TestClaudeCancelDuringLLMCallAndVerifyMessageStructure verifies message
// structure after LLM cancellation: the user message and a cancellation
// message must both be persisted, and a follow-up turn must not trip an API
// message-format error.
func TestClaudeCancelDuringLLMCallAndVerifyMessageStructure(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	h.NewConversation("Write a very long detailed story about a wizard.", "")
	time.Sleep(300 * time.Millisecond)
	h.Cancel()
	time.Sleep(500 * time.Millisecond)

	// Check message structure
	messages := h.GetMessages()
	t.Logf("Messages after LLM cancel: %d", len(messages))

	// Should have: system message, user message, cancellation message
	// The user message should be recorded even if Claude didn't respond
	userMessageFound := false
	cancelMessageFound := false

	for _, msg := range messages {
		t.Logf("Message type: %s", msg.Type)
		if msg.Type == string(db.MessageTypeUser) {
			userMessageFound = true
		}
		if msg.Type == string(db.MessageTypeAgent) && msg.LlmData != nil {
			var llmMsg llm.Message
			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err == nil {
				for _, content := range llmMsg.Content {
					if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "cancelled") {
						cancelMessageFound = true
					}
				}
			}
		}
	}

	if !userMessageFound {
		t.Error("expected user message to be recorded")
	}
	if !cancelMessageFound {
		t.Error("expected cancellation message to be recorded")
	}

	// Now send a follow-up and verify no API errors about message format
	h.Chat("Just say hello.")
	response := h.WaitResponse()
	t.Logf("Follow-up response: %s", response)

	// Response should not indicate an error
	lowerResp := strings.ToLower(response)
	if strings.Contains(lowerResp, "error") || strings.Contains(lowerResp, "invalid") {
		t.Errorf("response may indicate API error: %s", response)
	}

	h.VerifyTokensNonDecreasing()
}
 758
// TestClaudeResumeAfterCancellation tests that a conversation can be resumed
// after cancellation during a tool call: after cancelling a sleeping bash
// command, a new chat turn must produce additional messages and a response.
func TestClaudeResumeAfterCancellation(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start a conversation
	h.NewConversation("Please run: sleep 5", "")

	// Wait for tool to start
	h.WaitForAgentWorking()
	t.Log("Agent started tool call")

	// Cancel
	h.Cancel()
	t.Log("Cancelled")
	time.Sleep(500 * time.Millisecond)

	// Verify cancellation
	if !h.HasCancellationMessage() {
		t.Error("expected cancellation message")
	}

	messagesAfterCancel := len(h.GetMessages())
	t.Logf("Messages after cancel: %d", messagesAfterCancel)

	// Resume the conversation
	h.Chat("Hello, let's continue. Please just say 'resumed' and nothing else.")

	// Wait for response
	response := h.WaitResponse()
	t.Logf("Response after resume: %s", response)

	// Verify we got more messages
	messagesAfterResume := len(h.GetMessages())
	t.Logf("Messages after resume: %d", messagesAfterResume)

	if messagesAfterResume <= messagesAfterCancel {
		t.Error("expected more messages after resume")
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 802
// TestClaudeTokensMonotonicallyIncreasing tests that token count increases when
// resuming across three normal turns, i.e. conversation history keeps being sent.
// With prompt caching, total tokens = input + cache_creation + cache_read.
func TestClaudeTokensMonotonicallyIncreasing(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// First conversation turn
	h.NewConversation("Hello, please respond with 'first response' and nothing else.", "")
	h.WaitResponse()
	time.Sleep(500 * time.Millisecond) // Wait for any pending operations

	tokens1 := h.GetRequestTokens()
	if len(tokens1) == 0 {
		t.Skip("No token data recorded (API may not be returning it)")
	}
	lastToken1 := tokens1[len(tokens1)-1]
	t.Logf("First turn total tokens: %d", lastToken1)

	// Second conversation turn
	h.Chat("Now please respond with 'second response' and nothing else.")
	h.WaitResponse()
	time.Sleep(500 * time.Millisecond)

	tokens2 := h.GetRequestTokens()
	if len(tokens2) <= len(tokens1) {
		t.Fatal("expected more requests in second turn")
	}
	lastToken2 := tokens2[len(tokens2)-1]
	t.Logf("Second turn total tokens: %d", lastToken2)

	// With prompt caching, tokens should increase or stay similar
	// The key is that we're still sending context (total should be meaningful)
	if lastToken2 < lastToken1 {
		t.Errorf("tokens decreased significantly: first=%d, second=%d", lastToken1, lastToken2)
	}

	// Third turn
	h.Chat("Third turn - respond with 'third response' only.")
	h.WaitResponse()
	time.Sleep(500 * time.Millisecond)

	tokens3 := h.GetRequestTokens()
	if len(tokens3) <= len(tokens2) {
		t.Fatal("expected more requests in third turn")
	}
	lastToken3 := tokens3[len(tokens3)-1]
	t.Logf("Third turn total tokens: %d", lastToken3)

	// Each subsequent turn should have at least as many tokens as the first turn
	// (because we're including more conversation history)
	if lastToken3 < lastToken1 {
		t.Errorf("third turn has fewer tokens than first: first=%d, third=%d", lastToken1, lastToken3)
	}

	t.Logf("Token progression: %d -> %d -> %d", lastToken1, lastToken2, lastToken3)
}
 859
// TestClaudeResumeAfterCancellationPreservesContext tests context preservation
// after cancellation: a secret word established before a cancelled tool call
// must still be recallable in the resumed turn.
func TestClaudeResumeAfterCancellationPreservesContext(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Start with specific context
	h.NewConversation("Remember this secret word: ELEPHANT. I will ask you about it later. For now, just acknowledge with 'understood'.", "")
	response1 := h.WaitResponse()
	t.Logf("First response: %s", response1)

	tokens1 := h.GetRequestTokens()
	if len(tokens1) == 0 {
		t.Skip("No token data recorded")
	}
	t.Logf("Tokens after first exchange: %v", tokens1)

	// Start a slow command to trigger cancellation
	h.Chat("Run this command: sleep 10")
	h.WaitForAgentWorking()

	// Cancel
	h.Cancel()
	time.Sleep(500 * time.Millisecond)

	tokensAfterCancel := h.GetRequestTokens()
	t.Logf("Tokens after cancel: %v", tokensAfterCancel)

	// Resume and ask about the secret word
	h.Chat("What was the secret word I told you to remember?")
	response2 := h.WaitResponse()
	t.Logf("Response after resume: %s", response2)

	tokensAfterResume := h.GetRequestTokens()
	t.Logf("Tokens after resume: %v", tokensAfterResume)

	// Check that the response mentions ELEPHANT
	if !strings.Contains(strings.ToUpper(response2), "ELEPHANT") {
		t.Errorf("expected response to mention ELEPHANT, got: %s", response2)
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 903
// TestClaudeMultipleCancellations tests multiple cancellations in a row —
// two cancelled tool-call turns followed by a normal turn that must complete.
func TestClaudeMultipleCancellations(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// First cancellation
	h.NewConversation("Run: sleep 10", "")
	h.WaitForAgentWorking()
	h.Cancel()
	time.Sleep(300 * time.Millisecond)

	if !h.HasCancellationMessage() {
		t.Error("expected first cancellation message")
	}

	// Second cancellation
	// NOTE(review): fixed 2s sleep instead of WaitForAgentWorking —
	// cancellation may land before or after the tool starts.
	h.Chat("Run: sleep 10")
	time.Sleep(2 * time.Second) // Wait for Claude to respond and start tool
	h.Cancel()
	time.Sleep(300 * time.Millisecond)

	// Third: complete normally
	h.Chat("Just say 'done' and nothing else.")
	response := h.WaitResponse()
	t.Logf("Final response: %s", response)

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 933
// TestClaudeCancelImmediately tests cancelling immediately (~50ms) after
// sending a message, then verifies the conversation can still be resumed.
func TestClaudeCancelImmediately(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	h.NewConversation("Write a very long essay about everything.", "")

	// Cancel immediately
	time.Sleep(50 * time.Millisecond)
	h.Cancel()

	time.Sleep(500 * time.Millisecond)

	// Should still be able to resume
	h.Chat("Just say 'hello'")
	response := h.WaitResponse()
	t.Logf("Response after immediate cancel: %s", response)

	if response == "" {
		t.Error("expected a response after resuming from immediate cancel")
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 959
// TestClaudeCancelWithPendingToolResult tests that missing tool results are
// handled properly: cancelling mid-tool leaves a tool_use without a result,
// and the next turn must not trigger an API message-format complaint.
func TestClaudeCancelWithPendingToolResult(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// This tests the insertMissingToolResults logic
	h.NewConversation("Run: sleep 20", "")
	h.WaitForAgentWorking()

	// Cancel during tool execution
	h.Cancel()
	time.Sleep(500 * time.Millisecond)

	// Resume - this should handle the missing tool result
	h.Chat("Please just say 'recovered' if you can hear me.")
	response := h.WaitResponse()
	t.Logf("Recovery response: %s", response)

	// The conversation should have recovered
	// Claude should not complain about bad messages
	if strings.Contains(strings.ToLower(response), "error") {
		t.Errorf("response indicates an error, which may mean message handling failed: %s", response)
	}

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
 987
// TestClaudeCancelDuringLLMCallRapidFire tests rapid cancellations during LLM
// calls: three send-then-cancel cycles with ~100ms gaps, followed by a normal
// turn that must still produce a response.
func TestClaudeCancelDuringLLMCallRapidFire(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Send message and cancel as fast as possible, multiple times
	for i := 0; i < 3; i++ {
		if i == 0 {
			h.NewConversation("Write a long story.", "")
		} else {
			h.Chat("Write another long story.")
		}
		time.Sleep(100 * time.Millisecond)
		h.Cancel()
		time.Sleep(200 * time.Millisecond)
		t.Logf("Rapid cancel %d complete", i+1)
	}

	// Now do a normal conversation
	h.Chat("Just say 'stable' and nothing else.")
	response := h.WaitResponse()
	t.Logf("Final response after rapid cancels: %s", response)

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}
1014
// TestClaudeCancelDuringLLMCallWithToolUseResponse tests cancel when Claude is
// about to use a tool: the cancel may race the tool_use response arriving but
// before the tool runs; the conversation must remain usable afterwards.
func TestClaudeCancelDuringLLMCallWithToolUseResponse(t *testing.T) {
	h := NewClaudeTestHarness(t)
	defer h.Close()

	// Ask Claude to use a tool - the response will contain tool_use
	// Cancel before the tool actually executes
	h.NewConversation("Run: echo hello world", "")

	// Wait just enough for the LLM request to be sent but not for tool execution
	time.Sleep(500 * time.Millisecond)

	// Cancel - this might catch the LLM responding with tool_use but before tool execution
	h.Cancel()
	time.Sleep(500 * time.Millisecond)

	t.Logf("Cancelled during potential tool_use response")

	// Resume and verify conversation works
	h.Chat("Just say 'ok' if you can hear me.")
	response := h.WaitResponse()
	t.Logf("Response: %s", response)

	// Verify tokens are maintained
	h.VerifyTokensNonDecreasing()
}