context_server_tool.rs

 1use anyhow::{anyhow, bail};
 2use assistant_tool::Tool;
 3use context_servers::manager::ContextServerManager;
 4use context_servers::types;
 5use gpui::{Model, Task};
 6
 7pub struct ContextServerTool {
 8    server_manager: Model<ContextServerManager>,
 9    server_id: String,
10    tool: types::Tool,
11}
12
13impl ContextServerTool {
14    pub fn new(
15        server_manager: Model<ContextServerManager>,
16        server_id: impl Into<String>,
17        tool: types::Tool,
18    ) -> Self {
19        Self {
20            server_manager,
21            server_id: server_id.into(),
22            tool,
23        }
24    }
25}
26
27impl Tool for ContextServerTool {
28    fn name(&self) -> String {
29        self.tool.name.clone()
30    }
31
32    fn description(&self) -> String {
33        self.tool.description.clone().unwrap_or_default()
34    }
35
36    fn input_schema(&self) -> serde_json::Value {
37        match &self.tool.input_schema {
38            serde_json::Value::Null => {
39                serde_json::json!({ "type": "object", "properties": [] })
40            }
41            serde_json::Value::Object(map) if map.is_empty() => {
42                serde_json::json!({ "type": "object", "properties": [] })
43            }
44            _ => self.tool.input_schema.clone(),
45        }
46    }
47
48    fn run(
49        self: std::sync::Arc<Self>,
50        input: serde_json::Value,
51        _workspace: gpui::WeakView<workspace::Workspace>,
52        cx: &mut ui::WindowContext,
53    ) -> gpui::Task<gpui::Result<String>> {
54        if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
55            cx.foreground_executor().spawn({
56                let tool_name = self.tool.name.clone();
57                async move {
58                    let Some(protocol) = server.client.read().clone() else {
59                        bail!("Context server not initialized");
60                    };
61
62                    let arguments = if let serde_json::Value::Object(map) = input {
63                        Some(map.into_iter().collect())
64                    } else {
65                        None
66                    };
67
68                    log::trace!(
69                        "Running tool: {} with arguments: {:?}",
70                        tool_name,
71                        arguments
72                    );
73                    let response = protocol.run_tool(tool_name, arguments).await?;
74
75                    let tool_result = match response.tool_result {
76                        serde_json::Value::String(s) => s,
77                        _ => serde_json::to_string(&response.tool_result)?,
78                    };
79                    Ok(tool_result)
80                }
81            })
82        } else {
83            Task::ready(Err(anyhow!("Context server not found")))
84        }
85    }
86}