acp: Support calling tools provided by MCP servers (#36752)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/agent2/src/tests/mod.rs    | 441 ++++++++++++++++++++++++++++++++
crates/agent2/src/thread.rs       | 142 +++++++---
crates/context_server/src/test.rs |  36 +
3 files changed, 558 insertions(+), 61 deletions(-)

Detailed changes

crates/agent2/src/tests/mod.rs 🔗

@@ -4,26 +4,35 @@ use agent_client_protocol::{self as acp};
 use agent_settings::AgentProfileId;
 use anyhow::Result;
 use client::{Client, UserStore};
+use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use fs::{FakeFs, Fs};
-use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
+use futures::{
+    StreamExt,
+    channel::{
+        mpsc::{self, UnboundedReceiver},
+        oneshot,
+    },
+};
 use gpui::{
     App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
 };
 use indoc::indoc;
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
-    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
-    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
-    fake_provider::FakeLanguageModel,
+    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
+    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
+    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
 };
 use pretty_assertions::assert_eq;
-use project::Project;
+use project::{
+    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
+};
 use prompt_store::ProjectContext;
 use reqwest_client::ReqwestClient;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
-use settings::SettingsStore;
+use settings::{Settings, SettingsStore};
 use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
 use util::path;
 
@@ -931,6 +940,334 @@ async fn test_profiles(cx: &mut TestAppContext) {
     assert_eq!(tool_names, vec![InfiniteTool::name()]);
 }
 
