context_server_tool.rs

 1use anyhow::{anyhow, bail};
 2use assistant_tool::Tool;
 3use context_servers::manager::ContextServerManager;
 4use context_servers::types;
 5use gpui::Task;
 6
 7pub struct ContextServerTool {
 8    server_id: String,
 9    tool: types::Tool,
10}
11
12impl ContextServerTool {
13    pub fn new(server_id: impl Into<String>, tool: types::Tool) -> Self {
14        Self {
15            server_id: server_id.into(),
16            tool,
17        }
18    }
19}
20
21impl Tool for ContextServerTool {
22    fn name(&self) -> String {
23        self.tool.name.clone()
24    }
25
26    fn description(&self) -> String {
27        self.tool.description.clone().unwrap_or_default()
28    }
29
30    fn input_schema(&self) -> serde_json::Value {
31        match &self.tool.input_schema {
32            serde_json::Value::Null => {
33                serde_json::json!({ "type": "object", "properties": [] })
34            }
35            serde_json::Value::Object(map) if map.is_empty() => {
36                serde_json::json!({ "type": "object", "properties": [] })
37            }
38            _ => self.tool.input_schema.clone(),
39        }
40    }
41
42    fn run(
43        self: std::sync::Arc<Self>,
44        input: serde_json::Value,
45        _workspace: gpui::WeakView<workspace::Workspace>,
46        cx: &mut ui::WindowContext,
47    ) -> gpui::Task<gpui::Result<String>> {
48        let manager = ContextServerManager::global(cx);
49        let manager = manager.read(cx);
50        if let Some(server) = manager.get_server(&self.server_id) {
51            cx.foreground_executor().spawn({
52                let tool_name = self.tool.name.clone();
53                async move {
54                    let Some(protocol) = server.client.read().clone() else {
55                        bail!("Context server not initialized");
56                    };
57
58                    let arguments = if let serde_json::Value::Object(map) = input {
59                        Some(map.into_iter().collect())
60                    } else {
61                        None
62                    };
63
64                    log::trace!(
65                        "Running tool: {} with arguments: {:?}",
66                        tool_name,
67                        arguments
68                    );
69                    let response = protocol.run_tool(tool_name, arguments).await?;
70
71                    let tool_result = match response.tool_result {
72                        serde_json::Value::String(s) => s,
73                        _ => serde_json::to_string(&response.tool_result)?,
74                    };
75                    Ok(tool_result)
76                }
77            })
78        } else {
79            Task::ready(Err(anyhow!("Context server not found")))
80        }
81    }
82}