context_server_tool.rs

  1use std::sync::Arc;
  2
  3use anyhow::{anyhow, bail, Result};
  4use assistant_tool::{ActionLog, Tool, ToolSource};
  5use gpui::{App, Entity, Task};
  6use language_model::LanguageModelRequestMessage;
  7use project::Project;
  8
  9use crate::manager::ContextServerManager;
 10use crate::types;
 11
 12pub struct ContextServerTool {
 13    server_manager: Entity<ContextServerManager>,
 14    server_id: Arc<str>,
 15    tool: types::Tool,
 16}
 17
 18impl ContextServerTool {
 19    pub fn new(
 20        server_manager: Entity<ContextServerManager>,
 21        server_id: impl Into<Arc<str>>,
 22        tool: types::Tool,
 23    ) -> Self {
 24        Self {
 25            server_manager,
 26            server_id: server_id.into(),
 27            tool,
 28        }
 29    }
 30}
 31
 32impl Tool for ContextServerTool {
 33    fn name(&self) -> String {
 34        self.tool.name.clone()
 35    }
 36
 37    fn description(&self) -> String {
 38        self.tool.description.clone().unwrap_or_default()
 39    }
 40
 41    fn source(&self) -> ToolSource {
 42        ToolSource::ContextServer {
 43            id: self.server_id.clone().into(),
 44        }
 45    }
 46
 47    fn needs_confirmation(&self) -> bool {
 48        true
 49    }
 50
 51    fn input_schema(&self) -> serde_json::Value {
 52        match &self.tool.input_schema {
 53            serde_json::Value::Null => {
 54                serde_json::json!({ "type": "object", "properties": [] })
 55            }
 56            serde_json::Value::Object(map) if map.is_empty() => {
 57                serde_json::json!({ "type": "object", "properties": [] })
 58            }
 59            _ => self.tool.input_schema.clone(),
 60        }
 61    }
 62
 63    fn ui_text(&self, _input: &serde_json::Value) -> String {
 64        format!("Run MCP tool `{}`", self.tool.name)
 65    }
 66
 67    fn run(
 68        self: Arc<Self>,
 69        input: serde_json::Value,
 70        _messages: &[LanguageModelRequestMessage],
 71        _project: Entity<Project>,
 72        _action_log: Entity<ActionLog>,
 73        cx: &mut App,
 74    ) -> Task<Result<String>> {
 75        if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
 76            let tool_name = self.tool.name.clone();
 77            let server_clone = server.clone();
 78            let input_clone = input.clone();
 79
 80            cx.spawn(async move |_cx| {
 81                let Some(protocol) = server_clone.client() else {
 82                    bail!("Context server not initialized");
 83                };
 84
 85                let arguments = if let serde_json::Value::Object(map) = input_clone {
 86                    Some(map.into_iter().collect())
 87                } else {
 88                    None
 89                };
 90
 91                log::trace!(
 92                    "Running tool: {} with arguments: {:?}",
 93                    tool_name,
 94                    arguments
 95                );
 96                let response = protocol.run_tool(tool_name, arguments).await?;
 97
 98                let mut result = String::new();
 99                for content in response.content {
100                    match content {
101                        types::ToolResponseContent::Text { text } => {
102                            result.push_str(&text);
103                        }
104                        types::ToolResponseContent::Image { .. } => {
105                            log::warn!("Ignoring image content from tool response");
106                        }
107                        types::ToolResponseContent::Resource { .. } => {
108                            log::warn!("Ignoring resource content from tool response");
109                        }
110                    }
111                }
112                Ok(result)
113            })
114        } else {
115            Task::ready(Err(anyhow!("Context server not found")))
116        }
117    }
118}