context_server_tool.rs

  1use std::sync::Arc;
  2
  3use anyhow::{anyhow, bail, Result};
  4use assistant_tool::{Tool, ToolSource};
  5use gpui::{App, Entity, Task};
  6use project::Project;
  7
  8use crate::manager::ContextServerManager;
  9use crate::types;
 10
 11pub struct ContextServerTool {
 12    server_manager: Entity<ContextServerManager>,
 13    server_id: Arc<str>,
 14    tool: types::Tool,
 15}
 16
 17impl ContextServerTool {
 18    pub fn new(
 19        server_manager: Entity<ContextServerManager>,
 20        server_id: impl Into<Arc<str>>,
 21        tool: types::Tool,
 22    ) -> Self {
 23        Self {
 24            server_manager,
 25            server_id: server_id.into(),
 26            tool,
 27        }
 28    }
 29}
 30
 31impl Tool for ContextServerTool {
 32    fn name(&self) -> String {
 33        self.tool.name.clone()
 34    }
 35
 36    fn description(&self) -> String {
 37        self.tool.description.clone().unwrap_or_default()
 38    }
 39
 40    fn source(&self) -> ToolSource {
 41        ToolSource::ContextServer {
 42            id: self.server_id.clone().into(),
 43        }
 44    }
 45
 46    fn input_schema(&self) -> serde_json::Value {
 47        match &self.tool.input_schema {
 48            serde_json::Value::Null => {
 49                serde_json::json!({ "type": "object", "properties": [] })
 50            }
 51            serde_json::Value::Object(map) if map.is_empty() => {
 52                serde_json::json!({ "type": "object", "properties": [] })
 53            }
 54            _ => self.tool.input_schema.clone(),
 55        }
 56    }
 57
 58    fn run(
 59        self: Arc<Self>,
 60        input: serde_json::Value,
 61        _project: Entity<Project>,
 62        cx: &mut App,
 63    ) -> Task<Result<String>> {
 64        if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
 65            cx.foreground_executor().spawn({
 66                let tool_name = self.tool.name.clone();
 67                async move {
 68                    let Some(protocol) = server.client() else {
 69                        bail!("Context server not initialized");
 70                    };
 71
 72                    let arguments = if let serde_json::Value::Object(map) = input {
 73                        Some(map.into_iter().collect())
 74                    } else {
 75                        None
 76                    };
 77
 78                    log::trace!(
 79                        "Running tool: {} with arguments: {:?}",
 80                        tool_name,
 81                        arguments
 82                    );
 83                    let response = protocol.run_tool(tool_name, arguments).await?;
 84
 85                    let mut result = String::new();
 86                    for content in response.content {
 87                        match content {
 88                            types::ToolResponseContent::Text { text } => {
 89                                result.push_str(&text);
 90                            }
 91                            types::ToolResponseContent::Image { .. } => {
 92                                log::warn!("Ignoring image content from tool response");
 93                            }
 94                            types::ToolResponseContent::Resource { .. } => {
 95                                log::warn!("Ignoring resource content from tool response");
 96                            }
 97                        }
 98                    }
 99                    Ok(result)
100                }
101            })
102        } else {
103            Task::ready(Err(anyhow!("Context server not found")))
104        }
105    }
106}