agent_test.go

   1package fantasy
   2
   3import (
   4	"context"
   5	"encoding/base64"
   6	"encoding/json"
   7	"errors"
   8	"fmt"
   9	"testing"
  10
  11	"github.com/stretchr/testify/require"
  12)
  13
  14// Mock tool for testing
  15type mockTool struct {
  16	name            string
  17	providerOptions ProviderOptions
  18	description     string
  19	parameters      map[string]any
  20	required        []string
  21	executeFunc     func(ctx context.Context, call ToolCall) (ToolResponse, error)
  22}
  23
  24func (m *mockTool) SetProviderOptions(opts ProviderOptions) {
  25	m.providerOptions = opts
  26}
  27
  28func (m *mockTool) ProviderOptions() ProviderOptions {
  29	return m.providerOptions
  30}
  31
  32func (m *mockTool) Info() ToolInfo {
  33	return ToolInfo{
  34		Name:        m.name,
  35		Description: m.description,
  36		Parameters:  m.parameters,
  37		Required:    m.required,
  38	}
  39}
  40
  41func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  42	if m.executeFunc != nil {
  43		return m.executeFunc(ctx, call)
  44	}
  45	return ToolResponse{Content: "mock result", IsError: false}, nil
  46}
  47
  48// Mock language model for testing
  49type mockLanguageModel struct {
  50	generateFunc func(ctx context.Context, call Call) (*Response, error)
  51	streamFunc   func(ctx context.Context, call Call) (StreamResponse, error)
  52}
  53
  54func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
  55	if m.generateFunc != nil {
  56		return m.generateFunc(ctx, call)
  57	}
  58	return &Response{
  59		Content: []Content{
  60			TextContent{Text: "Hello, world!"},
  61		},
  62		Usage: Usage{
  63			InputTokens:  3,
  64			OutputTokens: 10,
  65			TotalTokens:  13,
  66		},
  67		FinishReason: FinishReasonStop,
  68	}, nil
  69}
  70
  71func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
  72	if m.streamFunc != nil {
  73		return m.streamFunc(ctx, call)
  74	}
  75	return nil, fmt.Errorf("mock stream not implemented")
  76}
  77
  78func (m *mockLanguageModel) Provider() string {
  79	return "mock-provider"
  80}
  81
  82func (m *mockLanguageModel) Model() string {
  83	return "mock-model"
  84}
  85
  86func (m *mockLanguageModel) GenerateObject(ctx context.Context, call ObjectCall) (*ObjectResponse, error) {
  87	return nil, fmt.Errorf("mock GenerateObject not implemented")
  88}
  89
  90func (m *mockLanguageModel) StreamObject(ctx context.Context, call ObjectCall) (ObjectStreamResponse, error) {
  91	return nil, fmt.Errorf("mock StreamObject not implemented")
  92}
  93
  94// Test result.content - comprehensive content types (matches TS test)
  95func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
  96	t.Parallel()
  97
  98	// Create a type-safe tool using the new API
  99	type TestInput struct {
 100		Value string `json:"value" description:"Test value"`
 101	}
 102
 103	tool1 := NewAgentTool(
 104		"tool1",
 105		"Test tool",
 106		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
 107			require.Equal(t, "value", input.Value)
 108			return ToolResponse{Content: "result1", IsError: false}, nil
 109		},
 110	)
 111
 112	model := &mockLanguageModel{
 113		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 114			return &Response{
 115				Content: []Content{
 116					TextContent{Text: "Hello, world!"},
 117					SourceContent{
 118						ID:         "123",
 119						URL:        "https://example.com",
 120						Title:      "Example",
 121						SourceType: SourceTypeURL,
 122					},
 123					FileContent{
 124						Data:      []byte{1, 2, 3},
 125						MediaType: "image/png",
 126					},
 127					ReasoningContent{
 128						Text: "I will open the conversation with witty banter.",
 129					},
 130					ToolCallContent{
 131						ToolCallID: "call-1",
 132						ToolName:   "tool1",
 133						Input:      `{"value":"value"}`,
 134					},
 135					TextContent{Text: "More text"},
 136				},
 137				Usage: Usage{
 138					InputTokens:  3,
 139					OutputTokens: 10,
 140					TotalTokens:  13,
 141				},
 142				FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
 143			}, nil
 144		},
 145	}
 146
 147	agent := NewAgent(model, WithTools(tool1))
 148	result, err := agent.Generate(context.Background(), AgentCall{
 149		Prompt: "prompt",
 150	})
 151
 152	require.NoError(t, err)
 153	require.NotNil(t, result)
 154	require.Len(t, result.Steps, 1) // Single step like TypeScript
 155
 156	// Check final response content includes tool result
 157	require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
 158
 159	// Verify each content type in order
 160	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 161	require.True(t, ok)
 162	require.Equal(t, "Hello, world!", textContent.Text)
 163
 164	sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
 165	require.True(t, ok)
 166	require.Equal(t, "123", sourceContent.ID)
 167
 168	fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
 169	require.True(t, ok)
 170	require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
 171
 172	reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
 173	require.True(t, ok)
 174	require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
 175
 176	toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
 177	require.True(t, ok)
 178	require.Equal(t, "call-1", toolCallContent.ToolCallID)
 179
 180	moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
 181	require.True(t, ok)
 182	require.Equal(t, "More text", moreTextContent.Text)
 183
 184	// Tool result should be appended
 185	toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
 186	require.True(t, ok)
 187	require.Equal(t, "call-1", toolResultContent.ToolCallID)
 188	require.Equal(t, "tool1", toolResultContent.ToolName)
 189}
 190
 191// Test result.text extraction
 192func TestAgent_Generate_ResultText(t *testing.T) {
 193	t.Parallel()
 194
 195	model := &mockLanguageModel{
 196		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 197			return &Response{
 198				Content: []Content{
 199					TextContent{Text: "Hello, world!"},
 200				},
 201				Usage: Usage{
 202					InputTokens:  3,
 203					OutputTokens: 10,
 204					TotalTokens:  13,
 205				},
 206				FinishReason: FinishReasonStop,
 207			}, nil
 208		},
 209	}
 210
 211	agent := NewAgent(model)
 212	result, err := agent.Generate(context.Background(), AgentCall{
 213		Prompt: "prompt",
 214	})
 215
 216	require.NoError(t, err)
 217	require.NotNil(t, result)
 218
 219	// Test text extraction from content
 220	text := result.Response.Content.Text()
 221	require.Equal(t, "Hello, world!", text)
 222}
 223
 224// Test result.toolCalls extraction (matches TS test exactly)
 225func TestAgent_Generate_ResultToolCalls(t *testing.T) {
 226	t.Parallel()
 227
 228	// Create type-safe tools using the new API
 229	type Tool1Input struct {
 230		Value string `json:"value" description:"Test value"`
 231	}
 232
 233	type Tool2Input struct {
 234		SomethingElse string `json:"somethingElse" description:"Another test value"`
 235	}
 236
 237	tool1 := NewAgentTool(
 238		"tool1",
 239		"Test tool 1",
 240		func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
 241			return ToolResponse{Content: "result1", IsError: false}, nil
 242		},
 243	)
 244
 245	tool2 := NewAgentTool(
 246		"tool2",
 247		"Test tool 2",
 248		func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
 249			return ToolResponse{Content: "result2", IsError: false}, nil
 250		},
 251	)
 252
 253	model := &mockLanguageModel{
 254		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 255			// Verify tools are passed correctly
 256			require.Len(t, call.Tools, 2)
 257			require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
 258
 259			// Verify prompt structure
 260			require.Len(t, call.Prompt, 1)
 261			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
 262
 263			return &Response{
 264				Content: []Content{
 265					ToolCallContent{
 266						ToolCallID: "call-1",
 267						ToolName:   "tool1",
 268						Input:      `{"value":"value"}`,
 269					},
 270				},
 271				Usage: Usage{
 272					InputTokens:  3,
 273					OutputTokens: 10,
 274					TotalTokens:  13,
 275				},
 276				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
 277			}, nil
 278		},
 279	}
 280
 281	agent := NewAgent(model, WithTools(tool1, tool2))
 282	result, err := agent.Generate(context.Background(), AgentCall{
 283		Prompt: "test-input",
 284	})
 285
 286	require.NoError(t, err)
 287	require.NotNil(t, result)
 288	require.Len(t, result.Steps, 1) // Single step
 289
 290	// Extract tool calls from final response (should be empty since tools don't execute)
 291	var toolCalls []ToolCallContent
 292	for _, content := range result.Response.Content {
 293		if toolCall, ok := AsContentType[ToolCallContent](content); ok {
 294			toolCalls = append(toolCalls, toolCall)
 295		}
 296	}
 297
 298	require.Len(t, toolCalls, 1)
 299	require.Equal(t, "call-1", toolCalls[0].ToolCallID)
 300	require.Equal(t, "tool1", toolCalls[0].ToolName)
 301
 302	// Parse and verify input
 303	var input map[string]any
 304	err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
 305	require.NoError(t, err)
 306	require.Equal(t, "value", input["value"])
 307}
 308
 309// Test result.toolResults extraction (matches TS test exactly)
 310func TestAgent_Generate_ResultToolResults(t *testing.T) {
 311	t.Parallel()
 312
 313	// Create type-safe tool using the new API
 314	type TestInput struct {
 315		Value string `json:"value" description:"Test value"`
 316	}
 317
 318	tool1 := NewAgentTool(
 319		"tool1",
 320		"Test tool",
 321		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
 322			require.Equal(t, "value", input.Value)
 323			return ToolResponse{Content: "result1", IsError: false}, nil
 324		},
 325	)
 326
 327	model := &mockLanguageModel{
 328		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 329			// Verify tools and tool choice
 330			require.Len(t, call.Tools, 1)
 331			require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
 332
 333			// Verify prompt
 334			require.Len(t, call.Prompt, 1)
 335			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
 336
 337			return &Response{
 338				Content: []Content{
 339					ToolCallContent{
 340						ToolCallID: "call-1",
 341						ToolName:   "tool1",
 342						Input:      `{"value":"value"}`,
 343					},
 344				},
 345				Usage: Usage{
 346					InputTokens:  3,
 347					OutputTokens: 10,
 348					TotalTokens:  13,
 349				},
 350				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
 351			}, nil
 352		},
 353	}
 354
 355	agent := NewAgent(model, WithTools(tool1))
 356	result, err := agent.Generate(context.Background(), AgentCall{
 357		Prompt: "test-input",
 358	})
 359
 360	require.NoError(t, err)
 361	require.NotNil(t, result)
 362	require.Len(t, result.Steps, 1) // Single step
 363
 364	// Extract tool results from final response
 365	var toolResults []ToolResultContent
 366	for _, content := range result.Response.Content {
 367		if toolResult, ok := AsContentType[ToolResultContent](content); ok {
 368			toolResults = append(toolResults, toolResult)
 369		}
 370	}
 371
 372	require.Len(t, toolResults, 1)
 373	require.Equal(t, "call-1", toolResults[0].ToolCallID)
 374	require.Equal(t, "tool1", toolResults[0].ToolName)
 375
 376	// Verify result content
 377	textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
 378	require.True(t, ok)
 379	require.Equal(t, "result1", textResult.Text)
 380}
 381
 382// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
 383func TestAgent_Generate_MultipleSteps(t *testing.T) {
 384	t.Parallel()
 385
 386	// Create type-safe tool using the new API
 387	type TestInput struct {
 388		Value string `json:"value" description:"Test value"`
 389	}
 390
 391	tool1 := NewAgentTool(
 392		"tool1",
 393		"Test tool",
 394		func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
 395			require.Equal(t, "value", input.Value)
 396			return ToolResponse{Content: "result1", IsError: false}, nil
 397		},
 398	)
 399
 400	callCount := 0
 401	model := &mockLanguageModel{
 402		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 403			callCount++
 404			switch callCount {
 405			case 1:
 406				// First call - return tool call with FinishReasonToolCalls
 407				return &Response{
 408					Content: []Content{
 409						ToolCallContent{
 410							ToolCallID: "call-1",
 411							ToolName:   "tool1",
 412							Input:      `{"value":"value"}`,
 413						},
 414					},
 415					Usage: Usage{
 416						InputTokens:  10,
 417						OutputTokens: 5,
 418						TotalTokens:  15,
 419					},
 420					FinishReason: FinishReasonToolCalls, // This triggers multi-step
 421				}, nil
 422			case 2:
 423				// Second call - return final text
 424				return &Response{
 425					Content: []Content{
 426						TextContent{Text: "Hello, world!"},
 427					},
 428					Usage: Usage{
 429						InputTokens:  3,
 430						OutputTokens: 10,
 431						TotalTokens:  13,
 432					},
 433					FinishReason: FinishReasonStop,
 434				}, nil
 435			default:
 436				t.Fatalf("Unexpected call count: %d", callCount)
 437				return nil, nil
 438			}
 439		},
 440	}
 441
 442	agent := NewAgent(model, WithTools(tool1))
 443	result, err := agent.Generate(context.Background(), AgentCall{
 444		Prompt: "test-input",
 445	})
 446
 447	require.NoError(t, err)
 448	require.NotNil(t, result)
 449	require.Len(t, result.Steps, 2)
 450
 451	// Check total usage sums both steps
 452	require.Equal(t, int64(13), result.TotalUsage.InputTokens)  // 10 + 3
 453	require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
 454	require.Equal(t, int64(28), result.TotalUsage.TotalTokens)  // 15 + 13
 455
 456	// Final response should be from last step
 457	require.Len(t, result.Response.Content, 1)
 458	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 459	require.True(t, ok)
 460	require.Equal(t, "Hello, world!", textContent.Text)
 461
 462	// result.toolCalls should be empty (from last step)
 463	var toolCalls []ToolCallContent
 464	for _, content := range result.Response.Content {
 465		if _, ok := AsContentType[ToolCallContent](content); ok {
 466			toolCalls = append(toolCalls, content.(ToolCallContent))
 467		}
 468	}
 469	require.Len(t, toolCalls, 0)
 470
 471	// result.toolResults should be empty (from last step)
 472	var toolResults []ToolResultContent
 473	for _, content := range result.Response.Content {
 474		if _, ok := AsContentType[ToolResultContent](content); ok {
 475			toolResults = append(toolResults, content.(ToolResultContent))
 476		}
 477	}
 478	require.Len(t, toolResults, 0)
 479}
 480
 481// Test basic text generation
 482func TestAgent_Generate_BasicText(t *testing.T) {
 483	t.Parallel()
 484
 485	model := &mockLanguageModel{
 486		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 487			return &Response{
 488				Content: []Content{
 489					TextContent{Text: "Hello, world!"},
 490				},
 491				Usage: Usage{
 492					InputTokens:  3,
 493					OutputTokens: 10,
 494					TotalTokens:  13,
 495				},
 496				FinishReason: FinishReasonStop,
 497			}, nil
 498		},
 499	}
 500
 501	agent := NewAgent(model)
 502	result, err := agent.Generate(context.Background(), AgentCall{
 503		Prompt: "test prompt",
 504	})
 505
 506	require.NoError(t, err)
 507	require.NotNil(t, result)
 508	require.Len(t, result.Steps, 1)
 509
 510	// Check final response
 511	require.Len(t, result.Response.Content, 1)
 512	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
 513	require.True(t, ok)
 514	require.Equal(t, "Hello, world!", textContent.Text)
 515
 516	// Check usage
 517	require.Equal(t, int64(3), result.Response.Usage.InputTokens)
 518	require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
 519	require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
 520
 521	// Check total usage
 522	require.Equal(t, int64(3), result.TotalUsage.InputTokens)
 523	require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
 524	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
 525}
 526
 527// Test empty prompt validation
 528func TestAgent_Generate_EmptyPrompt(t *testing.T) {
 529	t.Parallel()
 530
 531	model := &mockLanguageModel{}
 532	agent := NewAgent(model)
 533
 534	t.Run("fails without messages", func(t *testing.T) {
 535		result, err := agent.Generate(context.Background(), AgentCall{
 536			Prompt: "",
 537		})
 538		require.Error(t, err)
 539		require.Nil(t, result)
 540		require.Contains(t, err.Error(), "prompt can't be empty when there are no messages")
 541	})
 542
 543	t.Run("fails with files even if messages exist", func(t *testing.T) {
 544		result, err := agent.Generate(context.Background(), AgentCall{
 545			Prompt: "",
 546			Messages: []Message{
 547				{Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
 548			},
 549			Files: []FilePart{{Filename: "test.txt", Data: []byte("test"), MediaType: "text/plain"}},
 550		})
 551		require.Error(t, err)
 552		require.Nil(t, result)
 553		require.Contains(t, err.Error(), "prompt can't be empty when there are files")
 554	})
 555
 556	t.Run("fails when last message is assistant", func(t *testing.T) {
 557		result, err := agent.Generate(context.Background(), AgentCall{
 558			Prompt: "",
 559			Messages: []Message{
 560				{Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
 561				{Role: MessageRoleAssistant, Content: []MessagePart{TextPart{Text: "hi there"}}},
 562			},
 563		})
 564		require.Error(t, err)
 565		require.Nil(t, result)
 566		require.Contains(t, err.Error(), "prompt can't be empty when the last message is not a user or tool message")
 567	})
 568
 569	t.Run("succeeds when last message is user", func(t *testing.T) {
 570		model := &mockLanguageModel{
 571			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 572				return &Response{
 573					Content:      []Content{TextContent{Text: "response"}},
 574					FinishReason: FinishReasonStop,
 575				}, nil
 576			},
 577		}
 578		agent := NewAgent(model)
 579
 580		result, err := agent.Generate(context.Background(), AgentCall{
 581			Prompt: "",
 582			Messages: []Message{
 583				{Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
 584			},
 585		})
 586		require.NoError(t, err)
 587		require.NotNil(t, result)
 588	})
 589
 590	t.Run("succeeds when last message is tool", func(t *testing.T) {
 591		model := &mockLanguageModel{
 592			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 593				return &Response{
 594					Content:      []Content{TextContent{Text: "response"}},
 595					FinishReason: FinishReasonStop,
 596				}, nil
 597			},
 598		}
 599		agent := NewAgent(model)
 600
 601		result, err := agent.Generate(context.Background(), AgentCall{
 602			Prompt: "",
 603			Messages: []Message{
 604				{Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
 605				{Role: MessageRoleAssistant, Content: []MessagePart{ToolCallPart{ToolCallID: "call_1", ToolName: "test"}}},
 606				{Role: MessageRoleTool, Content: []MessagePart{ToolResultPart{ToolCallID: "call_1", Output: ToolResultOutputContentText{Text: "result"}}}},
 607			},
 608		})
 609		require.NoError(t, err)
 610		require.NotNil(t, result)
 611	})
 612}
 613
 614// Test with system prompt
 615func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
 616	t.Parallel()
 617
 618	model := &mockLanguageModel{
 619		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 620			// Verify system message is included
 621			require.Len(t, call.Prompt, 2) // system + user
 622			require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
 623			require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
 624
 625			systemPart, ok := call.Prompt[0].Content[0].(TextPart)
 626			require.True(t, ok)
 627			require.Equal(t, "You are a helpful assistant", systemPart.Text)
 628
 629			return &Response{
 630				Content: []Content{
 631					TextContent{Text: "Hello, world!"},
 632				},
 633				Usage: Usage{
 634					InputTokens:  3,
 635					OutputTokens: 10,
 636					TotalTokens:  13,
 637				},
 638				FinishReason: FinishReasonStop,
 639			}, nil
 640		},
 641	}
 642
 643	agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
 644	result, err := agent.Generate(context.Background(), AgentCall{
 645		Prompt: "test prompt",
 646	})
 647
 648	require.NoError(t, err)
 649	require.NotNil(t, result)
 650}
 651
 652// Test options.activeTools filtering
 653func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
 654	t.Parallel()
 655
 656	tool1 := &mockTool{
 657		name:        "tool1",
 658		description: "Test tool 1",
 659		parameters: map[string]any{
 660			"value": map[string]any{"type": "string"},
 661		},
 662		required: []string{"value"},
 663	}
 664
 665	tool2 := &mockTool{
 666		name:        "tool2",
 667		description: "Test tool 2",
 668		parameters: map[string]any{
 669			"value": map[string]any{"type": "string"},
 670		},
 671		required: []string{"value"},
 672	}
 673
 674	model := &mockLanguageModel{
 675		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 676			// Verify only tool1 is available
 677			require.Len(t, call.Tools, 1)
 678			functionTool, ok := call.Tools[0].(FunctionTool)
 679			require.True(t, ok)
 680			require.Equal(t, "tool1", functionTool.Name)
 681
 682			return &Response{
 683				Content: []Content{
 684					TextContent{Text: "Hello, world!"},
 685				},
 686				Usage: Usage{
 687					InputTokens:  3,
 688					OutputTokens: 10,
 689					TotalTokens:  13,
 690				},
 691				FinishReason: FinishReasonStop,
 692			}, nil
 693		},
 694	}
 695
 696	agent := NewAgent(model, WithTools(tool1, tool2))
 697	result, err := agent.Generate(context.Background(), AgentCall{
 698		Prompt:      "test-input",
 699		ActiveTools: []string{"tool1"}, // Only tool1 should be active
 700	})
 701
 702	require.NoError(t, err)
 703	require.NotNil(t, result)
 704}
 705
 706func TestAgent_Generate_OptionsActiveTools_WithProviderDefinedTools(t *testing.T) {
 707	t.Parallel()
 708
 709	tool1 := &mockTool{
 710		name:        "tool1",
 711		description: "Test tool 1",
 712		parameters: map[string]any{
 713			"value": map[string]any{"type": "string"},
 714		},
 715		required: []string{"value"},
 716	}
 717
 718	providerTool1 := ProviderDefinedTool{ID: "provider.web_search", Name: "web_search"}
 719	providerTool2 := ProviderDefinedTool{ID: "provider.code_execution", Name: "code_execution"}
 720
 721	model := &mockLanguageModel{
 722		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 723			require.Len(t, call.Tools, 2)
 724
 725			functionTool, ok := call.Tools[0].(FunctionTool)
 726			require.True(t, ok)
 727			require.Equal(t, "tool1", functionTool.Name)
 728
 729			providerTool, ok := call.Tools[1].(ProviderDefinedTool)
 730			require.True(t, ok)
 731			require.Equal(t, "web_search", providerTool.Name)
 732
 733			return &Response{
 734				Content: []Content{
 735					TextContent{Text: "Hello, world!"},
 736				},
 737				Usage: Usage{
 738					InputTokens:  3,
 739					OutputTokens: 10,
 740					TotalTokens:  13,
 741				},
 742				FinishReason: FinishReasonStop,
 743			}, nil
 744		},
 745	}
 746
 747	agent := NewAgent(model, WithTools(tool1), WithProviderDefinedTools(providerTool1, providerTool2))
 748	result, err := agent.Generate(context.Background(), AgentCall{
 749		Prompt:      "test-input",
 750		ActiveTools: []string{"tool1", "web_search"}, // Only tool1 and web_search should be active
 751	})
 752
 753	require.NoError(t, err)
 754	require.NotNil(t, result)
 755}
 756
 757func TestResponseContent_Getters(t *testing.T) {
 758	t.Parallel()
 759
 760	// Create test content with all types
 761	content := ResponseContent{
 762		TextContent{Text: "Hello world"},
 763		ReasoningContent{Text: "Let me think..."},
 764		FileContent{Data: []byte("file data"), MediaType: "text/plain"},
 765		SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
 766		ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
 767		ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
 768	}
 769
 770	// Test Text()
 771	require.Equal(t, "Hello world", content.Text())
 772
 773	// Test Reasoning()
 774	reasoning := content.Reasoning()
 775	require.Len(t, reasoning, 1)
 776	require.Equal(t, "Let me think...", reasoning[0].Text)
 777
 778	// Test ReasoningText()
 779	require.Equal(t, "Let me think...", content.ReasoningText())
 780
 781	// Test Files()
 782	files := content.Files()
 783	require.Len(t, files, 1)
 784	require.Equal(t, "text/plain", files[0].MediaType)
 785	require.Equal(t, []byte("file data"), files[0].Data)
 786
 787	// Test Sources()
 788	sources := content.Sources()
 789	require.Len(t, sources, 1)
 790	require.Equal(t, SourceTypeURL, sources[0].SourceType)
 791	require.Equal(t, "https://example.com", sources[0].URL)
 792	require.Equal(t, "Example", sources[0].Title)
 793
 794	// Test ToolCalls()
 795	toolCalls := content.ToolCalls()
 796	require.Len(t, toolCalls, 1)
 797	require.Equal(t, "call1", toolCalls[0].ToolCallID)
 798	require.Equal(t, "test_tool", toolCalls[0].ToolName)
 799	require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
 800
 801	// Test ToolResults()
 802	toolResults := content.ToolResults()
 803	require.Len(t, toolResults, 1)
 804	require.Equal(t, "call1", toolResults[0].ToolCallID)
 805	require.Equal(t, "test_tool", toolResults[0].ToolName)
 806	result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
 807	require.True(t, ok)
 808	require.Equal(t, "result", result.Text)
 809}
 810
 811func TestResponseContent_Getters_Empty(t *testing.T) {
 812	t.Parallel()
 813
 814	// Test with empty content
 815	content := ResponseContent{}
 816
 817	require.Equal(t, "", content.Text())
 818	require.Equal(t, "", content.ReasoningText())
 819	require.Empty(t, content.Reasoning())
 820	require.Empty(t, content.Files())
 821	require.Empty(t, content.Sources())
 822	require.Empty(t, content.ToolCalls())
 823	require.Empty(t, content.ToolResults())
 824}
 825
 826func TestResponseContent_Getters_MultipleItems(t *testing.T) {
 827	t.Parallel()
 828
 829	// Test with multiple items of same type
 830	content := ResponseContent{
 831		ReasoningContent{Text: "First thought"},
 832		ReasoningContent{Text: "Second thought"},
 833		FileContent{Data: []byte("file1"), MediaType: "text/plain"},
 834		FileContent{Data: []byte("file2"), MediaType: "image/png"},
 835	}
 836
 837	// Test multiple reasoning
 838	reasoning := content.Reasoning()
 839	require.Len(t, reasoning, 2)
 840	require.Equal(t, "First thought", reasoning[0].Text)
 841	require.Equal(t, "Second thought", reasoning[1].Text)
 842
 843	// Test concatenated reasoning text
 844	require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
 845
 846	// Test multiple files
 847	files := content.Files()
 848	require.Len(t, files, 2)
 849	require.Equal(t, "text/plain", files[0].MediaType)
 850	require.Equal(t, "image/png", files[1].MediaType)
 851}
 852
 853func TestStopConditions(t *testing.T) {
 854	t.Parallel()
 855
 856	// Create test steps
 857	step1 := StepResult{
 858		Response: Response{
 859			Content: ResponseContent{
 860				TextContent{Text: "Hello"},
 861			},
 862			FinishReason: FinishReasonToolCalls,
 863			Usage:        Usage{TotalTokens: 10},
 864		},
 865	}
 866
 867	step2 := StepResult{
 868		Response: Response{
 869			Content: ResponseContent{
 870				TextContent{Text: "World"},
 871				ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
 872			},
 873			FinishReason: FinishReasonStop,
 874			Usage:        Usage{TotalTokens: 15},
 875		},
 876	}
 877
 878	step3 := StepResult{
 879		Response: Response{
 880			Content: ResponseContent{
 881				ReasoningContent{Text: "Let me think..."},
 882				FileContent{Data: []byte("data"), MediaType: "text/plain"},
 883			},
 884			FinishReason: FinishReasonLength,
 885			Usage:        Usage{TotalTokens: 20},
 886		},
 887	}
 888
 889	t.Run("StepCountIs", func(t *testing.T) {
 890		t.Parallel()
 891		condition := StepCountIs(2)
 892
 893		// Should not stop with 1 step
 894		require.False(t, condition([]StepResult{step1}))
 895
 896		// Should stop with 2 steps
 897		require.True(t, condition([]StepResult{step1, step2}))
 898
 899		// Should stop with more than 2 steps
 900		require.True(t, condition([]StepResult{step1, step2, step3}))
 901
 902		// Should not stop with empty steps
 903		require.False(t, condition([]StepResult{}))
 904	})
 905
 906	t.Run("HasToolCall", func(t *testing.T) {
 907		t.Parallel()
 908		condition := HasToolCall("search")
 909
 910		// Should not stop when tool not called
 911		require.False(t, condition([]StepResult{step1}))
 912
 913		// Should stop when tool is called in last step
 914		require.True(t, condition([]StepResult{step1, step2}))
 915
 916		// Should not stop when tool called in earlier step but not last
 917		require.False(t, condition([]StepResult{step1, step2, step3}))
 918
 919		// Should not stop with empty steps
 920		require.False(t, condition([]StepResult{}))
 921
 922		// Should not stop when different tool is called
 923		differentToolCondition := HasToolCall("different_tool")
 924		require.False(t, differentToolCondition([]StepResult{step1, step2}))
 925	})
 926
 927	t.Run("HasContent", func(t *testing.T) {
 928		t.Parallel()
 929		reasoningCondition := HasContent(ContentTypeReasoning)
 930		fileCondition := HasContent(ContentTypeFile)
 931
 932		// Should not stop when content type not present
 933		require.False(t, reasoningCondition([]StepResult{step1, step2}))
 934
 935		// Should stop when content type is present in last step
 936		require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
 937		require.True(t, fileCondition([]StepResult{step1, step2, step3}))
 938
 939		// Should not stop with empty steps
 940		require.False(t, reasoningCondition([]StepResult{}))
 941	})
 942
 943	t.Run("FinishReasonIs", func(t *testing.T) {
 944		t.Parallel()
 945		stopCondition := FinishReasonIs(FinishReasonStop)
 946		lengthCondition := FinishReasonIs(FinishReasonLength)
 947
 948		// Should not stop when finish reason doesn't match
 949		require.False(t, stopCondition([]StepResult{step1}))
 950
 951		// Should stop when finish reason matches in last step
 952		require.True(t, stopCondition([]StepResult{step1, step2}))
 953		require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
 954
 955		// Should not stop with empty steps
 956		require.False(t, stopCondition([]StepResult{}))
 957	})
 958
 959	t.Run("MaxTokensUsed", func(t *testing.T) {
 960		condition := MaxTokensUsed(30)
 961
 962		// Should not stop when under limit
 963		require.False(t, condition([]StepResult{step1}))        // 10 tokens
 964		require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
 965
 966		// Should stop when at or over limit
 967		require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
 968
 969		// Should not stop with empty steps
 970		require.False(t, condition([]StepResult{}))
 971	})
 972}
 973
 974func TestStopConditions_Integration(t *testing.T) {
 975	t.Parallel()
 976
 977	t.Run("StepCountIs integration", func(t *testing.T) {
 978		t.Parallel()
 979		model := &mockLanguageModel{
 980			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 981				return &Response{
 982					Content: ResponseContent{
 983						TextContent{Text: "Mock response"},
 984					},
 985					Usage: Usage{
 986						InputTokens:  3,
 987						OutputTokens: 10,
 988						TotalTokens:  13,
 989					},
 990					FinishReason: FinishReasonStop,
 991				}, nil
 992			},
 993		}
 994
 995		agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
 996
 997		result, err := agent.Generate(context.Background(), AgentCall{
 998			Prompt: "test prompt",
 999		})
1000
1001		require.NoError(t, err)
1002		require.NotNil(t, result)
1003		require.Len(t, result.Steps, 1) // Should stop after 1 step
1004	})
1005
1006	t.Run("Multiple stop conditions", func(t *testing.T) {
1007		t.Parallel()
1008		model := &mockLanguageModel{
1009			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1010				return &Response{
1011					Content: ResponseContent{
1012						TextContent{Text: "Mock response"},
1013					},
1014					Usage: Usage{
1015						InputTokens:  3,
1016						OutputTokens: 10,
1017						TotalTokens:  13,
1018					},
1019					FinishReason: FinishReasonStop,
1020				}, nil
1021			},
1022		}
1023
1024		agent := NewAgent(model, WithStopConditions(
1025			StepCountIs(5),                   // Stop after 5 steps
1026			FinishReasonIs(FinishReasonStop), // Or stop on finish reason
1027		))
1028
1029		result, err := agent.Generate(context.Background(), AgentCall{
1030			Prompt: "test prompt",
1031		})
1032
1033		require.NoError(t, err)
1034		require.NotNil(t, result)
1035		// Should stop on first condition met (finish reason stop)
1036		require.Equal(t, FinishReasonStop, result.Response.FinishReason)
1037	})
1038}
1039
1040func TestPrepareStep(t *testing.T) {
1041	t.Parallel()
1042
1043	t.Run("System prompt modification", func(t *testing.T) {
1044		t.Parallel()
1045		var capturedSystemPrompt string
1046		model := &mockLanguageModel{
1047			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1048				// Capture the system message to verify it was modified
1049				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1050					if len(call.Prompt[0].Content) > 0 {
1051						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1052							capturedSystemPrompt = textPart.Text
1053						}
1054					}
1055				}
1056				return &Response{
1057					Content: ResponseContent{
1058						TextContent{Text: "Response"},
1059					},
1060					Usage:        Usage{TotalTokens: 10},
1061					FinishReason: FinishReasonStop,
1062				}, nil
1063			},
1064		}
1065
1066		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1067			newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
1068			return ctx, PrepareStepResult{
1069				Model:    options.Model,
1070				Messages: options.Messages,
1071				System:   &newSystem,
1072			}, nil
1073		}
1074
1075		agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
1076
1077		result, err := agent.Generate(context.Background(), AgentCall{
1078			Prompt:      "test prompt",
1079			PrepareStep: prepareStepFunc,
1080		})
1081
1082		require.NoError(t, err)
1083		require.NotNil(t, result)
1084		require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
1085	})
1086
1087	t.Run("Tool choice modification", func(t *testing.T) {
1088		t.Parallel()
1089		var capturedToolChoice *ToolChoice
1090		model := &mockLanguageModel{
1091			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1092				capturedToolChoice = call.ToolChoice
1093				return &Response{
1094					Content: ResponseContent{
1095						TextContent{Text: "Response"},
1096					},
1097					Usage:        Usage{TotalTokens: 10},
1098					FinishReason: FinishReasonStop,
1099				}, nil
1100			},
1101		}
1102
1103		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1104			toolChoice := ToolChoiceNone
1105			return ctx, PrepareStepResult{
1106				Model:      options.Model,
1107				Messages:   options.Messages,
1108				ToolChoice: &toolChoice,
1109			}, nil
1110		}
1111
1112		agent := NewAgent(model)
1113
1114		result, err := agent.Generate(context.Background(), AgentCall{
1115			Prompt:      "test prompt",
1116			PrepareStep: prepareStepFunc,
1117		})
1118
1119		require.NoError(t, err)
1120		require.NotNil(t, result)
1121		require.NotNil(t, capturedToolChoice)
1122		require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1123	})
1124
1125	t.Run("Active tools modification", func(t *testing.T) {
1126		t.Parallel()
1127		var capturedToolNames []string
1128		model := &mockLanguageModel{
1129			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1130				// Capture tool names to verify active tools were modified
1131				for _, tool := range call.Tools {
1132					capturedToolNames = append(capturedToolNames, tool.GetName())
1133				}
1134				return &Response{
1135					Content: ResponseContent{
1136						TextContent{Text: "Response"},
1137					},
1138					Usage:        Usage{TotalTokens: 10},
1139					FinishReason: FinishReasonStop,
1140				}, nil
1141			},
1142		}
1143
1144		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1145		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1146		tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1147
1148		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1149			activeTools := []string{"tool2"} // Only tool2 should be active
1150			return ctx, PrepareStepResult{
1151				Model:       options.Model,
1152				Messages:    options.Messages,
1153				ActiveTools: activeTools,
1154			}, nil
1155		}
1156
1157		agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1158
1159		result, err := agent.Generate(context.Background(), AgentCall{
1160			Prompt:      "test prompt",
1161			PrepareStep: prepareStepFunc,
1162		})
1163
1164		require.NoError(t, err)
1165		require.NotNil(t, result)
1166		require.Len(t, capturedToolNames, 1)
1167		require.Equal(t, "tool2", capturedToolNames[0])
1168	})
1169
1170	t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1171		t.Parallel()
1172		var capturedToolCount int
1173		model := &mockLanguageModel{
1174			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1175				capturedToolCount = len(call.Tools)
1176				return &Response{
1177					Content: ResponseContent{
1178						TextContent{Text: "Response"},
1179					},
1180					Usage:        Usage{TotalTokens: 10},
1181					FinishReason: FinishReasonStop,
1182				}, nil
1183			},
1184		}
1185
1186		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1187
1188		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1189			return ctx, PrepareStepResult{
1190				Model:           options.Model,
1191				Messages:        options.Messages,
1192				DisableAllTools: true, // Disable all tools for this step
1193			}, nil
1194		}
1195
1196		agent := NewAgent(model, WithTools(tool1))
1197
1198		result, err := agent.Generate(context.Background(), AgentCall{
1199			Prompt:      "test prompt",
1200			PrepareStep: prepareStepFunc,
1201		})
1202
1203		require.NoError(t, err)
1204		require.NotNil(t, result)
1205		require.Equal(t, 0, capturedToolCount) // No tools should be passed
1206	})
1207
1208	t.Run("All fields modified together", func(t *testing.T) {
1209		t.Parallel()
1210		var capturedSystemPrompt string
1211		var capturedToolChoice *ToolChoice
1212		var capturedToolNames []string
1213
1214		model := &mockLanguageModel{
1215			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1216				// Capture system prompt
1217				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1218					if len(call.Prompt[0].Content) > 0 {
1219						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1220							capturedSystemPrompt = textPart.Text
1221						}
1222					}
1223				}
1224				// Capture tool choice
1225				capturedToolChoice = call.ToolChoice
1226				// Capture tool names
1227				for _, tool := range call.Tools {
1228					capturedToolNames = append(capturedToolNames, tool.GetName())
1229				}
1230				return &Response{
1231					Content: ResponseContent{
1232						TextContent{Text: "Response"},
1233					},
1234					Usage:        Usage{TotalTokens: 10},
1235					FinishReason: FinishReasonStop,
1236				}, nil
1237			},
1238		}
1239
1240		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1241		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1242
1243		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1244			newSystem := "Step-specific system"
1245			toolChoice := SpecificToolChoice("tool1")
1246			activeTools := []string{"tool1"}
1247			return ctx, PrepareStepResult{
1248				Model:       options.Model,
1249				Messages:    options.Messages,
1250				System:      &newSystem,
1251				ToolChoice:  &toolChoice,
1252				ActiveTools: activeTools,
1253			}, nil
1254		}
1255
1256		agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1257
1258		result, err := agent.Generate(context.Background(), AgentCall{
1259			Prompt:      "test prompt",
1260			PrepareStep: prepareStepFunc,
1261		})
1262
1263		require.NoError(t, err)
1264		require.NotNil(t, result)
1265		require.Equal(t, "Step-specific system", capturedSystemPrompt)
1266		require.NotNil(t, capturedToolChoice)
1267		require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1268		require.Len(t, capturedToolNames, 1)
1269		require.Equal(t, "tool1", capturedToolNames[0])
1270	})
1271
1272	t.Run("Nil fields use parent values", func(t *testing.T) {
1273		t.Parallel()
1274		var capturedSystemPrompt string
1275		var capturedToolChoice *ToolChoice
1276		var capturedToolNames []string
1277
1278		model := &mockLanguageModel{
1279			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1280				// Capture system prompt
1281				if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1282					if len(call.Prompt[0].Content) > 0 {
1283						if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1284							capturedSystemPrompt = textPart.Text
1285						}
1286					}
1287				}
1288				// Capture tool choice
1289				capturedToolChoice = call.ToolChoice
1290				// Capture tool names
1291				for _, tool := range call.Tools {
1292					capturedToolNames = append(capturedToolNames, tool.GetName())
1293				}
1294				return &Response{
1295					Content: ResponseContent{
1296						TextContent{Text: "Response"},
1297					},
1298					Usage:        Usage{TotalTokens: 10},
1299					FinishReason: FinishReasonStop,
1300				}, nil
1301			},
1302		}
1303
1304		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1305
1306		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1307			// All optional fields are nil, should use parent values
1308			return ctx, PrepareStepResult{
1309				Model:       options.Model,
1310				Messages:    options.Messages,
1311				System:      nil, // Use parent
1312				ToolChoice:  nil, // Use parent (auto)
1313				ActiveTools: nil, // Use parent (all tools)
1314			}, nil
1315		}
1316
1317		agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1318
1319		result, err := agent.Generate(context.Background(), AgentCall{
1320			Prompt:      "test prompt",
1321			PrepareStep: prepareStepFunc,
1322		})
1323
1324		require.NoError(t, err)
1325		require.NotNil(t, result)
1326		require.Equal(t, "Parent system", capturedSystemPrompt)
1327		require.NotNil(t, capturedToolChoice)
1328		require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1329		require.Len(t, capturedToolNames, 1)
1330		require.Equal(t, "tool1", capturedToolNames[0])
1331	})
1332
1333	t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1334		t.Parallel()
1335		var capturedToolNames []string
1336		model := &mockLanguageModel{
1337			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1338				// Capture tool names to verify all tools are included
1339				for _, tool := range call.Tools {
1340					capturedToolNames = append(capturedToolNames, tool.GetName())
1341				}
1342				return &Response{
1343					Content: ResponseContent{
1344						TextContent{Text: "Response"},
1345					},
1346					Usage:        Usage{TotalTokens: 10},
1347					FinishReason: FinishReasonStop,
1348				}, nil
1349			},
1350		}
1351
1352		tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1353		tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1354
1355		prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1356			return ctx, PrepareStepResult{
1357				Model:       options.Model,
1358				Messages:    options.Messages,
1359				ActiveTools: []string{}, // Empty slice means all tools
1360			}, nil
1361		}
1362
1363		agent := NewAgent(model, WithTools(tool1, tool2))
1364
1365		result, err := agent.Generate(context.Background(), AgentCall{
1366			Prompt:      "test prompt",
1367			PrepareStep: prepareStepFunc,
1368		})
1369
1370		require.NoError(t, err)
1371		require.NotNil(t, result)
1372		require.Len(t, capturedToolNames, 2) // All tools should be included
1373		require.Contains(t, capturedToolNames, "tool1")
1374		require.Contains(t, capturedToolNames, "tool2")
1375	})
1376}
1377
1378func TestToolCallRepair(t *testing.T) {
1379	t.Parallel()
1380
1381	t.Run("Valid tool call passes validation", func(t *testing.T) {
1382		t.Parallel()
1383		model := &mockLanguageModel{
1384			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1385				return &Response{
1386					Content: ResponseContent{
1387						TextContent{Text: "Response"},
1388						ToolCallContent{
1389							ToolCallID: "call1",
1390							ToolName:   "test_tool",
1391							Input:      `{"value": "test"}`, // Valid JSON with required field
1392						},
1393					},
1394					Usage:        Usage{TotalTokens: 10},
1395					FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1396				}, nil
1397			},
1398		}
1399
1400		tool := &mockTool{
1401			name:        "test_tool",
1402			description: "Test tool",
1403			parameters: map[string]any{
1404				"value": map[string]any{"type": "string"},
1405			},
1406			required: []string{"value"},
1407			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1408				return ToolResponse{Content: "success", IsError: false}, nil
1409			},
1410		}
1411
1412		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1413
1414		result, err := agent.Generate(context.Background(), AgentCall{
1415			Prompt: "test prompt",
1416		})
1417
1418		require.NoError(t, err)
1419		require.NotNil(t, result)
1420		require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1421
1422		// Check that tool call was executed successfully
1423		toolCalls := result.Steps[0].Content.ToolCalls()
1424		require.Len(t, toolCalls, 1)
1425		require.False(t, toolCalls[0].Invalid) // Should be valid
1426	})
1427
1428	t.Run("Invalid tool call without repair function", func(t *testing.T) {
1429		t.Parallel()
1430		model := &mockLanguageModel{
1431			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1432				return &Response{
1433					Content: ResponseContent{
1434						TextContent{Text: "Response"},
1435						ToolCallContent{
1436							ToolCallID: "call1",
1437							ToolName:   "test_tool",
1438							Input:      `{"wrong_field": "test"}`, // Missing required field
1439						},
1440					},
1441					Usage:        Usage{TotalTokens: 10},
1442					FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1443				}, nil
1444			},
1445		}
1446
1447		tool := &mockTool{
1448			name:        "test_tool",
1449			description: "Test tool",
1450			parameters: map[string]any{
1451				"value": map[string]any{"type": "string"},
1452			},
1453			required: []string{"value"},
1454		}
1455
1456		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1457
1458		result, err := agent.Generate(context.Background(), AgentCall{
1459			Prompt: "test prompt",
1460		})
1461
1462		require.NoError(t, err)
1463		require.NotNil(t, result)
1464		require.Len(t, result.Steps, 1) // Only one step
1465
1466		// Check that tool call was marked as invalid
1467		toolCalls := result.Steps[0].Content.ToolCalls()
1468		require.Len(t, toolCalls, 1)
1469		require.True(t, toolCalls[0].Invalid) // Should be invalid
1470		require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1471	})
1472
1473	t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1474		t.Parallel()
1475		model := &mockLanguageModel{
1476			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1477				return &Response{
1478					Content: ResponseContent{
1479						TextContent{Text: "Response"},
1480						ToolCallContent{
1481							ToolCallID: "call1",
1482							ToolName:   "test_tool",
1483							Input:      `{"wrong_field": "test"}`, // Missing required field
1484						},
1485					},
1486					Usage:        Usage{TotalTokens: 10},
1487					FinishReason: FinishReasonStop, // Changed to stop
1488				}, nil
1489			},
1490		}
1491
1492		tool := &mockTool{
1493			name:        "test_tool",
1494			description: "Test tool",
1495			parameters: map[string]any{
1496				"value": map[string]any{"type": "string"},
1497			},
1498			required: []string{"value"},
1499			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1500				return ToolResponse{Content: "repaired_success", IsError: false}, nil
1501			},
1502		}
1503
1504		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1505			// Simple repair: add the missing required field
1506			repairedToolCall := options.OriginalToolCall
1507			repairedToolCall.Input = `{"value": "repaired"}`
1508			return &repairedToolCall, nil
1509		}
1510
1511		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1512
1513		result, err := agent.Generate(context.Background(), AgentCall{
1514			Prompt: "test prompt",
1515		})
1516
1517		require.NoError(t, err)
1518		require.NotNil(t, result)
1519		require.Len(t, result.Steps, 1) // Only one step
1520
1521		// Check that tool call was repaired and is now valid
1522		toolCalls := result.Steps[0].Content.ToolCalls()
1523		require.Len(t, toolCalls, 1)
1524		require.False(t, toolCalls[0].Invalid)                        // Should be valid after repair
1525		require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1526	})
1527
1528	t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1529		t.Parallel()
1530		model := &mockLanguageModel{
1531			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1532				return &Response{
1533					Content: ResponseContent{
1534						TextContent{Text: "Response"},
1535						ToolCallContent{
1536							ToolCallID: "call1",
1537							ToolName:   "test_tool",
1538							Input:      `{"wrong_field": "test"}`, // Missing required field
1539						},
1540					},
1541					Usage:        Usage{TotalTokens: 10},
1542					FinishReason: FinishReasonStop, // Changed to stop
1543				}, nil
1544			},
1545		}
1546
1547		tool := &mockTool{
1548			name:        "test_tool",
1549			description: "Test tool",
1550			parameters: map[string]any{
1551				"value": map[string]any{"type": "string"},
1552			},
1553			required: []string{"value"},
1554		}
1555
1556		repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1557			// Repair function fails
1558			return nil, errors.New("repair failed")
1559		}
1560
1561		agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1562
1563		result, err := agent.Generate(context.Background(), AgentCall{
1564			Prompt: "test prompt",
1565		})
1566
1567		require.NoError(t, err)
1568		require.NotNil(t, result)
1569		require.Len(t, result.Steps, 1) // Only one step
1570
1571		// Check that tool call was marked as invalid since repair failed
1572		toolCalls := result.Steps[0].Content.ToolCalls()
1573		require.Len(t, toolCalls, 1)
1574		require.True(t, toolCalls[0].Invalid) // Should be invalid
1575		require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1576	})
1577
1578	t.Run("Nonexistent tool call", func(t *testing.T) {
1579		t.Parallel()
1580		model := &mockLanguageModel{
1581			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1582				return &Response{
1583					Content: ResponseContent{
1584						TextContent{Text: "Response"},
1585						ToolCallContent{
1586							ToolCallID: "call1",
1587							ToolName:   "nonexistent_tool",
1588							Input:      `{"value": "test"}`,
1589						},
1590					},
1591					Usage:        Usage{TotalTokens: 10},
1592					FinishReason: FinishReasonStop, // Changed to stop
1593				}, nil
1594			},
1595		}
1596
1597		tool := &mockTool{name: "test_tool", description: "Test tool"}
1598
1599		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1600
1601		result, err := agent.Generate(context.Background(), AgentCall{
1602			Prompt: "test prompt",
1603		})
1604
1605		require.NoError(t, err)
1606		require.NotNil(t, result)
1607		require.Len(t, result.Steps, 1) // Only one step
1608
1609		// Check that tool call was marked as invalid due to nonexistent tool
1610		toolCalls := result.Steps[0].Content.ToolCalls()
1611		require.Len(t, toolCalls, 1)
1612		require.True(t, toolCalls[0].Invalid) // Should be invalid
1613		require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1614	})
1615
1616	t.Run("Invalid JSON in tool call", func(t *testing.T) {
1617		t.Parallel()
1618		model := &mockLanguageModel{
1619			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1620				return &Response{
1621					Content: ResponseContent{
1622						TextContent{Text: "Response"},
1623						ToolCallContent{
1624							ToolCallID: "call1",
1625							ToolName:   "test_tool",
1626							Input:      `{invalid json}`, // Invalid JSON
1627						},
1628					},
1629					Usage:        Usage{TotalTokens: 10},
1630					FinishReason: FinishReasonStop, // Changed to stop
1631				}, nil
1632			},
1633		}
1634
1635		tool := &mockTool{
1636			name:        "test_tool",
1637			description: "Test tool",
1638			parameters: map[string]any{
1639				"value": map[string]any{"type": "string"},
1640			},
1641			required: []string{"value"},
1642		}
1643
1644		agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1645
1646		result, err := agent.Generate(context.Background(), AgentCall{
1647			Prompt: "test prompt",
1648		})
1649
1650		require.NoError(t, err)
1651		require.NotNil(t, result)
1652		require.Len(t, result.Steps, 1) // Only one step
1653
1654		// Check that tool call was marked as invalid due to invalid JSON
1655		toolCalls := result.Steps[0].Content.ToolCalls()
1656		require.Len(t, toolCalls, 1)
1657		require.True(t, toolCalls[0].Invalid) // Should be invalid
1658		require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1659	})
1660}
1661
1662// Test media and image tool responses
1663func TestAgent_MediaToolResponses(t *testing.T) {
1664	t.Parallel()
1665
1666	imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
1667	audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
1668
1669	t.Run("Image tool response", func(t *testing.T) {
1670		t.Parallel()
1671
1672		imageTool := &mockTool{
1673			name:        "generate_image",
1674			description: "Generates an image",
1675			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1676				return NewImageResponse(imageData, "image/png"), nil
1677			},
1678		}
1679
1680		model := &mockLanguageModel{
1681			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1682				if len(call.Prompt) == 1 {
1683					// First call - request image tool
1684					return &Response{
1685						Content: []Content{
1686							ToolCallContent{
1687								ToolCallID: "img-1",
1688								ToolName:   "generate_image",
1689								Input:      `{}`,
1690							},
1691						},
1692						Usage:        Usage{TotalTokens: 10},
1693						FinishReason: FinishReasonToolCalls,
1694					}, nil
1695				}
1696				// Second call - after tool execution
1697				return &Response{
1698					Content:      []Content{TextContent{Text: "Image generated"}},
1699					Usage:        Usage{TotalTokens: 20},
1700					FinishReason: FinishReasonStop,
1701				}, nil
1702			},
1703		}
1704
1705		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1706
1707		result, err := agent.Generate(context.Background(), AgentCall{
1708			Prompt: "Generate an image",
1709		})
1710
1711		require.NoError(t, err)
1712		require.NotNil(t, result)
1713		require.Len(t, result.Steps, 2) // Tool call step + final response
1714
1715		// Check tool results in first step
1716		toolResults := result.Steps[0].Content.ToolResults()
1717		require.Len(t, toolResults, 1)
1718
1719		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1720		require.True(t, ok, "Expected media result")
1721		require.Equal(t, base64.StdEncoding.EncodeToString(imageData), mediaResult.Data)
1722		require.Equal(t, "image/png", mediaResult.MediaType)
1723	})
1724
1725	t.Run("Media tool response (audio)", func(t *testing.T) {
1726		t.Parallel()
1727
1728		audioTool := &mockTool{
1729			name:        "generate_audio",
1730			description: "Generates audio",
1731			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1732				return NewMediaResponse(audioData, "audio/wav"), nil
1733			},
1734		}
1735
1736		model := &mockLanguageModel{
1737			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1738				if len(call.Prompt) == 1 {
1739					return &Response{
1740						Content: []Content{
1741							ToolCallContent{
1742								ToolCallID: "audio-1",
1743								ToolName:   "generate_audio",
1744								Input:      `{}`,
1745							},
1746						},
1747						Usage:        Usage{TotalTokens: 10},
1748						FinishReason: FinishReasonToolCalls,
1749					}, nil
1750				}
1751				return &Response{
1752					Content:      []Content{TextContent{Text: "Audio generated"}},
1753					Usage:        Usage{TotalTokens: 20},
1754					FinishReason: FinishReasonStop,
1755				}, nil
1756			},
1757		}
1758
1759		agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
1760
1761		result, err := agent.Generate(context.Background(), AgentCall{
1762			Prompt: "Generate audio",
1763		})
1764
1765		require.NoError(t, err)
1766		require.NotNil(t, result)
1767
1768		toolResults := result.Steps[0].Content.ToolResults()
1769		require.Len(t, toolResults, 1)
1770
1771		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1772		require.True(t, ok, "Expected media result")
1773		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), mediaResult.Data)
1774		require.Equal(t, "audio/wav", mediaResult.MediaType)
1775	})
1776
1777	t.Run("Media response with text", func(t *testing.T) {
1778		t.Parallel()
1779
1780		imageTool := &mockTool{
1781			name:        "screenshot",
1782			description: "Takes a screenshot",
1783			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1784				resp := NewImageResponse(imageData, "image/png")
1785				resp.Content = "Screenshot captured successfully"
1786				return resp, nil
1787			},
1788		}
1789
1790		model := &mockLanguageModel{
1791			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1792				if len(call.Prompt) == 1 {
1793					return &Response{
1794						Content: []Content{
1795							ToolCallContent{
1796								ToolCallID: "screen-1",
1797								ToolName:   "screenshot",
1798								Input:      `{}`,
1799							},
1800						},
1801						Usage:        Usage{TotalTokens: 10},
1802						FinishReason: FinishReasonToolCalls,
1803					}, nil
1804				}
1805				return &Response{
1806					Content:      []Content{TextContent{Text: "Done"}},
1807					Usage:        Usage{TotalTokens: 20},
1808					FinishReason: FinishReasonStop,
1809				}, nil
1810			},
1811		}
1812
1813		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1814
1815		result, err := agent.Generate(context.Background(), AgentCall{
1816			Prompt: "Take a screenshot",
1817		})
1818
1819		require.NoError(t, err)
1820		require.NotNil(t, result)
1821
1822		toolResults := result.Steps[0].Content.ToolResults()
1823		require.Len(t, toolResults, 1)
1824
1825		mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1826		require.True(t, ok, "Expected media result")
1827		require.Equal(t, base64.StdEncoding.EncodeToString(imageData), mediaResult.Data)
1828		require.Equal(t, "image/png", mediaResult.MediaType)
1829		require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
1830	})
1831
1832	t.Run("Media response preserves metadata", func(t *testing.T) {
1833		t.Parallel()
1834
1835		type ImageMetadata struct {
1836			Width  int `json:"width"`
1837			Height int `json:"height"`
1838		}
1839
1840		imageTool := &mockTool{
1841			name:        "generate_image",
1842			description: "Generates an image",
1843			executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1844				resp := NewImageResponse(imageData, "image/png")
1845				return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
1846			},
1847		}
1848
1849		model := &mockLanguageModel{
1850			generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1851				if len(call.Prompt) == 1 {
1852					return &Response{
1853						Content: []Content{
1854							ToolCallContent{
1855								ToolCallID: "img-1",
1856								ToolName:   "generate_image",
1857								Input:      `{}`,
1858							},
1859						},
1860						Usage:        Usage{TotalTokens: 10},
1861						FinishReason: FinishReasonToolCalls,
1862					}, nil
1863				}
1864				return &Response{
1865					Content:      []Content{TextContent{Text: "Done"}},
1866					Usage:        Usage{TotalTokens: 20},
1867					FinishReason: FinishReasonStop,
1868				}, nil
1869			},
1870		}
1871
1872		agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1873
1874		result, err := agent.Generate(context.Background(), AgentCall{
1875			Prompt: "Generate image",
1876		})
1877
1878		require.NoError(t, err)
1879		require.NotNil(t, result)
1880
1881		toolResults := result.Steps[0].Content.ToolResults()
1882		require.Len(t, toolResults, 1)
1883
1884		// Check metadata was preserved
1885		require.NotEmpty(t, toolResults[0].ClientMetadata)
1886
1887		var metadata ImageMetadata
1888		err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
1889		require.NoError(t, err)
1890		require.Equal(t, 800, metadata.Width)
1891		require.Equal(t, 600, metadata.Height)
1892	})
1893}
1894
1895func TestToResponseMessages_ProviderExecutedRouting(t *testing.T) {
1896	t.Parallel()
1897
1898	// Build step content that mixes a provider-executed tool call/result
1899	// (e.g. web search) with a regular local tool call/result.
1900	content := []Content{
1901		// Provider-executed tool call.
1902		&ToolCallContent{
1903			ToolCallID:       "srvtoolu_01",
1904			ToolName:         "web_search",
1905			Input:            `{"query":"test"}`,
1906			ProviderExecuted: true,
1907		},
1908		// Provider-executed tool result.
1909		&ToolResultContent{
1910			ToolCallID:       "srvtoolu_01",
1911			ProviderExecuted: true,
1912		},
1913		// Regular (locally-executed) tool call.
1914		&ToolCallContent{
1915			ToolCallID: "toolu_02",
1916			ToolName:   "calculator",
1917			Input:      `{"expr":"1+1"}`,
1918		},
1919		// Regular tool result.
1920		&ToolResultContent{
1921			ToolCallID: "toolu_02",
1922			Result:     ToolResultOutputContentText{Text: "2"},
1923		},
1924		// Some trailing text.
1925		&TextContent{Text: "Done."},
1926	}
1927
1928	msgs := toResponseMessages(content)
1929
1930	// Expect two messages: assistant + tool.
1931	require.Len(t, msgs, 2)
1932
1933	// Assistant message should contain:
1934	//   1. provider-executed ToolCallPart
1935	//   2. provider-executed ToolResultPart
1936	//   3. regular ToolCallPart
1937	//   4. TextPart
1938	assistant := msgs[0]
1939	require.Equal(t, MessageRoleAssistant, assistant.Role)
1940	require.Len(t, assistant.Content, 4)
1941
1942	// Verify provider-executed tool call is in assistant.
1943	tc1, ok := AsMessagePart[ToolCallPart](assistant.Content[0])
1944	require.True(t, ok)
1945	require.Equal(t, "srvtoolu_01", tc1.ToolCallID)
1946	require.True(t, tc1.ProviderExecuted)
1947
1948	// Verify provider-executed tool result is in assistant.
1949	tr1, ok := AsMessagePart[ToolResultPart](assistant.Content[1])
1950	require.True(t, ok)
1951	require.Equal(t, "srvtoolu_01", tr1.ToolCallID)
1952	require.True(t, tr1.ProviderExecuted)
1953
1954	// Verify regular tool call is in assistant.
1955	tc2, ok := AsMessagePart[ToolCallPart](assistant.Content[2])
1956	require.True(t, ok)
1957	require.Equal(t, "toolu_02", tc2.ToolCallID)
1958	require.False(t, tc2.ProviderExecuted)
1959
1960	// Verify text part is in assistant.
1961	text, ok := AsMessagePart[TextPart](assistant.Content[3])
1962	require.True(t, ok)
1963	require.Equal(t, "Done.", text.Text)
1964
1965	// Tool message should contain only the regular tool result.
1966	toolMsg := msgs[1]
1967	require.Equal(t, MessageRoleTool, toolMsg.Role)
1968	require.Len(t, toolMsg.Content, 1)
1969
1970	tr2, ok := AsMessagePart[ToolResultPart](toolMsg.Content[0])
1971	require.True(t, ok)
1972	require.Equal(t, "toolu_02", tr2.ToolCallID)
1973	require.False(t, tr2.ProviderExecuted)
1974}
1975
1976// TestAgent_Generate_ExecutableProviderTool verifies that an
1977// ExecutableProviderTool registered via WithProviderDefinedTools is
1978// executed by the agent when the model returns a matching tool call.
1979func TestAgent_Generate_ExecutableProviderTool(t *testing.T) {
1980	t.Parallel()
1981
1982	runCalled := false
1983	execTool := NewExecutableProviderTool(
1984		ProviderDefinedTool{
1985			ID:   "test.computer",
1986			Name: "computer",
1987			Args: map[string]any{"display_width_px": 1920},
1988		},
1989		func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1990			runCalled = true
1991			return NewTextResponse("screenshot taken"), nil
1992		},
1993	)
1994
1995	model := &mockLanguageModel{
1996		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1997			return &Response{
1998				Content: []Content{
1999					ToolCallContent{
2000						ToolCallID: "call-1",
2001						ToolName:   "computer",
2002						Input:      `{"action":"screenshot"}`,
2003					},
2004				},
2005				Usage:        Usage{TotalTokens: 10},
2006				FinishReason: FinishReasonStop,
2007			}, nil
2008		},
2009	}
2010
2011	agent := NewAgent(model, WithProviderDefinedTools(execTool))
2012	result, err := agent.Generate(context.Background(), AgentCall{
2013		Prompt: "take a screenshot",
2014	})
2015
2016	require.NoError(t, err)
2017	require.NotNil(t, result)
2018	require.True(t, runCalled, "expected Run func to be called")
2019	require.Len(t, result.Steps, 1)
2020
2021	// Verify tool result is in the response.
2022	var toolResults []ToolResultContent
2023	for _, c := range result.Response.Content {
2024		if tr, ok := AsContentType[ToolResultContent](c); ok {
2025			toolResults = append(toolResults, tr)
2026		}
2027	}
2028	require.Len(t, toolResults, 1)
2029	require.Equal(t, "call-1", toolResults[0].ToolCallID)
2030	require.Equal(t, "computer", toolResults[0].ToolName)
2031
2032	textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
2033	require.True(t, ok)
2034	require.Equal(t, "screenshot taken", textResult.Text)
2035}
2036
2037// TestAgent_Generate_ExecutableProviderTool_ActiveTools verifies that
2038// active tool filtering works for ExecutableProviderTool.
2039func TestAgent_Generate_ExecutableProviderTool_ActiveTools(t *testing.T) {
2040	t.Parallel()
2041
2042	execTool := NewExecutableProviderTool(
2043		ProviderDefinedTool{
2044			ID:   "test.computer",
2045			Name: "computer",
2046			Args: map[string]any{"display_width_px": 1920},
2047		},
2048		func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2049			return NewTextResponse("ok"), nil
2050		},
2051	)
2052
2053	model := &mockLanguageModel{
2054		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2055			// With ActiveTools=["other"], computer should be filtered out.
2056			require.Empty(t, call.Tools)
2057
2058			return &Response{
2059				Content:      []Content{TextContent{Text: "no tools"}},
2060				Usage:        Usage{TotalTokens: 5},
2061				FinishReason: FinishReasonStop,
2062			}, nil
2063		},
2064	}
2065
2066	agent := NewAgent(model, WithProviderDefinedTools(execTool))
2067	result, err := agent.Generate(context.Background(), AgentCall{
2068		Prompt:      "test",
2069		ActiveTools: []string{"other"},
2070	})
2071
2072	require.NoError(t, err)
2073	require.NotNil(t, result)
2074}
2075
2076// TestAgent_Generate_ExecutableProviderTool_ActiveTools_Rejected
2077// verifies that a hallucinated tool call for an EPT excluded by
2078// activeTools is rejected at validation and execution time.
2079func TestAgent_Generate_ExecutableProviderTool_ActiveTools_Rejected(t *testing.T) {
2080	t.Parallel()
2081
2082	runCalled := false
2083	execTool := NewExecutableProviderTool(
2084		ProviderDefinedTool{
2085			ID:   "test.computer",
2086			Name: "computer",
2087			Args: map[string]any{"display_width_px": 1920},
2088		},
2089		func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2090			runCalled = true
2091			return NewTextResponse("ok"), nil
2092		},
2093	)
2094
2095	callCount := 0
2096	model := &mockLanguageModel{
2097		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2098			callCount++
2099			if callCount == 1 {
2100				// Model hallucinates a call to the excluded tool.
2101				return &Response{
2102					Content: []Content{ToolCallContent{
2103						ToolCallID: "call-1",
2104						ToolName:   "computer",
2105						Input:      `{"action":"screenshot"}`,
2106					}},
2107					Usage:        Usage{TotalTokens: 5},
2108					FinishReason: FinishReasonToolCalls,
2109				}, nil
2110			}
2111			// Second call: model stops.
2112			return &Response{
2113				Content:      []Content{TextContent{Text: "done"}},
2114				Usage:        Usage{TotalTokens: 3},
2115				FinishReason: FinishReasonStop,
2116			}, nil
2117		},
2118	}
2119
2120	agent := NewAgent(model, WithProviderDefinedTools(execTool))
2121	result, err := agent.Generate(context.Background(), AgentCall{
2122		Prompt:      "test",
2123		ActiveTools: []string{"other"},
2124	})
2125
2126	require.NoError(t, err)
2127	require.NotNil(t, result)
2128	require.False(t, runCalled, "excluded EPT should not have been executed")
2129
2130	// The tool call should have been marked invalid.
2131	var foundInvalidToolResult bool
2132	for _, step := range result.Steps {
2133		for _, content := range step.Content {
2134			if tr, ok := AsContentType[ToolResultContent](content); ok {
2135				if errResult, ok := tr.Result.(ToolResultOutputContentError); ok {
2136					require.Contains(t, errResult.Error.Error(), "tool not found")
2137					foundInvalidToolResult = true
2138				}
2139			}
2140		}
2141	}
2142	require.True(t, foundInvalidToolResult, "expected an error result for the excluded tool call")
2143}
2144
2145// TestAgent_Stream_ExecutableProviderTool verifies that an
2146// ExecutableProviderTool works through the Stream path.
2147func TestAgent_Stream_ExecutableProviderTool(t *testing.T) {
2148	t.Parallel()
2149
2150	runCalled := false
2151	execTool := NewExecutableProviderTool(
2152		ProviderDefinedTool{
2153			ID:   "test.computer",
2154			Name: "computer",
2155			Args: map[string]any{"display_width_px": 1920},
2156		},
2157		func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2158			runCalled = true
2159			return NewTextResponse("screenshot taken"), nil
2160		},
2161	)
2162
2163	model := &mockLanguageModel{
2164		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
2165			return func(yield func(StreamPart) bool) {
2166				if !yield(StreamPart{
2167					Type:          StreamPartTypeToolCall,
2168					ID:            "call-1",
2169					ToolCallName:  "computer",
2170					ToolCallInput: `{"action":"screenshot"}`,
2171				}) {
2172					return
2173				}
2174				yield(StreamPart{
2175					Type:         StreamPartTypeFinish,
2176					FinishReason: FinishReasonStop,
2177					Usage:        Usage{TotalTokens: 10},
2178				})
2179			}, nil
2180		},
2181	}
2182
2183	agent := NewAgent(model, WithProviderDefinedTools(execTool))
2184	result, err := agent.Stream(context.Background(), AgentStreamCall{
2185		Prompt: "take a screenshot",
2186	})
2187
2188	require.NoError(t, err)
2189	require.NotNil(t, result)
2190	require.True(t, runCalled, "expected Run func to be called")
2191	require.Len(t, result.Steps, 1)
2192
2193	// Verify tool result is in the step content.
2194	var toolResults []ToolResultContent
2195	for _, c := range result.Steps[0].Content {
2196		if tr, ok := AsContentType[ToolResultContent](c); ok {
2197			toolResults = append(toolResults, tr)
2198		}
2199	}
2200	require.Len(t, toolResults, 1)
2201	require.Equal(t, "call-1", toolResults[0].ToolCallID)
2202}
2203
2204// TestAgent_PrepareTools_ExecutableProviderTool verifies that
2205// prepareTools emits a ProviderDefinedTool (not a FunctionTool) when
2206// an ExecutableProviderTool is registered via WithProviderDefinedTools.
2207func TestAgent_PrepareTools_ExecutableProviderTool(t *testing.T) {
2208	t.Parallel()
2209
2210	execTool := NewExecutableProviderTool(
2211		ProviderDefinedTool{
2212			ID:   "test.computer",
2213			Name: "computer",
2214			Args: map[string]any{"display_width_px": 1920},
2215		},
2216		func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2217			return NewTextResponse("ok"), nil
2218		},
2219	)
2220
2221	model := &mockLanguageModel{
2222		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2223			// Verify the tool is emitted as a ProviderDefinedTool.
2224			require.Len(t, call.Tools, 1)
2225			pdt, ok := call.Tools[0].(ProviderDefinedTool)
2226			require.True(t, ok, "expected ProviderDefinedTool, got %T", call.Tools[0])
2227			require.Equal(t, "computer", pdt.Name)
2228			require.Equal(t, "test.computer", pdt.ID)
2229
2230			return &Response{
2231				Content:      []Content{TextContent{Text: "done"}},
2232				Usage:        Usage{TotalTokens: 5},
2233				FinishReason: FinishReasonStop,
2234			}, nil
2235		},
2236	}
2237
2238	agent := NewAgent(model, WithProviderDefinedTools(execTool))
2239	_, err := agent.Generate(context.Background(), AgentCall{
2240		Prompt: "test",
2241	})
2242	require.NoError(t, err)
2243}
2244
2245// TestAgent_ValidateToolCall_ExecutableProviderTool verifies that
2246// schema validation is skipped for executable provider tools, but
2247// JSON parsing is still checked.
2248func TestAgent_ValidateToolCall_ExecutableProviderTool(t *testing.T) {
2249	t.Parallel()
2250
2251	execTool := NewExecutableProviderTool(
2252		ProviderDefinedTool{
2253			ID:   "test.computer",
2254			Name: "computer",
2255		},
2256		func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2257			return NewTextResponse("ok"), nil
2258		},
2259	)
2260
2261	a := &agent{
2262		settings: agentSettings{
2263			executableProviderTools: []ExecutableProviderTool{execTool},
2264		},
2265	}
2266
2267	// Valid JSON should pass even without required fields.
2268	err := a.validateToolCall(ToolCallContent{
2269		ToolName: "computer",
2270		Input:    `{"action":"screenshot"}`,
2271	}, []AgentTool{}, []ExecutableProviderTool{execTool})
2272	require.NoError(t, err)
2273
2274	// Invalid JSON should still fail.
2275	err = a.validateToolCall(ToolCallContent{
2276		ToolName: "computer",
2277		Input:    `not-json`,
2278	}, []AgentTool{}, []ExecutableProviderTool{execTool})
2279	require.Error(t, err)
2280	require.Contains(t, err.Error(), "invalid JSON")
2281}
2282
2283// TestAgent_WithProviderDefinedTools_BackwardCompat verifies that
2284// passing a plain ProviderDefinedTool to WithProviderDefinedTools
2285// still works (web search path).
2286func TestAgent_WithProviderDefinedTools_BackwardCompat(t *testing.T) {
2287	t.Parallel()
2288
2289	webSearch := ProviderDefinedTool{
2290		ID:   "anthropic.web_search",
2291		Name: "web_search",
2292		Args: map[string]any{"max_results": 5},
2293	}
2294
2295	model := &mockLanguageModel{
2296		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2297			require.Len(t, call.Tools, 1)
2298			pdt, ok := call.Tools[0].(ProviderDefinedTool)
2299			require.True(t, ok, "expected ProviderDefinedTool, got %T", call.Tools[0])
2300			require.Equal(t, "web_search", pdt.Name)
2301			require.Equal(t, "anthropic.web_search", pdt.ID)
2302
2303			return &Response{
2304				Content:      []Content{TextContent{Text: "search results"}},
2305				Usage:        Usage{TotalTokens: 5},
2306				FinishReason: FinishReasonStop,
2307			}, nil
2308		},
2309	}
2310
2311	agent := NewAgent(model, WithProviderDefinedTools(webSearch))
2312	result, err := agent.Generate(context.Background(), AgentCall{
2313		Prompt: "search for something",
2314	})
2315
2316	require.NoError(t, err)
2317	require.NotNil(t, result)
2318	require.Equal(t, "search results", result.Response.Content.Text())
2319}
2320
2321// TestAgent_Generate_ExecutableProviderTool_ImageBase64 verifies that
2322// image data returned by an ExecutableProviderTool's run function is
2323// base64-encoded when stored in ToolResultOutputContentMedia.Data.
2324func TestAgent_Generate_ExecutableProviderTool_ImageBase64(t *testing.T) {
2325	t.Parallel()
2326
2327	rawPNG := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
2328
2329	execTool := NewExecutableProviderTool(
2330		ProviderDefinedTool{
2331			ID:   "test.computer",
2332			Name: "computer",
2333			Args: map[string]any{"display_width_px": 1920},
2334		},
2335		func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2336			return NewImageResponse(rawPNG, "image/png"), nil
2337		},
2338	)
2339
2340	callCount := 0
2341	model := &mockLanguageModel{
2342		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2343			callCount++
2344			if callCount == 1 {
2345				return &Response{
2346					Content: []Content{
2347						ToolCallContent{
2348							ToolCallID: "call-1",
2349							ToolName:   "computer",
2350							Input:      `{"action":"screenshot"}`,
2351						},
2352					},
2353					Usage:        Usage{TotalTokens: 10},
2354					FinishReason: FinishReasonToolCalls,
2355				}, nil
2356			}
2357			return &Response{
2358				Content:      []Content{TextContent{Text: "done"}},
2359				Usage:        Usage{TotalTokens: 5},
2360				FinishReason: FinishReasonStop,
2361			}, nil
2362		},
2363	}
2364
2365	agent := NewAgent(model, WithProviderDefinedTools(execTool))
2366	result, err := agent.Generate(context.Background(), AgentCall{
2367		Prompt: "take a screenshot",
2368	})
2369
2370	require.NoError(t, err)
2371	require.NotNil(t, result)
2372	require.Len(t, result.Steps, 2)
2373
2374	// The tool result in the first step must have base64-encoded data.
2375	toolResults := result.Steps[0].Content.ToolResults()
2376	require.Len(t, toolResults, 1)
2377
2378	mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
2379	require.True(t, ok, "expected media result")
2380	require.Equal(t, base64.StdEncoding.EncodeToString(rawPNG), mediaResult.Data)
2381	require.Equal(t, "image/png", mediaResult.MediaType)
2382}
2383
2384// TestAgent_Generate_ExecutableProviderTool_CriticalError verifies
2385// that a Go error returned from an ExecutableProviderTool's run
2386// function is treated as a critical error, stopping the agent loop.
2387func TestAgent_Generate_ExecutableProviderTool_CriticalError(t *testing.T) {
2388	t.Parallel()
2389
2390	execTool := NewExecutableProviderTool(
2391		ProviderDefinedTool{
2392			ID:   "test.computer",
2393			Name: "computer",
2394			Args: map[string]any{"display_width_px": 1920},
2395		},
2396		func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2397			return ToolResponse{}, fmt.Errorf("vnc connection lost")
2398		},
2399	)
2400
2401	callCount := 0
2402	model := &mockLanguageModel{
2403		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2404			callCount++
2405			return &Response{
2406				Content: []Content{
2407					ToolCallContent{
2408						ToolCallID: "call-1",
2409						ToolName:   "computer",
2410						Input:      `{"action":"screenshot"}`,
2411					},
2412				},
2413				Usage:        Usage{TotalTokens: 10},
2414				FinishReason: FinishReasonToolCalls,
2415			}, nil
2416		},
2417	}
2418
2419	agent := NewAgent(model, WithProviderDefinedTools(execTool), WithStopConditions(StepCountIs(5)))
2420	result, err := agent.Generate(context.Background(), AgentCall{
2421		Prompt: "take a screenshot",
2422	})
2423
2424	require.NoError(t, err)
2425	require.NotNil(t, result)
2426	// The model should only be called once — the critical error stops
2427	// the loop before a second model call.
2428	require.Equal(t, 1, callCount)
2429	require.Len(t, result.Steps, 1)
2430}