Support profiles in agent2 (#36034)

Antonio Scandurra and Ben Brandt created

We still need a profile selector.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

Cargo.lock                                         |   1 
crates/acp_thread/src/acp_thread.rs                |  51 +++
crates/agent2/Cargo.toml                           |   2 
crates/agent2/src/agent.rs                         |  34 +
crates/agent2/src/tests/mod.rs                     | 142 ++++++++
crates/agent2/src/thread.rs                        |  87 ++++-
crates/agent2/src/tools.rs                         |   2 
crates/agent2/src/tools/context_server_registry.rs | 231 ++++++++++++++++
crates/agent2/src/tools/diagnostics_tool.rs        |  18 -
crates/agent2/src/tools/edit_file_tool.rs          |  66 +++
crates/agent2/src/tools/fetch_tool.rs              |   8 
crates/agent2/src/tools/find_path_tool.rs          |   3 
crates/agent2/src/tools/grep_tool.rs               |  25 -
crates/agent2/src/tools/now_tool.rs                |  11 
crates/agent_settings/src/agent_profile.rs         |  14 
15 files changed, 587 insertions(+), 108 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -196,6 +196,7 @@ dependencies = [
  "clock",
  "cloud_llm_client",
  "collections",
+ "context_server",
  "ctor",
  "editor",
  "env_logger 0.11.8",

crates/acp_thread/src/acp_thread.rs 🔗

@@ -254,6 +254,15 @@ impl ToolCall {
         }
 
         if let Some(raw_output) = raw_output {
+            if self.content.is_empty() {
+                if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
+                {
+                    self.content
+                        .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
+                            markdown,
+                        }));
+                }
+            }
             self.raw_output = Some(raw_output);
         }
     }
@@ -1266,6 +1275,48 @@ impl AcpThread {
     }
 }
 
+fn markdown_for_raw_output(
+    raw_output: &serde_json::Value,
+    language_registry: &Arc<LanguageRegistry>,
+    cx: &mut App,
+) -> Option<Entity<Markdown>> {
+    match raw_output {
+        serde_json::Value::Null => None,
+        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
+            Markdown::new(
+                value.to_string().into(),
+                Some(language_registry.clone()),
+                None,
+                cx,
+            )
+        })),
+        serde_json::Value::Number(value) => Some(cx.new(|cx| {
+            Markdown::new(
+                value.to_string().into(),
+                Some(language_registry.clone()),
+                None,
+                cx,
+            )
+        })),
+        serde_json::Value::String(value) => Some(cx.new(|cx| {
+            Markdown::new(
+                value.clone().into(),
+                Some(language_registry.clone()),
+                None,
+                cx,
+            )
+        })),
+        value => Some(cx.new(|cx| {
+            Markdown::new(
+                format!("```json\n{}\n```", value).into(),
+                Some(language_registry.clone()),
+                None,
+                cx,
+            )
+        })),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

crates/agent2/Cargo.toml 🔗

@@ -23,6 +23,7 @@ assistant_tools.workspace = true
 chrono.workspace = true
 cloud_llm_client.workspace = true
 collections.workspace = true
+context_server.workspace = true
 fs.workspace = true
 futures.workspace = true
 gpui.workspace = true
@@ -60,6 +61,7 @@ workspace-hack.workspace = true
 ctor.workspace = true
 client = { workspace = true, "features" = ["test-support"] }
 clock = { workspace = true, "features" = ["test-support"] }
+context_server = { workspace = true, "features" = ["test-support"] }
 editor = { workspace = true, "features" = ["test-support"] }
 env_logger.workspace = true
 fs = { workspace = true, "features" = ["test-support"] }

crates/agent2/src/agent.rs 🔗

@@ -1,8 +1,8 @@
 use crate::{AgentResponseEvent, Thread, templates::Templates};
 use crate::{
-    CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool,
-    GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool,
-    ThinkingTool, ToolCallAuthorization, WebSearchTool,
+    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
+    FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
+    ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
 };
 use acp_thread::ModelSelector;
 use agent_client_protocol as acp;
