context_server_tool.rs

  1use std::sync::Arc;
  2
  3use anyhow::{anyhow, bail};
  4use assistant_tool::Tool;
  5use gpui::{App, Entity, Task, Window};
  6
  7use crate::manager::ContextServerManager;
  8use crate::types;
  9
 10pub struct ContextServerTool {
 11    server_manager: Entity<ContextServerManager>,
 12    server_id: Arc<str>,
 13    tool: types::Tool,
 14}
 15
 16impl ContextServerTool {
 17    pub fn new(
 18        server_manager: Entity<ContextServerManager>,
 19        server_id: impl Into<Arc<str>>,
 20        tool: types::Tool,
 21    ) -> Self {
 22        Self {
 23            server_manager,
 24            server_id: server_id.into(),
 25            tool,
 26        }
 27    }
 28}
 29
 30impl Tool for ContextServerTool {
 31    fn name(&self) -> String {
 32        self.tool.name.clone()
 33    }
 34
 35    fn description(&self) -> String {
 36        self.tool.description.clone().unwrap_or_default()
 37    }
 38
 39    fn input_schema(&self) -> serde_json::Value {
 40        match &self.tool.input_schema {
 41            serde_json::Value::Null => {
 42                serde_json::json!({ "type": "object", "properties": [] })
 43            }
 44            serde_json::Value::Object(map) if map.is_empty() => {
 45                serde_json::json!({ "type": "object", "properties": [] })
 46            }
 47            _ => self.tool.input_schema.clone(),
 48        }
 49    }
 50
 51    fn run(
 52        self: std::sync::Arc<Self>,
 53        input: serde_json::Value,
 54        _workspace: gpui::WeakEntity<workspace::Workspace>,
 55        _: &mut Window,
 56        cx: &mut App,
 57    ) -> gpui::Task<gpui::Result<String>> {
 58        if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
 59            cx.foreground_executor().spawn({
 60                let tool_name = self.tool.name.clone();
 61                async move {
 62                    let Some(protocol) = server.client() else {
 63                        bail!("Context server not initialized");
 64                    };
 65
 66                    let arguments = if let serde_json::Value::Object(map) = input {
 67                        Some(map.into_iter().collect())
 68                    } else {
 69                        None
 70                    };
 71
 72                    log::trace!(
 73                        "Running tool: {} with arguments: {:?}",
 74                        tool_name,
 75                        arguments
 76                    );
 77                    let response = protocol.run_tool(tool_name, arguments).await?;
 78
 79                    let mut result = String::new();
 80                    for content in response.content {
 81                        match content {
 82                            types::ToolResponseContent::Text { text } => {
 83                                result.push_str(&text);
 84                            }
 85                            types::ToolResponseContent::Image { .. } => {
 86                                log::warn!("Ignoring image content from tool response");
 87                            }
 88                            types::ToolResponseContent::Resource { .. } => {
 89                                log::warn!("Ignoring resource content from tool response");
 90                            }
 91                        }
 92                    }
 93                    Ok(result)
 94                }
 95            })
 96        } else {
 97            Task::ready(Err(anyhow!("Context server not found")))
 98        }
 99    }
100}