@@ -921,7 +921,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
}
func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
- preparedTools := make([]Tool, 0, len(tools))
+ preparedTools := make([]Tool, 0, len(tools)+len(providerDefinedTools))
// If explicitly disabling all tools, return no tools
if disableAllTools {
@@ -949,6 +949,11 @@ func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderD
})
}
for _, tool := range providerDefinedTools {
+ // If activeTools has items, only include tools in the list. If
+ // activeTools is empty, include all tools
+ if len(activeTools) > 0 && !slices.Contains(activeTools, tool.GetName()) {
+ continue
+ }
preparedTools = append(preparedTools, tool)
}
return preparedTools
@@ -631,6 +631,57 @@ func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
require.NotNil(t, result)
}
+func TestAgent_Generate_OptionsActiveTools_WithProviderDefinedTools(t *testing.T) {
+ t.Parallel()
+
+ tool1 := &mockTool{
+ name: "tool1",
+ description: "Test tool 1",
+ parameters: map[string]any{
+ "value": map[string]any{"type": "string"},
+ },
+ required: []string{"value"},
+ }
+
+ providerTool1 := ProviderDefinedTool{ID: "provider.web_search", Name: "web_search"}
+ providerTool2 := ProviderDefinedTool{ID: "provider.code_execution", Name: "code_execution"}
+
+ model := &mockLanguageModel{
+ generateFunc: func(ctx context.Context, call Call) (*Response, error) {
+ require.Len(t, call.Tools, 2)
+
+ functionTool, ok := call.Tools[0].(FunctionTool)
+ require.True(t, ok)
+ require.Equal(t, "tool1", functionTool.Name)
+
+ providerTool, ok := call.Tools[1].(ProviderDefinedTool)
+ require.True(t, ok)
+ require.Equal(t, "web_search", providerTool.Name)
+
+ return &Response{
+ Content: []Content{
+ TextContent{Text: "Hello, world!"},
+ },
+ Usage: Usage{
+ InputTokens: 3,
+ OutputTokens: 10,
+ TotalTokens: 13,
+ },
+ FinishReason: FinishReasonStop,
+ }, nil
+ },
+ }
+
+ agent := NewAgent(model, WithTools(tool1), WithProviderDefinedTools(providerTool1, providerTool2))
+ result, err := agent.Generate(context.Background(), AgentCall{
+ Prompt: "test-input",
+ ActiveTools: []string{"tool1", "web_search"}, // Only tool1 and web_search should be active
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+}
+
func TestResponseContent_Getters(t *testing.T) {
t.Parallel()