@@ -55,6 +55,7 @@ pub struct NativeAgent {
     project_context: Rc<RefCell<ProjectContext>>,
     project_context_needs_refresh: watch::Sender<()>,
     _maintain_project_context: Task<Result<()>>,
+    context_server_registry: Entity<ContextServerRegistry>,
     /// Shared templates for all threads
     templates: Arc<Templates>,
     project: Entity<Project>,
@@ -90,6 +91,9 @@ impl NativeAgent {
                 _maintain_project_context: cx.spawn(async move |this, cx| {
                     Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                 }),
+                context_server_registry: cx.new(|cx| {
+                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
+                }),
                 templates,
                 project,
                 prompt_store,
@@ -385,7 +389,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
             // Create AcpThread
             let acp_thread = cx.update(|cx| {
                 cx.new(|cx| {
-                    acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
+                    acp_thread::AcpThread::new(
+                        "agent2",
+                        self.clone(),
+                        project.clone(),
+                        session_id.clone(),
+                        cx,
+                    )
                 })
             })?;
             let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
@@ -413,11 +423,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                         })
                         .ok_or_else(|| {
                             log::warn!("No default model configured in settings");
-                            anyhow!("No default model configured. Please configure a default model in settings.")
+                            anyhow!(
+                                "No default model. Please configure a default model in settings."
+                            )
                         })?;
 
                     let thread = cx.new(|cx| {
-                        let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
+                        let mut thread = Thread::new(
+                            project.clone(),
+                            agent.project_context.clone(),
+                            agent.context_server_registry.clone(),
+                            action_log.clone(),
+                            agent.templates.clone(),
+                            default_model,
+                            cx,
+                        );
                         thread.add_tool(CreateDirectoryTool::new(project.clone()));
                         thread.add_tool(CopyPathTool::new(project.clone()));
                         thread.add_tool(DiagnosticsTool::new(project.clone()));
@@ -450,7 +470,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                         acp_thread: acp_thread.downgrade(),
                         _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                             this.sessions.remove(acp_thread.session_id());
-                        })
+                        }),
                     },
                 );
             })?;

crates/agent2/src/tests/mod.rs 🔗

@@ -2,6 +2,7 @@ use super::*;
 use acp_thread::AgentConnection;
 use action_log::ActionLog;
 use agent_client_protocol::{self as acp};
+use agent_settings::AgentProfileId;
 use anyhow::Result;
 use client::{Client, UserStore};
 use fs::{FakeFs, Fs};
@@ -165,7 +166,9 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
                     } else {
                         false
                     }
-                })
+                }),
+            "{}",
+            thread.to_markdown()
         );
     });
 }
@@ -469,6 +472,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
     });
 }
 
+#[gpui::test]
+async fn test_profiles(cx: &mut TestAppContext) {
+    let ThreadTest {
+        model, thread, fs, ..
+    } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    thread.update(cx, |thread, _cx| {
+        thread.add_tool(DelayTool);
+        thread.add_tool(EchoTool);
+        thread.add_tool(InfiniteTool);
+    });
+
+    // Override profiles and wait for settings to be loaded.
+    fs.insert_file(
+        paths::settings_file(),
+        json!({
+            "agent": {
+                "profiles": {
+                    "test-1": {
+                        "name": "Test Profile 1",
+                        "tools": {
+                            EchoTool.name(): true,
+                            DelayTool.name(): true,
+                        }
+                    },
+                    "test-2": {
+                        "name": "Test Profile 2",
+                        "tools": {
+                            InfiniteTool.name(): true,
+                        }
+                    }
+                }
+            }
+        })
+        .to_string()
+        .into_bytes(),
+    )
+    .await;
+    cx.run_until_parked();
+
+    // Test that test-1 profile (default) has echo and delay tools
+    thread.update(cx, |thread, cx| {
+        thread.set_profile(AgentProfileId("test-1".into()));
+        thread.send("test", cx);
+    });
+    cx.run_until_parked();
+
+    let mut pending_completions = fake_model.pending_completions();
+    assert_eq!(pending_completions.len(), 1);
+    let completion = pending_completions.pop().unwrap();
+    let tool_names: Vec<String> = completion
+        .tools
+        .iter()
+        .map(|tool| tool.name.clone())
+        .collect();
+    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
+    fake_model.end_last_completion_stream();
+
+    // Switch to test-2 profile, and verify that it has only the infinite tool.
+    thread.update(cx, |thread, cx| {
+        thread.set_profile(AgentProfileId("test-2".into()));
+        thread.send("test2", cx)
+    });
+    cx.run_until_parked();
+    let mut pending_completions = fake_model.pending_completions();
+    assert_eq!(pending_completions.len(), 1);
+    let completion = pending_completions.pop().unwrap();
+    let tool_names: Vec<String> = completion
+        .tools
+        .iter()
+        .map(|tool| tool.name.clone())
+        .collect();
+    assert_eq!(tool_names, vec![InfiniteTool.name()]);
+}
+
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_cancellation(cx: &mut TestAppContext) {
@@ -595,6 +674,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
         language_models::init(user_store.clone(), client.clone(), cx);
         Project::init_settings(cx);
         LanguageModelRegistry::test(cx);
+        agent_settings::init(cx);
     });
     cx.executor().forbid_parking();
 
