1use std::sync::Arc;
2
3use action_log::ActionLog;
4use anyhow::{Result, anyhow, bail};
5use assistant_tool::{Tool, ToolResult, ToolSource};
6use context_server::{ContextServerId, types};
7use gpui::{AnyWindowHandle, App, Entity, Task};
8use icons::IconName;
9use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
10use project::{Project, context_server_store::ContextServerStore};
11
/// An assistant [`Tool`] backed by a tool that a context (MCP) server exposes.
pub struct ContextServerTool {
    /// Store used to look up the running server when the tool is invoked.
    store: Entity<ContextServerStore>,
    /// Identifier of the context server that provides `tool`.
    server_id: ContextServerId,
    /// Tool definition (name, optional description, input schema) as reported by the server.
    tool: types::Tool,
}
17
18impl ContextServerTool {
19 pub fn new(
20 store: Entity<ContextServerStore>,
21 server_id: ContextServerId,
22 tool: types::Tool,
23 ) -> Self {
24 Self {
25 store,
26 server_id,
27 tool,
28 }
29 }
30}
31
32impl Tool for ContextServerTool {
33 fn name(&self) -> String {
34 self.tool.name.clone()
35 }
36
37 fn description(&self) -> String {
38 self.tool.description.clone().unwrap_or_default()
39 }
40
41 fn icon(&self) -> IconName {
42 IconName::ToolHammer
43 }
44
45 fn source(&self) -> ToolSource {
46 ToolSource::ContextServer {
47 id: self.server_id.clone().0.into(),
48 }
49 }
50
51 fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
52 true
53 }
54
55 fn may_perform_edits(&self) -> bool {
56 true
57 }
58
59 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
60 let mut schema = self.tool.input_schema.clone();
61 assistant_tool::adapt_schema_to_format(&mut schema, format)?;
62 Ok(match schema {
63 serde_json::Value::Null => {
64 serde_json::json!({ "type": "object", "properties": [] })
65 }
66 serde_json::Value::Object(map) if map.is_empty() => {
67 serde_json::json!({ "type": "object", "properties": [] })
68 }
69 _ => schema,
70 })
71 }
72
73 fn ui_text(&self, _input: &serde_json::Value) -> String {
74 format!("Run MCP tool `{}`", self.tool.name)
75 }
76
77 fn run(
78 self: Arc<Self>,
79 input: serde_json::Value,
80 _request: Arc<LanguageModelRequest>,
81 _project: Entity<Project>,
82 _action_log: Entity<ActionLog>,
83 _model: Arc<dyn LanguageModel>,
84 _window: Option<AnyWindowHandle>,
85 cx: &mut App,
86 ) -> ToolResult {
87 if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
88 let tool_name = self.tool.name.clone();
89
90 cx.spawn(async move |_cx| {
91 let Some(protocol) = server.client() else {
92 bail!("Context server not initialized");
93 };
94
95 let arguments = if let serde_json::Value::Object(map) = input {
96 Some(map.into_iter().collect())
97 } else {
98 None
99 };
100
101 log::trace!(
102 "Running tool: {} with arguments: {:?}",
103 tool_name,
104 arguments
105 );
106 let response = protocol
107 .request::<context_server::types::requests::CallTool>(
108 context_server::types::CallToolParams {
109 name: tool_name,
110 arguments,
111 meta: None,
112 },
113 )
114 .await?;
115
116 let mut result = String::new();
117 for content in response.content {
118 match content {
119 types::ToolResponseContent::Text { text } => {
120 result.push_str(&text);
121 }
122 types::ToolResponseContent::Image { .. } => {
123 log::warn!("Ignoring image content from tool response");
124 }
125 types::ToolResponseContent::Audio { .. } => {
126 log::warn!("Ignoring audio content from tool response");
127 }
128 types::ToolResponseContent::Resource { .. } => {
129 log::warn!("Ignoring resource content from tool response");
130 }
131 }
132 }
133 Ok(result.into())
134 })
135 .into()
136 } else {
137 Task::ready(Err(anyhow!("Context server not found"))).into()
138 }
139 }
140}