+#[gpui::test]
+async fn test_mcp_tools(cx: &mut TestAppContext) {
+    let ThreadTest {
+        model,
+        thread,
+        context_server_store,
+        fs,
+        ..
+    } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    // Override profiles and wait for settings to be loaded.
+    fs.insert_file(
+        paths::settings_file(),
+        json!({
+            "agent": {
+                "profiles": {
+                    "test": {
+                        "name": "Test Profile",
+                        "enable_all_context_servers": true,
+                        "tools": {
+                            EchoTool::name(): true,
+                        }
+                    },
+                }
+            }
+        })
+        .to_string()
+        .into_bytes(),
+    )
+    .await;
+    cx.run_until_parked();
+    thread.update(cx, |thread, _| {
+        thread.set_profile(AgentProfileId("test".into()))
+    });
+
+    let mut mcp_tool_calls = setup_context_server(
+        "test_server",
+        vec![context_server::types::Tool {
+            name: "echo".into(),
+            description: None,
+            input_schema: serde_json::to_value(
+                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
+            )
+            .unwrap(),
+            output_schema: None,
+            annotations: None,
+        }],
+        &context_server_store,
+        cx,
+    );
+
+    let events = thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
+    });
+    cx.run_until_parked();
+
+    // Simulate the model calling the MCP tool.
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        LanguageModelToolUse {
+            id: "tool_1".into(),
+            name: "echo".into(),
+            raw_input: json!({"text": "test"}).to_string(),
+            input: json!({"text": "test"}),
+            is_input_complete: true,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
+    assert_eq!(tool_call_params.name, "echo");
+    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
+    tool_call_response
+        .send(context_server::types::CallToolResponse {
+            content: vec![context_server::types::ToolResponseContent::Text {
+                text: "test".into(),
+            }],
+            is_error: None,
+            meta: None,
+            structured_content: None,
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
+    fake_model.send_last_completion_stream_text_chunk("Done!");
+    fake_model.end_last_completion_stream();
+    events.collect::<Vec<_>>().await;
+
+    // Send again after adding the echo tool, ensuring the name collision is resolved.
+    let events = thread.update(cx, |thread, cx| {
+        thread.add_tool(EchoTool);
+        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
+    });
+    cx.run_until_parked();
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        tool_names_for_completion(&completion),
+        vec!["echo", "test_server_echo"]
+    );
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        LanguageModelToolUse {
+            id: "tool_2".into(),
+            name: "test_server_echo".into(),
+            raw_input: json!({"text": "mcp"}).to_string(),
+            input: json!({"text": "mcp"}),
+            is_input_complete: true,
+        },
+    ));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        LanguageModelToolUse {
+            id: "tool_3".into(),
+            name: "echo".into(),
+            raw_input: json!({"text": "native"}).to_string(),
+            input: json!({"text": "native"}),
+            is_input_complete: true,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
+    assert_eq!(tool_call_params.name, "echo");
+    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
+    tool_call_response
+        .send(context_server::types::CallToolResponse {
+            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
+            is_error: None,
+            meta: None,
+            structured_content: None,
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    // Ensure the tool results were inserted with the correct names.
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        completion.messages.last().unwrap().content,
+        vec![
+            MessageContent::ToolResult(LanguageModelToolResult {
+                tool_use_id: "tool_3".into(),
+                tool_name: "echo".into(),
+                is_error: false,
+                content: "native".into(),
+                output: Some("native".into()),
+            },),
+            MessageContent::ToolResult(LanguageModelToolResult {
+                tool_use_id: "tool_2".into(),
+                tool_name: "test_server_echo".into(),
+                is_error: false,
+                content: "mcp".into(),
+                output: Some("mcp".into()),
+            },),
+        ]
+    );
+    fake_model.end_last_completion_stream();
+    events.collect::<Vec<_>>().await;
+}
+
+#[gpui::test]
+async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
+    let ThreadTest {
+        model,
+        thread,
+        context_server_store,
+        fs,
+        ..
+    } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    // Set up a profile with all tools enabled
+    fs.insert_file(
+        paths::settings_file(),
+        json!({
+            "agent": {
+                "profiles": {
+                    "test": {
+                        "name": "Test Profile",
+                        "enable_all_context_servers": true,
+                        "tools": {
+                            EchoTool::name(): true,
+                            DelayTool::name(): true,
+                            WordListTool::name(): true,
+                            ToolRequiringPermission::name(): true,
+                            InfiniteTool::name(): true,
+                        }
+                    },
+                }
+            }
+        })
+        .to_string()
+        .into_bytes(),
+    )
+    .await;
+    cx.run_until_parked();
+
+    thread.update(cx, |thread, _| {
+        thread.set_profile(AgentProfileId("test".into()));
+        thread.add_tool(EchoTool);
+        thread.add_tool(DelayTool);
+        thread.add_tool(WordListTool);
+        thread.add_tool(ToolRequiringPermission);
+        thread.add_tool(InfiniteTool);
+    });
+
+    // Set up multiple context servers with some overlapping tool names
+    let _server1_calls = setup_context_server(
+        "xxx",
+        vec![
+            context_server::types::Tool {
+                name: "echo".into(), // Conflicts with native EchoTool
+                description: None,
+                input_schema: serde_json::to_value(
+                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
+                )
+                .unwrap(),
+                output_schema: None,
+                annotations: None,
+            },
+            context_server::types::Tool {
+                name: "unique_tool_1".into(),
+                description: None,
+                input_schema: json!({"type": "object", "properties": {}}),
+                output_schema: None,
+                annotations: None,
+            },
+        ],
+        &context_server_store,
+        cx,
+    );
+
+    let _server2_calls = setup_context_server(
+        "yyy",
+        vec![
+            context_server::types::Tool {
+                name: "echo".into(), // Also conflicts with native EchoTool
+                description: None,
+                input_schema: serde_json::to_value(
+                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
+                )
+                .unwrap(),
+                output_schema: None,
+                annotations: None,
+            },
+            context_server::types::Tool {
+                name: "unique_tool_2".into(),
+                description: None,
+                input_schema: json!({"type": "object", "properties": {}}),
+                output_schema: None,
+                annotations: None,
+            },
+            context_server::types::Tool {
+                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
+                description: None,
+                input_schema: json!({"type": "object", "properties": {}}),
+                output_schema: None,
+                annotations: None,
+            },
+            context_server::types::Tool {
+                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
+                description: None,
+                input_schema: json!({"type": "object", "properties": {}}),
+                output_schema: None,
+                annotations: None,
+            },
+        ],
+        &context_server_store,
+        cx,
+    );
+    let _server3_calls = setup_context_server(
+        "zzz",
+        vec![
+            context_server::types::Tool {
+                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
+                description: None,
+                input_schema: json!({"type": "object", "properties": {}}),
+                output_schema: None,
+                annotations: None,
+            },
+            context_server::types::Tool {
+                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
+                description: None,
+                input_schema: json!({"type": "object", "properties": {}}),
+                output_schema: None,
+                annotations: None,
+            },
+            context_server::types::Tool {
+                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
+                description: None,
+                input_schema: json!({"type": "object", "properties": {}}),
+                output_schema: None,
+                annotations: None,
+            },
+        ],
+        &context_server_store,
+        cx,
+    );
+
+    thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Go"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(
+        tool_names_for_completion(&completion),
+        vec![
+            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
+            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
+            "delay",
+            "echo",
+            "infinite",
+            "tool_requiring_permission",
+            "unique_tool_1",
+            "unique_tool_2",
+            "word_list",
+            "xxx_echo",
+            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+            "yyy_echo",
+            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+        ]
+    );
+}
+
 #[gpui::test]
 #[cfg_attr(not(feature = "e2e"), ignore)]
 async fn test_cancellation(cx: &mut TestAppContext) {
@@ -1806,6 +2143,7 @@ struct ThreadTest {
     model: Arc<dyn LanguageModel>,
     thread: Entity<Thread>,
     project_context: Entity<ProjectContext>,
+    context_server_store: Entity<ContextServerStore>,
     fs: Arc<FakeFs>,
 }
 
@@ -1844,6 +2182,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
                             WordListTool::name(): true,
                             ToolRequiringPermission::name(): true,
                             InfiniteTool::name(): true,
+                            ThinkingTool::name(): true,
                         }
                     }
                 }