@@ -790,6 +870,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
             id: acp::ToolCallId("1".into()),
             fields: acp::ToolCallUpdateFields {
                 status: Some(acp::ToolCallStatus::Completed),
+                raw_output: Some("Finished thinking.".into()),
                 ..Default::default()
             },
         }
@@ -813,6 +894,7 @@ struct ThreadTest {
     model: Arc<dyn LanguageModel>,
     thread: Entity<Thread>,
     project_context: Rc<RefCell<ProjectContext>>,
+    fs: Arc<FakeFs>,
 }
 
 enum TestModel {
@@ -835,30 +917,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
     cx.executor().allow_parking();
 
     let fs = FakeFs::new(cx.background_executor.clone());
+    fs.create_dir(paths::settings_file().parent().unwrap())
+        .await
+        .unwrap();
+    fs.insert_file(
+        paths::settings_file(),
+        json!({
+            "agent": {
+                "default_profile": "test-profile",
+                "profiles": {
+                    "test-profile": {
+                        "name": "Test Profile",
+                        "tools": {
+                            EchoTool.name(): true,
+                            DelayTool.name(): true,
+                            WordListTool.name(): true,
+                            ToolRequiringPermission.name(): true,
+                            InfiniteTool.name(): true,
+                        }
+                    }
+                }
+            }
+        })
+        .to_string()
+        .into_bytes(),
+    )
+    .await;
 
     cx.update(|cx| {
         settings::init(cx);
-        watch_settings(fs.clone(), cx);
         Project::init_settings(cx);
         agent_settings::init(cx);
+        gpui_tokio::init(cx);
+        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
+        cx.set_http_client(Arc::new(http_client));
+
+        client::init_settings(cx);
+        let client = Client::production(cx);
+        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+        language_model::init(client.clone(), cx);
+        language_models::init(user_store.clone(), client.clone(), cx);
+
+        watch_settings(fs.clone(), cx);
     });
+
     let templates = Templates::new();
 
     fs.insert_tree(path!("/test"), json!({})).await;
-    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
 
     let model = cx
         .update(|cx| {
-            gpui_tokio::init(cx);
-            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
-            cx.set_http_client(Arc::new(http_client));
-
-            client::init_settings(cx);
-            let client = Client::production(cx);
-            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-            language_model::init(client.clone(), cx);
-            language_models::init(user_store.clone(), client.clone(), cx);
-
             if let TestModel::Fake = model {
                 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
             } else {
@@ -881,20 +990,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
         .await;
 
     let project_context = Rc::new(RefCell::new(ProjectContext::default()));
+    let context_server_registry =
+        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
     let action_log = cx.new(|_| ActionLog::new(project.clone()));
-    let thread = cx.new(|_| {
+    let thread = cx.new(|cx| {
         Thread::new(
             project,
             project_context.clone(),
+            context_server_registry,
             action_log,
             templates,
             model.clone(),
+            cx,
         )
     });
     ThreadTest {
         model,
         thread,
         project_context,
+        fs,
     }
 }
 

crates/agent2/src/thread.rs 🔗

@@ -1,7 +1,7 @@
-use crate::{SystemPromptTemplate, Template, Templates};
+use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
 use action_log::ActionLog;
 use agent_client_protocol as acp;
-use agent_settings::AgentSettings;
+use agent_settings::{AgentProfileId, AgentSettings};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
 use cloud_llm_client::{CompletionIntent, CompletionMode};
@@ -126,6 +126,8 @@ pub struct Thread {
     running_turn: Option<Task<()>>,
     pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
     tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+    context_server_registry: Entity<ContextServerRegistry>,
+    profile_id: AgentProfileId,
     project_context: Rc<RefCell<ProjectContext>>,
     templates: Arc<Templates>,
     pub selected_model: Arc<dyn LanguageModel>,
@@ -137,16 +139,21 @@ impl Thread {
     pub fn new(
         project: Entity<Project>,
         project_context: Rc<RefCell<ProjectContext>>,
+        context_server_registry: Entity<ContextServerRegistry>,
         action_log: Entity<ActionLog>,
         templates: Arc<Templates>,
         default_model: Arc<dyn LanguageModel>,
+        cx: &mut Context<Self>,
     ) -> Self {
+        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
         Self {
             messages: Vec::new(),
             completion_mode: CompletionMode::Normal,
             running_turn: None,
             pending_tool_uses: HashMap::default(),
             tools: BTreeMap::default(),
+            context_server_registry,
+            profile_id,
             project_context,
             templates,
             selected_model: default_model,
@@ -179,6 +186,10 @@ impl Thread {
         self.tools.remove(name).is_some()
     }
 
+    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
+        self.profile_id = profile_id;
+    }
+
     pub fn cancel(&mut self) {
         self.running_turn.take();
 
@@ -298,6 +309,7 @@ impl Thread {
                                 } else {
                                     acp::ToolCallStatus::Completed
                                 }),
+                                raw_output: tool_result.output.clone(),
                                 ..Default::default()
                             },
                         );
@@ -604,21 +616,23 @@ impl Thread {
         let messages = self.build_request_messages();
         log::info!("Request will include {} messages", messages.len());
 
-        let tools: Vec<LanguageModelRequestTool> = self
-            .tools
-            .values()
-            .filter_map(|tool| {
-                let tool_name = tool.name().to_string();
-                log::trace!("Including tool: {}", tool_name);
-                Some(LanguageModelRequestTool {
-                    name: tool_name,
-                    description: tool.description(cx).to_string(),
-                    input_schema: tool
-                        .input_schema(self.selected_model.tool_input_format())
-                        .log_err()?,
+        let tools = if let Some(tools) = self.tools(cx).log_err() {
+            tools
+                .filter_map(|tool| {
+                    let tool_name = tool.name().to_string();
+                    log::trace!("Including tool: {}", tool_name);
+                    Some(LanguageModelRequestTool {
+                        name: tool_name,
+                        description: tool.description().to_string(),
+                        input_schema: tool
+                            .input_schema(self.selected_model.tool_input_format())
+                            .log_err()?,
+                    })
                 })
-            })
-            .collect();
+                .collect()
+        } else {
+            Vec::new()
+        };
 
         log::info!("Request includes {} tools", tools.len());
 
@@ -639,6 +653,35 @@ impl Thread {
         request
     }
 
+    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
+        let profile = AgentSettings::get_global(cx)
+            .profiles
+            .get(&self.profile_id)
+            .context("profile not found")?;
+
+        Ok(self
+            .tools
+            .iter()
+            .filter_map(|(tool_name, tool)| {
+                if profile.is_tool_enabled(tool_name) {
+                    Some(tool)
+                } else {
+                    None
+                }
+            })
+            .chain(self.context_server_registry.read(cx).servers().flat_map(
+                |(server_id, tools)| {
+                    tools.iter().filter_map(|(tool_name, tool)| {
+                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
+                            Some(tool)
+                        } else {
+                            None
+                        }
+                    })
+                },
+            )))
+    }
+
     fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
         log::trace!(
             "Building request messages from {} thread messages",
@@ -686,7 +729,7 @@ where
 
     fn name(&self) -> SharedString;
 
-    fn description(&self, _cx: &mut App) -> SharedString {
+    fn description(&self) -> SharedString {
         let schema = schemars::schema_for!(Self::Input);
         SharedString::new(
             schema
@@ -722,13 +765,13 @@ where
 pub struct Erased<T>(T);
 
 pub struct AgentToolOutput {
-    llm_output: LanguageModelToolResultContent,
-    raw_output: serde_json::Value,
+    pub llm_output: LanguageModelToolResultContent,
+    pub raw_output: serde_json::Value,
 }
 
 pub trait AnyAgentTool {
     fn name(&self) -> SharedString;
-    fn description(&self, cx: &mut App) -> SharedString;
+    fn description(&self) -> SharedString;
     fn kind(&self) -> acp::ToolKind;
     fn initial_title(&self, input: serde_json::Value) -> SharedString;
     fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
@@ -748,8 +791,8 @@ where
         self.0.name()
     }
 
-    fn description(&self, cx: &mut App) -> SharedString {
-        self.0.description(cx)
+    fn description(&self) -> SharedString {
+        self.0.description()
     }
 
     fn kind(&self) -> agent_client_protocol::ToolKind {

crates/agent2/src/tools.rs 🔗

@@ -1,3 +1,4 @@
+mod context_server_registry;
 mod copy_path_tool;
 mod create_directory_tool;
 mod delete_path_tool;
@@ -15,6 +16,7 @@ mod terminal_tool;
 mod thinking_tool;
 mod web_search_tool;
 
+pub use context_server_registry::*;
 pub use copy_path_tool::*;
 pub use create_directory_tool::*;
 pub use delete_path_tool::*;

crates/agent2/src/tools/context_server_registry.rs 🔗

@@ -0,0 +1,231 @@
+use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
+use agent_client_protocol::ToolKind;
+use anyhow::{Result, anyhow, bail};
+use collections::{BTreeMap, HashMap};
+use context_server::ContextServerId;
+use gpui::{App, Context, Entity, SharedString, Task};
+use project::context_server_store::{ContextServerStatus, ContextServerStore};
+use std::sync::Arc;
+use util::ResultExt;
+
+pub struct ContextServerRegistry {
+    server_store: Entity<ContextServerStore>,
+    registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
+    _subscription: gpui::Subscription,
+}
+
+struct RegisteredContextServer {
+    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+    load_tools: Task<Result<()>>,
+}
+
+impl ContextServerRegistry {
+    pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
+        let mut this = Self {
+            server_store: server_store.clone(),
+            registered_servers: HashMap::default(),
+            _subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
+        };
+        for server in server_store.read(cx).running_servers() {
+            this.reload_tools_for_server(server.id(), cx);
+        }
+        this
+    }
+
+    pub fn servers(
+        &self,
+    ) -> impl Iterator<
+        Item = (
+            &ContextServerId,
+            &BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+        ),
+    > {
+        self.registered_servers
+            .iter()
+            .map(|(id, server)| (id, &server.tools))
+    }
+
+    fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
+        let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
+            return;
+        };
+        let Some(client) = server.client() else {
+            return;
+        };
+        if !client.capable(context_server::protocol::ServerCapability::Tools) {
+            return;
+        }
+
+        let registered_server =
+            self.registered_servers
+                .entry(server_id.clone())
+                .or_insert(RegisteredContextServer {
+                    tools: BTreeMap::default(),
+                    load_tools: Task::ready(Ok(())),
+                });
+        registered_server.load_tools = cx.spawn(async move |this, cx| {
+            let response = client
+                .request::<context_server::types::requests::ListTools>(())
+                .await;
+
+            this.update(cx, |this, cx| {
+                let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
+                    return;
+                };
+
+                registered_server.tools.clear();
+                if let Some(response) = response.log_err() {
+                    for tool in response.tools {
+                        let tool = Arc::new(ContextServerTool::new(
+                            this.server_store.clone(),
+                            server.id(),
+                            tool,
+                        ));
+                        registered_server.tools.insert(tool.name(), tool);
+                    }
+                    cx.notify();
+                }
+            })
+        });
+    }
+
+    fn handle_context_server_store_event(
+        &mut self,
+        _: Entity<ContextServerStore>,
+        event: &project::context_server_store::Event,
+        cx: &mut Context<Self>,
+    ) {
+        match event {
+            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
+                match status {
+                    ContextServerStatus::Starting => {}
+                    ContextServerStatus::Running => {
+                        self.reload_tools_for_server(server_id.clone(), cx);
+                    }
+                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
+                        self.registered_servers.remove(&server_id);
+                        cx.notify();
+                    }
+                }
+            }
+        }
+    }
+}
+
+struct ContextServerTool {
+    store: Entity<ContextServerStore>,
+    server_id: ContextServerId,
+    tool: context_server::types::Tool,
+}
+
+impl ContextServerTool {
+    fn new(
+        store: Entity<ContextServerStore>,
+        server_id: ContextServerId,
+        tool: context_server::types::Tool,
+    ) -> Self {
+        Self {
+            store,
+            server_id,
+            tool,
+        }
+    }
+}
+
+impl AnyAgentTool for ContextServerTool {
+    fn name(&self) -> SharedString {
+        self.tool.name.clone().into()
+    }
+
+    fn description(&self) -> SharedString {
+        self.tool.description.clone().unwrap_or_default().into()
+    }
+
+    fn kind(&self) -> ToolKind {
+        ToolKind::Other
+    }
+
+    fn initial_title(&self, _input: serde_json::Value) -> SharedString {
+        format!("Run MCP tool `{}`", self.tool.name).into()
+    }
+
+    fn input_schema(
+        &self,
+        format: language_model::LanguageModelToolSchemaFormat,
+    ) -> Result<serde_json::Value> {
+        let mut schema = self.tool.input_schema.clone();
+        assistant_tool::adapt_schema_to_format(&mut schema, format)?;
+        Ok(match schema {
+            serde_json::Value::Null => {
+                serde_json::json!({ "type": "object", "properties": [] })
+            }
+            serde_json::Value::Object(map) if map.is_empty() => {
+                serde_json::json!({ "type": "object", "properties": [] })
+            }
+            _ => schema,
+        })
+    }
+
+    fn run(
+        self: Arc<Self>,
+        input: serde_json::Value,
+        _event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Task<Result<AgentToolOutput>> {
+        let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
+            return Task::ready(Err(anyhow!("Context server not found")));
+        };
+        let tool_name = self.tool.name.clone();
+        let server_clone = server.clone();
+        let input_clone = input.clone();
+
+        cx.spawn(async move |_cx| {
+            let Some(protocol) = server_clone.client() else {
+                bail!("Context server not initialized");
+            };
+
+            let arguments = if let serde_json::Value::Object(map) = input_clone {
+                Some(map.into_iter().collect())
+            } else {
+                None
+            };
+
+            log::trace!(
+                "Running tool: {} with arguments: {:?}",
+                tool_name,
+                arguments
+            );
+            let response = protocol
+                .request::<context_server::types::requests::CallTool>(
+                    context_server::types::CallToolParams {
+                        name: tool_name,
+                        arguments,
+                        meta: None,
+                    },
+                )
+                .await?;
+
+            let mut result = String::new();
+            for content in response.content {
+                match content {
+                    context_server::types::ToolResponseContent::Text { text } => {
+                        result.push_str(&text);
+                    }
+                    context_server::types::ToolResponseContent::Image { .. } => {
+                        log::warn!("Ignoring image content from tool response");
+                    }
+                    context_server::types::ToolResponseContent::Audio { .. } => {
+                        log::warn!("Ignoring audio content from tool response");
+                    }
+                    context_server::types::ToolResponseContent::Resource { .. } => {
+                        log::warn!("Ignoring resource content from tool response");
+                    }
+                }
+            }
+            Ok(AgentToolOutput {
+                raw_output: result.clone().into(),
+                llm_output: result.into(),
+            })
+        })
+    }
+}

crates/agent2/src/tools/diagnostics_tool.rs 🔗

@@ -85,7 +85,7 @@ impl AgentTool for DiagnosticsTool {
     fn run(
         self: Arc<Self>,
         input: Self::Input,
-        event_stream: ToolCallEventStream,
+        _event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<Self::Output>> {
         match input.path {
@@ -119,11 +119,6 @@ impl AgentTool for DiagnosticsTool {
                             range.start.row + 1,
                             entry.diagnostic.message
                         )?;
-
-                        event_stream.update_fields(acp::ToolCallUpdateFields {
-                            content: Some(vec![output.clone().into()]),
-                            ..Default::default()
-                        });
                     }
 
                     if output.is_empty() {
@@ -158,18 +153,9 @@ impl AgentTool for DiagnosticsTool {
                 }
 
                 if has_diagnostics {
-                    event_stream.update_fields(acp::ToolCallUpdateFields {
-                        content: Some(vec![output.clone().into()]),
-                        ..Default::default()
-                    });
                     Task::ready(Ok(output))
                 } else {
-                    let text = "No errors or warnings found in the project.";
-                    event_stream.update_fields(acp::ToolCallUpdateFields {
-                        content: Some(vec![text.into()]),
-                        ..Default::default()
-                    });
-                    Task::ready(Ok(text.into()))
+                    Task::ready(Ok("No errors or warnings found in the project.".into()))
                 }
             }
         }

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -454,9 +454,8 @@ fn resolve_path(
 
 #[cfg(test)]
 mod tests {
-    use crate::Templates;
-
     use super::*;
+    use crate::{ContextServerRegistry, Templates};
     use action_log::ActionLog;
     use client::TelemetrySettings;
     use fs::Fs;
@@ -475,9 +474,20 @@ mod tests {
         fs.insert_tree("/root", json!({})).await;
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread =
-            cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
+        let thread = cx.new(|cx| {
+            Thread::new(
+                project,
+                Rc::default(),
+                context_server_registry,
+                action_log,
+                Templates::new(),
+                model,
+                cx,
+            )
+        });
         let result = cx
             .update(|cx| {
                 let input = EditFileToolInput {
@@ -661,14 +671,18 @@ mod tests {
         });
 
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|_| {
+        let thread = cx.new(|cx| {
             Thread::new(
                 project,
                 Rc::default(),
+                context_server_registry,
                 action_log.clone(),
                 Templates::new(),
                 model.clone(),
+                cx,
             )
         });
 
@@ -792,15 +806,19 @@ mod tests {
         .unwrap();
 
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|_| {
+        let thread = cx.new(|cx| {
             Thread::new(
                 project,
                 Rc::default(),
+                context_server_registry,
                 action_log.clone(),
                 Templates::new(),
                 model.clone(),
+                cx,
             )
         });
 
@@ -914,15 +932,19 @@ mod tests {
         init_test(cx);
         let fs = project::FakeFs::new(cx.executor());
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|_| {
+        let thread = cx.new(|cx| {
             Thread::new(
                 project,
                 Rc::default(),
+                context_server_registry,
                 action_log.clone(),
                 Templates::new(),
                 model.clone(),
+                cx,
             )
         });
         let tool = Arc::new(EditFileTool { thread });
@@ -1041,15 +1063,19 @@ mod tests {
         let fs = project::FakeFs::new(cx.executor());
         fs.insert_tree("/project", json!({})).await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|_| {
+        let thread = cx.new(|cx| {
             Thread::new(
                 project,
                 Rc::default(),
+                context_server_registry,
                 action_log.clone(),
                 Templates::new(),
                 model.clone(),
+                cx,
             )
         });
         let tool = Arc::new(EditFileTool { thread });
@@ -1148,14 +1174,18 @@ mod tests {
         .await;
 
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|_| {
+        let thread = cx.new(|cx| {
             Thread::new(
                 project.clone(),
                 Rc::default(),
+                context_server_registry.clone(),
                 action_log.clone(),
                 Templates::new(),
                 model.clone(),
+                cx,
             )
         });
         let tool = Arc::new(EditFileTool { thread });
@@ -1225,14 +1255,18 @@ mod tests {
         .await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|_| {
+        let thread = cx.new(|cx| {
             Thread::new(
                 project.clone(),
                 Rc::default(),
+                context_server_registry.clone(),
                 action_log.clone(),
                 Templates::new(),
                 model.clone(),
+                cx,
             )
         });
         let tool = Arc::new(EditFileTool { thread });
@@ -1305,14 +1339,18 @@ mod tests {
         .await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|_| {
+        let thread = cx.new(|cx| {
             Thread::new(
                 project.clone(),
                 Rc::default(),
+                context_server_registry.clone(),
                 action_log.clone(),
                 Templates::new(),
                 model.clone(),
+                cx,
             )
         });
         let tool = Arc::new(EditFileTool { thread });
@@ -1382,14 +1420,18 @@ mod tests {
         let fs = project::FakeFs::new(cx.executor());
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|_| {
+        let thread = cx.new(|cx| {
             Thread::new(
                 project.clone(),
                 Rc::default(),
+                context_server_registry,
                 action_log.clone(),
                 Templates::new(),
                 model.clone(),
+                cx,
             )
         });
         let tool = Arc::new(EditFileTool { thread });

crates/agent2/src/tools/fetch_tool.rs 🔗

@@ -136,7 +136,7 @@ impl AgentTool for FetchTool {
     fn run(
         self: Arc<Self>,
         input: Self::Input,
-        event_stream: ToolCallEventStream,
+        _event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<Self::Output>> {
         let text = cx.background_spawn({
@@ -149,12 +149,6 @@ impl AgentTool for FetchTool {
             if text.trim().is_empty() {
                 bail!("no textual content found");
             }
-
-            event_stream.update_fields(acp::ToolCallUpdateFields {
-                content: Some(vec![text.clone().into()]),
-                ..Default::default()
-            });
-
             Ok(text)
         })
     }

crates/agent2/src/tools/grep_tool.rs 🔗

@@ -101,7 +101,7 @@ impl AgentTool for GrepTool {
     fn run(
         self: Arc<Self>,
         input: Self::Input,
-        event_stream: ToolCallEventStream,
+        _event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<Self::Output>> {
         const CONTEXT_LINES: u32 = 2;
@@ -282,33 +282,22 @@ impl AgentTool for GrepTool {
                         }
                     }
 
-                    event_stream.update_fields(acp::ToolCallUpdateFields {
-                        content: Some(vec![output.clone().into()]),
-                        ..Default::default()
-                    });
                     matches_found += 1;
                 }
             }
 
-            let output = if matches_found == 0 {
-                "No matches found".to_string()
+            if matches_found == 0 {
+                Ok("No matches found".into())
             } else if has_more_matches {
-                format!(
+                Ok(format!(
                     "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
                     input.offset + 1,
                     input.offset + matches_found,
                     input.offset + RESULTS_PER_PAGE,
-                )
+                ))
             } else {
-                format!("Found {matches_found} matches:\n{output}")
-            };
-
-            event_stream.update_fields(acp::ToolCallUpdateFields {
-                content: Some(vec![output.clone().into()]),
-                ..Default::default()
-            });
-
-            Ok(output)
+                Ok(format!("Found {matches_found} matches:\n{output}"))
+            }
         })
     }
 }

crates/agent2/src/tools/now_tool.rs 🔗

@@ -47,20 +47,13 @@ impl AgentTool for NowTool {
     fn run(
         self: Arc<Self>,
         input: Self::Input,
-        event_stream: ToolCallEventStream,
+        _event_stream: ToolCallEventStream,
         _cx: &mut App,
     ) -> Task<Result<String>> {
         let now = match input.timezone {
             Timezone::Utc => Utc::now().to_rfc3339(),
             Timezone::Local => Local::now().to_rfc3339(),
         };
-        let content = format!("The current datetime is {now}.");
-
-        event_stream.update_fields(acp::ToolCallUpdateFields {
-            content: Some(vec![content.clone().into()]),
-            ..Default::default()
-        });
-
-        Task::ready(Ok(content))
+        Task::ready(Ok(format!("The current datetime is {now}.")))
     }
 }

crates/agent_settings/src/agent_profile.rs 🔗

@@ -48,6 +48,20 @@ pub struct AgentProfileSettings {
     pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
 }
 
+impl AgentProfileSettings {
+    pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
+        self.tools.get(tool_name) == Some(&true)
+    }
+
+    pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool {
+        self.enable_all_context_servers
+            || self
+                .context_servers
+                .get(server_id)
+                .map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
+    }
+}
+
 #[derive(Debug, Clone, Default)]
 pub struct ContextServerPreset {
     pub tools: IndexMap<Arc<str>, bool>,