agent_test.go

   1package ai
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"testing"
   9
  10	"github.com/stretchr/testify/require"
  11)
  12
  13// Mock tool for testing
  14type mockTool struct {
  15	name        string
  16	description string
  17	parameters  map[string]any
  18	required    []string
  19	executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
  20}
  21
  22func (m *mockTool) Info() ToolInfo {
  23	return ToolInfo{
  24		Name:        m.name,
  25		Description: m.description,
  26		Parameters:  m.parameters,
  27		Required:    m.required,
  28	}
  29}
  30
  31func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  32	if m.executeFunc != nil {
  33		return m.executeFunc(ctx, call)
  34	}
  35	return ToolResponse{Content: "mock result", IsError: false}, nil
  36}
  37
  38// Mock language model for testing
  39type mockLanguageModel struct {
  40	generateFunc func(ctx context.Context, call Call) (*Response, error)
  41	streamFunc   func(ctx context.Context, call Call) (StreamResponse, error)
  42}
  43
  44func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
  45	if m.generateFunc != nil {
  46		return m.generateFunc(ctx, call)
  47	}
  48	return &Response{
  49		Content: []Content{
  50			TextContent{Text: "Hello, world!"},
  51		},
  52		Usage: Usage{
  53			InputTokens:  3,
  54			OutputTokens: 10,
  55			TotalTokens:  13,
  56		},
  57		FinishReason: FinishReasonStop,
  58	}, nil
  59}
  60
  61func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
  62	if m.streamFunc != nil {
  63		return m.streamFunc(ctx, call)
  64	}
  65	return nil, fmt.Errorf("mock stream not implemented")
  66}
  67
  68func (m *mockLanguageModel) Provider() string {
  69	return "mock-provider"
  70}
  71
  72func (m *mockLanguageModel) Model() string {
  73	return "mock-model"
  74}
  75
  76// Test result.content - comprehensive content types (matches TS test)
  77func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
  78	t.Parallel()
  79
  80	// Create a type-safe tool using the new API
  81	type TestInput struct {
  82		Value string `json:"value" description:"Test value"`
  83	}
  84
  85	tool1 := NewAgentTool(
  86		"tool1",
  87		"Test tool",
  88		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
  89			require.Equal(t, "value", input.Value)
  90			return ToolResponse{Content: "result1", IsError: false}, nil
  91		},
  92	)
  93
  94	model := &mockLanguageModel{
  95		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
  96			return &Response{
  97				Content: []Content{
  98					TextContent{Text: "Hello, world!"},
  99					SourceContent{
 100						ID:         "123",
 101						URL:        "https://example.com",
 102						Title:      "Example",
 103						SourceType: SourceTypeURL,
 104						ProviderMetadata: ProviderMetadata{
 105							"provider": map[string]any{"custom": "value"},
 106						},
 107					},
 108					FileContent{
 109						Data:      []byte{1, 2, 3},
 110						MediaType: "image/png",
 111					},
 112					ReasoningContent{
 113						Text: "I will open the conversation with witty banter.",
 114					},
 115					ToolCallContent{
 116						ToolCallID: "call-1",
 117						ToolName:   "tool1",
 118						Input:      `{"value":"value"}`,
 119					},
 120					TextContent{Text: "More text"},
 121				},
 122				Usage: Usage{
 123					InputTokens:  3,
 124					OutputTokens: 10,
 125					TotalTokens:  13,
 126				},
 127				FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
 128			}, nil
 129		},
 130	}
 131
 132	agent := NewAgent(model, WithTools(tool1))
 133	result, err := agent.Generate(context.Background(), AgentCall{
 134		Prompt: "prompt",
 135	})
 136
 137	require.NoError(t, err)
 138	require.NotNil(t, result)
 139	require.Len(t, result.Steps, 1) // Single step like TypeScript
 140
 141	// Check final response content includes tool result
 142	require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
 143
 144	// Verify each content type in order
 145	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 146	require.True(t, ok)
 147	require.Equal(t, "Hello, world!", textContent.Text)
 148
 149	sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
 150	require.True(t, ok)
 151	require.Equal(t, "123", sourceContent.ID)
 152
 153	fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
 154	require.True(t, ok)
 155	require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
 156
 157	reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
 158	require.True(t, ok)
 159	require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
 160
 161	toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
 162	require.True(t, ok)
 163	require.Equal(t, "call-1", toolCallContent.ToolCallID)
 164
 165	moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
 166	require.True(t, ok)
 167	require.Equal(t, "More text", moreTextContent.Text)
 168
 169	// Tool result should be appended
 170	toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
 171	require.True(t, ok)
 172	require.Equal(t, "call-1", toolResultContent.ToolCallID)
 173	require.Equal(t, "tool1", toolResultContent.ToolName)
 174}
 175
 176// Test result.text extraction
 177func TestAgent_Generate_ResultText(t *testing.T) {
 178	t.Parallel()
 179
 180	model := &mockLanguageModel{
 181		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 182			return &Response{
 183				Content: []Content{
 184					TextContent{Text: "Hello, world!"},
 185				},
 186				Usage: Usage{
 187					InputTokens:  3,
 188					OutputTokens: 10,
 189					TotalTokens:  13,
 190				},
 191				FinishReason: FinishReasonStop,
 192			}, nil
 193		},
 194	}
 195
 196	agent := NewAgent(model)
 197	result, err := agent.Generate(context.Background(), AgentCall{
 198		Prompt: "prompt",
 199	})
 200
 201	require.NoError(t, err)
 202	require.NotNil(t, result)
 203
 204	// Test text extraction from content
 205	text := result.Response.Content.Text()
 206	require.Equal(t, "Hello, world!", text)
 207}
 208
 209// Test result.toolCalls extraction (matches TS test exactly)
 210func TestAgent_Generate_ResultToolCalls(t *testing.T) {
 211	t.Parallel()
 212
 213	// Create type-safe tools using the new API
 214	type Tool1Input struct {
 215		Value string `json:"value" description:"Test value"`
 216	}
 217
 218	type Tool2Input struct {
 219		SomethingElse string `json:"somethingElse" description:"Another test value"`
 220	}
 221
 222	tool1 := NewAgentTool(
 223		"tool1",
 224		"Test tool 1",
 225		func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
 226			return ToolResponse{Content: "result1", IsError: false}, nil
 227		},
 228	)
 229
 230	tool2 := NewAgentTool(
 231		"tool2",
 232		"Test tool 2",
 233		func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
 234			return ToolResponse{Content: "result2", IsError: false}, nil
 235		},
 236	)
 237
 238	model := &mockLanguageModel{
 239		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 240			// Verify tools are passed correctly
 241			require.Len(t, call.Tools, 2)
 242			require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
 243
 244			// Verify prompt structure
 245			require.Len(t, call.Prompt, 1)
 246			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
 247
 248			return &Response{
 249				Content: []Content{
 250					ToolCallContent{
 251						ToolCallID: "call-1",
 252						ToolName:   "tool1",
 253						Input:      `{"value":"value"}`,
 254					},
 255				},
 256				Usage: Usage{
 257					InputTokens:  3,
 258					OutputTokens: 10,
 259					TotalTokens:  13,
 260				},
 261				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
 262			}, nil
 263		},
 264	}
 265
 266	agent := NewAgent(model, WithTools(tool1, tool2))
 267	result, err := agent.Generate(context.Background(), AgentCall{
 268		Prompt: "test-input",
 269	})
 270
 271	require.NoError(t, err)
 272	require.NotNil(t, result)
 273	require.Len(t, result.Steps, 1) // Single step
 274
 275	// Extract tool calls from final response (should be empty since tools don't execute)
 276	var toolCalls []ToolCallContent
 277	for _, content := range result.Response.Content {
 278		if toolCall, ok := AsContentType[ToolCallContent](content); ok {
 279			toolCalls = append(toolCalls, toolCall)
 280		}
 281	}
 282
 283	require.Len(t, toolCalls, 1)
 284	require.Equal(t, "call-1", toolCalls[0].ToolCallID)
 285	require.Equal(t, "tool1", toolCalls[0].ToolName)
 286
 287	// Parse and verify input
 288	var input map[string]any
 289	err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
 290	require.NoError(t, err)
 291	require.Equal(t, "value", input["value"])
 292}
 293
 294// Test result.toolResults extraction (matches TS test exactly)
 295func TestAgent_Generate_ResultToolResults(t *testing.T) {
 296	t.Parallel()
 297
 298	// Create type-safe tool using the new API
 299	type TestInput struct {
 300		Value string `json:"value" description:"Test value"`
 301	}
 302
 303	tool1 := NewAgentTool(
 304		"tool1",
 305		"Test tool",
 306		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
 307			require.Equal(t, "value", input.Value)
 308			return ToolResponse{Content: "result1", IsError: false}, nil
 309		},
 310	)
 311
 312	model := &mockLanguageModel{
 313		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 314			// Verify tools and tool choice
 315			require.Len(t, call.Tools, 1)
 316			require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
 317
 318			// Verify prompt
 319			require.Len(t, call.Prompt, 1)
 320			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
 321
 322			return &Response{
 323				Content: []Content{
 324					ToolCallContent{
 325						ToolCallID: "call-1",
 326						ToolName:   "tool1",
 327						Input:      `{"value":"value"}`,
 328					},
 329				},
 330				Usage: Usage{
 331					InputTokens:  3,
 332					OutputTokens: 10,
 333					TotalTokens:  13,
 334				},
 335				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
 336			}, nil
 337		},
 338	}
 339
 340	agent := NewAgent(model, WithTools(tool1))
 341	result, err := agent.Generate(context.Background(), AgentCall{
 342		Prompt: "test-input",
 343	})
 344
 345	require.NoError(t, err)
 346	require.NotNil(t, result)
 347	require.Len(t, result.Steps, 1) // Single step
 348
 349	// Extract tool results from final response
 350	var toolResults []ToolResultContent
 351	for _, content := range result.Response.Content {
 352		if toolResult, ok := AsContentType[ToolResultContent](content); ok {
 353			toolResults = append(toolResults, toolResult)
 354		}
 355	}
 356
 357	require.Len(t, toolResults, 1)
 358	require.Equal(t, "call-1", toolResults[0].ToolCallID)
 359	require.Equal(t, "tool1", toolResults[0].ToolName)
 360
 361	// Verify result content
 362	textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
 363	require.True(t, ok)
 364	require.Equal(t, "result1", textResult.Text)
 365}
 366
 367// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
 368func TestAgent_Generate_MultipleSteps(t *testing.T) {
 369	t.Parallel()
 370
 371	// Create type-safe tool using the new API
 372	type TestInput struct {
 373		Value string `json:"value" description:"Test value"`
 374	}
 375
 376	tool1 := NewAgentTool(
 377		"tool1",
 378		"Test tool",
 379		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
 380			require.Equal(t, "value", input.Value)
 381			return ToolResponse{Content: "result1", IsError: false}, nil
 382		},
 383	)
 384
 385	callCount := 0
 386	model := &mockLanguageModel{
 387		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 388			callCount++
 389			switch callCount {
 390			case 1:
 391				// First call - return tool call with FinishReasonToolCalls
 392				return &Response{
 393					Content: []Content{
 394						ToolCallContent{
 395							ToolCallID: "call-1",
 396							ToolName:   "tool1",
 397							Input:      `{"value":"value"}`,
 398						},
 399					},
 400					Usage: Usage{
 401						InputTokens:  10,
 402						OutputTokens: 5,
 403						TotalTokens:  15,
 404					},
 405					FinishReason: FinishReasonToolCalls, // This triggers multi-step
 406				}, nil
 407			case 2:
 408				// Second call - return final text
 409				return &Response{
 410					Content: []Content{
 411						TextContent{Text: "Hello, world!"},
 412					},
 413					Usage: Usage{
 414						InputTokens:  3,
 415						OutputTokens: 10,
 416						TotalTokens:  13,
 417					},
 418					FinishReason: FinishReasonStop,
 419				}, nil
 420			default:
 421				t.Fatalf("Unexpected call count: %d", callCount)
 422				return nil, nil
 423			}
 424		},
 425	}
 426
 427	agent := NewAgent(model, WithTools(tool1))
 428	result, err := agent.Generate(context.Background(), AgentCall{
 429		Prompt: "test-input",
 430	})
 431
 432	require.NoError(t, err)
 433	require.NotNil(t, result)
 434	require.Len(t, result.Steps, 2)
 435
 436	// Check total usage sums both steps
 437	require.Equal(t, int64(13), result.TotalUsage.InputTokens)  // 10 + 3
 438	require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
 439	require.Equal(t, int64(28), result.TotalUsage.TotalTokens)  // 15 + 13
 440
 441	// Final response should be from last step
 442	require.Len(t, result.Response.Content, 1)
 443	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 444	require.True(t, ok)
 445	require.Equal(t, "Hello, world!", textContent.Text)
 446
 447	// result.toolCalls should be empty (from last step)
 448	var toolCalls []ToolCallContent
 449	for _, content := range result.Response.Content {
 450		if _, ok := AsContentType[ToolCallContent](content); ok {
 451			toolCalls = append(toolCalls, content.(ToolCallContent))
 452		}
 453	}
 454	require.Len(t, toolCalls, 0)
 455
 456	// result.toolResults should be empty (from last step)
 457	var toolResults []ToolResultContent
 458	for _, content := range result.Response.Content {
 459		if _, ok := AsContentType[ToolResultContent](content); ok {
 460			toolResults = append(toolResults, content.(ToolResultContent))
 461		}
 462	}
 463	require.Len(t, toolResults, 0)
 464}
 465
 466// Test basic text generation
 467func TestAgent_Generate_BasicText(t *testing.T) {
 468	t.Parallel()
 469
 470	model := &mockLanguageModel{
 471		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 472			return &Response{
 473				Content: []Content{
 474					TextContent{Text: "Hello, world!"},
 475				},
 476				Usage: Usage{
 477					InputTokens:  3,
 478					OutputTokens: 10,
 479					TotalTokens:  13,
 480				},
 481				FinishReason: FinishReasonStop,
 482			}, nil
 483		},
 484	}
 485
 486	agent := NewAgent(model)
 487	result, err := agent.Generate(context.Background(), AgentCall{
 488		Prompt: "test prompt",
 489	})
 490
 491	require.NoError(t, err)
 492	require.NotNil(t, result)
 493	require.Len(t, result.Steps, 1)
 494
 495	// Check final response
 496	require.Len(t, result.Response.Content, 1)
 497	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 498	require.True(t, ok)
 499	require.Equal(t, "Hello, world!", textContent.Text)
 500
 501	// Check usage
 502	require.Equal(t, int64(3), result.Response.Usage.InputTokens)
 503	require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
 504	require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
 505
 506	// Check total usage
 507	require.Equal(t, int64(3), result.TotalUsage.InputTokens)
 508	require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
 509	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
 510}
 511
 512// Test empty prompt error
 513func TestAgent_Generate_EmptyPrompt(t *testing.T) {
 514	t.Parallel()
 515
 516	model := &mockLanguageModel{}
 517	agent := NewAgent(model)
 518
 519	result, err := agent.Generate(context.Background(), AgentCall{
 520		Prompt: "", // Empty prompt should cause error
 521	})
 522
 523	require.Error(t, err)
 524	require.Nil(t, result)
 525	require.Contains(t, err.Error(), "Prompt can't be empty")
 526}
 527
 528// Test with system prompt
 529func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
 530	t.Parallel()
 531
 532	model := &mockLanguageModel{
 533		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 534			// Verify system message is included
 535			require.Len(t, call.Prompt, 2) // system + user
 536			require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
 537			require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
 538
 539			systemPart, ok := call.Prompt[0].Content[0].(TextPart)
 540			require.True(t, ok)
 541			require.Equal(t, "You are a helpful assistant", systemPart.Text)
 542
 543			return &Response{
 544				Content: []Content{
 545					TextContent{Text: "Hello, world!"},
 546				},
 547				Usage: Usage{
 548					InputTokens:  3,
 549					OutputTokens: 10,
 550					TotalTokens:  13,
 551				},
 552				FinishReason: FinishReasonStop,
 553			}, nil
 554		},
 555	}
 556
 557	agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
 558	result, err := agent.Generate(context.Background(), AgentCall{
 559		Prompt: "test prompt",
 560	})
 561
 562	require.NoError(t, err)
 563	require.NotNil(t, result)
 564}
 565
 566// Test options.activeTools filtering
 567func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
 568	t.Parallel()
 569
 570	tool1 := &mockTool{
 571		name:        "tool1",
 572		description: "Test tool 1",
 573		parameters: map[string]any{
 574			"value": map[string]any{"type": "string"},
 575		},
 576		required: []string{"value"},
 577	}
 578
 579	tool2 := &mockTool{
 580		name:        "tool2",
 581		description: "Test tool 2",
 582		parameters: map[string]any{
 583			"value": map[string]any{"type": "string"},
 584		},
 585		required: []string{"value"},
 586	}
 587
 588	model := &mockLanguageModel{
 589		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 590			// Verify only tool1 is available
 591			require.Len(t, call.Tools, 1)
 592			functionTool, ok := call.Tools[0].(FunctionTool)
 593			require.True(t, ok)
 594			require.Equal(t, "tool1", functionTool.Name)
 595
 596			return &Response{
 597				Content: []Content{
 598					TextContent{Text: "Hello, world!"},
 599				},
 600				Usage: Usage{
 601					InputTokens:  3,
 602					OutputTokens: 10,
 603					TotalTokens:  13,
 604				},
 605				FinishReason: FinishReasonStop,
 606			}, nil
 607		},
 608	}
 609
 610	agent := NewAgent(model, WithTools(tool1, tool2))
 611	result, err := agent.Generate(context.Background(), AgentCall{
 612		Prompt:      "test-input",
 613		ActiveTools: []string{"tool1"}, // Only tool1 should be active
 614	})
 615
 616	require.NoError(t, err)
 617	require.NotNil(t, result)
 618}
 619
 620func TestResponseContent_Getters(t *testing.T) {
 621	t.Parallel()
 622
 623	// Create test content with all types
 624	content := ResponseContent{
 625		TextContent{Text: "Hello world"},
 626		ReasoningContent{Text: "Let me think..."},
 627		FileContent{Data: []byte("file data"), MediaType: "text/plain"},
 628		SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
 629		ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
 630		ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
 631	}
 632
 633	// Test Text()
 634	require.Equal(t, "Hello world", content.Text())
 635
 636	// Test Reasoning()
 637	reasoning := content.Reasoning()
 638	require.Len(t, reasoning, 1)
 639	require.Equal(t, "Let me think...", reasoning[0].Text)
 640
 641	// Test ReasoningText()
 642	require.Equal(t, "Let me think...", content.ReasoningText())
 643
 644	// Test Files()
 645	files := content.Files()
 646	require.Len(t, files, 1)
 647	require.Equal(t, "text/plain", files[0].MediaType)
 648	require.Equal(t, []byte("file data"), files[0].Data)
 649
 650	// Test Sources()
 651	sources := content.Sources()
 652	require.Len(t, sources, 1)
 653	require.Equal(t, SourceTypeURL, sources[0].SourceType)
 654	require.Equal(t, "https://example.com", sources[0].URL)
 655	require.Equal(t, "Example", sources[0].Title)
 656
 657	// Test ToolCalls()
 658	toolCalls := content.ToolCalls()
 659	require.Len(t, toolCalls, 1)
 660	require.Equal(t, "call1", toolCalls[0].ToolCallID)
 661	require.Equal(t, "test_tool", toolCalls[0].ToolName)
 662	require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
 663
 664	// Test ToolResults()
 665	toolResults := content.ToolResults()
 666	require.Len(t, toolResults, 1)
 667	require.Equal(t, "call1", toolResults[0].ToolCallID)
 668	require.Equal(t, "test_tool", toolResults[0].ToolName)
 669	result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
 670	require.True(t, ok)
 671	require.Equal(t, "result", result.Text)
 672}
 673
 674func TestResponseContent_Getters_Empty(t *testing.T) {
 675	t.Parallel()
 676
 677	// Test with empty content
 678	content := ResponseContent{}
 679
 680	require.Equal(t, "", content.Text())
 681	require.Equal(t, "", content.ReasoningText())
 682	require.Empty(t, content.Reasoning())
 683	require.Empty(t, content.Files())
 684	require.Empty(t, content.Sources())
 685	require.Empty(t, content.ToolCalls())
 686	require.Empty(t, content.ToolResults())
 687}
 688
 689func TestResponseContent_Getters_MultipleItems(t *testing.T) {
 690	t.Parallel()
 691
 692	// Test with multiple items of same type
 693	content := ResponseContent{
 694		ReasoningContent{Text: "First thought"},
 695		ReasoningContent{Text: "Second thought"},
 696		FileContent{Data: []byte("file1"), MediaType: "text/plain"},
 697		FileContent{Data: []byte("file2"), MediaType: "image/png"},
 698	}
 699
 700	// Test multiple reasoning
 701	reasoning := content.Reasoning()
 702	require.Len(t, reasoning, 2)
 703	require.Equal(t, "First thought", reasoning[0].Text)
 704	require.Equal(t, "Second thought", reasoning[1].Text)
 705
 706	// Test concatenated reasoning text
 707	require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
 708
 709	// Test multiple files
 710	files := content.Files()
 711	require.Len(t, files, 2)
 712	require.Equal(t, "text/plain", files[0].MediaType)
 713	require.Equal(t, "image/png", files[1].MediaType)
 714}
 715
 716func TestStopConditions(t *testing.T) {
 717	t.Parallel()
 718
 719	// Create test steps
 720	step1 := StepResult{
 721		Response: Response{
 722			Content: ResponseContent{
 723				TextContent{Text: "Hello"},
 724			},
 725			FinishReason: FinishReasonToolCalls,
 726			Usage:        Usage{TotalTokens: 10},
 727		},
 728	}
 729
 730	step2 := StepResult{
 731		Response: Response{
 732			Content: ResponseContent{
 733				TextContent{Text: "World"},
 734				ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
 735			},
 736			FinishReason: FinishReasonStop,
 737			Usage:        Usage{TotalTokens: 15},
 738		},
 739	}
 740
 741	step3 := StepResult{
 742		Response: Response{
 743			Content: ResponseContent{
 744				ReasoningContent{Text: "Let me think..."},
 745				FileContent{Data: []byte("data"), MediaType: "text/plain"},
 746			},
 747			FinishReason: FinishReasonLength,
 748			Usage:        Usage{TotalTokens: 20},
 749		},
 750	}
 751
 752	t.Run("StepCountIs", func(t *testing.T) {
 753		t.Parallel()
 754		condition := StepCountIs(2)
 755
 756		// Should not stop with 1 step
 757		require.False(t, condition([]StepResult{step1}))
 758
 759		// Should stop with 2 steps
 760		require.True(t, condition([]StepResult{step1, step2}))
 761
 762		// Should stop with more than 2 steps
 763		require.True(t, condition([]StepResult{step1, step2, step3}))
 764
 765		// Should not stop with empty steps
 766		require.False(t, condition([]StepResult{}))
 767	})
 768
 769	t.Run("HasToolCall", func(t *testing.T) {
 770		t.Parallel()
 771		condition := HasToolCall("search")
 772
 773		// Should not stop when tool not called
 774		require.False(t, condition([]StepResult{step1}))
 775
 776		// Should stop when tool is called in last step
 777		require.True(t, condition([]StepResult{step1, step2}))
 778
 779		// Should not stop when tool called in earlier step but not last
 780		require.False(t, condition([]StepResult{step1, step2, step3}))
 781
 782		// Should not stop with empty steps
 783		require.False(t, condition([]StepResult{}))
 784
 785		// Should not stop when different tool is called
 786		differentToolCondition := HasToolCall("different_tool")
 787		require.False(t, differentToolCondition([]StepResult{step1, step2}))
 788	})
 789
 790	t.Run("HasContent", func(t *testing.T) {
 791		t.Parallel()
 792		reasoningCondition := HasContent(ContentTypeReasoning)
 793		fileCondition := HasContent(ContentTypeFile)
 794
 795		// Should not stop when content type not present
 796		require.False(t, reasoningCondition([]StepResult{step1, step2}))
 797
 798		// Should stop when content type is present in last step
 799		require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
 800		require.True(t, fileCondition([]StepResult{step1, step2, step3}))
 801
 802		// Should not stop with empty steps
 803		require.False(t, reasoningCondition([]StepResult{}))
 804	})
 805
 806	t.Run("FinishReasonIs", func(t *testing.T) {
 807		t.Parallel()
 808		stopCondition := FinishReasonIs(FinishReasonStop)
 809		lengthCondition := FinishReasonIs(FinishReasonLength)
 810
 811		// Should not stop when finish reason doesn't match
 812		require.False(t, stopCondition([]StepResult{step1}))
 813
 814		// Should stop when finish reason matches in last step
 815		require.True(t, stopCondition([]StepResult{step1, step2}))
 816		require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
 817
 818		// Should not stop with empty steps
 819		require.False(t, stopCondition([]StepResult{}))
 820	})
 821
 822	t.Run("MaxTokensUsed", func(t *testing.T) {
 823		condition := MaxTokensUsed(30)
 824
 825		// Should not stop when under limit
 826		require.False(t, condition([]StepResult{step1}))        // 10 tokens
 827		require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
 828
 829		// Should stop when at or over limit
 830		require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
 831
 832		// Should not stop with empty steps
 833		require.False(t, condition([]StepResult{}))
 834	})
 835}
 836
 837func TestStopConditions_Integration(t *testing.T) {
 838	t.Parallel()
 839
 840	t.Run("StepCountIs integration", func(t *testing.T) {
 841		t.Parallel()
 842		model := &mockLanguageModel{
 843			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 844				return &Response{
 845					Content: ResponseContent{
 846						TextContent{Text: "Mock response"},
 847					},
 848					Usage: Usage{
 849						InputTokens:  3,
 850						OutputTokens: 10,
 851						TotalTokens:  13,
 852					},
 853					FinishReason: FinishReasonStop,
 854				}, nil
 855			},
 856		}
 857
 858		agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
 859
 860		result, err := agent.Generate(context.Background(), AgentCall{
 861			Prompt: "test prompt",
 862		})
 863
 864		require.NoError(t, err)
 865		require.NotNil(t, result)
 866		require.Len(t, result.Steps, 1) // Should stop after 1 step
 867	})
 868
 869	t.Run("Multiple stop conditions", func(t *testing.T) {
 870		t.Parallel()
 871		model := &mockLanguageModel{
 872			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 873				return &Response{
 874					Content: ResponseContent{
 875						TextContent{Text: "Mock response"},
 876					},
 877					Usage: Usage{
 878						InputTokens:  3,
 879						OutputTokens: 10,
 880						TotalTokens:  13,
 881					},
 882					FinishReason: FinishReasonStop,
 883				}, nil
 884			},
 885		}
 886
 887		agent := NewAgent(model, WithStopConditions(
 888			StepCountIs(5),                   // Stop after 5 steps
 889			FinishReasonIs(FinishReasonStop), // Or stop on finish reason
 890		))
 891
 892		result, err := agent.Generate(context.Background(), AgentCall{
 893			Prompt: "test prompt",
 894		})
 895
 896		require.NoError(t, err)
 897		require.NotNil(t, result)
 898		// Should stop on first condition met (finish reason stop)
 899		require.Equal(t, FinishReasonStop, result.Response.FinishReason)
 900	})
 901}
 902
 903func TestPrepareStep(t *testing.T) {
 904	t.Parallel()
 905
 906	t.Run("System prompt modification", func(t *testing.T) {
 907		t.Parallel()
 908		var capturedSystemPrompt string
 909		model := &mockLanguageModel{
 910			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 911				// Capture the system message to verify it was modified
 912				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
 913					if len(call.Prompt[0].Content) > 0 {
 914						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
 915							capturedSystemPrompt = textPart.Text
 916						}
 917					}
 918				}
 919				return &Response{
 920					Content: ResponseContent{
 921						TextContent{Text: "Response"},
 922					},
 923					Usage:        Usage{TotalTokens: 10},
 924					FinishReason: FinishReasonStop,
 925				}, nil
 926			},
 927		}
 928
 929		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 930			newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
 931			return PrepareStepResult{
 932				Model:    options.Model,
 933				Messages: options.Messages,
 934				System:   &newSystem,
 935			}, nil
 936		}
 937
 938		agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
 939
 940		result, err := agent.Generate(context.Background(), AgentCall{
 941			Prompt:      "test prompt",
 942			PrepareStep: prepareStepFunc,
 943		})
 944
 945		require.NoError(t, err)
 946		require.NotNil(t, result)
 947		require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
 948	})
 949
 950	t.Run("Tool choice modification", func(t *testing.T) {
 951		t.Parallel()
 952		var capturedToolChoice *ToolChoice
 953		model := &mockLanguageModel{
 954			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 955				capturedToolChoice = call.ToolChoice
 956				return &Response{
 957					Content: ResponseContent{
 958						TextContent{Text: "Response"},
 959					},
 960					Usage:        Usage{TotalTokens: 10},
 961					FinishReason: FinishReasonStop,
 962				}, nil
 963			},
 964		}
 965
 966		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
 967			toolChoice := ToolChoiceNone
 968			return PrepareStepResult{
 969				Model:      options.Model,
 970				Messages:   options.Messages,
 971				ToolChoice: &toolChoice,
 972			}, nil
 973		}
 974
 975		agent := NewAgent(model)
 976
 977		result, err := agent.Generate(context.Background(), AgentCall{
 978			Prompt:      "test prompt",
 979			PrepareStep: prepareStepFunc,
 980		})
 981
 982		require.NoError(t, err)
 983		require.NotNil(t, result)
 984		require.NotNil(t, capturedToolChoice)
 985		require.Equal(t, ToolChoiceNone, *capturedToolChoice)
 986	})
 987
	t.Run("Active tools modification", func(t *testing.T) {
		t.Parallel()

		// Names of the tool definitions handed to the model, in call order.
		var offered []string
		model := &mockLanguageModel{
			generateFunc: func(_ context.Context, call Call) (*Response, error) {
				for _, def := range call.Tools {
					offered = append(offered, def.GetName())
				}
				resp := &Response{
					Content:      ResponseContent{TextContent{Text: "Response"}},
					Usage:        Usage{TotalTokens: 10},
					FinishReason: FinishReasonStop,
				}
				return resp, nil
			},
		}

		// onlyTool2 restricts the step to a single active tool.
		onlyTool2 := func(opts PrepareStepFunctionOptions) (PrepareStepResult, error) {
			return PrepareStepResult{
				Model:       opts.Model,
				Messages:    opts.Messages,
				ActiveTools: []string{"tool2"},
			}, nil
		}

		testAgent := NewAgent(model, WithTools(
			&mockTool{name: "tool1", description: "Tool 1"},
			&mockTool{name: "tool2", description: "Tool 2"},
			&mockTool{name: "tool3", description: "Tool 3"},
		))

		res, err := testAgent.Generate(context.Background(), AgentCall{
			Prompt:      "test prompt",
			PrepareStep: onlyTool2,
		})
		require.NoError(t, err)
		require.NotNil(t, res)

		// Exactly one tool, tool2, must have been offered to the model.
		require.Equal(t, []string{"tool2"}, offered)
	})
