@@ -4,26 +4,35 @@ use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
+use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use fs::{FakeFs, Fs};
-use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
+use futures::{
+ StreamExt,
+ channel::{
+ mpsc::{self, UnboundedReceiver},
+ oneshot,
+ },
+};
use gpui::{
App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
- LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
- fake_provider::FakeLanguageModel,
+ LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
+ LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
-use project::Project;
+use project::{
+ Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
+};
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
-use settings::SettingsStore;
+use settings::{Settings, SettingsStore};
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
@@ -931,6 +940,334 @@ async fn test_profiles(cx: &mut TestAppContext) {
assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
+#[gpui::test]
+async fn test_mcp_tools(cx: &mut TestAppContext) {
+ let ThreadTest {
+ model,
+ thread,
+ context_server_store,
+ fs,
+ ..
+ } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ // Override profiles and wait for settings to be loaded.
+ fs.insert_file(
+ paths::settings_file(),
+ json!({
+ "agent": {
+ "profiles": {
+ "test": {
+ "name": "Test Profile",
+ "enable_all_context_servers": true,
+ "tools": {
+ EchoTool::name(): true,
+ }
+ },
+ }
+ }
+ })
+ .to_string()
+ .into_bytes(),
+ )
+ .await;
+ cx.run_until_parked();
+ thread.update(cx, |thread, _| {
+ thread.set_profile(AgentProfileId("test".into()))
+ });
+
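+    // Start a fake context server exposing an "echo" tool that shares its name with the native EchoTool.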
+ let mut mcp_tool_calls = setup_context_server(
+ "test_server",
+ vec![context_server::types::Tool {
+ name: "echo".into(),
+ description: None,
+ input_schema: serde_json::to_value(
+ EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
+ )
+ .unwrap(),
+ output_schema: None,
+ annotations: None,
+ }],
+ &context_server_store,
+ cx,
+ );
+
+ let events = thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
+ });
+ cx.run_until_parked();
+
+ // Simulate the model calling the MCP tool.
+ let completion = fake_model.pending_completions().pop().unwrap();
+ assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "tool_1".into(),
+ name: "echo".into(),
+ raw_input: json!({"text": "test"}).to_string(),
+ input: json!({"text": "test"}),
+ is_input_complete: true,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
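+    // The fake server receives the call; reply so the thread can send a follow-up request.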
+ let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
+ assert_eq!(tool_call_params.name, "echo");
+ assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
+ tool_call_response
+ .send(context_server::types::CallToolResponse {
+ content: vec![context_server::types::ToolResponseContent::Text {
+ text: "test".into(),
+ }],
+ is_error: None,
+ meta: None,
+ structured_content: None,
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+    // The follow-up request (carrying the tool result) should still expose the MCP tool.
+    let completion = fake_model.pending_completions().pop().unwrap();
+    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
+ fake_model.send_last_completion_stream_text_chunk("Done!");
+ fake_model.end_last_completion_stream();
+ events.collect::<Vec<_>>().await;
+
+ // Send again after adding the echo tool, ensuring the name collision is resolved.
+ let events = thread.update(cx, |thread, cx| {
+ thread.add_tool(EchoTool);
+ thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
+ });
+ cx.run_until_parked();
+ let completion = fake_model.pending_completions().pop().unwrap();
+ assert_eq!(
+ tool_names_for_completion(&completion),
+ vec!["echo", "test_server_echo"]
+ );
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "tool_2".into(),
+ name: "test_server_echo".into(),
+ raw_input: json!({"text": "mcp"}).to_string(),
+ input: json!({"text": "mcp"}),
+ is_input_complete: true,
+ },
+ ));
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "tool_3".into(),
+ name: "echo".into(),
+ raw_input: json!({"text": "native"}).to_string(),
+ input: json!({"text": "native"}),
+ is_input_complete: true,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
+ assert_eq!(tool_call_params.name, "echo");
+ assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
+ tool_call_response
+ .send(context_server::types::CallToolResponse {
+ content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
+ is_error: None,
+ meta: None,
+ structured_content: None,
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ // Ensure the tool results were inserted with the correct names.
+ let completion = fake_model.pending_completions().pop().unwrap();
+ assert_eq!(
+ completion.messages.last().unwrap().content,
+ vec![
+ MessageContent::ToolResult(LanguageModelToolResult {
+ tool_use_id: "tool_3".into(),
+ tool_name: "echo".into(),
+ is_error: false,
+ content: "native".into(),
+ output: Some("native".into()),
+ },),
+ MessageContent::ToolResult(LanguageModelToolResult {
+ tool_use_id: "tool_2".into(),
+ tool_name: "test_server_echo".into(),
+ is_error: false,
+ content: "mcp".into(),
+ output: Some("mcp".into()),
+ },),
+ ]
+ );
+ fake_model.end_last_completion_stream();
+ events.collect::<Vec<_>>().await;
+}
+
+#[gpui::test]
+async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
+ let ThreadTest {
+ model,
+ thread,
+ context_server_store,
+ fs,
+ ..
+ } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ // Set up a profile with all tools enabled
+ fs.insert_file(
+ paths::settings_file(),
+ json!({
+ "agent": {
+ "profiles": {
+ "test": {
+ "name": "Test Profile",
+ "enable_all_context_servers": true,
+ "tools": {
+ EchoTool::name(): true,
+ DelayTool::name(): true,
+ WordListTool::name(): true,
+ ToolRequiringPermission::name(): true,
+ InfiniteTool::name(): true,
+ }
+ },
+ }
+ }
+ })
+ .to_string()
+ .into_bytes(),
+ )
+ .await;
+ cx.run_until_parked();
+
+ thread.update(cx, |thread, _| {
+ thread.set_profile(AgentProfileId("test".into()));
+ thread.add_tool(EchoTool);
+ thread.add_tool(DelayTool);
+ thread.add_tool(WordListTool);
+ thread.add_tool(ToolRequiringPermission);
+ thread.add_tool(InfiniteTool);
+ });
+
+ // Set up multiple context servers with some overlapping tool names
+ let _server1_calls = setup_context_server(
+ "xxx",
+ vec![
+ context_server::types::Tool {
+ name: "echo".into(), // Conflicts with native EchoTool
+ description: None,
+ input_schema: serde_json::to_value(
+ EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
+ )
+ .unwrap(),
+ output_schema: None,
+ annotations: None,
+ },
+ context_server::types::Tool {
+ name: "unique_tool_1".into(),
+ description: None,
+ input_schema: json!({"type": "object", "properties": {}}),
+ output_schema: None,
+ annotations: None,
+ },
+ ],
+ &context_server_store,
+ cx,
+ );
+
+ let _server2_calls = setup_context_server(
+ "yyy",
+ vec![
+ context_server::types::Tool {
+ name: "echo".into(), // Also conflicts with native EchoTool
+ description: None,
+ input_schema: serde_json::to_value(
+ EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
+ )
+ .unwrap(),
+ output_schema: None,
+ annotations: None,
+ },
+ context_server::types::Tool {
+ name: "unique_tool_2".into(),
+ description: None,
+ input_schema: json!({"type": "object", "properties": {}}),
+ output_schema: None,
+ annotations: None,
+ },
+ context_server::types::Tool {
+ name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
+ description: None,
+ input_schema: json!({"type": "object", "properties": {}}),
+ output_schema: None,
+ annotations: None,
+ },
+ context_server::types::Tool {
+ name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
+ description: None,
+ input_schema: json!({"type": "object", "properties": {}}),
+ output_schema: None,
+ annotations: None,
+ },
+ ],
+ &context_server_store,
+ cx,
+ );
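+    // A third server repeats the long names from "yyy" so they collide, and adds one that exceeds the limit outright.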
+ let _server3_calls = setup_context_server(
+ "zzz",
+ vec![
+ context_server::types::Tool {
+ name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
+ description: None,
+ input_schema: json!({"type": "object", "properties": {}}),
+ output_schema: None,
+ annotations: None,
+ },
+ context_server::types::Tool {
+ name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
+ description: None,
+ input_schema: json!({"type": "object", "properties": {}}),
+ output_schema: None,
+ annotations: None,
+ },
+ context_server::types::Tool {
+ name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
+ description: None,
+ input_schema: json!({"type": "object", "properties": {}}),
+ output_schema: None,
+ annotations: None,
+ },
+ ],
+ &context_server_store,
+ cx,
+ );
+
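+    // The request should advertise a deduplicated set of names, each at most MAX_TOOL_NAME_LENGTH characters.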
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Go"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+ let completion = fake_model.pending_completions().pop().unwrap();
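+    // Unique names are kept (truncated to MAX_TOOL_NAME_LENGTH if needed); colliding names get as
+    // much of the server ID as fits as a prefix, or stay as-is (last server wins) when there is no
+    // room to disambiguate.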
+ assert_eq!(
+ tool_names_for_completion(&completion),
+ vec![
+ "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
+ "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
+ "delay",
+ "echo",
+ "infinite",
+ "tool_requiring_permission",
+ "unique_tool_1",
+ "unique_tool_2",
+ "word_list",
+ "xxx_echo",
+ "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+ "yyy_echo",
+ "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+ ]
+ );
+}
+
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
@@ -1806,6 +2143,7 @@ struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Entity<ProjectContext>,
+ context_server_store: Entity<ContextServerStore>,
fs: Arc<FakeFs>,
}
@@ -1844,6 +2182,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
WordListTool::name(): true,
ToolRequiringPermission::name(): true,
InfiniteTool::name(): true,
+ ThinkingTool::name(): true,
}
}
}
@@ -1900,8 +2239,9 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
.await;
let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
let thread = cx.new(|cx| {
Thread::new(
project,
@@ -1916,6 +2256,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
model,
thread,
project_context,
+ context_server_store,
fs,
}
}
@@ -1950,3 +2291,89 @@ fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
})
.detach();
}
+
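+/// Collects the names of the tools offered to the model in a completion request.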
+fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
+ completion
+ .tools
+ .iter()
+ .map(|tool| tool.name.clone())
+ .collect()
+}
+
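+/// Registers and starts a fake context server exposing the given tools.
+///
+/// Returns a receiver yielding each `CallTool` request together with a oneshot sender the test
+/// uses to reply.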
+fn setup_context_server(
+ name: &'static str,
+ tools: Vec<context_server::types::Tool>,
+ context_server_store: &Entity<ContextServerStore>,
+ cx: &mut TestAppContext,
+) -> mpsc::UnboundedReceiver<(
+ context_server::types::CallToolParams,
+ oneshot::Sender<context_server::types::CallToolResponse>,
+)> {
+ cx.update(|cx| {
+ let mut settings = ProjectSettings::get_global(cx).clone();
+ settings.context_servers.insert(
+ name.into(),
+ project::project_settings::ContextServerSettings::Custom {
+ enabled: true,
+ command: ContextServerCommand {
+ path: "somebinary".into(),
+ args: Vec::new(),
+ env: None,
+ },
+ },
+ );
+ ProjectSettings::override_global(settings, cx);
+ });
+
+ let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
+ let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
+ .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
+ context_server::types::InitializeResponse {
+ protocol_version: context_server::types::ProtocolVersion(
+ context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
+ ),
+ server_info: context_server::types::Implementation {
+ name: name.into(),
+ version: "1.0.0".to_string(),
+ },
+ capabilities: context_server::types::ServerCapabilities {
+ tools: Some(context_server::types::ToolsCapabilities {
+ list_changed: Some(true),
+ }),
+ ..Default::default()
+ },
+ meta: None,
+ }
+ })
+ .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
+ let tools = tools.clone();
+ async move {
+ context_server::types::ListToolsResponse {
+ tools,
+ next_cursor: None,
+ meta: None,
+ }
+ }
+ })
+ .on_request::<context_server::types::requests::CallTool, _>(move |params| {
+ let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
+ async move {
+ let (response_tx, response_rx) = oneshot::channel();
+ mcp_tool_calls_tx
+ .unbounded_send((params, response_tx))
+ .unwrap();
+ response_rx.await.unwrap()
+ }
+ });
+ context_server_store.update(cx, |store, cx| {
+ store.start_server(
+ Arc::new(ContextServer::new(
+ ContextServerId(name.into()),
+ Arc::new(fake_transport),
+ )),
+ cx,
+ );
+ });
+ cx.run_until_parked();
+ mcp_tool_calls_rx
+}
@@ -9,15 +9,15 @@ use action_log::ActionLog;
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
use agent_client_protocol as acp;
use agent_settings::{
- AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
- SUMMARIZE_THREAD_PROMPT,
+ AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
+ SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
-use collections::{HashMap, IndexMap};
+use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use futures::{
FutureExt,
@@ -56,6 +56,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
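+/// Maximum length of a tool name sent to the model; longer names are truncated and colliding
+/// names are disambiguated with a server-ID prefix.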
+pub const MAX_TOOL_NAME_LENGTH: usize = 64;
/// The ID of the user prompt that initiated a request.
///
@@ -627,7 +628,20 @@ impl Thread {
stream: &ThreadEventStream,
cx: &mut Context<Self>,
) {
- let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
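+        // Check the native tools first, then fall back to tools provided by context servers.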
+ let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
+ self.context_server_registry
+ .read(cx)
+ .servers()
+                .find_map(|(_, tools)| tools.get(tool_use.name.as_ref()).cloned())
+ });
+
+ let Some(tool) = tool else {
stream
.0
.unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
@@ -1079,6 +1093,10 @@ impl Thread {
self.cancel(cx);
let model = self.model.clone().context("No language model configured")?;
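+        // Resolve the active profile up front; it decides which tools are enabled for this turn.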
+ let profile = AgentSettings::get_global(cx)
+ .profiles
+ .get(&self.profile_id)
+ .context("Profile not found")?;
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
let event_stream = ThreadEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
@@ -1086,6 +1104,7 @@ impl Thread {
self.summary = None;
self.running_turn = Some(RunningTurn {
event_stream: event_stream.clone(),
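+            // Snapshot the enabled tools so their (possibly disambiguated) names stay stable for the whole turn.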
+ tools: self.enabled_tools(profile, &model, cx),
_task: cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
@@ -1417,7 +1436,7 @@ impl Thread {
) -> Option<Task<LanguageModelToolResult>> {
cx.notify();
- let tool = self.tools.get(tool_use.name.as_ref()).cloned();
+ let tool = self.tool(tool_use.name.as_ref());
let mut title = SharedString::from(&tool_use.name);
let mut kind = acp::ToolKind::Other;
if let Some(tool) = tool.as_ref() {
@@ -1727,30 +1746,28 @@ impl Thread {
cx: &mut App,
) -> Result<LanguageModelRequest> {
let model = self.model().context("No language model configured")?;
-
- log::debug!("Building completion request");
- log::debug!("Completion intent: {:?}", completion_intent);
- log::debug!("Completion mode: {:?}", self.completion_mode);
-
- let messages = self.build_request_messages(cx);
- log::info!("Request will include {} messages", messages.len());
-
- let tools = if let Some(tools) = self.tools(cx).log_err() {
- tools
- .filter_map(|tool| {
- let tool_name = tool.name().to_string();
+ let tools = if let Some(turn) = self.running_turn.as_ref() {
+ turn.tools
+ .iter()
+ .filter_map(|(tool_name, tool)| {
log::trace!("Including tool: {}", tool_name);
Some(LanguageModelRequestTool {
- name: tool_name,
+ name: tool_name.to_string(),
description: tool.description().to_string(),
input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
})
})
- .collect()
+ .collect::<Vec<_>>()
} else {
Vec::new()
};
+ log::debug!("Building completion request");
+ log::debug!("Completion intent: {:?}", completion_intent);
+ log::debug!("Completion mode: {:?}", self.completion_mode);
+
+ let messages = self.build_request_messages(cx);
+ log::info!("Request will include {} messages", messages.len());
log::info!("Request includes {} tools", tools.len());
let request = LanguageModelRequest {
@@ -1770,37 +1787,76 @@ impl Thread {
Ok(request)
}
- fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
- let model = self.model().context("No language model configured")?;
-
- let profile = AgentSettings::get_global(cx)
- .profiles
- .get(&self.profile_id)
- .context("profile not found")?;
- let provider_id = model.provider_id();
+ fn enabled_tools(
+ &self,
+ profile: &AgentProfileSettings,
+ model: &Arc<dyn LanguageModel>,
+ cx: &App,
+ ) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
+ fn truncate(tool_name: &SharedString) -> SharedString {
+ if tool_name.len() > MAX_TOOL_NAME_LENGTH {
+ let mut truncated = tool_name.to_string();
+ truncated.truncate(MAX_TOOL_NAME_LENGTH);
+ truncated.into()
+ } else {
+ tool_name.clone()
+ }
+ }
- Ok(self
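+        // Start from the native tools that the profile enables and the current model's provider supports.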
+ let mut tools = self
.tools
.iter()
- .filter(move |(_, tool)| tool.supported_provider(&provider_id))
.filter_map(|(tool_name, tool)| {
- if profile.is_tool_enabled(tool_name) {
- Some(tool)
+ if tool.supported_provider(&model.provider_id())
+ && profile.is_tool_enabled(tool_name)
+ {
+ Some((truncate(tool_name), tool.clone()))
} else {
None
}
})
- .chain(self.context_server_registry.read(cx).servers().flat_map(
- |(server_id, tools)| {
- tools.iter().filter_map(|(tool_name, tool)| {
- if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
- Some(tool)
- } else {
- None
- }
- })
- },
- )))
+ .collect::<BTreeMap<_, _>>();
+
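+        // Gather enabled MCP tools, remembering which (truncated) names collide with native tools or with tools from other servers.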
+ let mut context_server_tools = Vec::new();
+ let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
+ let mut duplicate_tool_names = HashSet::default();
+ for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
+ for (tool_name, tool) in server_tools {
+ if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
+ let tool_name = truncate(tool_name);
+ if !seen_tools.insert(tool_name.clone()) {
+ duplicate_tool_names.insert(tool_name.clone());
+ }
+ context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
+ }
+ }
+ }
+
+ // When there are duplicate tool names, disambiguate by prefixing them
+ // with the server ID. In the rare case there isn't enough space for the
+ // disambiguated tool name, keep only the last tool with this name.
+ for (server_id, tool_name, tool) in context_server_tools {
+ if duplicate_tool_names.contains(&tool_name) {
+ let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
+ if available >= 2 {
+ let mut disambiguated = server_id.0.to_string();
+ disambiguated.truncate(available - 1);
+ disambiguated.push('_');
+ disambiguated.push_str(&tool_name);
+ tools.insert(disambiguated.into(), tool.clone());
+ } else {
+ tools.insert(tool_name, tool.clone());
+ }
+ } else {
+ tools.insert(tool_name, tool.clone());
+ }
+ }
+
+ tools
+ }
+
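+    /// Looks up a tool by the name it was exposed under for the current turn, if any.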
+ fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
+ self.running_turn.as_ref()?.tools.get(name).cloned()
}
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
@@ -1965,6 +2021,8 @@ struct RunningTurn {
/// The current event stream for the running turn. Used to report a final
/// cancellation event if we cancel the turn.
event_stream: ThreadEventStream,
+ /// The tools that were enabled for this turn.
+ tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
}
impl RunningTurn {