context_server_tool.rs

  1use std::sync::Arc;
  2
  3use action_log::ActionLog;
  4use anyhow::{Result, anyhow, bail};
  5use assistant_tool::{Tool, ToolResult, ToolSource};
  6use context_server::{ContextServerId, types};
  7use gpui::{AnyWindowHandle, App, Entity, Task};
  8use icons::IconName;
  9use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 10use project::{Project, context_server_store::ContextServerStore};
 11
 12pub struct ContextServerTool {
 13    store: Entity<ContextServerStore>,
 14    server_id: ContextServerId,
 15    tool: types::Tool,
 16}
 17
 18impl ContextServerTool {
 19    pub fn new(
 20        store: Entity<ContextServerStore>,
 21        server_id: ContextServerId,
 22        tool: types::Tool,
 23    ) -> Self {
 24        Self {
 25            store,
 26            server_id,
 27            tool,
 28        }
 29    }
 30}
 31
 32impl Tool for ContextServerTool {
 33    fn name(&self) -> String {
 34        self.tool.name.clone()
 35    }
 36
 37    fn description(&self) -> String {
 38        self.tool.description.clone().unwrap_or_default()
 39    }
 40
 41    fn icon(&self) -> IconName {
 42        IconName::ToolHammer
 43    }
 44
 45    fn source(&self) -> ToolSource {
 46        ToolSource::ContextServer {
 47            id: self.server_id.clone().0.into(),
 48        }
 49    }
 50
 51    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 52        true
 53    }
 54
 55    fn may_perform_edits(&self) -> bool {
 56        true
 57    }
 58
 59    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 60        let mut schema = self.tool.input_schema.clone();
 61        assistant_tool::adapt_schema_to_format(&mut schema, format)?;
 62        Ok(match schema {
 63            serde_json::Value::Null => {
 64                serde_json::json!({ "type": "object", "properties": [] })
 65            }
 66            serde_json::Value::Object(map) if map.is_empty() => {
 67                serde_json::json!({ "type": "object", "properties": [] })
 68            }
 69            _ => schema,
 70        })
 71    }
 72
 73    fn ui_text(&self, _input: &serde_json::Value) -> String {
 74        format!("Run MCP tool `{}`", self.tool.name)
 75    }
 76
 77    fn run(
 78        self: Arc<Self>,
 79        input: serde_json::Value,
 80        _request: Arc<LanguageModelRequest>,
 81        _project: Entity<Project>,
 82        _action_log: Entity<ActionLog>,
 83        _model: Arc<dyn LanguageModel>,
 84        _window: Option<AnyWindowHandle>,
 85        cx: &mut App,
 86    ) -> ToolResult {
 87        if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
 88            let tool_name = self.tool.name.clone();
 89
 90            cx.spawn(async move |_cx| {
 91                let Some(protocol) = server.client() else {
 92                    bail!("Context server not initialized");
 93                };
 94
 95                let arguments = if let serde_json::Value::Object(map) = input {
 96                    Some(map.into_iter().collect())
 97                } else {
 98                    None
 99                };
100
101                log::trace!(
102                    "Running tool: {} with arguments: {:?}",
103                    tool_name,
104                    arguments
105                );
106                let response = protocol
107                    .request::<context_server::types::requests::CallTool>(
108                        context_server::types::CallToolParams {
109                            name: tool_name,
110                            arguments,
111                            meta: None,
112                        },
113                    )
114                    .await?;
115
116                let mut result = String::new();
117                for content in response.content {
118                    match content {
119                        types::ToolResponseContent::Text { text } => {
120                            result.push_str(&text);
121                        }
122                        types::ToolResponseContent::Image { .. } => {
123                            log::warn!("Ignoring image content from tool response");
124                        }
125                        types::ToolResponseContent::Audio { .. } => {
126                            log::warn!("Ignoring audio content from tool response");
127                        }
128                        types::ToolResponseContent::Resource { .. } => {
129                            log::warn!("Ignoring resource content from tool response");
130                        }
131                    }
132                }
133                Ok(result.into())
134            })
135            .into()
136        } else {
137            Task::ready(Err(anyhow!("Context server not found"))).into()
138        }
139    }
140}