1032
1033	t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1034		t.Parallel()
1035		var capturedToolCount int
1036		model := &mockLanguageModel{
1037			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1038				capturedToolCount = len(call.Tools)
1039				return &Response{
1040					Content: ResponseContent{
1041						TextContent{Text: "Response"},
1042					},
1043					Usage:        Usage{TotalTokens: 10},
1044					FinishReason: FinishReasonStop,
1045				}, nil
1046			},
1047		}
1048
1049		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1050
1051		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1052			return PrepareStepResult{
1053				Model:           options.Model,
1054				Messages:        options.Messages,
1055				DisableAllTools: true, // Disable all tools for this step
1056			}, nil
1057		}
1058
1059		agent := NewAgent(model, WithTools(tool1))
1060
1061		result, err := agent.Generate(context.Background(), AgentCall{
1062			Prompt:      "test prompt",
1063			PrepareStep: prepareStepFunc,
1064		})
1065
1066		require.NoError(t, err)
1067		require.NotNil(t, result)
1068		require.Equal(t, 0, capturedToolCount) // No tools should be passed
1069	})
1070
1071	t.Run("All fields modified together", func(t *testing.T) {
1072		t.Parallel()
1073		var capturedSystemPrompt string
1074		var capturedToolChoice *ToolChoice
1075		var capturedToolNames []string
1076
1077		model := &mockLanguageModel{
1078			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1079				// Capture system prompt
1080				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1081					if len(call.Prompt[0].Content) > 0 {
1082						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1083							capturedSystemPrompt = textPart.Text
1084						}
1085					}
1086				}
1087				// Capture tool choice
1088				capturedToolChoice = call.ToolChoice
1089				// Capture tool names
1090				for _, tool := range call.Tools {
1091					capturedToolNames = append(capturedToolNames, tool.GetName())
1092				}
1093				return &Response{
1094					Content: ResponseContent{
1095						TextContent{Text: "Response"},
1096					},
1097					Usage:        Usage{TotalTokens: 10},
1098					FinishReason: FinishReasonStop,
1099				}, nil
1100			},
1101		}
1102
1103		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1104		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1105
1106		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1107			newSystem := "Step-specific system"
1108			toolChoice := SpecificToolChoice("tool1")
1109			activeTools := []string{"tool1"}
1110			return PrepareStepResult{
1111				Model:       options.Model,
1112				Messages:    options.Messages,
1113				System:      &newSystem,
1114				ToolChoice:  &toolChoice,
1115				ActiveTools: activeTools,
1116			}, nil
1117		}
1118
1119		agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1120
1121		result, err := agent.Generate(context.Background(), AgentCall{
1122			Prompt:      "test prompt",
1123			PrepareStep: prepareStepFunc,
1124		})
1125
1126		require.NoError(t, err)
1127		require.NotNil(t, result)
1128		require.Equal(t, "Step-specific system", capturedSystemPrompt)
1129		require.NotNil(t, capturedToolChoice)
1130		require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1131		require.Len(t, capturedToolNames, 1)
1132		require.Equal(t, "tool1", capturedToolNames[0])
1133	})
1134
1135	t.Run("Nil fields use parent values", func(t *testing.T) {
1136		t.Parallel()
1137		var capturedSystemPrompt string
1138		var capturedToolChoice *ToolChoice
1139		var capturedToolNames []string
1140
1141		model := &mockLanguageModel{
1142			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1143				// Capture system prompt
1144				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1145					if len(call.Prompt[0].Content) > 0 {
1146						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1147							capturedSystemPrompt = textPart.Text
1148						}
1149					}
1150				}
1151				// Capture tool choice
1152				capturedToolChoice = call.ToolChoice
1153				// Capture tool names
1154				for _, tool := range call.Tools {
1155					capturedToolNames = append(capturedToolNames, tool.GetName())
1156				}
1157				return &Response{
1158					Content: ResponseContent{
1159						TextContent{Text: "Response"},
1160					},
1161					Usage:        Usage{TotalTokens: 10},
1162					FinishReason: FinishReasonStop,
1163				}, nil
1164			},
1165		}
1166
1167		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1168
1169		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1170			// All optional fields are nil, should use parent values
1171			return PrepareStepResult{
1172				Model:       options.Model,
1173				Messages:    options.Messages,
1174				System:      nil, // Use parent
1175				ToolChoice:  nil, // Use parent (auto)
1176				ActiveTools: nil, // Use parent (all tools)
1177			}, nil
1178		}
1179
1180		agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1181
1182		result, err := agent.Generate(context.Background(), AgentCall{
1183			Prompt:      "test prompt",
1184			PrepareStep: prepareStepFunc,
1185		})
1186
1187		require.NoError(t, err)
1188		require.NotNil(t, result)
1189		require.Equal(t, "Parent system", capturedSystemPrompt)
1190		require.NotNil(t, capturedToolChoice)
1191		require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1192		require.Len(t, capturedToolNames, 1)
1193		require.Equal(t, "tool1", capturedToolNames[0])
1194	})
1195
1196	t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1197		t.Parallel()
1198		var capturedToolNames []string
1199		model := &mockLanguageModel{
1200			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1201				// Capture tool names to verify all tools are included
1202				for _, tool := range call.Tools {
1203					capturedToolNames = append(capturedToolNames, tool.GetName())
1204				}
1205				return &Response{
1206					Content: ResponseContent{
1207						TextContent{Text: "Response"},
1208					},
1209					Usage:        Usage{TotalTokens: 10},
1210					FinishReason: FinishReasonStop,
1211				}, nil
1212			},
1213		}
1214
1215		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1216		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1217
1218		prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1219			return PrepareStepResult{
1220				Model:       options.Model,
1221				Messages:    options.Messages,
1222				ActiveTools: []string{}, // Empty slice means all tools
1223			}, nil
1224		}
1225
1226		agent := NewAgent(model, WithTools(tool1, tool2))
1227
1228		result, err := agent.Generate(context.Background(), AgentCall{
1229			Prompt:      "test prompt",
1230			PrepareStep: prepareStepFunc,
1231		})
1232
1233		require.NoError(t, err)
1234		require.NotNil(t, result)
1235		require.Len(t, capturedToolNames, 2) // All tools should be included
1236		require.Contains(t, capturedToolNames, "tool1")
1237		require.Contains(t, capturedToolNames, "tool2")
1238	})
1239}
1240
// TestToolCallRepair checks how the agent validates tool calls emitted by the
// model: a well-formed call is accepted, a call missing a required parameter
// is marked invalid (unless a repair function fixes it), a failed repair
// leaves the call invalid, and calls to unknown tools or with malformed JSON
// input are marked invalid with a descriptive validation error.
func TestToolCallRepair(t *testing.T) {
	t.Parallel()

	// newToolCallModel returns a mock model whose Generate always answers with
	// some text plus one tool call (ID "call1") to toolName with the given raw
	// JSON input. FinishReasonStop lets the agent stop after a single step.
	newToolCallModel := func(toolName, input string) *mockLanguageModel {
		return &mockLanguageModel{
			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
				return &Response{
					Content: ResponseContent{
						TextContent{Text: "Response"},
						ToolCallContent{
							ToolCallID: "call1",
							ToolName:   toolName,
							Input:      input,
						},
					},
					Usage:        Usage{TotalTokens: 10},
					FinishReason: FinishReasonStop,
				}, nil
			},
		}
	}

	// newValueTool returns "test_tool", whose schema requires a single string
	// parameter named "value". A nil executeFunc makes mockTool fall back to
	// its default result.
	newValueTool := func(executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)) *mockTool {
		return &mockTool{
			name:        "test_tool",
			description: "Test tool",
			parameters: map[string]any{
				"value": map[string]any{"type": "string"},
			},
			required:    []string{"value"},
			executeFunc: executeFunc,
		}
	}

	t.Run("Valid tool call passes validation", func(t *testing.T) {
		t.Parallel()
		model := newToolCallModel("test_tool", `{"value": "test"}`) // valid JSON with the required field
		tool := newValueTool(func(ctx context.Context, call ToolCall) (ToolResponse, error) {
			return ToolResponse{Content: "success", IsError: false}, nil
		})

		// StepCountIs(2) is only a safety net in case the agent keeps looping.
		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))

		result, err := agent.Generate(context.Background(), AgentCall{
			Prompt: "test prompt",
		})

		require.NoError(t, err)
		require.NotNil(t, result)
		require.Len(t, result.Steps, 1) // FinishReasonStop ends the run after one step

		// The tool call must have passed validation.
		toolCalls := result.Steps[0].Content.ToolCalls()
		require.Len(t, toolCalls, 1)
		require.False(t, toolCalls[0].Invalid)
	})

	t.Run("Invalid tool call without repair function", func(t *testing.T) {
		t.Parallel()
		model := newToolCallModel("test_tool", `{"wrong_field": "test"}`) // missing required field
		tool := newValueTool(nil)

		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))

		result, err := agent.Generate(context.Background(), AgentCall{
			Prompt: "test prompt",
		})

		require.NoError(t, err)
		require.NotNil(t, result)
		require.Len(t, result.Steps, 1)

		// Without a repair function the call stays invalid. ErrorContains
		// reports a clean failure (instead of a nil-pointer panic) if no
		// validation error was recorded.
		toolCalls := result.Steps[0].Content.ToolCalls()
		require.Len(t, toolCalls, 1)
		require.True(t, toolCalls[0].Invalid)
		require.ErrorContains(t, toolCalls[0].ValidationError, "missing required parameter: value")
	})

	t.Run("Invalid tool call with successful repair", func(t *testing.T) {
		t.Parallel()
		model := newToolCallModel("test_tool", `{"wrong_field": "test"}`) // missing required field
		tool := newValueTool(func(ctx context.Context, call ToolCall) (ToolResponse, error) {
			return ToolResponse{Content: "repaired_success", IsError: false}, nil
		})

		// repairFunc supplies the missing required field.
		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
			repairedToolCall := options.OriginalToolCall
			repairedToolCall.Input = `{"value": "repaired"}`
			return &repairedToolCall, nil
		}

		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))

		result, err := agent.Generate(context.Background(), AgentCall{
			Prompt: "test prompt",
		})

		require.NoError(t, err)
		require.NotNil(t, result)
		require.Len(t, result.Steps, 1)

		// The repaired call replaces the original and is valid.
		toolCalls := result.Steps[0].Content.ToolCalls()
		require.Len(t, toolCalls, 1)
		require.False(t, toolCalls[0].Invalid)
		require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input)
	})

	t.Run("Invalid tool call with failed repair", func(t *testing.T) {
		t.Parallel()
		model := newToolCallModel("test_tool", `{"wrong_field": "test"}`) // missing required field
		tool := newValueTool(nil)

		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
			return nil, errors.New("repair failed")
		}

		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))

		result, err := agent.Generate(context.Background(), AgentCall{
			Prompt: "test prompt",
		})

		require.NoError(t, err)
		require.NotNil(t, result)
		require.Len(t, result.Steps, 1)

		// A failed repair leaves the original validation error in place.
		toolCalls := result.Steps[0].Content.ToolCalls()
		require.Len(t, toolCalls, 1)
		require.True(t, toolCalls[0].Invalid)
		require.ErrorContains(t, toolCalls[0].ValidationError, "missing required parameter: value")
	})

	t.Run("Nonexistent tool call", func(t *testing.T) {
		t.Parallel()
		model := newToolCallModel("nonexistent_tool", `{"value": "test"}`)
		tool := &mockTool{name: "test_tool", description: "Test tool"}

		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))

		result, err := agent.Generate(context.Background(), AgentCall{
			Prompt: "test prompt",
		})

		require.NoError(t, err)
		require.NotNil(t, result)
		require.Len(t, result.Steps, 1)

		// A call to an unregistered tool is invalid.
		toolCalls := result.Steps[0].Content.ToolCalls()
		require.Len(t, toolCalls, 1)
		require.True(t, toolCalls[0].Invalid)
		require.ErrorContains(t, toolCalls[0].ValidationError, "tool not found: nonexistent_tool")
	})

	t.Run("Invalid JSON in tool call", func(t *testing.T) {
		t.Parallel()
		model := newToolCallModel("test_tool", `{invalid json}`)
		tool := newValueTool(nil)

		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))

		result, err := agent.Generate(context.Background(), AgentCall{
			Prompt: "test prompt",
		})

		require.NoError(t, err)
		require.NotNil(t, result)
		require.Len(t, result.Steps, 1)

		// Malformed JSON input is rejected before schema validation.
		toolCalls := result.Steps[0].Content.ToolCalls()
		require.Len(t, toolCalls, 1)
		require.True(t, toolCalls[0].Invalid)
		require.ErrorContains(t, toolCalls[0].ValidationError, "invalid JSON input")
	})
}