agent_test.go

   1package ai
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"testing"
   9
  10	"github.com/charmbracelet/crush/internal/llm/tools"
  11	"github.com/stretchr/testify/require"
  12)
  13
  14// Mock tool for testing
  15type mockTool struct {
  16	name        string
  17	description string
  18	parameters  map[string]any
  19	required    []string
  20	executeFunc func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error)
  21}
  22
  23func (m *mockTool) Info() tools.ToolInfo {
  24	return tools.ToolInfo{
  25		Name:        m.name,
  26		Description: m.description,
  27		Parameters:  m.parameters,
  28		Required:    m.required,
  29	}
  30}
  31
  32func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
  33	if m.executeFunc != nil {
  34		return m.executeFunc(ctx, call)
  35	}
  36	return tools.ToolResponse{Content: "mock result", IsError: false}, nil
  37}
  38
  39// Mock language model for testing
  40type mockLanguageModel struct {
  41	generateFunc func(ctx context.Context, call Call) (*Response, error)
  42}
  43
  44func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
  45	if m.generateFunc != nil {
  46		return m.generateFunc(ctx, call)
  47	}
  48	return &Response{
  49		Content: []Content{
  50			TextContent{Text: "Hello, world!"},
  51		},
  52		Usage: Usage{
  53			InputTokens:  3,
  54			OutputTokens: 10,
  55			TotalTokens:  13,
  56		},
  57		FinishReason: FinishReasonStop,
  58	}, nil
  59}
  60
  61func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
  62	panic("not implemented")
  63}
  64
  65func (m *mockLanguageModel) Provider() string {
  66	return "mock-provider"
  67}
  68
  69func (m *mockLanguageModel) Model() string {
  70	return "mock-model"
  71}
  72
  73// Test result.content - comprehensive content types (matches TS test)
  74func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
  75	t.Parallel()
  76
  77	tool1 := &mockTool{
  78		name:        "tool1",
  79		description: "Test tool",
  80		parameters: map[string]any{
  81			"value": map[string]any{"type": "string"},
  82		},
  83		required: []string{"value"},
  84		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
  85			var input map[string]any
  86			err := json.Unmarshal([]byte(call.Input), &input)
  87			require.NoError(t, err)
  88			require.Equal(t, "value", input["value"])
  89			return tools.ToolResponse{Content: "result1", IsError: false}, nil
  90		},
  91	}
  92
  93	model := &mockLanguageModel{
  94		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
  95			return &Response{
  96				Content: []Content{
  97					TextContent{Text: "Hello, world!"},
  98					SourceContent{
  99						ID:         "123",
 100						URL:        "https://example.com",
 101						Title:      "Example",
 102						SourceType: SourceTypeURL,
 103						ProviderMetadata: ProviderMetadata{
 104							"provider": map[string]any{"custom": "value"},
 105						},
 106					},
 107					FileContent{
 108						Data:      []byte{1, 2, 3},
 109						MediaType: "image/png",
 110					},
 111					ReasoningContent{
 112						Text: "I will open the conversation with witty banter.",
 113					},
 114					ToolCallContent{
 115						ToolCallID: "call-1",
 116						ToolName:   "tool1",
 117						Input:      `{"value":"value"}`,
 118					},
 119					TextContent{Text: "More text"},
 120				},
 121				Usage: Usage{
 122					InputTokens:  3,
 123					OutputTokens: 10,
 124					TotalTokens:  13,
 125				},
 126				FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
 127			}, nil
 128		},
 129	}
 130
 131	agent := NewAgent(model, WithTools(tool1))
 132	result, err := agent.Generate(context.Background(), AgentCall{
 133		Prompt: "prompt",
 134	})
 135
 136	require.NoError(t, err)
 137	require.NotNil(t, result)
 138	require.Len(t, result.Steps, 1) // Single step like TypeScript
 139
 140	// Check final response content includes tool result
 141	require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
 142
 143	// Verify each content type in order
 144	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 145	require.True(t, ok)
 146	require.Equal(t, "Hello, world!", textContent.Text)
 147
 148	sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
 149	require.True(t, ok)
 150	require.Equal(t, "123", sourceContent.ID)
 151
 152	fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
 153	require.True(t, ok)
 154	require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
 155
 156	reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
 157	require.True(t, ok)
 158	require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
 159
 160	toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
 161	require.True(t, ok)
 162	require.Equal(t, "call-1", toolCallContent.ToolCallID)
 163
 164	moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
 165	require.True(t, ok)
 166	require.Equal(t, "More text", moreTextContent.Text)
 167
 168	// Tool result should be appended
 169	toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
 170	require.True(t, ok)
 171	require.Equal(t, "call-1", toolResultContent.ToolCallID)
 172	require.Equal(t, "tool1", toolResultContent.ToolName)
 173}
 174
 175// Test result.text extraction
 176func TestAgent_Generate_ResultText(t *testing.T) {
 177	t.Parallel()
 178
 179	model := &mockLanguageModel{
 180		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 181			return &Response{
 182				Content: []Content{
 183					TextContent{Text: "Hello, world!"},
 184				},
 185				Usage: Usage{
 186					InputTokens:  3,
 187					OutputTokens: 10,
 188					TotalTokens:  13,
 189				},
 190				FinishReason: FinishReasonStop,
 191			}, nil
 192		},
 193	}
 194
 195	agent := NewAgent(model)
 196	result, err := agent.Generate(context.Background(), AgentCall{
 197		Prompt: "prompt",
 198	})
 199
 200	require.NoError(t, err)
 201	require.NotNil(t, result)
 202
 203	// Test text extraction from content
 204	text := result.Response.Content.Text()
 205	require.Equal(t, "Hello, world!", text)
 206}
 207
 208// Test result.toolCalls extraction (matches TS test exactly)
 209func TestAgent_Generate_ResultToolCalls(t *testing.T) {
 210	t.Parallel()
 211
 212	tool1 := &mockTool{
 213		name:        "tool1",
 214		description: "Test tool 1",
 215		parameters: map[string]any{
 216			"value": map[string]any{"type": "string"},
 217		},
 218		required: []string{"value"},
 219	}
 220
 221	tool2 := &mockTool{
 222		name:        "tool2",
 223		description: "Test tool 2",
 224		parameters: map[string]any{
 225			"somethingElse": map[string]any{"type": "string"},
 226		},
 227		required: []string{"somethingElse"},
 228	}
 229
 230	model := &mockLanguageModel{
 231		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 232			// Verify tools are passed correctly
 233			require.Len(t, call.Tools, 2)
 234			require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
 235
 236			// Verify prompt structure
 237			require.Len(t, call.Prompt, 1)
 238			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
 239
 240			return &Response{
 241				Content: []Content{
 242					ToolCallContent{
 243						ToolCallID: "call-1",
 244						ToolName:   "tool1",
 245						Input:      `{"value":"value"}`,
 246					},
 247				},
 248				Usage: Usage{
 249					InputTokens:  3,
 250					OutputTokens: 10,
 251					TotalTokens:  13,
 252				},
 253				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
 254			}, nil
 255		},
 256	}
 257
 258	agent := NewAgent(model, WithTools(tool1, tool2))
 259	result, err := agent.Generate(context.Background(), AgentCall{
 260		Prompt: "test-input",
 261	})
 262
 263	require.NoError(t, err)
 264	require.NotNil(t, result)
 265	require.Len(t, result.Steps, 1) // Single step
 266
 267	// Extract tool calls from final response (should be empty since tools don't execute)
 268	var toolCalls []ToolCallContent
 269	for _, content := range result.Response.Content {
 270		if toolCall, ok := AsContentType[ToolCallContent](content); ok {
 271			toolCalls = append(toolCalls, toolCall)
 272		}
 273	}
 274
 275	require.Len(t, toolCalls, 1)
 276	require.Equal(t, "call-1", toolCalls[0].ToolCallID)
 277	require.Equal(t, "tool1", toolCalls[0].ToolName)
 278
 279	// Parse and verify input
 280	var input map[string]any
 281	err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
 282	require.NoError(t, err)
 283	require.Equal(t, "value", input["value"])
 284}
 285
 286// Test result.toolResults extraction (matches TS test exactly)
 287func TestAgent_Generate_ResultToolResults(t *testing.T) {
 288	t.Parallel()
 289
 290	tool1 := &mockTool{
 291		name:        "tool1",
 292		description: "Test tool",
 293		parameters: map[string]any{
 294			"value": map[string]any{"type": "string"},
 295		},
 296		required: []string{"value"},
 297		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
 298			var input map[string]any
 299			err := json.Unmarshal([]byte(call.Input), &input)
 300			require.NoError(t, err)
 301			require.Equal(t, "value", input["value"])
 302			return tools.ToolResponse{Content: "result1", IsError: false}, nil
 303		},
 304	}
 305
 306	model := &mockLanguageModel{
 307		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 308			// Verify tools and tool choice
 309			require.Len(t, call.Tools, 1)
 310			require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
 311
 312			// Verify prompt
 313			require.Len(t, call.Prompt, 1)
 314			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
 315
 316			return &Response{
 317				Content: []Content{
 318					ToolCallContent{
 319						ToolCallID: "call-1",
 320						ToolName:   "tool1",
 321						Input:      `{"value":"value"}`,
 322					},
 323				},
 324				Usage: Usage{
 325					InputTokens:  3,
 326					OutputTokens: 10,
 327					TotalTokens:  13,
 328				},
 329				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
 330			}, nil
 331		},
 332	}
 333
 334	agent := NewAgent(model, WithTools(tool1))
 335	result, err := agent.Generate(context.Background(), AgentCall{
 336		Prompt: "test-input",
 337	})
 338
 339	require.NoError(t, err)
 340	require.NotNil(t, result)
 341	require.Len(t, result.Steps, 1) // Single step
 342
 343	// Extract tool results from final response
 344	var toolResults []ToolResultContent
 345	for _, content := range result.Response.Content {
 346		if toolResult, ok := AsContentType[ToolResultContent](content); ok {
 347			toolResults = append(toolResults, toolResult)
 348		}
 349	}
 350
 351	require.Len(t, toolResults, 1)
 352	require.Equal(t, "call-1", toolResults[0].ToolCallID)
 353	require.Equal(t, "tool1", toolResults[0].ToolName)
 354
 355	// Verify result content
 356	textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
 357	require.True(t, ok)
 358	require.Equal(t, "result1", textResult.Text)
 359}
 360
 361// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
 362func TestAgent_Generate_MultipleSteps(t *testing.T) {
 363	t.Parallel()
 364
 365	tool1 := &mockTool{
 366		name:        "tool1",
 367		description: "Test tool",
 368		parameters: map[string]any{
 369			"value": map[string]any{"type": "string"},
 370		},
 371		required: []string{"value"},
 372		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
 373			var input map[string]any
 374			err := json.Unmarshal([]byte(call.Input), &input)
 375			require.NoError(t, err)
 376			require.Equal(t, "value", input["value"])
 377			return tools.ToolResponse{Content: "result1", IsError: false}, nil
 378		},
 379	}
 380
 381	callCount := 0
 382	model := &mockLanguageModel{
 383		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 384			callCount++
 385			switch callCount {
 386			case 1:
 387				// First call - return tool call with FinishReasonToolCalls
 388				return &Response{
 389					Content: []Content{
 390						ToolCallContent{
 391							ToolCallID: "call-1",
 392							ToolName:   "tool1",
 393							Input:      `{"value":"value"}`,
 394						},
 395					},
 396					Usage: Usage{
 397						InputTokens:  10,
 398						OutputTokens: 5,
 399						TotalTokens:  15,
 400					},
 401					FinishReason: FinishReasonToolCalls, // This triggers multi-step
 402				}, nil
 403			case 2:
 404				// Second call - return final text
 405				return &Response{
 406					Content: []Content{
 407						TextContent{Text: "Hello, world!"},
 408					},
 409					Usage: Usage{
 410						InputTokens:  3,
 411						OutputTokens: 10,
 412						TotalTokens:  13,
 413					},
 414					FinishReason: FinishReasonStop,
 415				}, nil
 416			default:
 417				t.Fatalf("Unexpected call count: %d", callCount)
 418				return nil, nil
 419			}
 420		},
 421	}
 422
 423	agent := NewAgent(model, WithTools(tool1))
 424	result, err := agent.Generate(context.Background(), AgentCall{
 425		Prompt: "test-input",
 426	})
 427
 428	require.NoError(t, err)
 429	require.NotNil(t, result)
 430	require.Len(t, result.Steps, 2)
 431
 432	// Check total usage sums both steps
 433	require.Equal(t, int64(13), result.TotalUsage.InputTokens)  // 10 + 3
 434	require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
 435	require.Equal(t, int64(28), result.TotalUsage.TotalTokens)  // 15 + 13
 436
 437	// Final response should be from last step
 438	require.Len(t, result.Response.Content, 1)
 439	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 440	require.True(t, ok)
 441	require.Equal(t, "Hello, world!", textContent.Text)
 442
 443	// result.toolCalls should be empty (from last step)
 444	var toolCalls []ToolCallContent
 445	for _, content := range result.Response.Content {
 446		if _, ok := AsContentType[ToolCallContent](content); ok {
 447			toolCalls = append(toolCalls, content.(ToolCallContent))
 448		}
 449	}
 450	require.Len(t, toolCalls, 0)
 451
 452	// result.toolResults should be empty (from last step)
 453	var toolResults []ToolResultContent
 454	for _, content := range result.Response.Content {
 455		if _, ok := AsContentType[ToolResultContent](content); ok {
 456			toolResults = append(toolResults, content.(ToolResultContent))
 457		}
 458	}
 459	require.Len(t, toolResults, 0)
 460}
 461
 462// Test basic text generation
 463func TestAgent_Generate_BasicText(t *testing.T) {
 464	t.Parallel()
 465
 466	model := &mockLanguageModel{
 467		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 468			return &Response{
 469				Content: []Content{
 470					TextContent{Text: "Hello, world!"},
 471				},
 472				Usage: Usage{
 473					InputTokens:  3,
 474					OutputTokens: 10,
 475					TotalTokens:  13,
 476				},
 477				FinishReason: FinishReasonStop,
 478			}, nil
 479		},
 480	}
 481
 482	agent := NewAgent(model)
 483	result, err := agent.Generate(context.Background(), AgentCall{
 484		Prompt: "test prompt",
 485	})
 486
 487	require.NoError(t, err)
 488	require.NotNil(t, result)
 489	require.Len(t, result.Steps, 1)
 490
 491	// Check final response
 492	require.Len(t, result.Response.Content, 1)
 493	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 494	require.True(t, ok)
 495	require.Equal(t, "Hello, world!", textContent.Text)
 496
 497	// Check usage
 498	require.Equal(t, int64(3), result.Response.Usage.InputTokens)
 499	require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
 500	require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
 501
 502	// Check total usage
 503	require.Equal(t, int64(3), result.TotalUsage.InputTokens)
 504	require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
 505	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
 506}
 507
 508// Test empty prompt error
 509func TestAgent_Generate_EmptyPrompt(t *testing.T) {
 510	t.Parallel()
 511
 512	model := &mockLanguageModel{}
 513	agent := NewAgent(model)
 514
 515	result, err := agent.Generate(context.Background(), AgentCall{
 516		Prompt: "", // Empty prompt should cause error
 517	})
 518
 519	require.Error(t, err)
 520	require.Nil(t, result)
 521	require.Contains(t, err.Error(), "Prompt can't be empty")
 522}
 523
 524// Test with system prompt
 525func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
 526	t.Parallel()
 527
 528	model := &mockLanguageModel{
 529		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 530			// Verify system message is included
 531			require.Len(t, call.Prompt, 2) // system + user
 532			require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
 533			require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
 534
 535			systemPart, ok := call.Prompt[0].Content[0].(TextPart)
 536			require.True(t, ok)
 537			require.Equal(t, "You are a helpful assistant", systemPart.Text)
 538
 539			return &Response{
 540				Content: []Content{
 541					TextContent{Text: "Hello, world!"},
 542				},
 543				Usage: Usage{
 544					InputTokens:  3,
 545					OutputTokens: 10,
 546					TotalTokens:  13,
 547				},
 548				FinishReason: FinishReasonStop,
 549			}, nil
 550		},
 551	}
 552
 553	agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
 554	result, err := agent.Generate(context.Background(), AgentCall{
 555		Prompt: "test prompt",
 556	})
 557
 558	require.NoError(t, err)
 559	require.NotNil(t, result)
 560}
 561
 562// Test options.headers
 563func TestAgent_Generate_OptionsHeaders(t *testing.T) {
 564	t.Parallel()
 565
 566	model := &mockLanguageModel{
 567		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 568			// Verify headers are passed
 569			require.Equal(t, map[string]string{
 570				"custom-request-header": "request-header-value",
 571			}, call.Headers)
 572
 573			return &Response{
 574				Content: []Content{
 575					TextContent{Text: "Hello, world!"},
 576				},
 577				Usage: Usage{
 578					InputTokens:  3,
 579					OutputTokens: 10,
 580					TotalTokens:  13,
 581				},
 582				FinishReason: FinishReasonStop,
 583			}, nil
 584		},
 585	}
 586
 587	agent := NewAgent(model)
 588	result, err := agent.Generate(context.Background(), AgentCall{
 589		Prompt:  "test-input",
 590		Headers: map[string]string{"custom-request-header": "request-header-value"},
 591	})
 592
 593	require.NoError(t, err)
 594	require.NotNil(t, result)
 595	require.Equal(t, "Hello, world!", result.Response.Content.Text())
 596}
 597
 598// Test options.activeTools filtering
 599func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
 600	t.Parallel()
 601
 602	tool1 := &mockTool{
 603		name:        "tool1",
 604		description: "Test tool 1",
 605		parameters: map[string]any{
 606			"value": map[string]any{"type": "string"},
 607		},
 608		required: []string{"value"},
 609	}
 610
 611	tool2 := &mockTool{
 612		name:        "tool2",
 613		description: "Test tool 2",
 614		parameters: map[string]any{
 615			"value": map[string]any{"type": "string"},
 616		},
 617		required: []string{"value"},
 618	}
 619
 620	model := &mockLanguageModel{
 621		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 622			// Verify only tool1 is available
 623			require.Len(t, call.Tools, 1)
 624			functionTool, ok := call.Tools[0].(FunctionTool)
 625			require.True(t, ok)
 626			require.Equal(t, "tool1", functionTool.Name)
 627
 628			return &Response{
 629				Content: []Content{
 630					TextContent{Text: "Hello, world!"},
 631				},
 632				Usage: Usage{
 633					InputTokens:  3,
 634					OutputTokens: 10,
 635					TotalTokens:  13,
 636				},
 637				FinishReason: FinishReasonStop,
 638			}, nil
 639		},
 640	}
 641
 642	agent := NewAgent(model, WithTools(tool1, tool2))
 643	result, err := agent.Generate(context.Background(), AgentCall{
 644		Prompt:      "test-input",
 645		ActiveTools: []string{"tool1"}, // Only tool1 should be active
 646	})
 647
 648	require.NoError(t, err)
 649	require.NotNil(t, result)
 650}
 651
 652func TestResponseContent_Getters(t *testing.T) {
 653	t.Parallel()
 654
 655	// Create test content with all types
 656	content := ResponseContent{
 657		TextContent{Text: "Hello world"},
 658		ReasoningContent{Text: "Let me think..."},
 659		FileContent{Data: []byte("file data"), MediaType: "text/plain"},
 660		SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
 661		ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
 662		ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
 663	}
 664
 665	// Test Text()
 666	require.Equal(t, "Hello world", content.Text())
 667
 668	// Test Reasoning()
 669	reasoning := content.Reasoning()
 670	require.Len(t, reasoning, 1)
 671	require.Equal(t, "Let me think...", reasoning[0].Text)
 672
 673	// Test ReasoningText()
 674	require.Equal(t, "Let me think...", content.ReasoningText())
 675
 676	// Test Files()
 677	files := content.Files()
 678	require.Len(t, files, 1)
 679	require.Equal(t, "text/plain", files[0].MediaType)
 680	require.Equal(t, []byte("file data"), files[0].Data)
 681
 682	// Test Sources()
 683	sources := content.Sources()
 684	require.Len(t, sources, 1)
 685	require.Equal(t, SourceTypeURL, sources[0].SourceType)
 686	require.Equal(t, "https://example.com", sources[0].URL)
 687	require.Equal(t, "Example", sources[0].Title)
 688
 689	// Test ToolCalls()
 690	toolCalls := content.ToolCalls()
 691	require.Len(t, toolCalls, 1)
 692	require.Equal(t, "call1", toolCalls[0].ToolCallID)
 693	require.Equal(t, "test_tool", toolCalls[0].ToolName)
 694	require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
 695
 696	// Test ToolResults()
 697	toolResults := content.ToolResults()
 698	require.Len(t, toolResults, 1)
 699	require.Equal(t, "call1", toolResults[0].ToolCallID)
 700	require.Equal(t, "test_tool", toolResults[0].ToolName)
 701	result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
 702	require.True(t, ok)
 703	require.Equal(t, "result", result.Text)
 704}
 705
 706func TestResponseContent_Getters_Empty(t *testing.T) {
 707	t.Parallel()
 708
 709	// Test with empty content
 710	content := ResponseContent{}
 711
 712	require.Equal(t, "", content.Text())
 713	require.Equal(t, "", content.ReasoningText())
 714	require.Empty(t, content.Reasoning())
 715	require.Empty(t, content.Files())
 716	require.Empty(t, content.Sources())
 717	require.Empty(t, content.ToolCalls())
 718	require.Empty(t, content.ToolResults())
 719}
 720
 721func TestResponseContent_Getters_MultipleItems(t *testing.T) {
 722	t.Parallel()
 723
 724	// Test with multiple items of same type
 725	content := ResponseContent{
 726		ReasoningContent{Text: "First thought"},
 727		ReasoningContent{Text: "Second thought"},
 728		FileContent{Data: []byte("file1"), MediaType: "text/plain"},
 729		FileContent{Data: []byte("file2"), MediaType: "image/png"},
 730	}
 731
 732	// Test multiple reasoning
 733	reasoning := content.Reasoning()
 734	require.Len(t, reasoning, 2)
 735	require.Equal(t, "First thought", reasoning[0].Text)
 736	require.Equal(t, "Second thought", reasoning[1].Text)
 737
 738	// Test concatenated reasoning text
 739	require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
 740
 741	// Test multiple files
 742	files := content.Files()
 743	require.Len(t, files, 2)
 744	require.Equal(t, "text/plain", files[0].MediaType)
 745	require.Equal(t, "image/png", files[1].MediaType)
 746}
 747
 748func TestStopConditions(t *testing.T) {
 749	t.Parallel()
 750
 751	// Create test steps
 752	step1 := StepResult{
 753		Response: Response{
 754			Content: ResponseContent{
 755				TextContent{Text: "Hello"},
 756			},
 757			FinishReason: FinishReasonToolCalls,
 758			Usage:        Usage{TotalTokens: 10},
 759		},
 760	}
 761
 762	step2 := StepResult{
 763		Response: Response{
 764			Content: ResponseContent{
 765				TextContent{Text: "World"},
 766				ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
 767			},
 768			FinishReason: FinishReasonStop,
 769			Usage:        Usage{TotalTokens: 15},
 770		},
 771	}
 772
 773	step3 := StepResult{
 774		Response: Response{
 775			Content: ResponseContent{
 776				ReasoningContent{Text: "Let me think..."},
 777				FileContent{Data: []byte("data"), MediaType: "text/plain"},
 778			},
 779			FinishReason: FinishReasonLength,
 780			Usage:        Usage{TotalTokens: 20},
 781		},
 782	}
 783
 784	t.Run("StepCountIs", func(t *testing.T) {
 785		condition := StepCountIs(2)
 786
 787		// Should not stop with 1 step
 788		require.False(t, condition([]StepResult{step1}))
 789
 790		// Should stop with 2 steps
 791		require.True(t, condition([]StepResult{step1, step2}))
 792
 793		// Should stop with more than 2 steps
 794		require.True(t, condition([]StepResult{step1, step2, step3}))
 795
 796		// Should not stop with empty steps
 797		require.False(t, condition([]StepResult{}))
 798	})
 799
 800	t.Run("HasToolCall", func(t *testing.T) {
 801		condition := HasToolCall("search")
 802
 803		// Should not stop when tool not called
 804		require.False(t, condition([]StepResult{step1}))
 805
 806		// Should stop when tool is called in last step
 807		require.True(t, condition([]StepResult{step1, step2}))
 808
 809		// Should not stop when tool called in earlier step but not last
 810		require.False(t, condition([]StepResult{step1, step2, step3}))
 811
 812		// Should not stop with empty steps
 813		require.False(t, condition([]StepResult{}))
 814
 815		// Should not stop when different tool is called
 816		differentToolCondition := HasToolCall("different_tool")
 817		require.False(t, differentToolCondition([]StepResult{step1, step2}))
 818	})
 819
 820	t.Run("HasContent", func(t *testing.T) {
 821		reasoningCondition := HasContent(ContentTypeReasoning)
 822		fileCondition := HasContent(ContentTypeFile)
 823
 824		// Should not stop when content type not present
 825		require.False(t, reasoningCondition([]StepResult{step1, step2}))
 826
 827		// Should stop when content type is present in last step
 828		require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
 829		require.True(t, fileCondition([]StepResult{step1, step2, step3}))
 830
 831		// Should not stop with empty steps
 832		require.False(t, reasoningCondition([]StepResult{}))
 833	})
 834
 835	t.Run("FinishReasonIs", func(t *testing.T) {
 836		stopCondition := FinishReasonIs(FinishReasonStop)
 837		lengthCondition := FinishReasonIs(FinishReasonLength)
 838
 839		// Should not stop when finish reason doesn't match
 840		require.False(t, stopCondition([]StepResult{step1}))
 841
 842		// Should stop when finish reason matches in last step
 843		require.True(t, stopCondition([]StepResult{step1, step2}))
 844		require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
 845
 846		// Should not stop with empty steps
 847		require.False(t, stopCondition([]StepResult{}))
 848	})
 849
 850	t.Run("MaxTokensUsed", func(t *testing.T) {
 851		condition := MaxTokensUsed(30)
 852
 853		// Should not stop when under limit
 854		require.False(t, condition([]StepResult{step1}))        // 10 tokens
 855		require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
 856
 857		// Should stop when at or over limit
 858		require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
 859
 860		// Should not stop with empty steps
 861		require.False(t, condition([]StepResult{}))
 862	})
 863}
 864
 865func TestStopConditions_Integration(t *testing.T) {
 866	t.Parallel()
 867
 868	t.Run("StepCountIs integration", func(t *testing.T) {
 869		model := &mockLanguageModel{
 870			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 871				return &Response{
 872					Content: ResponseContent{
 873						TextContent{Text: "Mock response"},
 874					},
 875					Usage: Usage{
 876						InputTokens:  3,
 877						OutputTokens: 10,
 878						TotalTokens:  13,
 879					},
 880					FinishReason: FinishReasonStop,
 881				}, nil
 882			},
 883		}
 884
 885		agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
 886
 887		result, err := agent.Generate(context.Background(), AgentCall{
 888			Prompt: "test prompt",
 889		})
 890
 891		require.NoError(t, err)
 892		require.NotNil(t, result)
 893		require.Len(t, result.Steps, 1) // Should stop after 1 step
 894	})
 895
 896	t.Run("Multiple stop conditions", func(t *testing.T) {
 897		model := &mockLanguageModel{
 898			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 899				return &Response{
 900					Content: ResponseContent{
 901						TextContent{Text: "Mock response"},
 902					},
 903					Usage: Usage{
 904						InputTokens:  3,
 905						OutputTokens: 10,
 906						TotalTokens:  13,
 907					},
 908					FinishReason: FinishReasonStop,
 909				}, nil
 910			},
 911		}
 912
 913		agent := NewAgent(model, WithStopConditions(
 914			StepCountIs(5),                   // Stop after 5 steps
 915			FinishReasonIs(FinishReasonStop), // Or stop on finish reason
 916		))
 917
 918		result, err := agent.Generate(context.Background(), AgentCall{
 919			Prompt: "test prompt",
 920		})
 921
 922		require.NoError(t, err)
 923		require.NotNil(t, result)
 924		// Should stop on first condition met (finish reason stop)
 925		require.Equal(t, FinishReasonStop, result.Response.FinishReason)
 926	})
 927}
 928
 929func TestPrepareStep(t *testing.T) {
 930	t.Parallel()
 931
 932	t.Run("System prompt modification", func(t *testing.T) {
 933		var capturedSystemPrompt string
 934		model := &mockLanguageModel{
 935			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 936				// Capture the system message to verify it was modified
 937				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
 938					if len(call.Prompt[0].Content) > 0 {
 939						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
 940							capturedSystemPrompt = textPart.Text
 941						}
 942					}
 943				}
 944				return &Response{
 945					Content: ResponseContent{
 946						TextContent{Text: "Response"},
 947					},
 948					Usage:        Usage{TotalTokens: 10},
 949					FinishReason: FinishReasonStop,
 950				}, nil
 951			},
 952		}
 953
 954		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
 955			newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
 956			return PrepareStepResult{
 957				Model:    options.Model,
 958				Messages: options.Messages,
 959				System:   &newSystem,
 960			}
 961		}
 962
 963		agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
 964
 965		result, err := agent.Generate(context.Background(), AgentCall{
 966			Prompt:      "test prompt",
 967			PrepareStep: prepareStepFunc,
 968		})
 969
 970		require.NoError(t, err)
 971		require.NotNil(t, result)
 972		require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
 973	})
 974
 975	t.Run("Tool choice modification", func(t *testing.T) {
 976		var capturedToolChoice *ToolChoice
 977		model := &mockLanguageModel{
 978			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 979				capturedToolChoice = call.ToolChoice
 980				return &Response{
 981					Content: ResponseContent{
 982						TextContent{Text: "Response"},
 983					},
 984					Usage:        Usage{TotalTokens: 10},
 985					FinishReason: FinishReasonStop,
 986				}, nil
 987			},
 988		}
 989
 990		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
 991			toolChoice := ToolChoiceNone
 992			return PrepareStepResult{
 993				Model:      options.Model,
 994				Messages:   options.Messages,
 995				ToolChoice: &toolChoice,
 996			}
 997		}
 998
 999		agent := NewAgent(model)
