convo_test.go

   1package conversation
   2
   3import (
   4	"cmp"
   5	"context"
   6	"encoding/json"
   7	"net/http"
   8	"os"
   9	"slices"
  10	"strings"
  11	"testing"
  12	"time"
  13
  14	"shelley.exe.dev/llm"
  15	"shelley.exe.dev/llm/ant"
  16	"shelley.exe.dev/loop"
  17	"sketch.dev/httprr"
  18)
  19
  20func TestBasicConvo(t *testing.T) {
  21	ctx := context.Background()
  22	rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
  23	if err != nil {
  24		t.Fatal(err)
  25	}
  26	rr.ScrubReq(func(req *http.Request) error {
  27		req.Header.Del("x-api-key")
  28		req.Header.Del("User-Agent")
  29		req.Header.Del("Shelley-Conversation-Id")
  30		return nil
  31	})
  32
  33	apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
  34	srv := &ant.Service{
  35		APIKey: apiKey,
  36		Model:  ant.Claude4Sonnet, // Use specific model to match cached responses
  37		HTTPC:  rr.Client(),
  38	}
  39	convo := New(ctx, srv, nil)
  40
  41	const name = "Cornelius"
  42	res, err := convo.SendUserTextMessage("Hi, my name is " + name)
  43	if err != nil {
  44		t.Fatal(err)
  45	}
  46	for _, part := range res.Content {
  47		t.Logf("%s", part.Text)
  48	}
  49	res, err = convo.SendUserTextMessage("What is my name?")
  50	if err != nil {
  51		t.Fatal(err)
  52	}
  53	got := ""
  54	for _, part := range res.Content {
  55		got += part.Text
  56	}
  57	if !strings.Contains(got, name) {
  58		t.Errorf("model does not know the given name %s: %q", name, got)
  59	}
  60}
  61
  62// TestCancelToolUse tests the CancelToolUse function of the Convo struct
  63func TestCancelToolUse(t *testing.T) {
  64	tests := []struct {
  65		name         string
  66		setupToolUse bool
  67		toolUseID    string
  68		cancelErr    error
  69		expectError  bool
  70		expectCancel bool
  71	}{
  72		{
  73			name:         "Cancel existing tool use",
  74			setupToolUse: true,
  75			toolUseID:    "tool123",
  76			cancelErr:    nil,
  77			expectError:  false,
  78			expectCancel: true,
  79		},
  80		{
  81			name:         "Cancel existing tool use with error",
  82			setupToolUse: true,
  83			toolUseID:    "tool456",
  84			cancelErr:    context.Canceled,
  85			expectError:  false,
  86			expectCancel: true,
  87		},
  88		{
  89			name:         "Cancel non-existent tool use",
  90			setupToolUse: false,
  91			toolUseID:    "tool789",
  92			cancelErr:    nil,
  93			expectError:  true,
  94			expectCancel: false,
  95		},
  96	}
  97
  98	srv := &ant.Service{}
  99	for _, tt := range tests {
 100		t.Run(tt.name, func(t *testing.T) {
 101			convo := New(context.Background(), srv, nil)
 102
 103			var cancelCalled bool
 104			var cancelledWithErr error
 105
 106			if tt.setupToolUse {
 107				// Setup a mock cancel function to track calls
 108				mockCancel := func(err error) {
 109					cancelCalled = true
 110					cancelledWithErr = err
 111				}
 112
 113				convo.toolUseCancelMu.Lock()
 114				convo.toolUseCancel[tt.toolUseID] = mockCancel
 115				convo.toolUseCancelMu.Unlock()
 116			}
 117
 118			err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
 119
 120			// Check if we got the expected error state
 121			if (err != nil) != tt.expectError {
 122				t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
 123			}
 124
 125			// Check if the cancel function was called as expected
 126			if cancelCalled != tt.expectCancel {
 127				t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
 128			}
 129
 130			// If we expected the cancel to be called, verify it was called with the right error
 131			if tt.expectCancel && cancelledWithErr != tt.cancelErr {
 132				t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
 133			}
 134
 135			// Verify the toolUseID was removed from the map if it was initially added
 136			if tt.setupToolUse {
 137				convo.toolUseCancelMu.Lock()
 138				_, exists := convo.toolUseCancel[tt.toolUseID]
 139				convo.toolUseCancelMu.Unlock()
 140
 141				if exists {
 142					t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
 143				}
 144			}
 145		})
 146	}
 147}
 148
 149// TestInsertMissingToolResults tests the insertMissingToolResults function
 150// to ensure it doesn't create duplicate tool results when multiple tool uses are missing results.
 151func TestInsertMissingToolResults(t *testing.T) {
 152	tests := []struct {
 153		name            string
 154		messages        []llm.Message
 155		currentMsg      llm.Message
 156		expectedCount   int
 157		expectedToolIDs []string
 158	}{
 159		{
 160			name: "Single missing tool result",
 161			messages: []llm.Message{
 162				{
 163					Role: llm.MessageRoleAssistant,
 164					Content: []llm.Content{
 165						{
 166							Type: llm.ContentTypeToolUse,
 167							ID:   "tool1",
 168						},
 169					},
 170				},
 171			},
 172			currentMsg: llm.Message{
 173				Role:    llm.MessageRoleUser,
 174				Content: []llm.Content{},
 175			},
 176			expectedCount:   1,
 177			expectedToolIDs: []string{"tool1"},
 178		},
 179		{
 180			name: "Multiple missing tool results",
 181			messages: []llm.Message{
 182				{
 183					Role: llm.MessageRoleAssistant,
 184					Content: []llm.Content{
 185						{
 186							Type: llm.ContentTypeToolUse,
 187							ID:   "tool1",
 188						},
 189						{
 190							Type: llm.ContentTypeToolUse,
 191							ID:   "tool2",
 192						},
 193						{
 194							Type: llm.ContentTypeToolUse,
 195							ID:   "tool3",
 196						},
 197					},
 198				},
 199			},
 200			currentMsg: llm.Message{
 201				Role:    llm.MessageRoleUser,
 202				Content: []llm.Content{},
 203			},
 204			expectedCount:   3,
 205			expectedToolIDs: []string{"tool1", "tool2", "tool3"},
 206		},
 207		{
 208			name: "No missing tool results when results already present",
 209			messages: []llm.Message{
 210				{
 211					Role: llm.MessageRoleAssistant,
 212					Content: []llm.Content{
 213						{
 214							Type: llm.ContentTypeToolUse,
 215							ID:   "tool1",
 216						},
 217					},
 218				},
 219			},
 220			currentMsg: llm.Message{
 221				Role: llm.MessageRoleUser,
 222				Content: []llm.Content{
 223					{
 224						Type:      llm.ContentTypeToolResult,
 225						ToolUseID: "tool1",
 226					},
 227				},
 228			},
 229			expectedCount:   1, // Only the existing one
 230			expectedToolIDs: []string{"tool1"},
 231		},
 232		{
 233			name: "No tool uses in previous message",
 234			messages: []llm.Message{
 235				{
 236					Role: llm.MessageRoleAssistant,
 237					Content: []llm.Content{
 238						{
 239							Type: llm.ContentTypeText,
 240							Text: "Just some text",
 241						},
 242					},
 243				},
 244			},
 245			currentMsg: llm.Message{
 246				Role:    llm.MessageRoleUser,
 247				Content: []llm.Content{},
 248			},
 249			expectedCount:   0,
 250			expectedToolIDs: []string{},
 251		},
 252	}
 253
 254	for _, tt := range tests {
 255		t.Run(tt.name, func(t *testing.T) {
 256			srv := &ant.Service{}
 257			convo := New(context.Background(), srv, nil)
 258
 259			// Create request with messages
 260			req := &llm.Request{
 261				Messages: append(tt.messages, tt.currentMsg),
 262			}
 263
 264			// Call insertMissingToolResults
 265			msg := tt.currentMsg
 266			convo.insertMissingToolResults(req, &msg)
 267
 268			// Count tool results in the message
 269			toolResultCount := 0
 270			toolIDs := []string{}
 271			for _, content := range msg.Content {
 272				if content.Type == llm.ContentTypeToolResult {
 273					toolResultCount++
 274					toolIDs = append(toolIDs, content.ToolUseID)
 275				}
 276			}
 277
 278			// Verify count
 279			if toolResultCount != tt.expectedCount {
 280				t.Errorf("Expected %d tool results, got %d", tt.expectedCount, toolResultCount)
 281			}
 282
 283			// Verify no duplicates by checking unique tool IDs
 284			seenIDs := make(map[string]int)
 285			for _, id := range toolIDs {
 286				seenIDs[id]++
 287			}
 288
 289			// Check for duplicates
 290			for id, count := range seenIDs {
 291				if count > 1 {
 292					t.Errorf("Duplicate tool result for ID %s: found %d times", id, count)
 293				}
 294			}
 295
 296			// Verify all expected tool IDs are present
 297			for _, expectedID := range tt.expectedToolIDs {
 298				if !slices.Contains(toolIDs, expectedID) {
 299					t.Errorf("Expected tool ID %s not found in results", expectedID)
 300				}
 301			}
 302		})
 303	}
 304}
 305
 306// TestSubConvo tests the SubConvo function
 307func TestSubConvo(t *testing.T) {
 308	ctx := context.Background()
 309	srv := &ant.Service{}
 310	parentConvo := New(ctx, srv, nil)
 311
 312	// Test that SubConvo creates a new conversation with the correct parent relationship
 313	subConvo := parentConvo.SubConvo()
 314
 315	if subConvo == nil {
 316		t.Fatal("SubConvo returned nil")
 317	}
 318
 319	if subConvo.Parent != parentConvo {
 320		t.Error("SubConvo did not set the correct parent")
 321	}
 322
 323	if subConvo.Service != parentConvo.Service {
 324		t.Error("SubConvo did not inherit the service")
 325	}
 326
 327	if subConvo.PromptCaching != parentConvo.PromptCaching {
 328		t.Error("SubConvo did not inherit PromptCaching setting")
 329	}
 330
 331	// Check that the sub-convo has a different ID
 332	if subConvo.ID == parentConvo.ID {
 333		t.Error("SubConvo should have a different ID from parent")
 334	}
 335
 336	// Check that the sub-convo shares tool uses with parent
 337	if &subConvo.usage.ToolUses == &parentConvo.usage.ToolUses {
 338		t.Error("SubConvo should share tool uses map with parent")
 339	}
 340
 341	// Check that the sub-convo has its own usage instance
 342	if subConvo.usage == parentConvo.usage {
 343		t.Error("SubConvo should have its own usage instance (but sharing ToolUses)")
 344	}
 345}
 346
 347// TestSubConvoWithHistory tests the SubConvoWithHistory function
 348
 349// TestDepth tests the Depth function
 350
 351// TestFindTool tests the findTool function
 352func TestFindTool(t *testing.T) {
 353	ctx := context.Background()
 354	srv := &ant.Service{}
 355	convo := New(ctx, srv, nil)
 356
 357	// Add some tools to the conversation
 358	tool1 := &llm.Tool{Name: "tool1"}
 359	tool2 := &llm.Tool{Name: "tool2"}
 360	convo.Tools = append(convo.Tools, tool1, tool2)
 361
 362	// Test finding an existing tool
 363	foundTool, err := convo.findTool("tool1")
 364	if err != nil {
 365		t.Errorf("findTool returned error for existing tool: %v", err)
 366	}
 367	if foundTool != tool1 {
 368		t.Error("findTool did not return the correct tool")
 369	}
 370
 371	// Test finding another existing tool
 372	foundTool, err = convo.findTool("tool2")
 373	if err != nil {
 374		t.Errorf("findTool returned error for existing tool: %v", err)
 375	}
 376	if foundTool != tool2 {
 377		t.Error("findTool did not return the correct tool")
 378	}
 379
 380	// Test finding a non-existent tool
 381	_, err = convo.findTool("nonexistent")
 382	if err == nil {
 383		t.Error("findTool should return error for non-existent tool")
 384	}
 385	expectedErr := `tool "nonexistent" not found`
 386	if err.Error() != expectedErr {
 387		t.Errorf("Expected error %q, got %q", expectedErr, err.Error())
 388	}
 389}
 390
 391// TestToolCallInfoFromContext tests the ToolCallInfoFromContext function
 392func TestToolCallInfoFromContext(t *testing.T) {
 393	// Test with no tool call info in context
 394	ctx := context.Background()
 395	info := ToolCallInfoFromContext(ctx)
 396	if info.ToolUseID != "" {
 397		t.Error("ToolCallInfoFromContext should return empty info when no tool call info is in context")
 398	}
 399
 400	// Test with tool call info in context
 401	toolInfo := ToolCallInfo{
 402		ToolUseID: "testID",
 403	}
 404	ctxWithInfo := context.WithValue(ctx, toolCallInfoKey, toolInfo)
 405	info = ToolCallInfoFromContext(ctxWithInfo)
 406	if info.ToolUseID != "testID" {
 407		t.Errorf("Expected ToolUseID 'testID', got %q", info.ToolUseID)
 408	}
 409}
 410
 411// TestCumulativeUsageMethods tests CumulativeUsage methods
 412func TestCumulativeUsageMethods(t *testing.T) {
 413	// Test Clone method
 414	original := &CumulativeUsage{
 415		StartTime:                time.Now(),
 416		Responses:                5,
 417		InputTokens:              100,
 418		OutputTokens:             200,
 419		CacheReadInputTokens:     50,
 420		CacheCreationInputTokens: 30,
 421		TotalCostUSD:             1.23,
 422		ToolUses: map[string]int{
 423			"tool1": 3,
 424			"tool2": 2,
 425		},
 426	}
 427
 428	clone := original.Clone()
 429
 430	// Check that values are copied correctly
 431	if clone.StartTime != original.StartTime {
 432		t.Error("Clone did not copy StartTime correctly")
 433	}
 434	if clone.Responses != original.Responses {
 435		t.Error("Clone did not copy Responses correctly")
 436	}
 437	if clone.InputTokens != original.InputTokens {
 438		t.Error("Clone did not copy InputTokens correctly")
 439	}
 440	if clone.OutputTokens != original.OutputTokens {
 441		t.Error("Clone did not copy OutputTokens correctly")
 442	}
 443	if clone.CacheReadInputTokens != original.CacheReadInputTokens {
 444		t.Error("Clone did not copy CacheReadInputTokens correctly")
 445	}
 446	if clone.CacheCreationInputTokens != original.CacheCreationInputTokens {
 447		t.Error("Clone did not copy CacheCreationInputTokens correctly")
 448	}
 449	if clone.TotalCostUSD != original.TotalCostUSD {
 450		t.Error("Clone did not copy TotalCostUSD correctly")
 451	}
 452	if len(clone.ToolUses) != len(original.ToolUses) {
 453		t.Error("Clone did not copy ToolUses correctly")
 454	}
 455	for k, v := range original.ToolUses {
 456		if clone.ToolUses[k] != v {
 457			t.Errorf("Clone did not copy ToolUses correctly for key %s", k)
 458		}
 459	}
 460
 461	// Check that maps are separate instances
 462	clone.ToolUses["tool3"] = 1
 463	if _, exists := original.ToolUses["tool3"]; exists {
 464		t.Error("Clone should have separate ToolUses map")
 465	}
 466}
 467
 468// TestUsageMethods tests various usage calculation methods
 469func TestUsageMethods(t *testing.T) {
 470	ctx := context.Background()
 471	srv := loop.NewPredictableService()
 472	convo := New(ctx, srv, nil)
 473
 474	// Test CumulativeUsage on empty conversation
 475	usage := convo.CumulativeUsage()
 476	if usage.Responses != 0 {
 477		t.Error("CumulativeUsage should be empty for new conversation")
 478	}
 479
 480	// Test WallTime method
 481	wallTime := usage.WallTime()
 482	if wallTime <= 0 {
 483		t.Error("WallTime should be positive")
 484	}
 485
 486	// Test DollarsPerHour method
 487	dollarsPerHour := usage.DollarsPerHour()
 488	if dollarsPerHour != 0 {
 489		t.Error("DollarsPerHour should be 0 for empty usage")
 490	}
 491
 492	// Test TotalInputTokens method
 493	totalInputTokens := usage.TotalInputTokens()
 494	if totalInputTokens != 0 {
 495		t.Error("TotalInputTokens should be 0 for empty usage")
 496	}
 497
 498	// Test Attr method
 499	attr := usage.Attr()
 500	if attr.Key != "usage" {
 501		t.Error("Attr should have key 'usage'")
 502	}
 503}
 504
 505// TestLastUsage tests the LastUsage function
 506func TestLastUsage(t *testing.T) {
 507	ctx := context.Background()
 508	srv := loop.NewPredictableService()
 509	convo := New(ctx, srv, nil)
 510
 511	// Test LastUsage on empty conversation
 512	lastUsage := convo.LastUsage()
 513	if lastUsage.InputTokens != 0 {
 514		t.Error("LastUsage should be empty for new conversation")
 515	}
 516
 517	// Send a message to generate some usage
 518	_, err := convo.SendUserTextMessage("echo: hello")
 519	if err != nil {
 520		t.Fatalf("SendUserTextMessage failed: %v", err)
 521	}
 522
 523	// Test LastUsage after sending a message
 524	lastUsage = convo.LastUsage()
 525	if lastUsage.InputTokens == 0 {
 526		t.Error("LastUsage should have input tokens after sending a message")
 527	}
 528}
 529
 530// TestOverBudget tests the OverBudget function
 531func TestOverBudget(t *testing.T) {
 532	ctx := context.Background()
 533	srv := loop.NewPredictableService()
 534	convo := New(ctx, srv, nil)
 535
 536	// Test OverBudget with no budget set
 537	err := convo.OverBudget()
 538	if err != nil {
 539		t.Errorf("OverBudget should return nil when no budget is set, got %v", err)
 540	}
 541
 542	// Set a budget
 543	convo.Budget.MaxDollars = 10.0
 544
 545	// Test OverBudget with budget not exceeded
 546	err = convo.OverBudget()
 547	if err != nil {
 548		t.Errorf("OverBudget should return nil when budget is not exceeded, got %v", err)
 549	}
 550
 551	// Test with sub-conversation
 552	subConvo := convo.SubConvo()
 553	err = subConvo.OverBudget()
 554	if err != nil {
 555		t.Errorf("OverBudget should return nil for sub-conversation when budget is not exceeded, got %v", err)
 556	}
 557}
 558
 559// TestResetBudget tests the ResetBudget function
 560func TestResetBudget(t *testing.T) {
 561	ctx := context.Background()
 562	srv := loop.NewPredictableService()
 563	convo := New(ctx, srv, nil)
 564
 565	// Set initial budget
 566	initialBudget := Budget{MaxDollars: 5.0}
 567	convo.ResetBudget(initialBudget)
 568
 569	// Check that budget was set
 570	if convo.Budget.MaxDollars != 5.0 {
 571		t.Errorf("Expected budget MaxDollars to be 5.0, got %f", convo.Budget.MaxDollars)
 572	}
 573
 574	// Send a message to accumulate some usage
 575	_, err := convo.SendUserTextMessage("echo: hello")
 576	if err != nil {
 577		t.Fatalf("SendUserTextMessage failed: %v", err)
 578	}
 579
 580	// Get current usage
 581	usage := convo.CumulativeUsage()
 582	usedAmount := usage.TotalCostUSD
 583
 584	// Reset budget again
 585	newBudget := Budget{MaxDollars: 10.0}
 586	convo.ResetBudget(newBudget)
 587
 588	// Check that budget was adjusted by usage
 589	expectedBudget := 10.0 + usedAmount
 590	if convo.Budget.MaxDollars != expectedBudget {
 591		t.Errorf("Expected adjusted budget MaxDollars to be %f, got %f", expectedBudget, convo.Budget.MaxDollars)
 592	}
 593}
 594
 595// TestOverBudgetFunction tests the overBudget function
 596func TestOverBudgetFunction(t *testing.T) {
 597	ctx := context.Background()
 598	srv := loop.NewPredictableService()
 599	convo := New(ctx, srv, nil)
 600
 601	// Test overBudget with no budget set
 602	err := convo.overBudget()
 603	if err != nil {
 604		t.Errorf("overBudget should return nil when no budget is set, got %v", err)
 605	}
 606
 607	// Set a budget
 608	convo.Budget.MaxDollars = 5.0
 609
 610	// Test overBudget with budget not exceeded
 611	err = convo.overBudget()
 612	if err != nil {
 613		t.Errorf("overBudget should return nil when budget is not exceeded, got %v", err)
 614	}
 615}
 616
 617// TestGetID tests the GetID function
 618
 619// TestListenerMethods tests the listener methods
 620func TestListenerMethods(t *testing.T) {
 621	listener := &NoopListener{}
 622	ctx := context.Background()
 623	convo := &Convo{}
 624
 625	// Test that noop listener methods don't panic
 626	listener.OnToolCall(ctx, convo, "id", "toolName", json.RawMessage(`{"key":"value"}`), llm.Content{})
 627	listener.OnToolResult(ctx, convo, "id", "toolName", json.RawMessage(`{"key":"value"}`), llm.Content{}, nil, nil)
 628	listener.OnResponse(ctx, convo, "id", &llm.Response{})
 629	listener.OnRequest(ctx, convo, "id", &llm.Message{})
 630
 631	t.Log("NoopListener methods executed without panic")
 632}
 633
 634// TestIncrementToolUse tests the incrementToolUse function
 635func TestIncrementToolUse(t *testing.T) {
 636	ctx := context.Background()
 637	srv := loop.NewPredictableService()
 638	convo := New(ctx, srv, nil)
 639
 640	// Check initial state
 641	usage := convo.CumulativeUsage()
 642	if usage.ToolUses["testTool"] != 0 {
 643		t.Errorf("Expected 0 uses of testTool, got %d", usage.ToolUses["testTool"])
 644	}
 645
 646	// Increment tool use
 647	convo.incrementToolUse("testTool")
 648
 649	// Check that tool use was incremented
 650	usage = convo.CumulativeUsage()
 651	if usage.ToolUses["testTool"] != 1 {
 652		t.Errorf("Expected 1 use of testTool, got %d", usage.ToolUses["testTool"])
 653	}
 654
 655	// Increment again
 656	convo.incrementToolUse("testTool")
 657
 658	// Check that tool use was incremented again
 659	usage = convo.CumulativeUsage()
 660	if usage.ToolUses["testTool"] != 2 {
 661		t.Errorf("Expected 2 uses of testTool, got %d", usage.ToolUses["testTool"])
 662	}
 663
 664	// Test with different tool
 665	convo.incrementToolUse("anotherTool")
 666	usage = convo.CumulativeUsage()
 667	if usage.ToolUses["anotherTool"] != 1 {
 668		t.Errorf("Expected 1 use of anotherTool, got %d", usage.ToolUses["anotherTool"])
 669	}
 670}
 671
 672// TestDebugJSON tests the DebugJSON function
 673// TestToolResultCancelContents tests the ToolResultCancelContents function
 674func TestToolResultCancelContents(t *testing.T) {
 675	ctx := context.Background()
 676	srv := &ant.Service{}
 677	convo := New(ctx, srv, nil)
 678
 679	// Test with response that doesn't have tool use stop reason
 680	resp := &llm.Response{
 681		StopReason: llm.StopReasonEndTurn,
 682	}
 683	contents, err := convo.ToolResultCancelContents(resp)
 684	if err != nil {
 685		t.Errorf("ToolResultCancelContents should not error with non-tool-use response: %v", err)
 686	}
 687	if contents != nil {
 688		t.Error("ToolResultCancelContents should return nil with non-tool-use response")
 689	}
 690
 691	// Test with response that has tool use stop reason but no tool use content
 692	resp = &llm.Response{
 693		StopReason: llm.StopReasonToolUse,
 694		Content: []llm.Content{
 695			{Type: llm.ContentTypeText, Text: "Hello"},
 696		},
 697	}
 698	contents, err = convo.ToolResultCancelContents(resp)
 699	if err != nil {
 700		t.Errorf("ToolResultCancelContents should not error with tool use response but no tool content: %v", err)
 701	}
 702	// Check if contents is nil (this is expected when no tool uses are found)
 703	if contents != nil && len(contents) != 0 {
 704		t.Errorf("ToolResultCancelContents should return nil or empty slice with tool use response but no tool content, got length %d", len(contents))
 705	}
 706
 707	// Test with response that has tool use stop reason and actual tool use content
 708	resp = &llm.Response{
 709		StopReason: llm.StopReasonToolUse,
 710		Content: []llm.Content{
 711			{Type: llm.ContentTypeToolUse, ID: "tool1", ToolName: "testTool"},
 712		},
 713	}
 714	contents, err = convo.ToolResultCancelContents(resp)
 715	if err != nil {
 716		t.Errorf("ToolResultCancelContents should not error with tool use response and tool content: %v", err)
 717	}
 718	if contents == nil {
 719		t.Error("ToolResultCancelContents should return non-nil slice with tool use response and tool content")
 720	} else if len(contents) != 1 {
 721		t.Errorf("ToolResultCancelContents should return slice with one element with tool use response and tool content, got length %d", len(contents))
 722	} else {
 723		// Check that the returned content has the correct properties
 724		if contents[0].Type != llm.ContentTypeToolResult {
 725			t.Errorf("ToolResultCancelContents should return tool result content, got type %v", contents[0].Type)
 726		}
 727		if contents[0].ToolUseID != "tool1" {
 728			t.Errorf("ToolResultCancelContents should return content with correct ToolUseID, got %v", contents[0].ToolUseID)
 729		}
 730		if !contents[0].ToolError {
 731			t.Error("ToolResultCancelContents should return content with ToolError set to true")
 732		}
 733	}
 734}
 735
 736// TestNewToolUseContext tests the newToolUseContext function
 737func TestNewToolUseContext(t *testing.T) {
 738	ctx := context.Background()
 739	srv := &ant.Service{}
 740	convo := New(ctx, srv, nil)
 741
 742	// Test creating a new tool use context
 743	toolUseID := "test-tool-use-id"
 744	toolCtx, cancel := convo.newToolUseContext(ctx, toolUseID)
 745
 746	if toolCtx == nil {
 747		t.Error("newToolUseContext should return a valid context")
 748	}
 749
 750	if cancel == nil {
 751		t.Error("newToolUseContext should return a valid cancel function")
 752	}
 753
 754	// Check that the tool use was registered
 755	convo.toolUseCancelMu.Lock()
 756	_, exists := convo.toolUseCancel[toolUseID]
 757	convo.toolUseCancelMu.Unlock()
 758
 759	if !exists {
 760		t.Error("newToolUseContext should register the tool use cancel function")
 761	}
 762
 763	// Test that cancel function works
 764	cancel()
 765
 766	// Check that the tool use was unregistered
 767	convo.toolUseCancelMu.Lock()
 768	_, exists = convo.toolUseCancel[toolUseID]
 769	convo.toolUseCancelMu.Unlock()
 770
 771	if exists {
 772		t.Error("Cancel function should unregister the tool use")
 773	}
 774}
 775
 776// TestToolResultContents tests the ToolResultContents function
 777func TestToolResultContents(t *testing.T) {
 778	ctx := context.Background()
 779	srv := &ant.Service{}
 780	convo := New(ctx, srv, nil)
 781
 782	// Skip nil response test as the function doesn't handle nil properly
 783	// This would cause a nil pointer dereference in the actual function
 784
 785	// Test with response that doesn't have tool use stop reason
 786	resp := &llm.Response{
 787		StopReason: llm.StopReasonEndTurn,
 788	}
 789	contents, endsTurn, err := convo.ToolResultContents(ctx, resp)
 790	if err != nil {
 791		t.Errorf("ToolResultContents should not error with non-tool-use response: %v", err)
 792	}
 793	if contents != nil {
 794		t.Error("ToolResultContents should return nil with non-tool-use response")
 795	}
 796	if endsTurn {
 797		t.Error("ToolResultContents should return false for endsTurn with non-tool-use response")
 798	}
 799}
 800
 801// testListener is a custom listener implementation for testing
 802type testListener struct {
 803	events []string
 804}
 805
 806func (tl *testListener) OnToolCall(ctx context.Context, convo *Convo, id, toolName string, toolInput json.RawMessage, content llm.Content) {
 807	tl.events = append(tl.events, "OnToolCall")
 808}
 809
 810func (tl *testListener) OnToolResult(ctx context.Context, convo *Convo, id, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
 811	tl.events = append(tl.events, "OnToolResult")
 812}
 813
 814func (tl *testListener) OnResponse(ctx context.Context, convo *Convo, id string, resp *llm.Response) {
 815	tl.events = append(tl.events, "OnResponse")
 816}
 817
 818func (tl *testListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {
 819	tl.events = append(tl.events, "OnRequest")
 820}
 821
 822// TestListenerInterface tests that the Listener interface methods are called
 823func TestListenerInterface(t *testing.T) {
 824	listener := &testListener{}
 825	ctx := context.Background()
 826	convo := &Convo{}
 827
 828	// Test that all listener methods can be called without panicking
 829	listener.OnToolCall(ctx, convo, "id", "toolName", json.RawMessage(`{"key":"value"}`), llm.Content{})
 830	listener.OnToolResult(ctx, convo, "id", "toolName", json.RawMessage(`{"key":"value"}`), llm.Content{}, nil, nil)
 831	listener.OnResponse(ctx, convo, "id", &llm.Response{})
 832	listener.OnRequest(ctx, convo, "id", &llm.Message{})
 833
 834	// Check that events were recorded
 835	if len(listener.events) != 4 {
 836		t.Errorf("Expected 4 events, got %d", len(listener.events))
 837	}
 838
 839	expectedEvents := []string{"OnToolCall", "OnToolResult", "OnResponse", "OnRequest"}
 840	for i, expected := range expectedEvents {
 841		if listener.events[i] != expected {
 842			t.Errorf("Expected event %s, got %s", expected, listener.events[i])
 843		}
 844	}
 845}
 846
 847// TestToolResultContentsWithToolUse tests ToolResultContents with actual tool use
 848func TestToolResultContentsWithToolUse(t *testing.T) {
 849	ctx := context.Background()
 850	srv := loop.NewPredictableService()
 851	convo := New(ctx, srv, nil)
 852
 853	// Add a simple echo tool
 854	convo.Tools = append(convo.Tools, &llm.Tool{
 855		Name:        "echo",
 856		Description: "Echo tool for testing",
 857		InputSchema: json.RawMessage(`{"type": "object", "properties": {"message": {"type": "string"}}}`),
 858		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
 859			return llm.ToolOut{
 860				LLMContent: []llm.Content{{Type: llm.ContentTypeText, Text: "echo response"}},
 861			}
 862		},
 863	})
 864
 865	// Create a response with tool use stop reason
 866	resp := &llm.Response{
 867		StopReason: llm.StopReasonToolUse,
 868		Content: []llm.Content{
 869			{
 870				Type:      llm.ContentTypeToolUse,
 871				ID:        "test-tool-call",
 872				ToolName:  "echo",
 873				ToolInput: json.RawMessage(`{"message": "test"}`),
 874			},
 875		},
 876	}
 877
 878	// Test ToolResultContents with tool use
 879	contents, endsTurn, err := convo.ToolResultContents(ctx, resp)
 880	if err != nil {
 881		t.Fatalf("ToolResultContents failed: %v", err)
 882	}
 883
 884	// Should return tool results
 885	if len(contents) == 0 {
 886		t.Error("ToolResultContents should return tool results")
 887	}
 888
 889	// Check the content type
 890	if contents[0].Type != llm.ContentTypeToolResult {
 891		t.Errorf("Expected ContentTypeToolResult, got %s", contents[0].Type)
 892	}
 893
 894	// For our echo tool, endsTurn should be false
 895	if endsTurn {
 896		t.Error("Expected endsTurn to be false for echo tool")
 897	}
 898}
 899
 900// TestOverBudgetWithExceeded tests OverBudget when budget is exceeded
 901func TestOverBudgetWithExceeded(t *testing.T) {
 902	ctx := context.Background()
 903	srv := loop.NewPredictableService()
 904	convo := New(ctx, srv, nil)
 905
 906	// Set a tiny budget
 907	convo.Budget.MaxDollars = 0.0000001
 908
 909	// Send a message to accumulate usage
 910	_, err := convo.SendUserTextMessage("test message")
 911	if err != nil {
 912		t.Fatalf("SendUserTextMessage failed: %v", err)
 913	}
 914
 915	// Test that OverBudget returns an error
 916	err = convo.OverBudget()
 917	if err == nil {
 918		t.Error("OverBudget should return an error when budget is exceeded")
 919	}
 920}
 921
 922// TestResetBudgetWithUsage tests ResetBudget with existing usage
 923func TestResetBudgetWithUsage(t *testing.T) {
 924	ctx := context.Background()
 925	srv := loop.NewPredictableService()
 926	convo := New(ctx, srv, nil)
 927
 928	// Send a message to accumulate usage
 929	_, err := convo.SendUserTextMessage("test message")
 930	if err != nil {
 931		t.Fatalf("SendUserTextMessage failed: %v", err)
 932	}
 933
 934	// Get current usage
 935	initialUsage := convo.CumulativeUsage()
 936	initialCost := initialUsage.TotalCostUSD
 937
 938	// Reset budget
 939	newBudget := Budget{MaxDollars: 10.0}
 940	convo.ResetBudget(newBudget)
 941
 942	// Check that budget was adjusted
 943	expectedBudget := 10.0 + initialCost
 944	if convo.Budget.MaxDollars != expectedBudget {
 945		t.Errorf("Expected budget to be %f, got %f", expectedBudget, convo.Budget.MaxDollars)
 946	}
 947}
 948
 949// TestSubConvoWithHistory tests SubConvoWithHistory method
 950
 951// TestDepth tests Depth method
 952
 953// TestGetID tests GetID method
 954
 955// TestDebugJSON tests DebugJSON method
 956
 957// recordingListener is a listener that records all calls for testing
 958type recordingListener struct {
 959	calls []string
 960}
 961
 962func (rl *recordingListener) OnToolCall(ctx context.Context, convo *Convo, id, toolName string, toolInput json.RawMessage, content llm.Content) {
 963	rl.calls = append(rl.calls, "OnToolCall")
 964}
 965
 966func (rl *recordingListener) OnToolResult(ctx context.Context, convo *Convo, id, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
 967	rl.calls = append(rl.calls, "OnToolResult")
 968}
 969
 970func (rl *recordingListener) OnResponse(ctx context.Context, convo *Convo, id string, resp *llm.Response) {
 971	rl.calls = append(rl.calls, "OnResponse")
 972}
 973
 974func (rl *recordingListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {
 975	rl.calls = append(rl.calls, "OnRequest")
 976}
 977
 978// TestConvoListenerIntegration tests that Convo actually calls listener methods during operation
 979func TestConvoListenerIntegration(t *testing.T) {
 980	ctx := context.Background()
 981	srv := loop.NewPredictableService()
 982	convo := New(ctx, srv, nil)
 983
 984	// Set up recording listener
 985	listener := &recordingListener{}
 986	convo.Listener = listener
 987
 988	// Send a message to trigger listener calls
 989	_, err := convo.SendUserTextMessage("Hello")
 990	if err != nil {
 991		t.Fatalf("SendUserTextMessage failed: %v", err)
 992	}
 993
 994	// Check that we recorded some calls
 995	if len(listener.calls) == 0 {
 996		t.Error("Expected listener methods to be called during conversation, but no calls were recorded")
 997	}
 998
 999	// Verify that request and response events were recorded
1000	requestFound := false
1001	responseFound := false
1002	for _, call := range listener.calls {
1003		if call == "OnRequest" {
1004			requestFound = true
1005		}
1006		if call == "OnResponse" {
1007			responseFound = true
1008		}
1009	}
1010
1011	if !requestFound {
1012		t.Error("Expected OnRequest to be called during conversation")
1013	}
1014	if !responseFound {
1015		t.Error("Expected OnResponse to be called during conversation")
1016	}
1017}
1018
1019// TestSubConvoWithHistory tests SubConvoWithHistory method
1020func TestSubConvoWithHistoryAdditional(t *testing.T) {
1021	ctx := context.Background()
1022	srv := loop.NewPredictableService()
1023	convo := New(ctx, srv, nil)
1024
1025	// Send a message to create some history
1026	_, err := convo.SendUserTextMessage("Hello")
1027	if err != nil {
1028		t.Fatalf("SendUserTextMessage failed: %v", err)
1029	}
1030
1031	// Create sub-conversation with history
1032	subConvo := convo.SubConvoWithHistory()
1033	if subConvo == nil {
1034		t.Fatal("SubConvoWithHistory should return a valid conversation")
1035	}
1036
1037	// Check that sub-conversation has parent
1038	if subConvo.Parent != convo {
1039		t.Error("Sub-conversation should have parent set")
1040	}
1041
1042	// Check that sub-conversation has messages (history)
1043	if len(subConvo.messages) == 0 {
1044		t.Error("Sub-conversation should have messages from parent")
1045	}
1046
1047	// Check that the first message is from the parent conversation
1048	if len(subConvo.messages) < 1 {
1049		t.Error("Sub-conversation should have at least one message")
1050	}
1051}
1052
1053// TestDepthAdditional tests Depth method
1054func TestDepthAdditional(t *testing.T) {
1055	ctx := context.Background()
1056	srv := loop.NewPredictableService()
1057	convo := New(ctx, srv, nil)
1058
1059	// Root conversation should have depth 0
1060	if convo.Depth() != 0 {
1061		t.Errorf("Expected depth 0, got %d", convo.Depth())
1062	}
1063
1064	// Sub-conversation should have depth 1
1065	subConvo := convo.SubConvo()
1066	if subConvo.Depth() != 1 {
1067		t.Errorf("Expected depth 1, got %d", subConvo.Depth())
1068	}
1069
1070	// Sub-sub-conversation should have depth 2
1071	subSubConvo := subConvo.SubConvo()
1072	if subSubConvo.Depth() != 2 {
1073		t.Errorf("Expected depth 2, got %d", subSubConvo.Depth())
1074	}
1075}
1076
1077// TestGetIDAdditional tests GetID method
1078func TestGetIDAdditional(t *testing.T) {
1079	ctx := context.Background()
1080	srv := loop.NewPredictableService()
1081	convo := New(ctx, srv, nil)
1082
1083	id := convo.GetID()
1084	if id == "" {
1085		t.Error("GetID should return a non-empty ID")
1086	}
1087	if id != convo.ID {
1088		t.Error("GetID should return the conversation ID")
1089	}
1090}
1091
1092// TestDebugJSONAdditional tests DebugJSON method
1093func TestDebugJSONAdditional(t *testing.T) {
1094	ctx := context.Background()
1095	srv := loop.NewPredictableService()
1096	convo := New(ctx, srv, nil)
1097
1098	// Test with empty conversation
1099	jsonData, err := convo.DebugJSON()
1100	if err != nil {
1101		t.Errorf("DebugJSON failed: %v", err)
1102	}
1103	if len(jsonData) == 0 {
1104		t.Error("DebugJSON should return non-empty data")
1105	}
1106
1107	// Test with conversation that has messages
1108	_, err = convo.SendUserTextMessage("Hello")
1109	if err != nil {
1110		t.Fatalf("SendUserTextMessage failed: %v", err)
1111	}
1112
1113	jsonData, err = convo.DebugJSON()
1114	if err != nil {
1115		t.Errorf("DebugJSON failed: %v", err)
1116	}
1117	if len(jsonData) == 0 {
1118		t.Error("DebugJSON should return non-empty data")
1119	}
1120
1121	// Verify it's valid JSON by trying to unmarshal it
1122	var parsed interface{}
1123	err = json.Unmarshal(jsonData, &parsed)
1124	if err != nil {
1125		t.Errorf("DebugJSON should return valid JSON: %v", err)
1126	}
1127}