1use std::sync::Arc;
2
3use anyhow::{Result, anyhow, bail};
4use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
5use context_server::{ContextServerId, types};
6use gpui::{AnyWindowHandle, App, Entity, Task};
7use icons::IconName;
8use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
9use project::{Project, context_server_store::ContextServerStore};
10
11pub struct ContextServerTool {
12 store: Entity<ContextServerStore>,
13 server_id: ContextServerId,
14 tool: types::Tool,
15}
16
17impl ContextServerTool {
18 pub fn new(
19 store: Entity<ContextServerStore>,
20 server_id: ContextServerId,
21 tool: types::Tool,
22 ) -> Self {
23 Self {
24 store,
25 server_id,
26 tool,
27 }
28 }
29}
30
31impl Tool for ContextServerTool {
32 type Input = serde_json::Value;
33
34 fn name(&self) -> String {
35 self.tool.name.clone()
36 }
37
38 fn description(&self) -> String {
39 self.tool.description.clone().unwrap_or_default()
40 }
41
42 fn icon(&self) -> IconName {
43 IconName::Cog
44 }
45
46 fn source(&self) -> ToolSource {
47 ToolSource::ContextServer {
48 id: self.server_id.clone().0.into(),
49 }
50 }
51
52 fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
53 true
54 }
55
56 fn may_perform_edits(&self) -> bool {
57 true
58 }
59
60 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
61 let mut schema = self.tool.input_schema.clone();
62 assistant_tool::adapt_schema_to_format(&mut schema, format)?;
63 Ok(match schema {
64 serde_json::Value::Null => {
65 serde_json::json!({ "type": "object", "properties": [] })
66 }
67 serde_json::Value::Object(map) if map.is_empty() => {
68 serde_json::json!({ "type": "object", "properties": [] })
69 }
70 _ => schema,
71 })
72 }
73
74 fn ui_text(&self, _input: &Self::Input) -> String {
75 format!("Run MCP tool `{}`", self.tool.name)
76 }
77
78 fn run(
79 self: Arc<Self>,
80 input: Self::Input,
81 _request: Arc<LanguageModelRequest>,
82 _project: Entity<Project>,
83 _action_log: Entity<ActionLog>,
84 _model: Arc<dyn LanguageModel>,
85 _window: Option<AnyWindowHandle>,
86 cx: &mut App,
87 ) -> ToolResult {
88 if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
89 let tool_name = self.tool.name.clone();
90 let server_clone = server.clone();
91 let input_clone = input.clone();
92
93 cx.spawn(async move |_cx| {
94 let Some(protocol) = server_clone.client() else {
95 bail!("Context server not initialized");
96 };
97
98 let arguments = if let serde_json::Value::Object(map) = input_clone {
99 Some(map.into_iter().collect())
100 } else {
101 None
102 };
103
104 log::trace!(
105 "Running tool: {} with arguments: {:?}",
106 tool_name,
107 arguments
108 );
109 let response = protocol
110 .request::<context_server::types::requests::CallTool>(
111 context_server::types::CallToolParams {
112 name: tool_name,
113 arguments,
114 meta: None,
115 },
116 )
117 .await?;
118
119 let mut result = String::new();
120 for content in response.content {
121 match content {
122 types::ToolResponseContent::Text { text } => {
123 result.push_str(&text);
124 }
125 types::ToolResponseContent::Image { .. } => {
126 log::warn!("Ignoring image content from tool response");
127 }
128 types::ToolResponseContent::Audio { .. } => {
129 log::warn!("Ignoring audio content from tool response");
130 }
131 types::ToolResponseContent::Resource { .. } => {
132 log::warn!("Ignoring resource content from tool response");
133 }
134 }
135 }
136 Ok(result.into())
137 })
138 .into()
139 } else {
140 Task::ready(Err(anyhow!("Context server not found"))).into()
141 }
142 }
143}