1use std::sync::Arc;
2
3use anyhow::{anyhow, bail, Result};
4use assistant_tool::{ActionLog, Tool, ToolSource};
5use gpui::{App, Entity, Task};
6use language_model::LanguageModelRequestMessage;
7use project::Project;
8
9use crate::manager::ContextServerManager;
10use crate::types;
11
/// Adapts one tool definition exposed by a context server so it can be
/// invoked through the assistant's [`Tool`] trait.
pub struct ContextServerTool {
    /// Manager used to look up the owning server (by `server_id`) when the
    /// tool is run.
    server_manager: Entity<ContextServerManager>,
    /// Identifier of the context server that provides this tool.
    server_id: Arc<str>,
    /// The tool's definition; its `name`, `description`, and `input_schema`
    /// are what the `Tool` impl reports.
    tool: types::Tool,
}
17
18impl ContextServerTool {
19 pub fn new(
20 server_manager: Entity<ContextServerManager>,
21 server_id: impl Into<Arc<str>>,
22 tool: types::Tool,
23 ) -> Self {
24 Self {
25 server_manager,
26 server_id: server_id.into(),
27 tool,
28 }
29 }
30}
31
32impl Tool for ContextServerTool {
33 fn name(&self) -> String {
34 self.tool.name.clone()
35 }
36
37 fn description(&self) -> String {
38 self.tool.description.clone().unwrap_or_default()
39 }
40
41 fn source(&self) -> ToolSource {
42 ToolSource::ContextServer {
43 id: self.server_id.clone().into(),
44 }
45 }
46
47 fn needs_confirmation(&self) -> bool {
48 true
49 }
50
51 fn input_schema(&self) -> serde_json::Value {
52 match &self.tool.input_schema {
53 serde_json::Value::Null => {
54 serde_json::json!({ "type": "object", "properties": [] })
55 }
56 serde_json::Value::Object(map) if map.is_empty() => {
57 serde_json::json!({ "type": "object", "properties": [] })
58 }
59 _ => self.tool.input_schema.clone(),
60 }
61 }
62
63 fn ui_text(&self, _input: &serde_json::Value) -> String {
64 format!("Run MCP tool `{}`", self.tool.name)
65 }
66
67 fn run(
68 self: Arc<Self>,
69 input: serde_json::Value,
70 _messages: &[LanguageModelRequestMessage],
71 _project: Entity<Project>,
72 _action_log: Entity<ActionLog>,
73 cx: &mut App,
74 ) -> Task<Result<String>> {
75 if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
76 let tool_name = self.tool.name.clone();
77 let server_clone = server.clone();
78 let input_clone = input.clone();
79
80 cx.spawn(async move |_cx| {
81 let Some(protocol) = server_clone.client() else {
82 bail!("Context server not initialized");
83 };
84
85 let arguments = if let serde_json::Value::Object(map) = input_clone {
86 Some(map.into_iter().collect())
87 } else {
88 None
89 };
90
91 log::trace!(
92 "Running tool: {} with arguments: {:?}",
93 tool_name,
94 arguments
95 );
96 let response = protocol.run_tool(tool_name, arguments).await?;
97
98 let mut result = String::new();
99 for content in response.content {
100 match content {
101 types::ToolResponseContent::Text { text } => {
102 result.push_str(&text);
103 }
104 types::ToolResponseContent::Image { .. } => {
105 log::warn!("Ignoring image content from tool response");
106 }
107 types::ToolResponseContent::Resource { .. } => {
108 log::warn!("Ignoring resource content from tool response");
109 }
110 }
111 }
112 Ok(result)
113 })
114 } else {
115 Task::ready(Err(anyhow!("Context server not found")))
116 }
117 }
118}