Cargo.lock 🔗
@@ -196,6 +196,7 @@ dependencies = [
"clock",
"cloud_llm_client",
"collections",
+ "context_server",
"ctor",
"editor",
"env_logger 0.11.8",
Antonio Scandurra and Ben Brandt created
We still need a profile selector.
Release Notes:
- N/A
---------
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Cargo.lock | 1
crates/acp_thread/src/acp_thread.rs | 51 +++
crates/agent2/Cargo.toml | 2
crates/agent2/src/agent.rs | 34 +
crates/agent2/src/tests/mod.rs | 142 ++++++++
crates/agent2/src/thread.rs | 87 ++++-
crates/agent2/src/tools.rs | 2
crates/agent2/src/tools/context_server_registry.rs | 231 ++++++++++++++++
crates/agent2/src/tools/diagnostics_tool.rs | 18 -
crates/agent2/src/tools/edit_file_tool.rs | 66 +++
crates/agent2/src/tools/fetch_tool.rs | 8
crates/agent2/src/tools/find_path_tool.rs | 3
crates/agent2/src/tools/grep_tool.rs | 25 -
crates/agent2/src/tools/now_tool.rs | 11
crates/agent_settings/src/agent_profile.rs | 14
15 files changed, 587 insertions(+), 108 deletions(-)
@@ -196,6 +196,7 @@ dependencies = [
"clock",
"cloud_llm_client",
"collections",
+ "context_server",
"ctor",
"editor",
"env_logger 0.11.8",
@@ -254,6 +254,15 @@ impl ToolCall {
}
if let Some(raw_output) = raw_output {
+ if self.content.is_empty() {
+ if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
+ {
+ self.content
+ .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
+ markdown,
+ }));
+ }
+ }
self.raw_output = Some(raw_output);
}
}
@@ -1266,6 +1275,48 @@ impl AcpThread {
}
}
+fn markdown_for_raw_output(
+ raw_output: &serde_json::Value,
+ language_registry: &Arc<LanguageRegistry>,
+ cx: &mut App,
+) -> Option<Entity<Markdown>> {
+ match raw_output {
+ serde_json::Value::Null => None,
+ serde_json::Value::Bool(value) => Some(cx.new(|cx| {
+ Markdown::new(
+ value.to_string().into(),
+ Some(language_registry.clone()),
+ None,
+ cx,
+ )
+ })),
+ serde_json::Value::Number(value) => Some(cx.new(|cx| {
+ Markdown::new(
+ value.to_string().into(),
+ Some(language_registry.clone()),
+ None,
+ cx,
+ )
+ })),
+ serde_json::Value::String(value) => Some(cx.new(|cx| {
+ Markdown::new(
+ value.clone().into(),
+ Some(language_registry.clone()),
+ None,
+ cx,
+ )
+ })),
+ value => Some(cx.new(|cx| {
+ Markdown::new(
+ format!("```json\n{}\n```", value).into(),
+ Some(language_registry.clone()),
+ None,
+ cx,
+ )
+ })),
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -23,6 +23,7 @@ assistant_tools.workspace = true
chrono.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
+context_server.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -60,6 +61,7 @@ workspace-hack.workspace = true
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
+context_server = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
@@ -1,8 +1,8 @@
use crate::{AgentResponseEvent, Thread, templates::Templates};
use crate::{
- CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool,
- GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool,
- ThinkingTool, ToolCallAuthorization, WebSearchTool,
+ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
+ FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
+ ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
};
use acp_thread::ModelSelector;
use agent_client_protocol as acp;
@@ -55,6 +55,7 @@ pub struct NativeAgent {
project_context: Rc<RefCell<ProjectContext>>,
project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task<Result<()>>,
+ context_server_registry: Entity<ContextServerRegistry>,
/// Shared templates for all threads
templates: Arc<Templates>,
project: Entity<Project>,
@@ -90,6 +91,9 @@ impl NativeAgent {
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
}),
+ context_server_registry: cx.new(|cx| {
+ ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
+ }),
templates,
project,
prompt_store,
@@ -385,7 +389,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
- acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
+ acp_thread::AcpThread::new(
+ "agent2",
+ self.clone(),
+ project.clone(),
+ session_id.clone(),
+ cx,
+ )
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
@@ -413,11 +423,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
})
.ok_or_else(|| {
log::warn!("No default model configured in settings");
- anyhow!("No default model configured. Please configure a default model in settings.")
+ anyhow!(
+ "No default model. Please configure a default model in settings."
+ )
})?;
let thread = cx.new(|cx| {
- let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
+ let mut thread = Thread::new(
+ project.clone(),
+ agent.project_context.clone(),
+ agent.context_server_registry.clone(),
+ action_log.clone(),
+ agent.templates.clone(),
+ default_model,
+ cx,
+ );
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
@@ -450,7 +470,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
- })
+ }),
},
);
})?;
@@ -2,6 +2,7 @@ use super::*;
use acp_thread::AgentConnection;
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
+use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
@@ -165,7 +166,9 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
} else {
false
}
- })
+ }),
+ "{}",
+ thread.to_markdown()
);
});
}
@@ -469,6 +472,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_profiles(cx: &mut TestAppContext) {
+ let ThreadTest {
+ model, thread, fs, ..
+ } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ thread.update(cx, |thread, _cx| {
+ thread.add_tool(DelayTool);
+ thread.add_tool(EchoTool);
+ thread.add_tool(InfiniteTool);
+ });
+
+ // Override profiles and wait for settings to be loaded.
+ fs.insert_file(
+ paths::settings_file(),
+ json!({
+ "agent": {
+ "profiles": {
+ "test-1": {
+ "name": "Test Profile 1",
+ "tools": {
+ EchoTool.name(): true,
+ DelayTool.name(): true,
+ }
+ },
+ "test-2": {
+ "name": "Test Profile 2",
+ "tools": {
+ InfiniteTool.name(): true,
+ }
+ }
+ }
+ }
+ })
+ .to_string()
+ .into_bytes(),
+ )
+ .await;
+ cx.run_until_parked();
+
+ // Test that test-1 profile (default) has echo and delay tools
+ thread.update(cx, |thread, cx| {
+ thread.set_profile(AgentProfileId("test-1".into()));
+ thread.send("test", cx);
+ });
+ cx.run_until_parked();
+
+ let mut pending_completions = fake_model.pending_completions();
+ assert_eq!(pending_completions.len(), 1);
+ let completion = pending_completions.pop().unwrap();
+ let tool_names: Vec<String> = completion
+ .tools
+ .iter()
+ .map(|tool| tool.name.clone())
+ .collect();
+ assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
+ fake_model.end_last_completion_stream();
+
+ // Switch to test-2 profile, and verify that it has only the infinite tool.
+ thread.update(cx, |thread, cx| {
+ thread.set_profile(AgentProfileId("test-2".into()));
+ thread.send("test2", cx)
+ });
+ cx.run_until_parked();
+ let mut pending_completions = fake_model.pending_completions();
+ assert_eq!(pending_completions.len(), 1);
+ let completion = pending_completions.pop().unwrap();
+ let tool_names: Vec<String> = completion
+ .tools
+ .iter()
+ .map(|tool| tool.name.clone())
+ .collect();
+ assert_eq!(tool_names, vec![InfiniteTool.name()]);
+}
+
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
@@ -595,6 +674,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
language_models::init(user_store.clone(), client.clone(), cx);
Project::init_settings(cx);
LanguageModelRegistry::test(cx);
+ agent_settings::init(cx);
});
cx.executor().forbid_parking();
@@ -790,6 +870,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
+ raw_output: Some("Finished thinking.".into()),
..Default::default()
},
}
@@ -813,6 +894,7 @@ struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>,
+ fs: Arc<FakeFs>,
}
enum TestModel {
@@ -835,30 +917,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.executor().allow_parking();
let fs = FakeFs::new(cx.background_executor.clone());
+ fs.create_dir(paths::settings_file().parent().unwrap())
+ .await
+ .unwrap();
+ fs.insert_file(
+ paths::settings_file(),
+ json!({
+ "agent": {
+ "default_profile": "test-profile",
+ "profiles": {
+ "test-profile": {
+ "name": "Test Profile",
+ "tools": {
+ EchoTool.name(): true,
+ DelayTool.name(): true,
+ WordListTool.name(): true,
+ ToolRequiringPermission.name(): true,
+ InfiniteTool.name(): true,
+ }
+ }
+ }
+ }
+ })
+ .to_string()
+ .into_bytes(),
+ )
+ .await;
cx.update(|cx| {
settings::init(cx);
- watch_settings(fs.clone(), cx);
Project::init_settings(cx);
agent_settings::init(cx);
+ gpui_tokio::init(cx);
+ let http_client = ReqwestClient::user_agent("agent tests").unwrap();
+ cx.set_http_client(Arc::new(http_client));
+
+ client::init_settings(cx);
+ let client = Client::production(cx);
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ language_model::init(client.clone(), cx);
+ language_models::init(user_store.clone(), client.clone(), cx);
+
+ watch_settings(fs.clone(), cx);
});
+
let templates = Templates::new();
fs.insert_tree(path!("/test"), json!({})).await;
- let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+ let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let model = cx
.update(|cx| {
- gpui_tokio::init(cx);
- let http_client = ReqwestClient::user_agent("agent tests").unwrap();
- cx.set_http_client(Arc::new(http_client));
-
- client::init_settings(cx);
- let client = Client::production(cx);
- let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(client.clone(), cx);
- language_models::init(user_store.clone(), client.clone(), cx);
-
if let TestModel::Fake = model {
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
} else {
@@ -881,20 +990,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
.await;
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project,
project_context.clone(),
+ context_server_registry,
action_log,
templates,
model.clone(),
+ cx,
)
});
ThreadTest {
model,
thread,
project_context,
+ fs,
}
}
@@ -1,7 +1,7 @@
-use crate::{SystemPromptTemplate, Template, Templates};
+use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
use action_log::ActionLog;
use agent_client_protocol as acp;
-use agent_settings::AgentSettings;
+use agent_settings::{AgentProfileId, AgentSettings};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use cloud_llm_client::{CompletionIntent, CompletionMode};
@@ -126,6 +126,8 @@ pub struct Thread {
running_turn: Option<Task<()>>,
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+ context_server_registry: Entity<ContextServerRegistry>,
+ profile_id: AgentProfileId,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
@@ -137,16 +139,21 @@ impl Thread {
pub fn new(
project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
+ context_server_registry: Entity<ContextServerRegistry>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
default_model: Arc<dyn LanguageModel>,
+ cx: &mut Context<Self>,
) -> Self {
+ let profile_id = AgentSettings::get_global(cx).default_profile.clone();
Self {
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
running_turn: None,
pending_tool_uses: HashMap::default(),
tools: BTreeMap::default(),
+ context_server_registry,
+ profile_id,
project_context,
templates,
selected_model: default_model,
@@ -179,6 +186,10 @@ impl Thread {
self.tools.remove(name).is_some()
}
+ pub fn set_profile(&mut self, profile_id: AgentProfileId) {
+ self.profile_id = profile_id;
+ }
+
pub fn cancel(&mut self) {
self.running_turn.take();
@@ -298,6 +309,7 @@ impl Thread {
} else {
acp::ToolCallStatus::Completed
}),
+ raw_output: tool_result.output.clone(),
..Default::default()
},
);
@@ -604,21 +616,23 @@ impl Thread {
let messages = self.build_request_messages();
log::info!("Request will include {} messages", messages.len());
- let tools: Vec<LanguageModelRequestTool> = self
- .tools
- .values()
- .filter_map(|tool| {
- let tool_name = tool.name().to_string();
- log::trace!("Including tool: {}", tool_name);
- Some(LanguageModelRequestTool {
- name: tool_name,
- description: tool.description(cx).to_string(),
- input_schema: tool
- .input_schema(self.selected_model.tool_input_format())
- .log_err()?,
+ let tools = if let Some(tools) = self.tools(cx).log_err() {
+ tools
+ .filter_map(|tool| {
+ let tool_name = tool.name().to_string();
+ log::trace!("Including tool: {}", tool_name);
+ Some(LanguageModelRequestTool {
+ name: tool_name,
+ description: tool.description().to_string(),
+ input_schema: tool
+ .input_schema(self.selected_model.tool_input_format())
+ .log_err()?,
+ })
})
- })
- .collect();
+ .collect()
+ } else {
+ Vec::new()
+ };
log::info!("Request includes {} tools", tools.len());
@@ -639,6 +653,35 @@ impl Thread {
request
}
+ fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
+ let profile = AgentSettings::get_global(cx)
+ .profiles
+ .get(&self.profile_id)
+ .context("profile not found")?;
+
+ Ok(self
+ .tools
+ .iter()
+ .filter_map(|(tool_name, tool)| {
+ if profile.is_tool_enabled(tool_name) {
+ Some(tool)
+ } else {
+ None
+ }
+ })
+ .chain(self.context_server_registry.read(cx).servers().flat_map(
+ |(server_id, tools)| {
+ tools.iter().filter_map(|(tool_name, tool)| {
+ if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
+ Some(tool)
+ } else {
+ None
+ }
+ })
+ },
+ )))
+ }
+
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
log::trace!(
"Building request messages from {} thread messages",
@@ -686,7 +729,7 @@ where
fn name(&self) -> SharedString;
- fn description(&self, _cx: &mut App) -> SharedString {
+ fn description(&self) -> SharedString {
let schema = schemars::schema_for!(Self::Input);
SharedString::new(
schema
@@ -722,13 +765,13 @@ where
pub struct Erased<T>(T);
pub struct AgentToolOutput {
- llm_output: LanguageModelToolResultContent,
- raw_output: serde_json::Value,
+ pub llm_output: LanguageModelToolResultContent,
+ pub raw_output: serde_json::Value,
}
pub trait AnyAgentTool {
fn name(&self) -> SharedString;
- fn description(&self, cx: &mut App) -> SharedString;
+ fn description(&self) -> SharedString;
fn kind(&self) -> acp::ToolKind;
fn initial_title(&self, input: serde_json::Value) -> SharedString;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
@@ -748,8 +791,8 @@ where
self.0.name()
}
- fn description(&self, cx: &mut App) -> SharedString {
- self.0.description(cx)
+ fn description(&self) -> SharedString {
+ self.0.description()
}
fn kind(&self) -> agent_client_protocol::ToolKind {
@@ -1,3 +1,4 @@
+mod context_server_registry;
mod copy_path_tool;
mod create_directory_tool;
mod delete_path_tool;
@@ -15,6 +16,7 @@ mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
+pub use context_server_registry::*;
pub use copy_path_tool::*;
pub use create_directory_tool::*;
pub use delete_path_tool::*;
@@ -0,0 +1,231 @@
+use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
+use agent_client_protocol::ToolKind;
+use anyhow::{Result, anyhow, bail};
+use collections::{BTreeMap, HashMap};
+use context_server::ContextServerId;
+use gpui::{App, Context, Entity, SharedString, Task};
+use project::context_server_store::{ContextServerStatus, ContextServerStore};
+use std::sync::Arc;
+use util::ResultExt;
+
+pub struct ContextServerRegistry {
+ server_store: Entity<ContextServerStore>,
+ registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
+ _subscription: gpui::Subscription,
+}
+
+struct RegisteredContextServer {
+ tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+ load_tools: Task<Result<()>>,
+}
+
+impl ContextServerRegistry {
+ pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
+ let mut this = Self {
+ server_store: server_store.clone(),
+ registered_servers: HashMap::default(),
+ _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
+ };
+ for server in server_store.read(cx).running_servers() {
+ this.reload_tools_for_server(server.id(), cx);
+ }
+ this
+ }
+
+ pub fn servers(
+ &self,
+ ) -> impl Iterator<
+ Item = (
+ &ContextServerId,
+ &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+ ),
+ > {
+ self.registered_servers
+ .iter()
+ .map(|(id, server)| (id, &server.tools))
+ }
+
+ fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
+ let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
+ return;
+ };
+ let Some(client) = server.client() else {
+ return;
+ };
+ if !client.capable(context_server::protocol::ServerCapability::Tools) {
+ return;
+ }
+
+ let registered_server =
+ self.registered_servers
+ .entry(server_id.clone())
+ .or_insert(RegisteredContextServer {
+ tools: BTreeMap::default(),
+ load_tools: Task::ready(Ok(())),
+ });
+ registered_server.load_tools = cx.spawn(async move |this, cx| {
+ let response = client
+ .request::<context_server::types::requests::ListTools>(())
+ .await;
+
+ this.update(cx, |this, cx| {
+ let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
+ return;
+ };
+
+ registered_server.tools.clear();
+ if let Some(response) = response.log_err() {
+ for tool in response.tools {
+ let tool = Arc::new(ContextServerTool::new(
+ this.server_store.clone(),
+ server.id(),
+ tool,
+ ));
+ registered_server.tools.insert(tool.name(), tool);
+ }
+ cx.notify();
+ }
+ })
+ });
+ }
+
+ fn handle_context_server_store_event(
+ &mut self,
+ _: Entity<ContextServerStore>,
+ event: &project::context_server_store::Event,
+ cx: &mut Context<Self>,
+ ) {
+ match event {
+ project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
+ match status {
+ ContextServerStatus::Starting => {}
+ ContextServerStatus::Running => {
+ self.reload_tools_for_server(server_id.clone(), cx);
+ }
+ ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
+ self.registered_servers.remove(&server_id);
+ cx.notify();
+ }
+ }
+ }
+ }
+ }
+}
+
+struct ContextServerTool {
+ store: Entity<ContextServerStore>,
+ server_id: ContextServerId,
+ tool: context_server::types::Tool,
+}
+
+impl ContextServerTool {
+ fn new(
+ store: Entity<ContextServerStore>,
+ server_id: ContextServerId,
+ tool: context_server::types::Tool,
+ ) -> Self {
+ Self {
+ store,
+ server_id,
+ tool,
+ }
+ }
+}
+
+impl AnyAgentTool for ContextServerTool {
+ fn name(&self) -> SharedString {
+ self.tool.name.clone().into()
+ }
+
+ fn description(&self) -> SharedString {
+ self.tool.description.clone().unwrap_or_default().into()
+ }
+
+ fn kind(&self) -> ToolKind {
+ ToolKind::Other
+ }
+
+ fn initial_title(&self, _input: serde_json::Value) -> SharedString {
+ format!("Run MCP tool `{}`", self.tool.name).into()
+ }
+
+ fn input_schema(
+ &self,
+ format: language_model::LanguageModelToolSchemaFormat,
+ ) -> Result<serde_json::Value> {
+ let mut schema = self.tool.input_schema.clone();
+ assistant_tool::adapt_schema_to_format(&mut schema, format)?;
+ Ok(match schema {
+ serde_json::Value::Null => {
+ serde_json::json!({ "type": "object", "properties": [] })
+ }
+ serde_json::Value::Object(map) if map.is_empty() => {
+ serde_json::json!({ "type": "object", "properties": [] })
+ }
+ _ => schema,
+ })
+ }
+
+ fn run(
+ self: Arc<Self>,
+ input: serde_json::Value,
+ _event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Task<Result<AgentToolOutput>> {
+ let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
+ return Task::ready(Err(anyhow!("Context server not found")));
+ };
+ let tool_name = self.tool.name.clone();
+ let server_clone = server.clone();
+ let input_clone = input.clone();
+
+ cx.spawn(async move |_cx| {
+ let Some(protocol) = server_clone.client() else {
+ bail!("Context server not initialized");
+ };
+
+ let arguments = if let serde_json::Value::Object(map) = input_clone {
+ Some(map.into_iter().collect())
+ } else {
+ None
+ };
+
+ log::trace!(
+ "Running tool: {} with arguments: {:?}",
+ tool_name,
+ arguments
+ );
+ let response = protocol
+ .request::<context_server::types::requests::CallTool>(
+ context_server::types::CallToolParams {
+ name: tool_name,
+ arguments,
+ meta: None,
+ },
+ )
+ .await?;
+
+ let mut result = String::new();
+ for content in response.content {
+ match content {
+ context_server::types::ToolResponseContent::Text { text } => {
+ result.push_str(&text);
+ }
+ context_server::types::ToolResponseContent::Image { .. } => {
+ log::warn!("Ignoring image content from tool response");
+ }
+ context_server::types::ToolResponseContent::Audio { .. } => {
+ log::warn!("Ignoring audio content from tool response");
+ }
+ context_server::types::ToolResponseContent::Resource { .. } => {
+ log::warn!("Ignoring resource content from tool response");
+ }
+ }
+ }
+ Ok(AgentToolOutput {
+ raw_output: result.clone().into(),
+ llm_output: result.into(),
+ })
+ })
+ }
+}
@@ -85,7 +85,7 @@ impl AgentTool for DiagnosticsTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- event_stream: ToolCallEventStream,
+ _event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
match input.path {
@@ -119,11 +119,6 @@ impl AgentTool for DiagnosticsTool {
range.start.row + 1,
entry.diagnostic.message
)?;
-
- event_stream.update_fields(acp::ToolCallUpdateFields {
- content: Some(vec![output.clone().into()]),
- ..Default::default()
- });
}
if output.is_empty() {
@@ -158,18 +153,9 @@ impl AgentTool for DiagnosticsTool {
}
if has_diagnostics {
- event_stream.update_fields(acp::ToolCallUpdateFields {
- content: Some(vec![output.clone().into()]),
- ..Default::default()
- });
Task::ready(Ok(output))
} else {
- let text = "No errors or warnings found in the project.";
- event_stream.update_fields(acp::ToolCallUpdateFields {
- content: Some(vec![text.into()]),
- ..Default::default()
- });
- Task::ready(Ok(text.into()))
+ Task::ready(Ok("No errors or warnings found in the project.".into()))
}
}
}
@@ -454,9 +454,8 @@ fn resolve_path(
#[cfg(test)]
mod tests {
- use crate::Templates;
-
use super::*;
+ use crate::{ContextServerRegistry, Templates};
use action_log::ActionLog;
use client::TelemetrySettings;
use fs::Fs;
@@ -475,9 +474,20 @@ mod tests {
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
- let thread =
- cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
+ let thread = cx.new(|cx| {
+ Thread::new(
+ project,
+ Rc::default(),
+ context_server_registry,
+ action_log,
+ Templates::new(),
+ model,
+ cx,
+ )
+ });
let result = cx
.update(|cx| {
let input = EditFileToolInput {
@@ -661,14 +671,18 @@ mod tests {
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
+ context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
+ cx,
)
});
@@ -792,15 +806,19 @@ mod tests {
.unwrap();
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
+ context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
+ cx,
)
});
@@ -914,15 +932,19 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
+ context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
+ cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1041,15 +1063,19 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
+ context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
+ cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1148,14 +1174,18 @@ mod tests {
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
+ context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
+ cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1225,14 +1255,18 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
+ context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
+ cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1305,14 +1339,18 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
+ context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
+ cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1382,14 +1420,18 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
Thread::new(
project.clone(),
Rc::default(),
+ context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
+ cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -136,7 +136,7 @@ impl AgentTool for FetchTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- event_stream: ToolCallEventStream,
+ _event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let text = cx.background_spawn({
@@ -149,12 +149,6 @@ impl AgentTool for FetchTool {
if text.trim().is_empty() {
bail!("no textual content found");
}
-
- event_stream.update_fields(acp::ToolCallUpdateFields {
- content: Some(vec![text.clone().into()]),
- ..Default::default()
- });
-
Ok(text)
})
}
@@ -139,9 +139,6 @@ impl AgentTool for FindPathTool {
})
.collect(),
),
- raw_output: Some(serde_json::json!({
- "paths": &matches,
- })),
..Default::default()
});
@@ -101,7 +101,7 @@ impl AgentTool for GrepTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- event_stream: ToolCallEventStream,
+ _event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
const CONTEXT_LINES: u32 = 2;
@@ -282,33 +282,22 @@ impl AgentTool for GrepTool {
}
}
- event_stream.update_fields(acp::ToolCallUpdateFields {
- content: Some(vec![output.clone().into()]),
- ..Default::default()
- });
matches_found += 1;
}
}
- let output = if matches_found == 0 {
- "No matches found".to_string()
+ if matches_found == 0 {
+ Ok("No matches found".into())
} else if has_more_matches {
- format!(
+ Ok(format!(
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
input.offset + 1,
input.offset + matches_found,
input.offset + RESULTS_PER_PAGE,
- )
+ ))
} else {
- format!("Found {matches_found} matches:\n{output}")
- };
-
- event_stream.update_fields(acp::ToolCallUpdateFields {
- content: Some(vec![output.clone().into()]),
- ..Default::default()
- });
-
- Ok(output)
+ Ok(format!("Found {matches_found} matches:\n{output}"))
+ }
})
}
}
@@ -47,20 +47,13 @@ impl AgentTool for NowTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- event_stream: ToolCallEventStream,
+ _event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
let now = match input.timezone {
Timezone::Utc => Utc::now().to_rfc3339(),
Timezone::Local => Local::now().to_rfc3339(),
};
- let content = format!("The current datetime is {now}.");
-
- event_stream.update_fields(acp::ToolCallUpdateFields {
- content: Some(vec![content.clone().into()]),
- ..Default::default()
- });
-
- Task::ready(Ok(content))
+ Task::ready(Ok(format!("The current datetime is {now}.")))
}
}
@@ -48,6 +48,20 @@ pub struct AgentProfileSettings {
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
}
+impl AgentProfileSettings {
+ pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
+ self.tools.get(tool_name) == Some(&true)
+ }
+
+ pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool {
+ self.enable_all_context_servers
+ || self
+ .context_servers
+ .get(server_id)
+ .map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
+ }
+}
+
#[derive(Debug, Clone, Default)]
pub struct ContextServerPreset {
pub tools: IndexMap<Arc<str>, bool>,