1use anyhow::{anyhow, bail};
2use assistant_tool::Tool;
3use context_servers::manager::ContextServerManager;
4use context_servers::types;
5use gpui::Task;
6
7pub struct ContextServerTool {
8 server_id: String,
9 tool: types::Tool,
10}
11
12impl ContextServerTool {
13 pub fn new(server_id: impl Into<String>, tool: types::Tool) -> Self {
14 Self {
15 server_id: server_id.into(),
16 tool,
17 }
18 }
19}
20
21impl Tool for ContextServerTool {
22 fn name(&self) -> String {
23 self.tool.name.clone()
24 }
25
26 fn description(&self) -> String {
27 self.tool.description.clone().unwrap_or_default()
28 }
29
30 fn input_schema(&self) -> serde_json::Value {
31 match &self.tool.input_schema {
32 serde_json::Value::Null => {
33 serde_json::json!({ "type": "object", "properties": [] })
34 }
35 serde_json::Value::Object(map) if map.is_empty() => {
36 serde_json::json!({ "type": "object", "properties": [] })
37 }
38 _ => self.tool.input_schema.clone(),
39 }
40 }
41
42 fn run(
43 self: std::sync::Arc<Self>,
44 input: serde_json::Value,
45 _workspace: gpui::WeakView<workspace::Workspace>,
46 cx: &mut ui::WindowContext,
47 ) -> gpui::Task<gpui::Result<String>> {
48 let manager = ContextServerManager::global(cx);
49 let manager = manager.read(cx);
50 if let Some(server) = manager.get_server(&self.server_id) {
51 cx.foreground_executor().spawn({
52 let tool_name = self.tool.name.clone();
53 async move {
54 let Some(protocol) = server.client.read().clone() else {
55 bail!("Context server not initialized");
56 };
57
58 let arguments = if let serde_json::Value::Object(map) = input {
59 Some(map.into_iter().collect())
60 } else {
61 None
62 };
63
64 log::trace!(
65 "Running tool: {} with arguments: {:?}",
66 tool_name,
67 arguments
68 );
69 let response = protocol.run_tool(tool_name, arguments).await?;
70
71 let tool_result = match response.tool_result {
72 serde_json::Value::String(s) => s,
73 _ => serde_json::to_string(&response.tool_result)?,
74 };
75 Ok(tool_result)
76 }
77 })
78 } else {
79 Task::ready(Err(anyhow!("Context server not found")))
80 }
81 }
82}