agent2: Allow tools to be provider specific (#36111)

Ben Brandt created

Our WebSearch tool requires access to a Zed provider

Release Notes:

- N/A

Change summary

crates/agent2/src/thread.rs                | 21 ++++++++++++++++++---
crates/agent2/src/tools/web_search_tool.rs |  9 ++++++++-
2 files changed, 26 insertions(+), 4 deletions(-)

Detailed changes

crates/agent2/src/thread.rs 🔗

@@ -15,9 +15,9 @@ use futures::{
 use gpui::{App, Context, Entity, SharedString, Task};
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
-    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
-    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+    LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
+    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
+    LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
 };
 use log;
 use project::Project;
@@ -681,10 +681,12 @@ impl Thread {
             .profiles
             .get(&self.profile_id)
             .context("profile not found")?;
+        let provider_id = self.selected_model.provider_id();
 
         Ok(self
             .tools
             .iter()
+            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
             .filter_map(|(tool_name, tool)| {
                 if profile.is_tool_enabled(tool_name) {
                     Some(tool)
@@ -782,6 +784,12 @@ where
         schemars::schema_for!(Self::Input)
     }
 
+    /// Some tools rely on a provider for the underlying billing or other reasons.
+    /// Allow the tool to check if they are compatible, or should be filtered out.
+    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
+        true
+    }
+
     /// Runs the tool with the provided input.
     fn run(
         self: Arc<Self>,
@@ -808,6 +816,9 @@ pub trait AnyAgentTool {
     fn kind(&self) -> acp::ToolKind;
     fn initial_title(&self, input: serde_json::Value) -> SharedString;
     fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
+    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
+        true
+    }
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
@@ -843,6 +854,10 @@ where
         Ok(json)
     }
 
+    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
+        self.0.supported_provider(provider)
+    }
+
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,

crates/agent2/src/tools/web_search_tool.rs 🔗

@@ -5,7 +5,9 @@ use agent_client_protocol as acp;
 use anyhow::{Result, anyhow};
 use cloud_llm_client::WebSearchResponse;
 use gpui::{App, AppContext, Task};
-use language_model::LanguageModelToolResultContent;
+use language_model::{
+    LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
+};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use ui::prelude::*;
@@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
         "Searching the Web".into()
     }
 
+    /// We currently only support Zed Cloud as a provider.
+    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
+        provider == &ZED_CLOUD_PROVIDER_ID
+    }
+
     fn run(
         self: Arc<Self>,
         input: Self::Input,