context_server_tool.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Result, anyhow, bail};
  4use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
  5use context_server::{ContextServerId, types};
  6use gpui::{AnyWindowHandle, App, Entity, Task};
  7use icons::IconName;
  8use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  9use project::{Project, context_server_store::ContextServerStore};
 10
 11pub struct ContextServerTool {
 12    store: Entity<ContextServerStore>,
 13    server_id: ContextServerId,
 14    tool: types::Tool,
 15}
 16
 17impl ContextServerTool {
 18    pub fn new(
 19        store: Entity<ContextServerStore>,
 20        server_id: ContextServerId,
 21        tool: types::Tool,
 22    ) -> Self {
 23        Self {
 24            store,
 25            server_id,
 26            tool,
 27        }
 28    }
 29}
 30
 31impl Tool for ContextServerTool {
 32    fn name(&self) -> String {
 33        self.tool.name.clone()
 34    }
 35
 36    fn description(&self) -> String {
 37        self.tool.description.clone().unwrap_or_default()
 38    }
 39
 40    fn icon(&self) -> IconName {
 41        IconName::ToolHammer
 42    }
 43
 44    fn source(&self) -> ToolSource {
 45        ToolSource::ContextServer {
 46            id: self.server_id.clone().0.into(),
 47        }
 48    }
 49
 50    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 51        true
 52    }
 53
 54    fn may_perform_edits(&self) -> bool {
 55        true
 56    }
 57
 58    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 59        let mut schema = self.tool.input_schema.clone();
 60        assistant_tool::adapt_schema_to_format(&mut schema, format)?;
 61        Ok(match schema {
 62            serde_json::Value::Null => {
 63                serde_json::json!({ "type": "object", "properties": [] })
 64            }
 65            serde_json::Value::Object(map) if map.is_empty() => {
 66                serde_json::json!({ "type": "object", "properties": [] })
 67            }
 68            _ => schema,
 69        })
 70    }
 71
 72    fn ui_text(&self, _input: &serde_json::Value) -> String {
 73        format!("Run MCP tool `{}`", self.tool.name)
 74    }
 75
 76    fn run(
 77        self: Arc<Self>,
 78        input: serde_json::Value,
 79        _request: Arc<LanguageModelRequest>,
 80        _project: Entity<Project>,
 81        _action_log: Entity<ActionLog>,
 82        _model: Arc<dyn LanguageModel>,
 83        _window: Option<AnyWindowHandle>,
 84        cx: &mut App,
 85    ) -> ToolResult {
 86        if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
 87            let tool_name = self.tool.name.clone();
 88            let server_clone = server.clone();
 89            let input_clone = input.clone();
 90
 91            cx.spawn(async move |_cx| {
 92                let Some(protocol) = server_clone.client() else {
 93                    bail!("Context server not initialized");
 94                };
 95
 96                let arguments = if let serde_json::Value::Object(map) = input_clone {
 97                    Some(map.into_iter().collect())
 98                } else {
 99                    None
100                };
101
102                log::trace!(
103                    "Running tool: {} with arguments: {:?}",
104                    tool_name,
105                    arguments
106                );
107                let response = protocol
108                    .request::<context_server::types::requests::CallTool>(
109                        context_server::types::CallToolParams {
110                            name: tool_name,
111                            arguments,
112                            meta: None,
113                        },
114                    )
115                    .await?;
116
117                let mut result = String::new();
118                for content in response.content {
119                    match content {
120                        types::ToolResponseContent::Text { text } => {
121                            result.push_str(&text);
122                        }
123                        types::ToolResponseContent::Image { .. } => {
124                            log::warn!("Ignoring image content from tool response");
125                        }
126                        types::ToolResponseContent::Audio { .. } => {
127                            log::warn!("Ignoring audio content from tool response");
128                        }
129                        types::ToolResponseContent::Resource { .. } => {
130                            log::warn!("Ignoring resource content from tool response");
131                        }
132                    }
133                }
134                Ok(result.into())
135            })
136            .into()
137        } else {
138            Task::ready(Err(anyhow!("Context server not found"))).into()
139        }
140    }
141}