agent_test.go

   1package fantasy
   2
   3import (
   4	"context"
   5	"encoding/json"
   6	"errors"
   7	"fmt"
   8	"testing"
   9
  10	"github.com/stretchr/testify/require"
  11)
  12
  13// Mock tool for testing
  14type mockTool struct {
  15	name            string
  16	providerOptions ProviderOptions
  17	description     string
  18	parameters      map[string]any
  19	required        []string
  20	executeFunc     func(ctx context.Context, call ToolCall) (ToolResponse, error)
  21}
  22
  23func (m *mockTool) SetProviderOptions(opts ProviderOptions) {
  24	m.providerOptions = opts
  25}
  26
  27func (m *mockTool) ProviderOptions() ProviderOptions {
  28	return m.providerOptions
  29}
  30
  31func (m *mockTool) Info() ToolInfo {
  32	return ToolInfo{
  33		Name:        m.name,
  34		Description: m.description,
  35		Parameters:  m.parameters,
  36		Required:    m.required,
  37	}
  38}
  39
  40func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  41	if m.executeFunc != nil {
  42		return m.executeFunc(ctx, call)
  43	}
  44	return ToolResponse{Content: "mock result", IsError: false}, nil
  45}
  46
  47// Mock language model for testing
  48type mockLanguageModel struct {
  49	generateFunc func(ctx context.Context, call Call) (*Response, error)
  50	streamFunc   func(ctx context.Context, call Call) (StreamResponse, error)
  51}
  52
  53func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
  54	if m.generateFunc != nil {
  55		return m.generateFunc(ctx, call)
  56	}
  57	return &Response{
  58		Content: []Content{
  59			TextContent{Text: "Hello, world!"},
  60		},
  61		Usage: Usage{
  62			InputTokens:  3,
  63			OutputTokens: 10,
  64			TotalTokens:  13,
  65		},
  66		FinishReason: FinishReasonStop,
  67	}, nil
  68}
  69
  70func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
  71	if m.streamFunc != nil {
  72		return m.streamFunc(ctx, call)
  73	}
  74	return nil, fmt.Errorf("mock stream not implemented")
  75}
  76
  77func (m *mockLanguageModel) Provider() string {
  78	return "mock-provider"
  79}
  80
  81func (m *mockLanguageModel) Model() string {
  82	return "mock-model"
  83}
  84
  85func (m *mockLanguageModel) GenerateObject(ctx context.Context, call ObjectCall) (*ObjectResponse, error) {
  86	return nil, fmt.Errorf("mock GenerateObject not implemented")
  87}
  88
  89func (m *mockLanguageModel) StreamObject(ctx context.Context, call ObjectCall) (ObjectStreamResponse, error) {
  90	return nil, fmt.Errorf("mock StreamObject not implemented")
  91}
  92
  93// Test result.content - comprehensive content types (matches TS test)
  94func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
  95	t.Parallel()
  96
  97	// Create a type-safe tool using the new API
  98	type TestInput struct {
  99		Value string `json:"value" description:"Test value"`
 100	}
 101
 102	tool1 := NewAgentTool(
 103		"tool1",
 104		"Test tool",
 105		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
 106			require.Equal(t, "value", input.Value)
 107			return ToolResponse{Content: "result1", IsError: false}, nil
 108		},
 109	)
 110
 111	model := &mockLanguageModel{
 112		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 113			return &Response{
 114				Content: []Content{
 115					TextContent{Text: "Hello, world!"},
 116					SourceContent{
 117						ID:         "123",
 118						URL:        "https://example.com",
 119						Title:      "Example",
 120						SourceType: SourceTypeURL,
 121					},
 122					FileContent{
 123						Data:      []byte{1, 2, 3},
 124						MediaType: "image/png",
 125					},
 126					ReasoningContent{
 127						Text: "I will open the conversation with witty banter.",
 128					},
 129					ToolCallContent{
 130						ToolCallID: "call-1",
 131						ToolName:   "tool1",
 132						Input:      `{"value":"value"}`,
 133					},
 134					TextContent{Text: "More text"},
 135				},
 136				Usage: Usage{
 137					InputTokens:  3,
 138					OutputTokens: 10,
 139					TotalTokens:  13,
 140				},
 141				FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
 142			}, nil
 143		},
 144	}
 145
 146	agent := NewAgent(model, WithTools(tool1))
 147	result, err := agent.Generate(context.Background(), AgentCall{
 148		Prompt: "prompt",
 149	})
 150
 151	require.NoError(t, err)
 152	require.NotNil(t, result)
 153	require.Len(t, result.Steps, 1) // Single step like TypeScript
 154
 155	// Check final response content includes tool result
 156	require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
 157
 158	// Verify each content type in order
 159	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 160	require.True(t, ok)
 161	require.Equal(t, "Hello, world!", textContent.Text)
 162
 163	sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
 164	require.True(t, ok)
 165	require.Equal(t, "123", sourceContent.ID)
 166
 167	fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
 168	require.True(t, ok)
 169	require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
 170
 171	reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
 172	require.True(t, ok)
 173	require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
 174
 175	toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
 176	require.True(t, ok)
 177	require.Equal(t, "call-1", toolCallContent.ToolCallID)
 178
 179	moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
 180	require.True(t, ok)
 181	require.Equal(t, "More text", moreTextContent.Text)
 182
 183	// Tool result should be appended
 184	toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
 185	require.True(t, ok)
 186	require.Equal(t, "call-1", toolResultContent.ToolCallID)
 187	require.Equal(t, "tool1", toolResultContent.ToolName)
 188}
 189
 190// Test result.text extraction
 191func TestAgent_Generate_ResultText(t *testing.T) {
 192	t.Parallel()
 193
 194	model := &mockLanguageModel{
 195		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 196			return &Response{
 197				Content: []Content{
 198					TextContent{Text: "Hello, world!"},
 199				},
 200				Usage: Usage{
 201					InputTokens:  3,
 202					OutputTokens: 10,
 203					TotalTokens:  13,
 204				},
 205				FinishReason: FinishReasonStop,
 206			}, nil
 207		},
 208	}
 209
 210	agent := NewAgent(model)
 211	result, err := agent.Generate(context.Background(), AgentCall{
 212		Prompt: "prompt",
 213	})
 214
 215	require.NoError(t, err)
 216	require.NotNil(t, result)
 217
 218	// Test text extraction from content
 219	text := result.Response.Content.Text()
 220	require.Equal(t, "Hello, world!", text)
 221}
 222
 223// Test result.toolCalls extraction (matches TS test exactly)
 224func TestAgent_Generate_ResultToolCalls(t *testing.T) {
 225	t.Parallel()
 226
 227	// Create type-safe tools using the new API
 228	type Tool1Input struct {
 229		Value string `json:"value" description:"Test value"`
 230	}
 231
 232	type Tool2Input struct {
 233		SomethingElse string `json:"somethingElse" description:"Another test value"`
 234	}
 235
 236	tool1 := NewAgentTool(
 237		"tool1",
 238		"Test tool 1",
 239		func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
 240			return ToolResponse{Content: "result1", IsError: false}, nil
 241		},
 242	)
 243
 244	tool2 := NewAgentTool(
 245		"tool2",
 246		"Test tool 2",
 247		func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
 248			return ToolResponse{Content: "result2", IsError: false}, nil
 249		},
 250	)
 251
 252	model := &mockLanguageModel{
 253		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 254			// Verify tools are passed correctly
 255			require.Len(t, call.Tools, 2)
 256			require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
 257
 258			// Verify prompt structure
 259			require.Len(t, call.Prompt, 1)
 260			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
 261
 262			return &Response{
 263				Content: []Content{
 264					ToolCallContent{
 265						ToolCallID: "call-1",
 266						ToolName:   "tool1",
 267						Input:      `{"value":"value"}`,
 268					},
 269				},
 270				Usage: Usage{
 271					InputTokens:  3,
 272					OutputTokens: 10,
 273					TotalTokens:  13,
 274				},
 275				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
 276			}, nil
 277		},
 278	}
 279
 280	agent := NewAgent(model, WithTools(tool1, tool2))
 281	result, err := agent.Generate(context.Background(), AgentCall{
 282		Prompt: "test-input",
 283	})
 284
 285	require.NoError(t, err)
 286	require.NotNil(t, result)
 287	require.Len(t, result.Steps, 1) // Single step
 288
 289	// Extract tool calls from final response (should be empty since tools don't execute)
 290	var toolCalls []ToolCallContent
 291	for _, content := range result.Response.Content {
 292		if toolCall, ok := AsContentType[ToolCallContent](content); ok {
 293			toolCalls = append(toolCalls, toolCall)
 294		}
 295	}
 296
 297	require.Len(t, toolCalls, 1)
 298	require.Equal(t, "call-1", toolCalls[0].ToolCallID)
 299	require.Equal(t, "tool1", toolCalls[0].ToolName)
 300
 301	// Parse and verify input
 302	var input map[string]any
 303	err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
 304	require.NoError(t, err)
 305	require.Equal(t, "value", input["value"])
 306}
 307
 308// Test result.toolResults extraction (matches TS test exactly)
 309func TestAgent_Generate_ResultToolResults(t *testing.T) {
 310	t.Parallel()
 311
 312	// Create type-safe tool using the new API
 313	type TestInput struct {
 314		Value string `json:"value" description:"Test value"`
 315	}
 316
 317	tool1 := NewAgentTool(
 318		"tool1",
 319		"Test tool",
 320		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
 321			require.Equal(t, "value", input.Value)
 322			return ToolResponse{Content: "result1", IsError: false}, nil
 323		},
 324	)
 325
 326	model := &mockLanguageModel{
 327		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 328			// Verify tools and tool choice
 329			require.Len(t, call.Tools, 1)
 330			require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
 331
 332			// Verify prompt
 333			require.Len(t, call.Prompt, 1)
 334			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
 335
 336			return &Response{
 337				Content: []Content{
 338					ToolCallContent{
 339						ToolCallID: "call-1",
 340						ToolName:   "tool1",
 341						Input:      `{"value":"value"}`,
 342					},
 343				},
 344				Usage: Usage{
 345					InputTokens:  3,
 346					OutputTokens: 10,
 347					TotalTokens:  13,
 348				},
 349				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
 350			}, nil
 351		},
 352	}
 353
 354	agent := NewAgent(model, WithTools(tool1))
 355	result, err := agent.Generate(context.Background(), AgentCall{
 356		Prompt: "test-input",
 357	})
 358
 359	require.NoError(t, err)
 360	require.NotNil(t, result)
 361	require.Len(t, result.Steps, 1) // Single step
 362
 363	// Extract tool results from final response
 364	var toolResults []ToolResultContent
 365	for _, content := range result.Response.Content {
 366		if toolResult, ok := AsContentType[ToolResultContent](content); ok {
 367			toolResults = append(toolResults, toolResult)
 368		}
 369	}
 370
 371	require.Len(t, toolResults, 1)
 372	require.Equal(t, "call-1", toolResults[0].ToolCallID)
 373	require.Equal(t, "tool1", toolResults[0].ToolName)
 374
 375	// Verify result content
 376	textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
 377	require.True(t, ok)
 378	require.Equal(t, "result1", textResult.Text)
 379}
 380
 381// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
 382func TestAgent_Generate_MultipleSteps(t *testing.T) {
 383	t.Parallel()
 384
 385	// Create type-safe tool using the new API
 386	type TestInput struct {
 387		Value string `json:"value" description:"Test value"`
 388	}
 389
 390	tool1 := NewAgentTool(
 391		"tool1",
 392		"Test tool",
 393		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
 394			require.Equal(t, "value", input.Value)
 395			return ToolResponse{Content: "result1", IsError: false}, nil
 396		},
 397	)
 398
 399	callCount := 0
 400	model := &mockLanguageModel{
 401		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 402			callCount++
 403			switch callCount {
 404			case 1:
 405				// First call - return tool call with FinishReasonToolCalls
 406				return &Response{
 407					Content: []Content{
 408						ToolCallContent{
 409							ToolCallID: "call-1",
 410							ToolName:   "tool1",
 411							Input:      `{"value":"value"}`,
 412						},
 413					},
 414					Usage: Usage{
 415						InputTokens:  10,
 416						OutputTokens: 5,
 417						TotalTokens:  15,
 418					},
 419					FinishReason: FinishReasonToolCalls, // This triggers multi-step
 420				}, nil
 421			case 2:
 422				// Second call - return final text
 423				return &Response{
 424					Content: []Content{
 425						TextContent{Text: "Hello, world!"},
 426					},
 427					Usage: Usage{
 428						InputTokens:  3,
 429						OutputTokens: 10,
 430						TotalTokens:  13,
 431					},
 432					FinishReason: FinishReasonStop,
 433				}, nil
 434			default:
 435				t.Fatalf("Unexpected call count: %d", callCount)
 436				return nil, nil
 437			}
 438		},
 439	}
 440
 441	agent := NewAgent(model, WithTools(tool1))
 442	result, err := agent.Generate(context.Background(), AgentCall{
 443		Prompt: "test-input",
 444	})
 445
 446	require.NoError(t, err)
 447	require.NotNil(t, result)
 448	require.Len(t, result.Steps, 2)
 449
 450	// Check total usage sums both steps
 451	require.Equal(t, int64(13), result.TotalUsage.InputTokens)  // 10 + 3
 452	require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
 453	require.Equal(t, int64(28), result.TotalUsage.TotalTokens)  // 15 + 13
 454
 455	// Final response should be from last step
 456	require.Len(t, result.Response.Content, 1)
 457	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 458	require.True(t, ok)
 459	require.Equal(t, "Hello, world!", textContent.Text)
 460
 461	// result.toolCalls should be empty (from last step)
 462	var toolCalls []ToolCallContent
 463	for _, content := range result.Response.Content {
 464		if _, ok := AsContentType[ToolCallContent](content); ok {
 465			toolCalls = append(toolCalls, content.(ToolCallContent))
 466		}
 467	}
 468	require.Len(t, toolCalls, 0)
 469
 470	// result.toolResults should be empty (from last step)
 471	var toolResults []ToolResultContent
 472	for _, content := range result.Response.Content {
 473		if _, ok := AsContentType[ToolResultContent](content); ok {
 474			toolResults = append(toolResults, content.(ToolResultContent))
 475		}
 476	}
 477	require.Len(t, toolResults, 0)
 478}
 479
 480// Test basic text generation
 481func TestAgent_Generate_BasicText(t *testing.T) {
 482	t.Parallel()
 483
 484	model := &mockLanguageModel{
 485		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 486			return &Response{
 487				Content: []Content{
 488					TextContent{Text: "Hello, world!"},
 489				},
 490				Usage: Usage{
 491					InputTokens:  3,
 492					OutputTokens: 10,
 493					TotalTokens:  13,
 494				},
 495				FinishReason: FinishReasonStop,
 496			}, nil
 497		},
 498	}
 499
 500	agent := NewAgent(model)
 501	result, err := agent.Generate(context.Background(), AgentCall{
 502		Prompt: "test prompt",
 503	})
 504
 505	require.NoError(t, err)
 506	require.NotNil(t, result)
 507	require.Len(t, result.Steps, 1)
 508
 509	// Check final response
 510	require.Len(t, result.Response.Content, 1)
 511	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 512	require.True(t, ok)
 513	require.Equal(t, "Hello, world!", textContent.Text)
 514
 515	// Check usage
 516	require.Equal(t, int64(3), result.Response.Usage.InputTokens)
 517	require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
 518	require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
 519
 520	// Check total usage
 521	require.Equal(t, int64(3), result.TotalUsage.InputTokens)
 522	require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
 523	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
 524}
 525
 526// Test empty prompt validation
 527func TestAgent_Generate_EmptyPrompt(t *testing.T) {
 528	t.Parallel()
 529
 530	model := &mockLanguageModel{}
 531	agent := NewAgent(model)
 532
 533	t.Run("fails without messages", func(t *testing.T) {
 534		result, err := agent.Generate(context.Background(), AgentCall{
 535			Prompt: "",
 536		})
 537		require.Error(t, err)
 538		require.Nil(t, result)
 539		require.Contains(t, err.Error(), "prompt can't be empty when there are no messages")
 540	})
 541
 542	t.Run("fails with files even if messages exist", func(t *testing.T) {
 543		result, err := agent.Generate(context.Background(), AgentCall{
 544			Prompt: "",
 545			Messages: []Message{
 546				{Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
 547			},
 548			Files: []FilePart{{Filename: "test.txt", Data: []byte("test"), MediaType: "text/plain"}},
 549		})
 550		require.Error(t, err)
 551		require.Nil(t, result)
 552		require.Contains(t, err.Error(), "prompt can't be empty when there are files")
 553	})
 554
 555	t.Run("fails when last message is assistant", func(t *testing.T) {
 556		result, err := agent.Generate(context.Background(), AgentCall{
 557			Prompt: "",
 558			Messages: []Message{
 559				{Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
 560				{Role: MessageRoleAssistant, Content: []MessagePart{TextPart{Text: "hi there"}}},
 561			},
 562		})
 563		require.Error(t, err)
 564		require.Nil(t, result)
 565		require.Contains(t, err.Error(), "prompt can't be empty when the last message is not a user or tool message")
 566	})
 567
 568	t.Run("succeeds when last message is user", func(t *testing.T) {
 569		model := &mockLanguageModel{
 570			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 571				return &Response{
 572					Content:      []Content{TextContent{Text: "response"}},
 573					FinishReason: FinishReasonStop,
 574				}, nil
 575			},
 576		}
 577		agent := NewAgent(model)
 578
 579		result, err := agent.Generate(context.Background(), AgentCall{
 580			Prompt: "",
 581			Messages: []Message{
 582				{Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
 583			},
 584		})
 585		require.NoError(t, err)
 586		require.NotNil(t, result)
 587	})
 588
 589	t.Run("succeeds when last message is tool", func(t *testing.T) {
 590		model := &mockLanguageModel{
 591			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 592				return &Response{
 593					Content:      []Content{TextContent{Text: "response"}},
 594					FinishReason: FinishReasonStop,
 595				}, nil
 596			},
 597		}
 598		agent := NewAgent(model)
 599
 600		result, err := agent.Generate(context.Background(), AgentCall{
 601			Prompt: "",
 602			Messages: []Message{
 603				{Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
 604				{Role: MessageRoleAssistant, Content: []MessagePart{ToolCallPart{ToolCallID: "call_1", ToolName: "test"}}},
 605				{Role: MessageRoleTool, Content: []MessagePart{ToolResultPart{ToolCallID: "call_1", Output: ToolResultOutputContentText{Text: "result"}}}},
 606			},
 607		})
 608		require.NoError(t, err)
 609		require.NotNil(t, result)
 610	})
 611}
 612
 613// Test with system prompt
 614func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
 615	t.Parallel()
 616
 617	model := &mockLanguageModel{
 618		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 619			// Verify system message is included
 620			require.Len(t, call.Prompt, 2) // system + user
 621			require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
 622			require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
 623
 624			systemPart, ok := call.Prompt[0].Content[0].(TextPart)
 625			require.True(t, ok)
 626			require.Equal(t, "You are a helpful assistant", systemPart.Text)
 627
 628			return &Response{
 629				Content: []Content{
 630					TextContent{Text: "Hello, world!"},
 631				},
 632				Usage: Usage{
 633					InputTokens:  3,
 634					OutputTokens: 10,
 635					TotalTokens:  13,
 636				},
 637				FinishReason: FinishReasonStop,
 638			}, nil
 639		},
 640	}
 641
 642	agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
 643	result, err := agent.Generate(context.Background(), AgentCall{
 644		Prompt: "test prompt",
 645	})
 646
 647	require.NoError(t, err)
 648	require.NotNil(t, result)
 649}
 650
 651// Test options.activeTools filtering
 652func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
 653	t.Parallel()
 654
 655	tool1 := &mockTool{
 656		name:        "tool1",
 657		description: "Test tool 1",
 658		parameters: map[string]any{
 659			"value": map[string]any{"type": "string"},
 660		},
 661		required: []string{"value"},
 662	}
 663
 664	tool2 := &mockTool{
 665		name:        "tool2",
 666		description: "Test tool 2",
 667		parameters: map[string]any{
 668			"value": map[string]any{"type": "string"},
 669		},
 670		required: []string{"value"},
 671	}
 672
 673	model := &mockLanguageModel{
 674		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 675			// Verify only tool1 is available
 676			require.Len(t, call.Tools, 1)
 677			functionTool, ok := call.Tools[0].(FunctionTool)
 678			require.True(t, ok)
 679			require.Equal(t, "tool1", functionTool.Name)
 680
 681			return &Response{
 682				Content: []Content{
 683					TextContent{Text: "Hello, world!"},
 684				},
 685				Usage: Usage{
 686					InputTokens:  3,
 687					OutputTokens: 10,
 688					TotalTokens:  13,
 689				},
 690				FinishReason: FinishReasonStop,
 691			}, nil
 692		},
 693	}
 694
 695	agent := NewAgent(model, WithTools(tool1, tool2))
 696	result, err := agent.Generate(context.Background(), AgentCall{
 697		Prompt:      "test-input",
 698		ActiveTools: []string{"tool1"}, // Only tool1 should be active
 699	})
 700
 701	require.NoError(t, err)
 702	require.NotNil(t, result)
 703}
 704
 705func TestAgent_Generate_OptionsActiveTools_WithProviderDefinedTools(t *testing.T) {
 706	t.Parallel()
 707
 708	tool1 := &mockTool{
 709		name:        "tool1",
 710		description: "Test tool 1",
 711		parameters: map[string]any{
 712			"value": map[string]any{"type": "string"},
 713		},
 714		required: []string{"value"},
 715	}
 716
 717	providerTool1 := ProviderDefinedTool{ID: "provider.web_search", Name: "web_search"}
 718	providerTool2 := ProviderDefinedTool{ID: "provider.code_execution", Name: "code_execution"}
 719
 720	model := &mockLanguageModel{
 721		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 722			require.Len(t, call.Tools, 2)
 723
 724			functionTool, ok := call.Tools[0].(FunctionTool)
 725			require.True(t, ok)
 726			require.Equal(t, "tool1", functionTool.Name)
 727
 728			providerTool, ok := call.Tools[1].(ProviderDefinedTool)
 729			require.True(t, ok)
 730			require.Equal(t, "web_search", providerTool.Name)
 731
 732			return &Response{
 733				Content: []Content{
 734					TextContent{Text: "Hello, world!"},
 735				},
 736				Usage: Usage{
 737					InputTokens:  3,
 738					OutputTokens: 10,
 739					TotalTokens:  13,
 740				},
 741				FinishReason: FinishReasonStop,
 742			}, nil
 743		},
 744	}
 745
 746	agent := NewAgent(model, WithTools(tool1), WithProviderDefinedTools(providerTool1, providerTool2))
 747	result, err := agent.Generate(context.Background(), AgentCall{
 748		Prompt:      "test-input",
 749		ActiveTools: []string{"tool1", "web_search"}, // Only tool1 and web_search should be active
 750	})
 751
 752	require.NoError(t, err)
 753	require.NotNil(t, result)
 754}
 755
 756func TestResponseContent_Getters(t *testing.T) {
 757	t.Parallel()
 758
 759	// Create test content with all types
 760	content := ResponseContent{
 761		TextContent{Text: "Hello world"},
 762		ReasoningContent{Text: "Let me think..."},
 763		FileContent{Data: []byte("file data"), MediaType: "text/plain"},
 764		SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
 765		ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
 766		ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
 767	}
 768
 769	// Test Text()
 770	require.Equal(t, "Hello world", content.Text())
 771
 772	// Test Reasoning()
 773	reasoning := content.Reasoning()
 774	require.Len(t, reasoning, 1)
 775	require.Equal(t, "Let me think...", reasoning[0].Text)
 776
 777	// Test ReasoningText()
 778	require.Equal(t, "Let me think...", content.ReasoningText())
 779
 780	// Test Files()
 781	files := content.Files()
 782	require.Len(t, files, 1)
 783	require.Equal(t, "text/plain", files[0].MediaType)
 784	require.Equal(t, []byte("file data"), files[0].Data)
 785
 786	// Test Sources()
 787	sources := content.Sources()
 788	require.Len(t, sources, 1)
 789	require.Equal(t, SourceTypeURL, sources[0].SourceType)
 790	require.Equal(t, "https://example.com", sources[0].URL)
 791	require.Equal(t, "Example", sources[0].Title)
 792
 793	// Test ToolCalls()
 794	toolCalls := content.ToolCalls()
 795	require.Len(t, toolCalls, 1)
 796	require.Equal(t, "call1", toolCalls[0].ToolCallID)
 797	require.Equal(t, "test_tool", toolCalls[0].ToolName)
 798	require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
 799
 800	// Test ToolResults()
 801	toolResults := content.ToolResults()
 802	require.Len(t, toolResults, 1)
 803	require.Equal(t, "call1", toolResults[0].ToolCallID)
 804	require.Equal(t, "test_tool", toolResults[0].ToolName)
 805	result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
 806	require.True(t, ok)
 807	require.Equal(t, "result", result.Text)
 808}
 809
 810func TestResponseContent_Getters_Empty(t *testing.T) {
 811	t.Parallel()
 812
 813	// Test with empty content
 814	content := ResponseContent{}
 815
 816	require.Equal(t, "", content.Text())
 817	require.Equal(t, "", content.ReasoningText())
 818	require.Empty(t, content.Reasoning())
 819	require.Empty(t, content.Files())
 820	require.Empty(t, content.Sources())
 821	require.Empty(t, content.ToolCalls())
 822	require.Empty(t, content.ToolResults())
 823}
 824
 825func TestResponseContent_Getters_MultipleItems(t *testing.T) {
 826	t.Parallel()
 827
 828	// Test with multiple items of same type
 829	content := ResponseContent{
 830		ReasoningContent{Text: "First thought"},
 831		ReasoningContent{Text: "Second thought"},
 832		FileContent{Data: []byte("file1"), MediaType: "text/plain"},
 833		FileContent{Data: []byte("file2"), MediaType: "image/png"},
 834	}
 835
 836	// Test multiple reasoning
 837	reasoning := content.Reasoning()
 838	require.Len(t, reasoning, 2)
 839	require.Equal(t, "First thought", reasoning[0].Text)
 840	require.Equal(t, "Second thought", reasoning[1].Text)
 841
 842	// Test concatenated reasoning text
 843	require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
 844
 845	// Test multiple files
 846	files := content.Files()
 847	require.Len(t, files, 2)
 848	require.Equal(t, "text/plain", files[0].MediaType)
 849	require.Equal(t, "image/png", files[1].MediaType)
 850}
 851
 852func TestStopConditions(t *testing.T) {
 853	t.Parallel()
 854
 855	// Create test steps
 856	step1 := StepResult{
 857		Response: Response{
 858			Content: ResponseContent{
 859				TextContent{Text: "Hello"},
 860			},
 861			FinishReason: FinishReasonToolCalls,
 862			Usage:        Usage{TotalTokens: 10},
 863		},
 864	}
 865
 866	step2 := StepResult{
 867		Response: Response{
 868			Content: ResponseContent{
 869				TextContent{Text: "World"},
 870				ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
 871			},
 872			FinishReason: FinishReasonStop,
 873			Usage:        Usage{TotalTokens: 15},
 874		},
 875	}
 876
 877	step3 := StepResult{
 878		Response: Response{
 879			Content: ResponseContent{
 880				ReasoningContent{Text: "Let me think..."},
 881				FileContent{Data: []byte("data"), MediaType: "text/plain"},
 882			},
 883			FinishReason: FinishReasonLength,
 884			Usage:        Usage{TotalTokens: 20},
 885		},
 886	}
 887
 888	t.Run("StepCountIs", func(t *testing.T) {
 889		t.Parallel()
 890		condition := StepCountIs(2)
 891
 892		// Should not stop with 1 step
 893		require.False(t, condition([]StepResult{step1}))
 894
 895		// Should stop with 2 steps
 896		require.True(t, condition([]StepResult{step1, step2}))
 897
 898		// Should stop with more than 2 steps
 899		require.True(t, condition([]StepResult{step1, step2, step3}))
 900
 901		// Should not stop with empty steps
 902		require.False(t, condition([]StepResult{}))
 903	})
 904
 905	t.Run("HasToolCall", func(t *testing.T) {
 906		t.Parallel()
 907		condition := HasToolCall("search")
 908
 909		// Should not stop when tool not called
 910		require.False(t, condition([]StepResult{step1}))
 911
 912		// Should stop when tool is called in last step
 913		require.True(t, condition([]StepResult{step1, step2}))
 914
 915		// Should not stop when tool called in earlier step but not last
 916		require.False(t, condition([]StepResult{step1, step2, step3}))
 917
 918		// Should not stop with empty steps
 919		require.False(t, condition([]StepResult{}))
 920
 921		// Should not stop when different tool is called
 922		differentToolCondition := HasToolCall("different_tool")
 923		require.False(t, differentToolCondition([]StepResult{step1, step2}))
 924	})
 925
 926	t.Run("HasContent", func(t *testing.T) {
 927		t.Parallel()
 928		reasoningCondition := HasContent(ContentTypeReasoning)
 929		fileCondition := HasContent(ContentTypeFile)
 930
 931		// Should not stop when content type not present
 932		require.False(t, reasoningCondition([]StepResult{step1, step2}))
 933
 934		// Should stop when content type is present in last step
 935		require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
 936		require.True(t, fileCondition([]StepResult{step1, step2, step3}))
 937
 938		// Should not stop with empty steps
 939		require.False(t, reasoningCondition([]StepResult{}))
 940	})
 941
 942	t.Run("FinishReasonIs", func(t *testing.T) {
 943		t.Parallel()
 944		stopCondition := FinishReasonIs(FinishReasonStop)
 945		lengthCondition := FinishReasonIs(FinishReasonLength)
 946
 947		// Should not stop when finish reason doesn't match
 948		require.False(t, stopCondition([]StepResult{step1}))
 949
 950		// Should stop when finish reason matches in last step
 951		require.True(t, stopCondition([]StepResult{step1, step2}))
 952		require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
 953
 954		// Should not stop with empty steps
 955		require.False(t, stopCondition([]StepResult{}))
 956	})
 957
 958	t.Run("MaxTokensUsed", func(t *testing.T) {
 959		condition := MaxTokensUsed(30)
 960
 961		// Should not stop when under limit
 962		require.False(t, condition([]StepResult{step1}))        // 10 tokens
 963		require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
 964
 965		// Should stop when at or over limit
 966		require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
 967
 968		// Should not stop with empty steps
 969		require.False(t, condition([]StepResult{}))
 970	})
 971}
 972
 973func TestStopConditions_Integration(t *testing.T) {
 974	t.Parallel()
 975
 976	t.Run("StepCountIs integration", func(t *testing.T) {
 977		t.Parallel()
 978		model := &mockLanguageModel{
 979			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 980				return &Response{
 981					Content: ResponseContent{
 982						TextContent{Text: "Mock response"},
 983					},
 984					Usage: Usage{
 985						InputTokens:  3,
 986						OutputTokens: 10,
 987						TotalTokens:  13,
 988					},
 989					FinishReason: FinishReasonStop,
 990				}, nil
 991			},
 992		}
 993
 994		agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
 995
 996		result, err := agent.Generate(context.Background(), AgentCall{
 997			Prompt: "test prompt",
 998		})
 999
1000		require.NoError(t, err)
1001		require.NotNil(t, result)
1002		require.Len(t, result.Steps, 1) // Should stop after 1 step
1003	})
1004
1005	t.Run("Multiple stop conditions", func(t *testing.T) {
1006		t.Parallel()
1007		model := &mockLanguageModel{
1008			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1009				return &Response{
1010					Content: ResponseContent{
1011						TextContent{Text: "Mock response"},
1012					},
1013					Usage: Usage{
1014						InputTokens:  3,
1015						OutputTokens: 10,
1016						TotalTokens:  13,
1017					},
1018					FinishReason: FinishReasonStop,
1019				}, nil
1020			},
1021		}
1022
1023		agent := NewAgent(model, WithStopConditions(
1024			StepCountIs(5),                   // Stop after 5 steps
1025			FinishReasonIs(FinishReasonStop), // Or stop on finish reason
1026		))
1027
1028		result, err := agent.Generate(context.Background(), AgentCall{
1029			Prompt: "test prompt",
1030		})
1031
1032		require.NoError(t, err)
1033		require.NotNil(t, result)
1034		// Should stop on first condition met (finish reason stop)
1035		require.Equal(t, FinishReasonStop, result.Response.FinishReason)
1036	})
1037}
1038
1039func TestPrepareStep(t *testing.T) {
1040	t.Parallel()
1041
1042	t.Run("System prompt modification", func(t *testing.T) {
1043		t.Parallel()
1044		var capturedSystemPrompt string
1045		model := &mockLanguageModel{
1046			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1047				// Capture the system message to verify it was modified
1048				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1049					if len(call.Prompt[0].Content) > 0 {
1050						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1051							capturedSystemPrompt = textPart.Text
1052						}
1053					}
1054				}
1055				return &Response{
1056					Content: ResponseContent{
1057						TextContent{Text: "Response"},
1058					},
1059					Usage:        Usage{TotalTokens: 10},
1060					FinishReason: FinishReasonStop,
1061				}, nil
1062			},
1063		}
1064
1065		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1066			newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
1067			return ctx, PrepareStepResult{
1068				Model:    options.Model,
1069				Messages: options.Messages,
1070				System:   &newSystem,
1071			}, nil
1072		}
1073
1074		agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
1075
1076		result, err := agent.Generate(context.Background(), AgentCall{
1077			Prompt:      "test prompt",
1078			PrepareStep: prepareStepFunc,
1079		})
1080
1081		require.NoError(t, err)
1082		require.NotNil(t, result)
1083		require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
1084	})
1085
1086	t.Run("Tool choice modification", func(t *testing.T) {
1087		t.Parallel()
1088		var capturedToolChoice *ToolChoice
1089		model := &mockLanguageModel{
1090			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1091				capturedToolChoice = call.ToolChoice
1092				return &Response{
1093					Content: ResponseContent{
1094						TextContent{Text: "Response"},
1095					},
1096					Usage:        Usage{TotalTokens: 10},
1097					FinishReason: FinishReasonStop,
1098				}, nil
1099			},
1100		}
1101
1102		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1103			toolChoice := ToolChoiceNone
1104			return ctx, PrepareStepResult{
1105				Model:      options.Model,
1106				Messages:   options.Messages,
1107				ToolChoice: &toolChoice,
1108			}, nil
1109		}
1110
1111		agent := NewAgent(model)
1112
1113		result, err := agent.Generate(context.Background(), AgentCall{
1114			Prompt:      "test prompt",
1115			PrepareStep: prepareStepFunc,
1116		})
1117
1118		require.NoError(t, err)
1119		require.NotNil(t, result)
1120		require.NotNil(t, capturedToolChoice)
1121		require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1122	})
1123
1124	t.Run("Active tools modification", func(t *testing.T) {
1125		t.Parallel()
1126		var capturedToolNames []string
1127		model := &mockLanguageModel{
1128			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1129				// Capture tool names to verify active tools were modified
1130				for _, tool := range call.Tools {
1131					capturedToolNames = append(capturedToolNames, tool.GetName())
1132				}
1133				return &Response{
1134					Content: ResponseContent{
1135						TextContent{Text: "Response"},
1136					},
1137					Usage:        Usage{TotalTokens: 10},
1138					FinishReason: FinishReasonStop,
1139				}, nil
1140			},
1141		}
1142
1143		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1144		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1145		tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1146
1147		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1148			activeTools := []string{"tool2"} // Only tool2 should be active
1149			return ctx, PrepareStepResult{
1150				Model:       options.Model,
1151				Messages:    options.Messages,
1152				ActiveTools: activeTools,
1153			}, nil
1154		}
1155
1156		agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1157
1158		result, err := agent.Generate(context.Background(), AgentCall{
1159			Prompt:      "test prompt",
1160			PrepareStep: prepareStepFunc,
1161		})
1162
1163		require.NoError(t, err)
1164		require.NotNil(t, result)
1165		require.Len(t, capturedToolNames, 1)
1166		require.Equal(t, "tool2", capturedToolNames[0])
1167	})
1168
1169	t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1170		t.Parallel()
1171		var capturedToolCount int
1172		model := &mockLanguageModel{
1173			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1174				capturedToolCount = len(call.Tools)
1175				return &Response{
1176					Content: ResponseContent{
1177						TextContent{Text: "Response"},
1178					},
1179					Usage:        Usage{TotalTokens: 10},
1180					FinishReason: FinishReasonStop,
1181				}, nil
1182			},
1183		}
1184
1185		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1186
1187		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1188			return ctx, PrepareStepResult{
1189				Model:           options.Model,
1190				Messages:        options.Messages,
1191				DisableAllTools: true, // Disable all tools for this step
1192			}, nil
1193		}
1194
1195		agent := NewAgent(model, WithTools(tool1))
1196
1197		result, err := agent.Generate(context.Background(), AgentCall{
1198			Prompt:      "test prompt",
1199			PrepareStep: prepareStepFunc,
1200		})
1201
1202		require.NoError(t, err)
1203		require.NotNil(t, result)
1204		require.Equal(t, 0, capturedToolCount) // No tools should be passed
1205	})
1206
1207	t.Run("All fields modified together", func(t *testing.T) {
1208		t.Parallel()
1209		var capturedSystemPrompt string
1210		var capturedToolChoice *ToolChoice
1211		var capturedToolNames []string
1212
1213		model := &mockLanguageModel{
1214			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1215				// Capture system prompt
1216				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1217					if len(call.Prompt[0].Content) > 0 {
1218						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1219							capturedSystemPrompt = textPart.Text
1220						}
1221					}
1222				}
1223				// Capture tool choice
1224				capturedToolChoice = call.ToolChoice
1225				// Capture tool names
1226				for _, tool := range call.Tools {
1227					capturedToolNames = append(capturedToolNames, tool.GetName())
1228				}
1229				return &Response{
1230					Content: ResponseContent{
1231						TextContent{Text: "Response"},
1232					},
1233					Usage:        Usage{TotalTokens: 10},
1234					FinishReason: FinishReasonStop,
1235				}, nil
1236			},
1237		}
1238
1239		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1240		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1241
1242		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1243			newSystem := "Step-specific system"
1244			toolChoice := SpecificToolChoice("tool1")
1245			activeTools := []string{"tool1"}
1246			return ctx, PrepareStepResult{
1247				Model:       options.Model,
1248				Messages:    options.Messages,
1249				System:      &newSystem,
1250				ToolChoice:  &toolChoice,
1251				ActiveTools: activeTools,
1252			}, nil
1253		}
1254
1255		agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1256
1257		result, err := agent.Generate(context.Background(), AgentCall{
1258			Prompt:      "test prompt",
1259			PrepareStep: prepareStepFunc,
1260		})
1261
1262		require.NoError(t, err)
1263		require.NotNil(t, result)
1264		require.Equal(t, "Step-specific system", capturedSystemPrompt)
1265		require.NotNil(t, capturedToolChoice)
1266		require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1267		require.Len(t, capturedToolNames, 1)
1268		require.Equal(t, "tool1", capturedToolNames[0])
1269	})
1270
1271	t.Run("Nil fields use parent values", func(t *testing.T) {
1272		t.Parallel()
1273		var capturedSystemPrompt string
1274		var capturedToolChoice *ToolChoice
1275		var capturedToolNames []string
1276
1277		model := &mockLanguageModel{
1278			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1279				// Capture system prompt
1280				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1281					if len(call.Prompt[0].Content) > 0 {
1282						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1283							capturedSystemPrompt = textPart.Text
1284						}
1285					}
1286				}
1287				// Capture tool choice
1288				capturedToolChoice = call.ToolChoice
1289				// Capture tool names
1290				for _, tool := range call.Tools {
1291					capturedToolNames = append(capturedToolNames, tool.GetName())
1292				}
1293				return &Response{
1294					Content: ResponseContent{
1295						TextContent{Text: "Response"},
1296					},
1297					Usage:        Usage{TotalTokens: 10},
1298					FinishReason: FinishReasonStop,
1299				}, nil
1300			},
1301		}
1302
1303		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1304
1305		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1306			// All optional fields are nil, should use parent values
1307			return ctx, PrepareStepResult{
1308				Model:       options.Model,
1309				Messages:    options.Messages,
1310				System:      nil, // Use parent
1311				ToolChoice:  nil, // Use parent (auto)
1312				ActiveTools: nil, // Use parent (all tools)
1313			}, nil
1314		}
1315
1316		agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1317
1318		result, err := agent.Generate(context.Background(), AgentCall{
1319			Prompt:      "test prompt",
1320			PrepareStep: prepareStepFunc,
1321		})
1322
1323		require.NoError(t, err)
1324		require.NotNil(t, result)
1325		require.Equal(t, "Parent system", capturedSystemPrompt)
1326		require.NotNil(t, capturedToolChoice)
1327		require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1328		require.Len(t, capturedToolNames, 1)
1329		require.Equal(t, "tool1", capturedToolNames[0])
1330	})
1331
1332	t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1333		t.Parallel()
1334		var capturedToolNames []string
1335		model := &mockLanguageModel{
1336			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1337				// Capture tool names to verify all tools are included
1338				for _, tool := range call.Tools {
1339					capturedToolNames = append(capturedToolNames, tool.GetName())
1340				}
1341				return &Response{
1342					Content: ResponseContent{
1343						TextContent{Text: "Response"},
1344					},
1345					Usage:        Usage{TotalTokens: 10},
1346					FinishReason: FinishReasonStop,
1347				}, nil
1348			},
1349		}
1350
1351		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1352		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1353
1354		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1355			return ctx, PrepareStepResult{
1356				Model:       options.Model,
1357				Messages:    options.Messages,
1358				ActiveTools: []string{}, // Empty slice means all tools
1359			}, nil
1360		}
1361
1362		agent := NewAgent(model, WithTools(tool1, tool2))
1363
1364		result, err := agent.Generate(context.Background(), AgentCall{
1365			Prompt:      "test prompt",
1366			PrepareStep: prepareStepFunc,
1367		})
1368
1369		require.NoError(t, err)
1370		require.NotNil(t, result)
1371		require.Len(t, capturedToolNames, 2) // All tools should be included
1372		require.Contains(t, capturedToolNames, "tool1")
1373		require.Contains(t, capturedToolNames, "tool2")
1374	})
1375}
1376
1377func TestToolCallRepair(t *testing.T) {
1378	t.Parallel()
1379
1380	t.Run("Valid tool call passes validation", func(t *testing.T) {
1381		t.Parallel()
1382		model := &mockLanguageModel{
1383			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1384				return &Response{
1385					Content: ResponseContent{
1386						TextContent{Text: "Response"},
1387						ToolCallContent{
1388							ToolCallID: "call1",
1389							ToolName:   "test_tool",
1390							Input:      `{"value": "test"}`, // Valid JSON with required field
1391						},
1392					},
1393					Usage:        Usage{TotalTokens: 10},
1394					FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1395				}, nil
1396			},
1397		}
1398
1399		tool := &mockTool{
1400			name:        "test_tool",
1401			description: "Test tool",
1402			parameters: map[string]any{
1403				"value": map[string]any{"type": "string"},
1404			},
1405			required: []string{"value"},
1406			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1407				return ToolResponse{Content: "success", IsError: false}, nil
1408			},
1409		}
1410
1411		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1412
1413		result, err := agent.Generate(context.Background(), AgentCall{
1414			Prompt: "test prompt",
1415		})
1416
1417		require.NoError(t, err)
1418		require.NotNil(t, result)
1419		require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1420
1421		// Check that tool call was executed successfully
1422		toolCalls := result.Steps[0].Content.ToolCalls()
1423		require.Len(t, toolCalls, 1)
1424		require.False(t, toolCalls[0].Invalid) // Should be valid
1425	})
1426
1427	t.Run("Invalid tool call without repair function", func(t *testing.T) {
1428		t.Parallel()
1429		model := &mockLanguageModel{
1430			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1431				return &Response{
1432					Content: ResponseContent{
1433						TextContent{Text: "Response"},
1434						ToolCallContent{
1435							ToolCallID: "call1",
1436							ToolName:   "test_tool",
1437							Input:      `{"wrong_field": "test"}`, // Missing required field
1438						},
1439					},
1440					Usage:        Usage{TotalTokens: 10},
1441					FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1442				}, nil
1443			},
1444		}
1445
1446		tool := &mockTool{
1447			name:        "test_tool",
1448			description: "Test tool",
1449			parameters: map[string]any{
1450				"value": map[string]any{"type": "string"},
1451			},
1452			required: []string{"value"},
1453		}
1454
1455		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1456
1457		result, err := agent.Generate(context.Background(), AgentCall{
1458			Prompt: "test prompt",
1459		})
1460
1461		require.NoError(t, err)
1462		require.NotNil(t, result)
1463		require.Len(t, result.Steps, 1) // Only one step
1464
1465		// Check that tool call was marked as invalid
1466		toolCalls := result.Steps[0].Content.ToolCalls()
1467		require.Len(t, toolCalls, 1)
1468		require.True(t, toolCalls[0].Invalid) // Should be invalid
1469		require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1470	})
1471
1472	t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1473		t.Parallel()
1474		model := &mockLanguageModel{
1475			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1476				return &Response{
1477					Content: ResponseContent{
1478						TextContent{Text: "Response"},
1479						ToolCallContent{
1480							ToolCallID: "call1",
1481							ToolName:   "test_tool",
1482							Input:      `{"wrong_field": "test"}`, // Missing required field
1483						},
1484					},
1485					Usage:        Usage{TotalTokens: 10},
1486					FinishReason: FinishReasonStop, // Changed to stop
1487				}, nil
1488			},
1489		}
1490
1491		tool := &mockTool{
1492			name:        "test_tool",
1493			description: "Test tool",
1494			parameters: map[string]any{
1495				"value": map[string]any{"type": "string"},
1496			},
1497			required: []string{"value"},
1498			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1499				return ToolResponse{Content: "repaired_success", IsError: false}, nil
1500			},
1501		}
1502
1503		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1504			// Simple repair: add the missing required field
1505			repairedToolCall := options.OriginalToolCall
1506			repairedToolCall.Input = `{"value": "repaired"}`
1507			return &repairedToolCall, nil
1508		}
1509
1510		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1511
1512		result, err := agent.Generate(context.Background(), AgentCall{
1513			Prompt: "test prompt",
1514		})
1515
1516		require.NoError(t, err)
1517		require.NotNil(t, result)
1518		require.Len(t, result.Steps, 1) // Only one step
1519
1520		// Check that tool call was repaired and is now valid
1521		toolCalls := result.Steps[0].Content.ToolCalls()
1522		require.Len(t, toolCalls, 1)
1523		require.False(t, toolCalls[0].Invalid)                        // Should be valid after repair
1524		require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1525	})
1526
1527	t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1528		t.Parallel()
1529		model := &mockLanguageModel{
1530			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1531				return &Response{
1532					Content: ResponseContent{
1533						TextContent{Text: "Response"},
1534						ToolCallContent{
1535							ToolCallID: "call1",
1536							ToolName:   "test_tool",
1537							Input:      `{"wrong_field": "test"}`, // Missing required field
1538						},
1539					},
1540					Usage:        Usage{TotalTokens: 10},
1541					FinishReason: FinishReasonStop, // Changed to stop
1542				}, nil
1543			},
1544		}
1545
1546		tool := &mockTool{
1547			name:        "test_tool",
1548			description: "Test tool",
1549			parameters: map[string]any{
1550				"value": map[string]any{"type": "string"},
1551			},
1552			required: []string{"value"},
1553		}
1554
1555		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1556			// Repair function fails
1557			return nil, errors.New("repair failed")
1558		}
1559
1560		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1561
1562		result, err := agent.Generate(context.Background(), AgentCall{
1563			Prompt: "test prompt",
1564		})
1565
1566		require.NoError(t, err)
1567		require.NotNil(t, result)
1568		require.Len(t, result.Steps, 1) // Only one step
1569
1570		// Check that tool call was marked as invalid since repair failed
1571		toolCalls := result.Steps[0].Content.ToolCalls()
1572		require.Len(t, toolCalls, 1)
1573		require.True(t, toolCalls[0].Invalid) // Should be invalid
1574		require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1575	})
1576
1577	t.Run("Nonexistent tool call", func(t *testing.T) {
1578		t.Parallel()
1579		model := &mockLanguageModel{
1580			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1581				return &Response{
1582					Content: ResponseContent{
1583						TextContent{Text: "Response"},
1584						ToolCallContent{
1585							ToolCallID: "call1",
1586							ToolName:   "nonexistent_tool",
1587							Input:      `{"value": "test"}`,
1588						},
1589					},
1590					Usage:        Usage{TotalTokens: 10},
1591					FinishReason: FinishReasonStop, // Changed to stop
1592				}, nil
1593			},
1594		}
1595
1596		tool := &mockTool{name: "test_tool", description: "Test tool"}
1597
1598		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1599
1600		result, err := agent.Generate(context.Background(), AgentCall{
1601			Prompt: "test prompt",
1602		})
1603
1604		require.NoError(t, err)
1605		require.NotNil(t, result)
1606		require.Len(t, result.Steps, 1) // Only one step
1607
1608		// Check that tool call was marked as invalid due to nonexistent tool
1609		toolCalls := result.Steps[0].Content.ToolCalls()
1610		require.Len(t, toolCalls, 1)
1611		require.True(t, toolCalls[0].Invalid) // Should be invalid
1612		require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1613	})
1614
1615	t.Run("Invalid JSON in tool call", func(t *testing.T) {
1616		t.Parallel()
1617		model := &mockLanguageModel{
1618			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1619				return &Response{
1620					Content: ResponseContent{
1621						TextContent{Text: "Response"},
1622						ToolCallContent{
1623							ToolCallID: "call1",
1624							ToolName:   "test_tool",
1625							Input:      `{invalid json}`, // Invalid JSON
1626						},
1627					},
1628					Usage:        Usage{TotalTokens: 10},
1629					FinishReason: FinishReasonStop, // Changed to stop
1630				}, nil
1631			},
1632		}
1633
1634		tool := &mockTool{
1635			name:        "test_tool",
1636			description: "Test tool",
1637			parameters: map[string]any{
1638				"value": map[string]any{"type": "string"},
1639			},
1640			required: []string{"value"},
1641		}
1642
1643		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1644
1645		result, err := agent.Generate(context.Background(), AgentCall{
1646			Prompt: "test prompt",
1647		})
1648
1649		require.NoError(t, err)
1650		require.NotNil(t, result)
1651		require.Len(t, result.Steps, 1) // Only one step
1652
1653		// Check that tool call was marked as invalid due to invalid JSON
1654		toolCalls := result.Steps[0].Content.ToolCalls()
1655		require.Len(t, toolCalls, 1)
1656		require.True(t, toolCalls[0].Invalid) // Should be invalid
1657		require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1658	})
1659}
1660
1661// Test media and image tool responses
1662func TestAgent_MediaToolResponses(t *testing.T) {
1663	t.Parallel()
1664
1665	imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
1666	audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
1667
1668	t.Run("Image tool response", func(t *testing.T) {
1669		t.Parallel()
1670
1671		imageTool := &mockTool{
1672			name:        "generate_image",
1673			description: "Generates an image",
1674			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1675				return NewImageResponse(imageData, "image/png"), nil
1676			},
1677		}
1678
1679		model := &mockLanguageModel{
1680			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1681				if len(call.Prompt) == 1 {
1682					// First call - request image tool
1683					return &Response{
1684						Content: []Content{
1685							ToolCallContent{
1686								ToolCallID: "img-1",
1687								ToolName:   "generate_image",
1688								Input:      `{}`,
1689							},
1690						},
1691						Usage:        Usage{TotalTokens: 10},
1692						FinishReason: FinishReasonToolCalls,
1693					}, nil
1694				}
1695				// Second call - after tool execution
1696				return &Response{
1697					Content:      []Content{TextContent{Text: "Image generated"}},
1698					Usage:        Usage{TotalTokens: 20},
1699					FinishReason: FinishReasonStop,
1700				}, nil
1701			},
1702		}
1703
1704		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1705
1706		result, err := agent.Generate(context.Background(), AgentCall{
1707			Prompt: "Generate an image",
1708		})
1709
1710		require.NoError(t, err)
1711		require.NotNil(t, result)
1712		require.Len(t, result.Steps, 2) // Tool call step + final response
1713
1714		// Check tool results in first step
1715		toolResults := result.Steps[0].Content.ToolResults()
1716		require.Len(t, toolResults, 1)
1717
1718		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1719		require.True(t, ok, "Expected media result")
1720		require.Equal(t, string(imageData), mediaResult.Data)
1721		require.Equal(t, "image/png", mediaResult.MediaType)
1722	})
1723
1724	t.Run("Media tool response (audio)", func(t *testing.T) {
1725		t.Parallel()
1726
1727		audioTool := &mockTool{
1728			name:        "generate_audio",
1729			description: "Generates audio",
1730			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1731				return NewMediaResponse(audioData, "audio/wav"), nil
1732			},
1733		}
1734
1735		model := &mockLanguageModel{
1736			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1737				if len(call.Prompt) == 1 {
1738					return &Response{
1739						Content: []Content{
1740							ToolCallContent{
1741								ToolCallID: "audio-1",
1742								ToolName:   "generate_audio",
1743								Input:      `{}`,
1744							},
1745						},
1746						Usage:        Usage{TotalTokens: 10},
1747						FinishReason: FinishReasonToolCalls,
1748					}, nil
1749				}
1750				return &Response{
1751					Content:      []Content{TextContent{Text: "Audio generated"}},
1752					Usage:        Usage{TotalTokens: 20},
1753					FinishReason: FinishReasonStop,
1754				}, nil
1755			},
1756		}
1757
1758		agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
1759
1760		result, err := agent.Generate(context.Background(), AgentCall{
1761			Prompt: "Generate audio",
1762		})
1763
1764		require.NoError(t, err)
1765		require.NotNil(t, result)
1766
1767		toolResults := result.Steps[0].Content.ToolResults()
1768		require.Len(t, toolResults, 1)
1769
1770		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1771		require.True(t, ok, "Expected media result")
1772		require.Equal(t, string(audioData), mediaResult.Data)
1773		require.Equal(t, "audio/wav", mediaResult.MediaType)
1774	})
1775
1776	t.Run("Media response with text", func(t *testing.T) {
1777		t.Parallel()
1778
1779		imageTool := &mockTool{
1780			name:        "screenshot",
1781			description: "Takes a screenshot",
1782			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1783				resp := NewImageResponse(imageData, "image/png")
1784				resp.Content = "Screenshot captured successfully"
1785				return resp, nil
1786			},
1787		}
1788
1789		model := &mockLanguageModel{
1790			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1791				if len(call.Prompt) == 1 {
1792					return &Response{
1793						Content: []Content{
1794							ToolCallContent{
1795								ToolCallID: "screen-1",
1796								ToolName:   "screenshot",
1797								Input:      `{}`,
1798							},
1799						},
1800						Usage:        Usage{TotalTokens: 10},
1801						FinishReason: FinishReasonToolCalls,
1802					}, nil
1803				}
1804				return &Response{
1805					Content:      []Content{TextContent{Text: "Done"}},
1806					Usage:        Usage{TotalTokens: 20},
1807					FinishReason: FinishReasonStop,
1808				}, nil
1809			},
1810		}
1811
1812		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1813
1814		result, err := agent.Generate(context.Background(), AgentCall{
1815			Prompt: "Take a screenshot",
1816		})
1817
1818		require.NoError(t, err)
1819		require.NotNil(t, result)
1820
1821		toolResults := result.Steps[0].Content.ToolResults()
1822		require.Len(t, toolResults, 1)
1823
1824		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1825		require.True(t, ok, "Expected media result")
1826		require.Equal(t, string(imageData), mediaResult.Data)
1827		require.Equal(t, "image/png", mediaResult.MediaType)
1828		require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
1829	})
1830
1831	t.Run("Media response preserves metadata", func(t *testing.T) {
1832		t.Parallel()
1833
1834		type ImageMetadata struct {
1835			Width  int `json:"width"`
1836			Height int `json:"height"`
1837		}
1838
1839		imageTool := &mockTool{
1840			name:        "generate_image",
1841			description: "Generates an image",
1842			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1843				resp := NewImageResponse(imageData, "image/png")
1844				return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
1845			},
1846		}
1847
1848		model := &mockLanguageModel{
1849			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1850				if len(call.Prompt) == 1 {
1851					return &Response{
1852						Content: []Content{
1853							ToolCallContent{
1854								ToolCallID: "img-1",
1855								ToolName:   "generate_image",
1856								Input:      `{}`,
1857							},
1858						},
1859						Usage:        Usage{TotalTokens: 10},
1860						FinishReason: FinishReasonToolCalls,
1861					}, nil
1862				}
1863				return &Response{
1864					Content:      []Content{TextContent{Text: "Done"}},
1865					Usage:        Usage{TotalTokens: 20},
1866					FinishReason: FinishReasonStop,
1867				}, nil
1868			},
1869		}
1870
1871		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1872
1873		result, err := agent.Generate(context.Background(), AgentCall{
1874			Prompt: "Generate image",
1875		})
1876
1877		require.NoError(t, err)
1878		require.NotNil(t, result)
1879
1880		toolResults := result.Steps[0].Content.ToolResults()
1881		require.Len(t, toolResults, 1)
1882
1883		// Check metadata was preserved
1884		require.NotEmpty(t, toolResults[0].ClientMetadata)
1885
1886		var metadata ImageMetadata
1887		err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
1888		require.NoError(t, err)
1889		require.Equal(t, 800, metadata.Width)
1890		require.Equal(t, 600, metadata.Height)
1891	})
1892}
1893
1894func TestToResponseMessages_ProviderExecutedRouting(t *testing.T) {
1895	t.Parallel()
1896
1897	// Build step content that mixes a provider-executed tool call/result
1898	// (e.g. web search) with a regular local tool call/result.
1899	content := []Content{
1900		// Provider-executed tool call.
1901		&ToolCallContent{
1902			ToolCallID:       "srvtoolu_01",
1903			ToolName:         "web_search",
1904			Input:            `{"query":"test"}`,
1905			ProviderExecuted: true,
1906		},
1907		// Provider-executed tool result.
1908		&ToolResultContent{
1909			ToolCallID:       "srvtoolu_01",
1910			ProviderExecuted: true,
1911		},
1912		// Regular (locally-executed) tool call.
1913		&ToolCallContent{
1914			ToolCallID: "toolu_02",
1915			ToolName:   "calculator",
1916			Input:      `{"expr":"1+1"}`,
1917		},
1918		// Regular tool result.
1919		&ToolResultContent{
1920			ToolCallID: "toolu_02",
1921			Result:     ToolResultOutputContentText{Text: "2"},
1922		},
1923		// Some trailing text.
1924		&TextContent{Text: "Done."},
1925	}
1926
1927	msgs := toResponseMessages(content)
1928
1929	// Expect two messages: assistant + tool.
1930	require.Len(t, msgs, 2)
1931
1932	// Assistant message should contain:
1933	//   1. provider-executed ToolCallPart
1934	//   2. provider-executed ToolResultPart
1935	//   3. regular ToolCallPart
1936	//   4. TextPart
1937	assistant := msgs[0]
1938	require.Equal(t, MessageRoleAssistant, assistant.Role)
1939	require.Len(t, assistant.Content, 4)
1940
1941	// Verify provider-executed tool call is in assistant.
1942	tc1, ok := AsMessagePart[ToolCallPart](assistant.Content[0])
1943	require.True(t, ok)
1944	require.Equal(t, "srvtoolu_01", tc1.ToolCallID)
1945	require.True(t, tc1.ProviderExecuted)
1946
1947	// Verify provider-executed tool result is in assistant.
1948	tr1, ok := AsMessagePart[ToolResultPart](assistant.Content[1])
1949	require.True(t, ok)
1950	require.Equal(t, "srvtoolu_01", tr1.ToolCallID)
1951	require.True(t, tr1.ProviderExecuted)
1952
1953	// Verify regular tool call is in assistant.
1954	tc2, ok := AsMessagePart[ToolCallPart](assistant.Content[2])
1955	require.True(t, ok)
1956	require.Equal(t, "toolu_02", tc2.ToolCallID)
1957	require.False(t, tc2.ProviderExecuted)
1958
1959	// Verify text part is in assistant.
1960	text, ok := AsMessagePart[TextPart](assistant.Content[3])
1961	require.True(t, ok)
1962	require.Equal(t, "Done.", text.Text)
1963
1964	// Tool message should contain only the regular tool result.
1965	toolMsg := msgs[1]
1966	require.Equal(t, MessageRoleTool, toolMsg.Role)
1967	require.Len(t, toolMsg.Content, 1)
1968
1969	tr2, ok := AsMessagePart[ToolResultPart](toolMsg.Content[0])
1970	require.True(t, ok)
1971	require.Equal(t, "toolu_02", tr2.ToolCallID)
1972	require.False(t, tr2.ProviderExecuted)
1973}