context_server.rs

  1pub mod client;
  2pub mod listener;
  3pub mod protocol;
  4#[cfg(any(test, feature = "test-support"))]
  5pub mod test;
  6pub mod transport;
  7pub mod types;
  8
  9use std::path::Path;
 10use std::sync::Arc;
 11use std::{fmt::Display, path::PathBuf};
 12
 13use anyhow::Result;
 14use client::Client;
 15use collections::HashMap;
 16use gpui::AsyncApp;
 17use parking_lot::RwLock;
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20use util::redact::should_redact;
 21
 22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 23pub struct ContextServerId(pub Arc<str>);
 24
 25impl Display for ContextServerId {
 26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 27        write!(f, "{}", self.0)
 28    }
 29}
 30
// Launch configuration for a stdio context server.
// NOTE: field declaration order determines serde serialization order, and `///` comments
// on fields become JSON-schema descriptions via schemars, so plain `//` comments are used
// here for anything that should not appear in the generated schema.
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
pub struct ContextServerCommand {
    // Read/written under the key `command`; becomes the `executable` of the stdio
    // client's binary configuration when the server is started.
    #[serde(rename = "command")]
    pub path: PathBuf,
    // Forwarded as `args` in the stdio client's binary configuration.
    pub args: Vec<String>,
    // Forwarded as `env` in the stdio client's binary configuration. The `Debug` impl
    // replaces values whose keys `should_redact` flags.
    pub env: Option<HashMap<String, String>>,
    /// Timeout for tool calls in milliseconds. Defaults to 60000 (60 seconds) if not specified.
    pub timeout: Option<u64>,
}
 40
 41impl std::fmt::Debug for ContextServerCommand {
 42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 43        let filtered_env = self.env.as_ref().map(|env| {
 44            env.iter()
 45                .map(|(k, v)| (k, if should_redact(k) { "[REDACTED]" } else { v }))
 46                .collect::<Vec<_>>()
 47        });
 48
 49        f.debug_struct("ContextServerCommand")
 50            .field("path", &self.path)
 51            .field("args", &self.args)
 52            .field("env", &filtered_env)
 53            .finish()
 54    }
 55}
 56
/// How a `ContextServer` builds its client each time it is started.
enum ContextServerTransport {
    /// Command configuration plus an optional working directory, both handed to
    /// `Client::stdio`.
    Stdio(ContextServerCommand, Option<PathBuf>),
    /// Caller-supplied transport, cloned and handed to `Client::new`.
    Custom(Arc<dyn crate::transport::Transport>),
}
 61
/// A context server definition together with its initialized protocol client, if any.
pub struct ContextServer {
    id: ContextServerId,
    // `None` until `start`/`start_with_handlers` succeeds, and again after `stop`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    // Used to build a fresh `Client` on every start.
    configuration: ContextServerTransport,
}
 67
 68impl ContextServer {
 69    pub fn stdio(
 70        id: ContextServerId,
 71        command: ContextServerCommand,
 72        working_directory: Option<Arc<Path>>,
 73    ) -> Self {
 74        Self {
 75            id,
 76            client: RwLock::new(None),
 77            configuration: ContextServerTransport::Stdio(
 78                command,
 79                working_directory.map(|directory| directory.to_path_buf()),
 80            ),
 81        }
 82    }
 83
 84    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 85        Self {
 86            id,
 87            client: RwLock::new(None),
 88            configuration: ContextServerTransport::Custom(transport),
 89        }
 90    }
 91
 92    pub fn id(&self) -> ContextServerId {
 93        self.id.clone()
 94    }
 95
 96    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 97        self.client.read().clone()
 98    }
 99
100    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
101        self.initialize(self.new_client(cx)?).await
102    }
103
104    /// Starts the context server, making sure handlers are registered before initialization happens
105    pub async fn start_with_handlers(
106        &self,
107        notification_handlers: Vec<(
108            &'static str,
109            Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
110        )>,
111        cx: &AsyncApp,
112    ) -> Result<()> {
113        let client = self.new_client(cx)?;
114        for (method, handler) in notification_handlers {
115            client.on_notification(method, handler);
116        }
117        self.initialize(client).await
118    }
119
120    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
121        Ok(match &self.configuration {
122            ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
123                client::ContextServerId(self.id.0.clone()),
124                client::ModelContextServerBinary {
125                    executable: Path::new(&command.path).to_path_buf(),
126                    args: command.args.clone(),
127                    env: command.env.clone(),
128                    timeout: command.timeout,
129                },
130                working_directory,
131                cx.clone(),
132            )?,
133            ContextServerTransport::Custom(transport) => Client::new(
134                client::ContextServerId(self.id.0.clone()),
135                self.id().0,
136                transport.clone(),
137                None,
138                cx.clone(),
139            )?,
140        })
141    }
142
143    async fn initialize(&self, client: Client) -> Result<()> {
144        log::debug!("starting context server {}", self.id);
145        let protocol = crate::protocol::ModelContextProtocol::new(client);
146        let client_info = types::Implementation {
147            name: "Zed".to_string(),
148            version: env!("CARGO_PKG_VERSION").to_string(),
149        };
150        let initialized_protocol = protocol.initialize(client_info).await?;
151
152        log::debug!(
153            "context server {} initialized: {:?}",
154            self.id,
155            initialized_protocol.initialize,
156        );
157
158        *self.client.write() = Some(Arc::new(initialized_protocol));
159        Ok(())
160    }
161
162    pub fn stop(&self) -> Result<()> {
163        let mut client = self.client.write();
164        if let Some(protocol) = client.take() {
165            drop(protocol);
166        }
167        Ok(())
168    }
169}