@@ -15,9 +15,9 @@ use futures::{
use gpui::{App, Context, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
- LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
- LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
- LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+ LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
+ LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
+ LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use log;
use project::Project;
@@ -681,10 +681,12 @@ impl Thread {
.profiles
.get(&self.profile_id)
.context("profile not found")?;
+ let provider_id = self.selected_model.provider_id();
Ok(self
.tools
.iter()
+ .filter(move |(_, tool)| tool.supported_provider(&provider_id))
.filter_map(|(tool_name, tool)| {
if profile.is_tool_enabled(tool_name) {
Some(tool)
@@ -782,6 +784,12 @@ where
schemars::schema_for!(Self::Input)
}
+ /// Some tools rely on a provider for the underlying billing or other reasons.
+ /// Allow the tool to check if they are compatible, or should be filtered out.
+ fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
+ true
+ }
+
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
@@ -808,6 +816,9 @@ pub trait AnyAgentTool {
fn kind(&self) -> acp::ToolKind;
fn initial_title(&self, input: serde_json::Value) -> SharedString;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
+ fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
+ true
+ }
fn run(
self: Arc<Self>,
input: serde_json::Value,
@@ -843,6 +854,10 @@ where
Ok(json)
}
+ fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
+ self.0.supported_provider(provider)
+ }
+
fn run(
self: Arc<Self>,
input: serde_json::Value,
@@ -5,7 +5,9 @@ use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use cloud_llm_client::WebSearchResponse;
use gpui::{App, AppContext, Task};
-use language_model::LanguageModelToolResultContent;
+use language_model::{
+ LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
+};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::prelude::*;
@@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
"Searching the Web".into()
}
+ /// We currently only support Zed Cloud as a provider.
+ fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
+ provider == &ZED_CLOUD_PROVIDER_ID
+ }
+
fn run(
self: Arc<Self>,
input: Self::Input,