diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 678e4cb5d25b75e366dced4ba7fba4cb01b9d44f..1a571e8009673a1bbdf540b27c1048b825086747 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -15,9 +15,9 @@ use futures::{ use gpui::{App, Context, Entity, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, - LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, - LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, + LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, }; use log; use project::Project; @@ -681,10 +681,12 @@ impl Thread { .profiles .get(&self.profile_id) .context("profile not found")?; + let provider_id = self.selected_model.provider_id(); Ok(self .tools .iter() + .filter(move |(_, tool)| tool.supported_provider(&provider_id)) .filter_map(|(tool_name, tool)| { if profile.is_tool_enabled(tool_name) { Some(tool) @@ -782,6 +784,12 @@ where schemars::schema_for!(Self::Input) } + /// Some tools rely on a provider for the underlying billing or other reasons. + /// Allow the tool to check if they are compatible, or should be filtered out. + fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { + true + } + /// Runs the tool with the provided input. fn run( self: Arc, @@ -808,6 +816,9 @@ pub trait AnyAgentTool { fn kind(&self) -> acp::ToolKind; fn initial_title(&self, input: serde_json::Value) -> SharedString; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { + true + } fn run( self: Arc, input: serde_json::Value, @@ -843,6 +854,10 @@ where Ok(json) } + fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool { + self.0.supported_provider(provider) + } + fn run( self: Arc, input: serde_json::Value, diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent2/src/tools/web_search_tool.rs index 12587c2f67eec9e18a906cf9324a249bfa5d0e3e..c1c09707426431bf8a3ad4c59a012a567366d392 100644 --- a/crates/agent2/src/tools/web_search_tool.rs +++ b/crates/agent2/src/tools/web_search_tool.rs @@ -5,7 +5,9 @@ use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use cloud_llm_client::WebSearchResponse; use gpui::{App, AppContext, Task}; -use language_model::LanguageModelToolResultContent; +use language_model::{ + LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID, +}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use ui::prelude::*; @@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool { "Searching the Web".into() } + /// We currently only support Zed Cloud as a provider. + fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool { + provider == &ZED_CLOUD_PROVIDER_ID + } + fn run( self: Arc, input: Self::Input,