context_server_tool.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Result, anyhow, bail};
  4use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
  5use context_server::{ContextServerId, types};
  6use gpui::{AnyWindowHandle, App, Entity, Task};
  7use icons::IconName;
  8use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  9use project::{Project, context_server_store::ContextServerStore};
 10
 11pub struct ContextServerTool {
 12    store: Entity<ContextServerStore>,
 13    server_id: ContextServerId,
 14    tool: types::Tool,
 15}
 16
 17impl ContextServerTool {
 18    pub fn new(
 19        store: Entity<ContextServerStore>,
 20        server_id: ContextServerId,
 21        tool: types::Tool,
 22    ) -> Self {
 23        Self {
 24            store,
 25            server_id,
 26            tool,
 27        }
 28    }
 29}
 30
 31impl Tool for ContextServerTool {
 32    type Input = serde_json::Value;
 33
 34    fn name(&self) -> String {
 35        self.tool.name.clone()
 36    }
 37
 38    fn description(&self) -> String {
 39        self.tool.description.clone().unwrap_or_default()
 40    }
 41
 42    fn icon(&self) -> IconName {
 43        IconName::Cog
 44    }
 45
 46    fn source(&self) -> ToolSource {
 47        ToolSource::ContextServer {
 48            id: self.server_id.clone().0.into(),
 49        }
 50    }
 51
 52    fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
 53        true
 54    }
 55
 56    fn may_perform_edits(&self) -> bool {
 57        true
 58    }
 59
 60    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 61        let mut schema = self.tool.input_schema.clone();
 62        assistant_tool::adapt_schema_to_format(&mut schema, format)?;
 63        Ok(match schema {
 64            serde_json::Value::Null => {
 65                serde_json::json!({ "type": "object", "properties": [] })
 66            }
 67            serde_json::Value::Object(map) if map.is_empty() => {
 68                serde_json::json!({ "type": "object", "properties": [] })
 69            }
 70            _ => schema,
 71        })
 72    }
 73
 74    fn ui_text(&self, _input: &Self::Input) -> String {
 75        format!("Run MCP tool `{}`", self.tool.name)
 76    }
 77
 78    fn run(
 79        self: Arc<Self>,
 80        input: Self::Input,
 81        _request: Arc<LanguageModelRequest>,
 82        _project: Entity<Project>,
 83        _action_log: Entity<ActionLog>,
 84        _model: Arc<dyn LanguageModel>,
 85        _window: Option<AnyWindowHandle>,
 86        cx: &mut App,
 87    ) -> ToolResult {
 88        if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) {
 89            let tool_name = self.tool.name.clone();
 90            let server_clone = server.clone();
 91            let input_clone = input.clone();
 92
 93            cx.spawn(async move |_cx| {
 94                let Some(protocol) = server_clone.client() else {
 95                    bail!("Context server not initialized");
 96                };
 97
 98                let arguments = if let serde_json::Value::Object(map) = input_clone {
 99                    Some(map.into_iter().collect())
100                } else {
101                    None
102                };
103
104                log::trace!(
105                    "Running tool: {} with arguments: {:?}",
106                    tool_name,
107                    arguments
108                );
109                let response = protocol
110                    .request::<context_server::types::requests::CallTool>(
111                        context_server::types::CallToolParams {
112                            name: tool_name,
113                            arguments,
114                            meta: None,
115                        },
116                    )
117                    .await?;
118
119                let mut result = String::new();
120                for content in response.content {
121                    match content {
122                        types::ToolResponseContent::Text { text } => {
123                            result.push_str(&text);
124                        }
125                        types::ToolResponseContent::Image { .. } => {
126                            log::warn!("Ignoring image content from tool response");
127                        }
128                        types::ToolResponseContent::Audio { .. } => {
129                            log::warn!("Ignoring audio content from tool response");
130                        }
131                        types::ToolResponseContent::Resource { .. } => {
132                            log::warn!("Ignoring resource content from tool response");
133                        }
134                    }
135                }
136                Ok(result.into())
137            })
138            .into()
139        } else {
140            Task::ready(Err(anyhow!("Context server not found"))).into()
141        }
142    }
143}