diff --git a/agent.go b/agent.go index 011f5edd9e4bdf2257aed7b39921007b687a0cf6..3f90ea5ee824112966e6e304aee4dc05a11cc756 100644 --- a/agent.go +++ b/agent.go @@ -921,7 +921,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, } func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool { - preparedTools := make([]Tool, 0, len(tools)) + preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools)) // If explicitly disabling all tools, return no tools if disableAllTools { @@ -949,6 +949,11 @@ func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderD }) } for _, tool := range providerDefinedTools { + // If activeTools has items, only include tools in the list. If + // activeTools is empty, include all tools + if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) { + continue + } preparedTools = append(preparedTools, tool) } return preparedTools diff --git a/agent_test.go b/agent_test.go index 2d207610892c8a26522080aaf598738c3c07e5ca..f6df57987d124b65ba7bb6b4964eddd7cc5ee188 100644 --- a/agent_test.go +++ b/agent_test.go @@ -631,6 +631,57 @@ func TestAgent_Generate_OptionsActiveTools(t *testing.T) { require.NotNil(t, result) } +func TestAgent_Generate_OptionsActiveTools_WithProviderDefinedTools(t *testing.T) { + t.Parallel() + + tool1 := &mockTool{ + name: "tool1", + description: "Test tool 1", + parameters: map[string]any{ + "value": map[string]any{"type": "string"}, + }, + required: []string{"value"}, + } + + providerTool1 := ProviderDefinedTool{ID: "provider.web_search", Name: "web_search"} + providerTool2 := ProviderDefinedTool{ID: "provider.code_execution", Name: "code_execution"} + + model := &mockLanguageModel{ + generateFunc: func(ctx context.Context, call Call) (*Response, error) { + require.Len(t, call.Tools, 2) + + functionTool, ok := call.Tools[0].(FunctionTool) + require.True(t, ok) + require.Equal(t, "tool1", functionTool.Name) + + providerTool, ok := call.Tools[1].(ProviderDefinedTool) + require.True(t, ok) + require.Equal(t, "web_search", providerTool.Name) + + return &Response{ + Content: []Content{ + TextContent{Text: "Hello, world!"}, + }, + Usage: Usage{ + InputTokens: 3, + OutputTokens: 10, + TotalTokens: 13, + }, + FinishReason: FinishReasonStop, + }, nil + }, + } + + agent := NewAgent(model, WithTools(tool1), WithProviderDefinedTools(providerTool1, providerTool2)) + result, err := agent.Generate(context.Background(), AgentCall{ + Prompt: "test-input", + ActiveTools: []string{"tool1", "web_search"}, // Only tool1 and web_search should be active + }) + + require.NoError(t, err) + require.NotNil(t, result) +} + func TestResponseContent_Getters(t *testing.T) { t.Parallel()