fix(agent): make per-step tool selection apply to all tool types

Christian Rocha created

Change summary

agent.go      |  7 ++++++-
agent_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 57 insertions(+), 1 deletion(-)

Detailed changes

agent.go 🔗

@@ -921,7 +921,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 }
 
 func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
-	preparedTools := make([]Tool, 0, len(tools))
+	preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
 
 	// If explicitly disabling all tools, return no tools
 	if disableAllTools {
@@ -949,6 +949,11 @@ func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderD
 		})
 	}
 	for _, tool := range providerDefinedTools {
+		// If activeTools has items, only include tools in the list. If
+		// activeTools is empty, include all tools
+		if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
+			continue
+		}
 		preparedTools = append(preparedTools, tool)
 	}
 	return preparedTools

agent_test.go 🔗

@@ -631,6 +631,57 @@ func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
 	require.NotNil(t, result)
 }
 
+func TestAgent_Generate_OptionsActiveTools_WithProviderDefinedTools(t *testing.T) {
+	t.Parallel()
+
+	tool1 := &mockTool{
+		name:        "tool1",
+		description: "Test tool 1",
+		parameters: map[string]any{
+			"value": map[string]any{"type": "string"},
+		},
+		required: []string{"value"},
+	}
+
+	providerTool1 := ProviderDefinedTool{ID: "provider.web_search", Name: "web_search"}
+	providerTool2 := ProviderDefinedTool{ID: "provider.code_execution", Name: "code_execution"}
+
+	model := &mockLanguageModel{
+		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+			require.Len(t, call.Tools, 2)
+
+			functionTool, ok := call.Tools[0].(FunctionTool)
+			require.True(t, ok)
+			require.Equal(t, "tool1", functionTool.Name)
+
+			providerTool, ok := call.Tools[1].(ProviderDefinedTool)
+			require.True(t, ok)
+			require.Equal(t, "web_search", providerTool.Name)
+
+			return &Response{
+				Content: []Content{
+					TextContent{Text: "Hello, world!"},
+				},
+				Usage: Usage{
+					InputTokens:  3,
+					OutputTokens: 10,
+					TotalTokens:  13,
+				},
+				FinishReason: FinishReasonStop,
+			}, nil
+		},
+	}
+
+	agent := NewAgent(model, WithTools(tool1), WithProviderDefinedTools(providerTool1, providerTool2))
+	result, err := agent.Generate(context.Background(), AgentCall{
+		Prompt:      "test-input",
+		ActiveTools: []string{"tool1", "web_search"}, // Only tool1 and web_search should be active
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, result)
+}
+
 func TestResponseContent_Getters(t *testing.T) {
 	t.Parallel()