@@ -1900,8 +2239,9 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
         .await;
 
     let project_context = cx.new(|_cx| ProjectContext::default());
+    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
     let context_server_registry =
-        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
     let thread = cx.new(|cx| {
         Thread::new(
             project,
@@ -1916,6 +2256,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
         model,
         thread,
         project_context,
+        context_server_store,
         fs,
     }
 }
@@ -1950,3 +2291,89 @@ fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
     })
     .detach();
 }
+
+fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
+    completion
+        .tools
+        .iter()
+        .map(|tool| tool.name.clone())
+        .collect()
+}
+
+fn setup_context_server(
+    name: &'static str,
+    tools: Vec<context_server::types::Tool>,
+    context_server_store: &Entity<ContextServerStore>,
+    cx: &mut TestAppContext,
+) -> mpsc::UnboundedReceiver<(
+    context_server::types::CallToolParams,
+    oneshot::Sender<context_server::types::CallToolResponse>,
+)> {
+    cx.update(|cx| {
+        let mut settings = ProjectSettings::get_global(cx).clone();
+        settings.context_servers.insert(
+            name.into(),
+            project::project_settings::ContextServerSettings::Custom {
+                enabled: true,
+                command: ContextServerCommand {
+                    path: "somebinary".into(),
+                    args: Vec::new(),
+                    env: None,
+                },
+            },
+        );
+        ProjectSettings::override_global(settings, cx);
+    });
+
+    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
+    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
+        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
+            context_server::types::InitializeResponse {
+                protocol_version: context_server::types::ProtocolVersion(
+                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
+                ),
+                server_info: context_server::types::Implementation {
+                    name: name.into(),
+                    version: "1.0.0".to_string(),
+                },
+                capabilities: context_server::types::ServerCapabilities {
+                    tools: Some(context_server::types::ToolsCapabilities {
+                        list_changed: Some(true),
+                    }),
+                    ..Default::default()
+                },
+                meta: None,
+            }
+        })
+        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
+            let tools = tools.clone();
+            async move {
+                context_server::types::ListToolsResponse {
+                    tools,
+                    next_cursor: None,
+                    meta: None,
+                }
+            }
+        })
+        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
+            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
+            async move {
+                let (response_tx, response_rx) = oneshot::channel();
+                mcp_tool_calls_tx
+                    .unbounded_send((params, response_tx))
+                    .unwrap();
+                response_rx.await.unwrap()
+            }
+        });
+    context_server_store.update(cx, |store, cx| {
+        store.start_server(
+            Arc::new(ContextServer::new(
+                ContextServerId(name.into()),
+                Arc::new(fake_transport),
+            )),
+            cx,
+        );
+    });
+    cx.run_until_parked();
+    mcp_tool_calls_rx
+}

crates/agent2/src/thread.rs 🔗

@@ -9,15 +9,15 @@ use action_log::ActionLog;
 use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
 use agent_client_protocol as acp;
 use agent_settings::{
-    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
-    SUMMARIZE_THREAD_PROMPT,
+    AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
+    SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
 };
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
 use chrono::{DateTime, Utc};
 use client::{ModelRequestUsage, RequestUsage};
 use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
