1use std::sync::Arc;
2
3use anyhow::{Result, anyhow, bail};
4use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
5use context_server::{ContextServerId, types};
6use gpui::{AnyWindowHandle, App, Entity, Task};
7use icons::IconName;
8use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
9use project::{Project, context_server_store::ContextServerStore};
10
11pub struct ContextServerTool {
12 store: Entity<ContextServerStore>,
13 server_id: ContextServerId,
14 tool: types::Tool,
15}
16
17impl ContextServerTool {
18 pub fn new(
19 store: Entity<ContextServerStore>,
20 server_id: ContextServerId,
21 tool: types::Tool,
22 ) -> Self {
23 Self {
24 store,
25 server_id,
26 tool,
27 }
28 }
29}
30
31impl Tool for ContextServerTool {
32 fn name(&self) -> String {
33 self.tool.name.clone()
34 }
35
36 fn description(&self) -> String {
37 self.tool.description.clone().unwrap_or_default()
38 }
39
40 fn icon(&self) -> IconName {
41 IconName::ToolHammer
42 }
43
44 fn source(&self) -> ToolSource {
45 ToolSource::ContextServer {
46 id: self.server_id.clone().0.into(),
47 }
48 }
49
50 fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
51 true
52 }
53
54 fn may_perform_edits(&self) -> bool {
55 true
56 }
57
58 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
59 let mut schema = self.tool.input_schema.clone();
60 assistant_tool::adapt_schema_to_format(&mut schema, format)?;
61 Ok(match schema {
62 serde_json::Value::Null => {
63 serde_json::json!({ "type": "object", "properties": [] })
64 }
65 serde_json::Value::Object(map) if map.is_empty() => {
66 serde_json::json!({ "type": "object", "properties": [] })
67 }
68 _ => schema,
69 })
70 }
71
72 fn ui_text(&self, _input: &serde_json::Value) -> String {
73 format!("Run MCP tool `{}`", self.tool.name)
74 }
75
76 fn run(
77 self: Arc<Self>,
78 input: serde_json::Value,
79 _request: Arc<LanguageModelRequest>,
80 _project: Entity<Project>,
81 _action_log: Entity<ActionLog>,
82 _model: Arc<dyn LanguageModel>,
83 _window: Option<AnyWindowHandle>,
84 cx: &mut App,
85 ) -> ToolResult {
86 if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
87 let tool_name = self.tool.name.clone();
88 let server_clone = server.clone();
89 let input_clone = input.clone();
90
91 cx.spawn(async move |_cx| {
92 let Some(protocol) = server_clone.client() else {
93 bail!("Context server not initialized");
94 };
95
96 let arguments = if let serde_json::Value::Object(map) = input_clone {
97 Some(map.into_iter().collect())
98 } else {
99 None
100 };
101
102 log::trace!(
103 "Running tool: {} with arguments: {:?}",
104 tool_name,
105 arguments
106 );
107 let response = protocol
108 .request::<context_server::types::requests::CallTool>(
109 context_server::types::CallToolParams {
110 name: tool_name,
111 arguments,
112 meta: None,
113 },
114 )
115 .await?;
116
117 let mut result = String::new();
118 for content in response.content {
119 match content {
120 types::ToolResponseContent::Text { text } => {
121 result.push_str(&text);
122 }
123 types::ToolResponseContent::Image { .. } => {
124 log::warn!("Ignoring image content from tool response");
125 }
126 types::ToolResponseContent::Audio { .. } => {
127 log::warn!("Ignoring audio content from tool response");
128 }
129 types::ToolResponseContent::Resource { .. } => {
130 log::warn!("Ignoring resource content from tool response");
131 }
132 }
133 }
134 Ok(result.into())
135 })
136 .into()
137 } else {
138 Task::ready(Err(anyhow!("Context server not found"))).into()
139 }
140 }
141}