context_server_tool.rs

 1use std::sync::Arc;
 2
 3use anyhow::{anyhow, bail};
 4use assistant_tool::Tool;
 5use context_servers::manager::ContextServerManager;
 6use context_servers::types;
 7use gpui::{Model, Task};
 8
 9pub struct ContextServerTool {
10    server_manager: Model<ContextServerManager>,
11    server_id: Arc<str>,
12    tool: types::Tool,
13}
14
15impl ContextServerTool {
16    pub fn new(
17        server_manager: Model<ContextServerManager>,
18        server_id: impl Into<Arc<str>>,
19        tool: types::Tool,
20    ) -> Self {
21        Self {
22            server_manager,
23            server_id: server_id.into(),
24            tool,
25        }
26    }
27}
28
29impl Tool for ContextServerTool {
30    fn name(&self) -> String {
31        self.tool.name.clone()
32    }
33
34    fn description(&self) -> String {
35        self.tool.description.clone().unwrap_or_default()
36    }
37
38    fn input_schema(&self) -> serde_json::Value {
39        match &self.tool.input_schema {
40            serde_json::Value::Null => {
41                serde_json::json!({ "type": "object", "properties": [] })
42            }
43            serde_json::Value::Object(map) if map.is_empty() => {
44                serde_json::json!({ "type": "object", "properties": [] })
45            }
46            _ => self.tool.input_schema.clone(),
47        }
48    }
49
50    fn run(
51        self: std::sync::Arc<Self>,
52        input: serde_json::Value,
53        _workspace: gpui::WeakView<workspace::Workspace>,
54        cx: &mut ui::WindowContext,
55    ) -> gpui::Task<gpui::Result<String>> {
56        if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
57            cx.foreground_executor().spawn({
58                let tool_name = self.tool.name.clone();
59                async move {
60                    let Some(protocol) = server.client() else {
61                        bail!("Context server not initialized");
62                    };
63
64                    let arguments = if let serde_json::Value::Object(map) = input {
65                        Some(map.into_iter().collect())
66                    } else {
67                        None
68                    };
69
70                    log::trace!(
71                        "Running tool: {} with arguments: {:?}",
72                        tool_name,
73                        arguments
74                    );
75                    let response = protocol.run_tool(tool_name, arguments).await?;
76
77                    let mut result = String::new();
78                    for content in response.content {
79                        match content {
80                            types::ToolResponseContent::Text { text } => {
81                                result.push_str(&text);
82                            }
83                            types::ToolResponseContent::Image { .. } => {
84                                log::warn!("Ignoring image content from tool response");
85                            }
86                            types::ToolResponseContent::Resource { .. } => {
87                                log::warn!("Ignoring resource content from tool response");
88                            }
89                        }
90                    }
91                    Ok(result)
92                }
93            })
94        } else {
95            Task::ready(Err(anyhow!("Context server not found")))
96        }
97    }
98}