-use collections::{HashMap, IndexMap};
+use collections::{HashMap, HashSet, IndexMap};
 use fs::Fs;
 use futures::{
     FutureExt,
@@ -56,6 +56,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
 use uuid::Uuid;
 
 const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
+pub const MAX_TOOL_NAME_LENGTH: usize = 64;
 
 /// The ID of the user prompt that initiated a request.
 ///
@@ -627,7 +628,20 @@ impl Thread {
         stream: &ThreadEventStream,
         cx: &mut Context<Self>,
     ) {
-        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
+        let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
+            self.context_server_registry
+                .read(cx)
+                .servers()
+                .find_map(|(_, tools)| {
+                    if let Some(tool) = tools.get(tool_use.name.as_ref()) {
+                        Some(tool.clone())
+                    } else {
+                        None
+                    }
+                })
+        });
+
+        let Some(tool) = tool else {
             stream
                 .0
                 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
@@ -1079,6 +1093,10 @@ impl Thread {
         self.cancel(cx);
 
         let model = self.model.clone().context("No language model configured")?;
+        let profile = AgentSettings::get_global(cx)
+            .profiles
+            .get(&self.profile_id)
+            .context("Profile not found")?;
         let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
         let event_stream = ThreadEventStream(events_tx);
         let message_ix = self.messages.len().saturating_sub(1);
@@ -1086,6 +1104,7 @@ impl Thread {
         self.summary = None;
         self.running_turn = Some(RunningTurn {
             event_stream: event_stream.clone(),
+            tools: self.enabled_tools(profile, &model, cx),
             _task: cx.spawn(async move |this, cx| {
                 log::info!("Starting agent turn execution");
 
@@ -1417,7 +1436,7 @@ impl Thread {
     ) -> Option<Task<LanguageModelToolResult>> {
         cx.notify();
 
-        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
+        let tool = self.tool(tool_use.name.as_ref());
         let mut title = SharedString::from(&tool_use.name);
         let mut kind = acp::ToolKind::Other;
         if let Some(tool) = tool.as_ref() {
@@ -1727,30 +1746,28 @@ impl Thread {
         cx: &mut App,
     ) -> Result<LanguageModelRequest> {
         let model = self.model().context("No language model configured")?;
-
-        log::debug!("Building completion request");
-        log::debug!("Completion intent: {:?}", completion_intent);
-        log::debug!("Completion mode: {:?}", self.completion_mode);
-
-        let messages = self.build_request_messages(cx);
-        log::info!("Request will include {} messages", messages.len());
-
-        let tools = if let Some(tools) = self.tools(cx).log_err() {
-            tools
-                .filter_map(|tool| {
-                    let tool_name = tool.name().to_string();
+        let tools = if let Some(turn) = self.running_turn.as_ref() {
+            turn.tools
+                .iter()
+                .filter_map(|(tool_name, tool)| {
                     log::trace!("Including tool: {}", tool_name);
                     Some(LanguageModelRequestTool {
-                        name: tool_name,
+                        name: tool_name.to_string(),
                         description: tool.description().to_string(),
                         input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
                     })
                 })
-                .collect()
+                .collect::<Vec<_>>()
         } else {
             Vec::new()
         };
 
+        log::debug!("Building completion request");
+        log::debug!("Completion intent: {:?}", completion_intent);
+        log::debug!("Completion mode: {:?}", self.completion_mode);
+
+        let messages = self.build_request_messages(cx);
+        log::info!("Request will include {} messages", messages.len());
         log::info!("Request includes {} tools", tools.len());
 
         let request = LanguageModelRequest {
@@ -1770,37 +1787,76 @@ impl Thread {
         Ok(request)
     }
 
-    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
-        let model = self.model().context("No language model configured")?;
-
-        let profile = AgentSettings::get_global(cx)
-            .profiles
-            .get(&self.profile_id)
-            .context("profile not found")?;
-        let provider_id = model.provider_id();
+    fn enabled_tools(
+        &self,
+        profile: &AgentProfileSettings,
+        model: &Arc<dyn LanguageModel>,
+        cx: &App,
+    ) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
+        fn truncate(tool_name: &SharedString) -> SharedString {
+            if tool_name.len() > MAX_TOOL_NAME_LENGTH {
+                let mut truncated = tool_name.to_string();
+                truncated.truncate(MAX_TOOL_NAME_LENGTH);
+                truncated.into()
+            } else {
+                tool_name.clone()
+            }
+        }
 
-        Ok(self
+        let mut tools = self
             .tools
             .iter()
-            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
             .filter_map(|(tool_name, tool)| {
-                if profile.is_tool_enabled(tool_name) {
-                    Some(tool)
+                if tool.supported_provider(&model.provider_id())
+                    && profile.is_tool_enabled(tool_name)
+                {
+                    Some((truncate(tool_name), tool.clone()))
                 } else {
                     None
                 }
             })
-            .chain(self.context_server_registry.read(cx).servers().flat_map(
-                |(server_id, tools)| {
-                    tools.iter().filter_map(|(tool_name, tool)| {
-                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
-                            Some(tool)
-                        } else {
-                            None
-                        }
-                    })
-                },
-            )))
+            .collect::<BTreeMap<_, _>>();
+
+        let mut context_server_tools = Vec::new();
+        let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
+        let mut duplicate_tool_names = HashSet::default();
+        for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
+            for (tool_name, tool) in server_tools {
+                if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
+                    let tool_name = truncate(tool_name);
+                    if !seen_tools.insert(tool_name.clone()) {
+                        duplicate_tool_names.insert(tool_name.clone());
+                    }
+                    context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
+                }
+            }
+        }
+
+        // When there are duplicate tool names, disambiguate by prefixing them
+        // with the server ID. In the rare case there isn't enough space for the
+        // disambiguated tool name, keep only the last tool with this name.
+        for (server_id, tool_name, tool) in context_server_tools {
+            if duplicate_tool_names.contains(&tool_name) {
+                let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
+                if available >= 2 {
+                    let mut disambiguated = server_id.0.to_string();
+                    disambiguated.truncate(available - 1);
+                    disambiguated.push('_');
+                    disambiguated.push_str(&tool_name);
+                    tools.insert(disambiguated.into(), tool.clone());
+                } else {
+                    tools.insert(tool_name, tool.clone());
+                }
+            } else {
+                tools.insert(tool_name, tool.clone());
+            }
+        }
+
+        tools
+    }
+
+    fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
+        self.running_turn.as_ref()?.tools.get(name).cloned()
     }
 
     fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
@@ -1965,6 +2021,8 @@ struct RunningTurn {
     /// The current event stream for the running turn. Used to report a final
     /// cancellation event if we cancel the turn.
     event_stream: ThreadEventStream,
+    /// The tools that were enabled for this turn.
+    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 }
 
 impl RunningTurn {

crates/context_server/src/test.rs 🔗

@@ -1,6 +1,6 @@
 use anyhow::Context as _;
 use collections::HashMap;
-use futures::{Stream, StreamExt as _, lock::Mutex};
+use futures::{FutureExt, Stream, StreamExt as _, future::BoxFuture, lock::Mutex};
 use gpui::BackgroundExecutor;
 use std::{pin::Pin, sync::Arc};
 
@@ -14,9 +14,12 @@ pub fn create_fake_transport(
     executor: BackgroundExecutor,
 ) -> FakeTransport {
     let name = name.into();
-    FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
-        create_initialize_response(name.clone())
-    })
+    FakeTransport::new(executor).on_request::<crate::types::requests::Initialize, _>(
+        move |_params| {
+            let name = name.clone();
+            async move { create_initialize_response(name.clone()) }
+        },
+    )
 }
 
 fn create_initialize_response(server_name: String) -> InitializeResponse {
@@ -32,8 +35,10 @@ fn create_initialize_response(server_name: String) -> InitializeResponse {
 }
 
 pub struct FakeTransport {
-    request_handlers:
-        HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>,
+    request_handlers: HashMap<
+        &'static str,
+        Arc<dyn Send + Sync + Fn(serde_json::Value) -> BoxFuture<'static, serde_json::Value>>,
+    >,
     tx: futures::channel::mpsc::UnboundedSender<String>,
     rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
     executor: BackgroundExecutor,
@@ -50,18 +55,25 @@ impl FakeTransport {
         }
     }
 
-    pub fn on_request<T: crate::types::Request>(
+    pub fn on_request<T, Fut>(
         mut self,
-        handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static,
-    ) -> Self {
+        handler: impl 'static + Send + Sync + Fn(T::Params) -> Fut,
+    ) -> Self
+    where
+        T: crate::types::Request,
+        Fut: 'static + Send + Future<Output = T::Response>,
+    {
         self.request_handlers.insert(
             T::METHOD,
             Arc::new(move |value| {
-                let params = value.get("params").expect("Missing parameters").clone();
+                let params = value
+                    .get("params")
+                    .cloned()
+                    .unwrap_or(serde_json::Value::Null);
                 let params: T::Params =
                     serde_json::from_value(params).expect("Invalid parameters received");
                 let response = handler(params);
-                serde_json::to_value(response).unwrap()
+                async move { serde_json::to_value(response.await).unwrap() }.boxed()
             }),
         );
         self
@@ -77,7 +89,7 @@ impl Transport for FakeTransport {
             if let Some(method) = msg.get("method") {
                 let method = method.as_str().expect("Invalid method received");
                 if let Some(handler) = self.request_handlers.get(method) {
-                    let payload = handler(msg);
+                    let payload = handler(msg).await;
                     let response = serde_json::json!({
                         "jsonrpc": "2.0",
                         "id": id,