1000
1001		result, err := agent.Generate(context.Background(), AgentCall{
1002			Prompt:      "test prompt",
1003			PrepareStep: prepareStepFunc,
1004		})
1005
1006		require.NoError(t, err)
1007		require.NotNil(t, result)
1008		require.NotNil(t, capturedToolChoice)
1009		require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1010	})
1011
1012	t.Run("Active tools modification", func(t *testing.T) {
1013		var capturedToolNames []string
1014		model := &mockLanguageModel{
1015			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1016				// Capture tool names to verify active tools were modified
1017				for _, tool := range call.Tools {
1018					capturedToolNames = append(capturedToolNames, tool.GetName())
1019				}
1020				return &Response{
1021					Content: ResponseContent{
1022						TextContent{Text: "Response"},
1023					},
1024					Usage:        Usage{TotalTokens: 10},
1025					FinishReason: FinishReasonStop,
1026				}, nil
1027			},
1028		}
1029
1030		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1031		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1032		tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1033
1034		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1035			activeTools := []string{"tool2"} // Only tool2 should be active
1036			return PrepareStepResult{
1037				Model:       options.Model,
1038				Messages:    options.Messages,
1039				ActiveTools: activeTools,
1040			}
1041		}
1042
1043		agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1044
1045		result, err := agent.Generate(context.Background(), AgentCall{
1046			Prompt:      "test prompt",
1047			PrepareStep: prepareStepFunc,
1048		})
1049
1050		require.NoError(t, err)
1051		require.NotNil(t, result)
1052		require.Len(t, capturedToolNames, 1)
1053		require.Equal(t, "tool2", capturedToolNames[0])
1054	})
1055
1056	t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1057		var capturedToolCount int
1058		model := &mockLanguageModel{
1059			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1060				capturedToolCount = len(call.Tools)
1061				return &Response{
1062					Content: ResponseContent{
1063						TextContent{Text: "Response"},
1064					},
1065					Usage:        Usage{TotalTokens: 10},
1066					FinishReason: FinishReasonStop,
1067				}, nil
1068			},
1069		}
1070
1071		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1072
1073		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1074			return PrepareStepResult{
1075				Model:           options.Model,
1076				Messages:        options.Messages,
1077				DisableAllTools: true, // Disable all tools for this step
1078			}
1079		}
1080
1081		agent := NewAgent(model, WithTools(tool1))
1082
1083		result, err := agent.Generate(context.Background(), AgentCall{
1084			Prompt:      "test prompt",
1085			PrepareStep: prepareStepFunc,
1086		})
1087
1088		require.NoError(t, err)
1089		require.NotNil(t, result)
1090		require.Equal(t, 0, capturedToolCount) // No tools should be passed
1091	})
1092
1093	t.Run("All fields modified together", func(t *testing.T) {
1094		var capturedSystemPrompt string
1095		var capturedToolChoice *ToolChoice
1096		var capturedToolNames []string
1097
1098		model := &mockLanguageModel{
1099			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1100				// Capture system prompt
1101				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1102					if len(call.Prompt[0].Content) > 0 {
1103						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1104							capturedSystemPrompt = textPart.Text
1105						}
1106					}
1107				}
1108				// Capture tool choice
1109				capturedToolChoice = call.ToolChoice
1110				// Capture tool names
1111				for _, tool := range call.Tools {
1112					capturedToolNames = append(capturedToolNames, tool.GetName())
1113				}
1114				return &Response{
1115					Content: ResponseContent{
1116						TextContent{Text: "Response"},
1117					},
1118					Usage:        Usage{TotalTokens: 10},
1119					FinishReason: FinishReasonStop,
1120				}, nil
1121			},
1122		}
1123
1124		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1125		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1126
1127		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1128			newSystem := "Step-specific system"
1129			toolChoice := SpecificToolChoice("tool1")
1130			activeTools := []string{"tool1"}
1131			return PrepareStepResult{
1132				Model:       options.Model,
1133				Messages:    options.Messages,
1134				System:      &newSystem,
1135				ToolChoice:  &toolChoice,
1136				ActiveTools: activeTools,
1137			}
1138		}
1139
1140		agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1141
1142		result, err := agent.Generate(context.Background(), AgentCall{
1143			Prompt:      "test prompt",
1144			PrepareStep: prepareStepFunc,
1145		})
1146
1147		require.NoError(t, err)
1148		require.NotNil(t, result)
1149		require.Equal(t, "Step-specific system", capturedSystemPrompt)
1150		require.NotNil(t, capturedToolChoice)
1151		require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1152		require.Len(t, capturedToolNames, 1)
1153		require.Equal(t, "tool1", capturedToolNames[0])
1154	})
1155
1156	t.Run("Nil fields use parent values", func(t *testing.T) {
1157		var capturedSystemPrompt string
1158		var capturedToolChoice *ToolChoice
1159		var capturedToolNames []string
1160
1161		model := &mockLanguageModel{
1162			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1163				// Capture system prompt
1164				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1165					if len(call.Prompt[0].Content) > 0 {
1166						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1167							capturedSystemPrompt = textPart.Text
1168						}
1169					}
1170				}
1171				// Capture tool choice
1172				capturedToolChoice = call.ToolChoice
1173				// Capture tool names
1174				for _, tool := range call.Tools {
1175					capturedToolNames = append(capturedToolNames, tool.GetName())
1176				}
1177				return &Response{
1178					Content: ResponseContent{
1179						TextContent{Text: "Response"},
1180					},
1181					Usage:        Usage{TotalTokens: 10},
1182					FinishReason: FinishReasonStop,
1183				}, nil
1184			},
1185		}
1186
1187		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1188
1189		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1190			// All optional fields are nil, should use parent values
1191			return PrepareStepResult{
1192				Model:       options.Model,
1193				Messages:    options.Messages,
1194				System:      nil, // Use parent
1195				ToolChoice:  nil, // Use parent (auto)
1196				ActiveTools: nil, // Use parent (all tools)
1197			}
1198		}
1199
1200		agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1201
1202		result, err := agent.Generate(context.Background(), AgentCall{
1203			Prompt:      "test prompt",
1204			PrepareStep: prepareStepFunc,
1205		})
1206
1207		require.NoError(t, err)
1208		require.NotNil(t, result)
1209		require.Equal(t, "Parent system", capturedSystemPrompt)
1210		require.NotNil(t, capturedToolChoice)
1211		require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1212		require.Len(t, capturedToolNames, 1)
1213		require.Equal(t, "tool1", capturedToolNames[0])
1214	})
1215
1216	t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1217		var capturedToolNames []string
1218		model := &mockLanguageModel{
1219			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1220				// Capture tool names to verify all tools are included
1221				for _, tool := range call.Tools {
1222					capturedToolNames = append(capturedToolNames, tool.GetName())
1223				}
1224				return &Response{
1225					Content: ResponseContent{
1226						TextContent{Text: "Response"},
1227					},
1228					Usage:        Usage{TotalTokens: 10},
1229					FinishReason: FinishReasonStop,
1230				}, nil
1231			},
1232		}
1233
1234		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1235		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1236
1237		prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1238			return PrepareStepResult{
1239				Model:       options.Model,
1240				Messages:    options.Messages,
1241				ActiveTools: []string{}, // Empty slice means all tools
1242			}
1243		}
1244
1245		agent := NewAgent(model, WithTools(tool1, tool2))
1246
1247		result, err := agent.Generate(context.Background(), AgentCall{
1248			Prompt:      "test prompt",
1249			PrepareStep: prepareStepFunc,
1250		})
1251
1252		require.NoError(t, err)
1253		require.NotNil(t, result)
1254		require.Len(t, capturedToolNames, 2) // All tools should be included
1255		require.Contains(t, capturedToolNames, "tool1")
1256		require.Contains(t, capturedToolNames, "tool2")
1257	})
1258}
1259
1260func TestToolCallRepair(t *testing.T) {
1261	t.Parallel()
1262
1263	t.Run("Valid tool call passes validation", func(t *testing.T) {
1264		model := &mockLanguageModel{
1265			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1266				return &Response{
1267					Content: ResponseContent{
1268						TextContent{Text: "Response"},
1269						ToolCallContent{
1270							ToolCallID: "call1",
1271							ToolName:   "test_tool",
1272							Input:      `{"value": "test"}`, // Valid JSON with required field
1273						},
1274					},
1275					Usage:        Usage{TotalTokens: 10},
1276					FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1277				}, nil
1278			},
1279		}
1280
1281		tool := &mockTool{
1282			name:        "test_tool",
1283			description: "Test tool",
1284			parameters: map[string]any{
1285				"value": map[string]any{"type": "string"},
1286			},
1287			required: []string{"value"},
1288			executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
1289				return tools.ToolResponse{Content: "success", IsError: false}, nil
1290			},
1291		}
1292
1293		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1294
1295		result, err := agent.Generate(context.Background(), AgentCall{
1296			Prompt: "test prompt",
1297		})
1298
1299		require.NoError(t, err)
1300		require.NotNil(t, result)
1301		require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1302
1303		// Check that tool call was executed successfully
1304		toolCalls := result.Steps[0].Response.Content.ToolCalls()
1305		require.Len(t, toolCalls, 1)
1306		require.False(t, toolCalls[0].Invalid) // Should be valid
1307	})
1308
1309	t.Run("Invalid tool call without repair function", func(t *testing.T) {
1310		model := &mockLanguageModel{
1311			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1312				return &Response{
1313					Content: ResponseContent{
1314						TextContent{Text: "Response"},
1315						ToolCallContent{
1316							ToolCallID: "call1",
1317							ToolName:   "test_tool",
1318							Input:      `{"wrong_field": "test"}`, // Missing required field
1319						},
1320					},
1321					Usage:        Usage{TotalTokens: 10},
1322					FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1323				}, nil
1324			},
1325		}
1326
1327		tool := &mockTool{
1328			name:        "test_tool",
1329			description: "Test tool",
1330			parameters: map[string]any{
1331				"value": map[string]any{"type": "string"},
1332			},
1333			required: []string{"value"},
1334		}
1335
1336		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1337
1338		result, err := agent.Generate(context.Background(), AgentCall{
1339			Prompt: "test prompt",
1340		})
1341
1342		require.NoError(t, err)
1343		require.NotNil(t, result)
1344		require.Len(t, result.Steps, 1) // Only one step
1345
1346		// Check that tool call was marked as invalid
1347		toolCalls := result.Steps[0].Response.Content.ToolCalls()
1348		require.Len(t, toolCalls, 1)
1349		require.True(t, toolCalls[0].Invalid) // Should be invalid
1350		require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1351	})
1352
1353	t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1354		model := &mockLanguageModel{
1355			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1356				return &Response{
1357					Content: ResponseContent{
1358						TextContent{Text: "Response"},
1359						ToolCallContent{
1360							ToolCallID: "call1",
1361							ToolName:   "test_tool",
1362							Input:      `{"wrong_field": "test"}`, // Missing required field
1363						},
1364					},
1365					Usage:        Usage{TotalTokens: 10},
1366					FinishReason: FinishReasonStop, // Changed to stop
1367				}, nil
1368			},
1369		}
1370
1371		tool := &mockTool{
1372			name:        "test_tool",
1373			description: "Test tool",
1374			parameters: map[string]any{
1375				"value": map[string]any{"type": "string"},
1376			},
1377			required: []string{"value"},
1378			executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
1379				return tools.ToolResponse{Content: "repaired_success", IsError: false}, nil
1380			},
1381		}
1382
1383		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1384			// Simple repair: add the missing required field
1385			repairedToolCall := options.OriginalToolCall
1386			repairedToolCall.Input = `{"value": "repaired"}`
1387			return &repairedToolCall, nil
1388		}
1389
1390		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1391
1392		result, err := agent.Generate(context.Background(), AgentCall{
1393			Prompt: "test prompt",
1394		})
1395
1396		require.NoError(t, err)
1397		require.NotNil(t, result)
1398		require.Len(t, result.Steps, 1) // Only one step
1399
1400		// Check that tool call was repaired and is now valid
1401		toolCalls := result.Steps[0].Response.Content.ToolCalls()
1402		require.Len(t, toolCalls, 1)
1403		require.False(t, toolCalls[0].Invalid)                        // Should be valid after repair
1404		require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1405	})
1406
1407	t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1408		model := &mockLanguageModel{
1409			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1410				return &Response{
1411					Content: ResponseContent{
1412						TextContent{Text: "Response"},
1413						ToolCallContent{
1414							ToolCallID: "call1",
1415							ToolName:   "test_tool",
1416							Input:      `{"wrong_field": "test"}`, // Missing required field
1417						},
1418					},
1419					Usage:        Usage{TotalTokens: 10},
1420					FinishReason: FinishReasonStop, // Changed to stop
1421				}, nil
1422			},
1423		}
1424
1425		tool := &mockTool{
1426			name:        "test_tool",
1427			description: "Test tool",
1428			parameters: map[string]any{
1429				"value": map[string]any{"type": "string"},
1430			},
1431			required: []string{"value"},
1432		}
1433
1434		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1435			// Repair function fails
1436			return nil, errors.New("repair failed")
1437		}
1438
1439		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1440
1441		result, err := agent.Generate(context.Background(), AgentCall{
1442			Prompt: "test prompt",
1443		})
1444
1445		require.NoError(t, err)
1446		require.NotNil(t, result)
1447		require.Len(t, result.Steps, 1) // Only one step
1448
1449		// Check that tool call was marked as invalid since repair failed
1450		toolCalls := result.Steps[0].Response.Content.ToolCalls()
1451		require.Len(t, toolCalls, 1)
1452		require.True(t, toolCalls[0].Invalid) // Should be invalid
1453		require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1454	})
1455
1456	t.Run("Nonexistent tool call", func(t *testing.T) {
1457		model := &mockLanguageModel{
1458			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1459				return &Response{
1460					Content: ResponseContent{
1461						TextContent{Text: "Response"},
1462						ToolCallContent{
1463							ToolCallID: "call1",
1464							ToolName:   "nonexistent_tool",
1465							Input:      `{"value": "test"}`,
1466						},
1467					},
1468					Usage:        Usage{TotalTokens: 10},
1469					FinishReason: FinishReasonStop, // Changed to stop
1470				}, nil
1471			},
1472		}
1473
1474		tool := &mockTool{name: "test_tool", description: "Test tool"}
1475
1476		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1477
1478		result, err := agent.Generate(context.Background(), AgentCall{
1479			Prompt: "test prompt",
1480		})
1481
1482		require.NoError(t, err)
1483		require.NotNil(t, result)
1484		require.Len(t, result.Steps, 1) // Only one step
1485
1486		// Check that tool call was marked as invalid due to nonexistent tool
1487		toolCalls := result.Steps[0].Response.Content.ToolCalls()
1488		require.Len(t, toolCalls, 1)
1489		require.True(t, toolCalls[0].Invalid) // Should be invalid
1490		require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1491	})
1492
1493	t.Run("Invalid JSON in tool call", func(t *testing.T) {
1494		model := &mockLanguageModel{
1495			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1496				return &Response{
1497					Content: ResponseContent{
1498						TextContent{Text: "Response"},
1499						ToolCallContent{
1500							ToolCallID: "call1",
1501							ToolName:   "test_tool",
1502							Input:      `{invalid json}`, // Invalid JSON
1503						},
1504					},
1505					Usage:        Usage{TotalTokens: 10},
1506					FinishReason: FinishReasonStop, // Changed to stop
1507				}, nil
1508			},
1509		}
1510
1511		tool := &mockTool{
1512			name:        "test_tool",
1513			description: "Test tool",
1514			parameters: map[string]any{
1515				"value": map[string]any{"type": "string"},
1516			},
1517			required: []string{"value"},
1518		}
1519
1520		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1521
1522		result, err := agent.Generate(context.Background(), AgentCall{
1523			Prompt: "test prompt",
1524		})
1525
1526		require.NoError(t, err)
1527		require.NotNil(t, result)
1528		require.Len(t, result.Steps, 1) // Only one step
1529
1530		// Check that tool call was marked as invalid due to invalid JSON
1531		toolCalls := result.Steps[0].Response.Content.ToolCalls()
1532		require.Len(t, toolCalls, 1)
1533		require.True(t, toolCalls[0].Invalid) // Should be invalid
1534		require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1535	})
1536}