context_server.rs

  1pub mod client;
  2pub mod listener;
  3pub mod protocol;
  4#[cfg(any(test, feature = "test-support"))]
  5pub mod test;
  6pub mod transport;
  7pub mod types;
  8
  9use collections::HashMap;
 10use http_client::HttpClient;
 11use std::path::Path;
 12use std::sync::Arc;
 13use std::{fmt::Display, path::PathBuf};
 14
 15use anyhow::Result;
 16use client::Client;
 17use gpui::AsyncApp;
 18use parking_lot::RwLock;
 19pub use settings::ContextServerCommand;
 20use url::Url;
 21
 22use crate::transport::HttpTransport;
 23
 24#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 25pub struct ContextServerId(pub Arc<str>);
 26
 27impl Display for ContextServerId {
 28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 29        write!(f, "{}", self.0)
 30    }
 31}
 32
 33enum ContextServerTransport {
 34    Stdio(ContextServerCommand, Option<PathBuf>),
 35    Custom(Arc<dyn crate::transport::Transport>),
 36}
 37
 38pub struct ContextServer {
 39    id: ContextServerId,
 40    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 41    configuration: ContextServerTransport,
 42}
 43
 44impl ContextServer {
 45    pub fn stdio(
 46        id: ContextServerId,
 47        command: ContextServerCommand,
 48        working_directory: Option<Arc<Path>>,
 49    ) -> Self {
 50        Self {
 51            id,
 52            client: RwLock::new(None),
 53            configuration: ContextServerTransport::Stdio(
 54                command,
 55                working_directory.map(|directory| directory.to_path_buf()),
 56            ),
 57        }
 58    }
 59
 60    pub fn http(
 61        id: ContextServerId,
 62        endpoint: &Url,
 63        headers: HashMap<String, String>,
 64        http_client: Arc<dyn HttpClient>,
 65        executor: gpui::BackgroundExecutor,
 66    ) -> Result<Self> {
 67        let transport = match endpoint.scheme() {
 68            "http" | "https" => {
 69                log::info!("Using HTTP transport for {}", endpoint);
 70                let transport =
 71                    HttpTransport::new(http_client, endpoint.to_string(), headers, executor);
 72                Arc::new(transport) as _
 73            }
 74            _ => anyhow::bail!("unsupported MCP url scheme {}", endpoint.scheme()),
 75        };
 76        Ok(Self::new(id, transport))
 77    }
 78
 79    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 80        Self {
 81            id,
 82            client: RwLock::new(None),
 83            configuration: ContextServerTransport::Custom(transport),
 84        }
 85    }
 86
 87    pub fn id(&self) -> ContextServerId {
 88        self.id.clone()
 89    }
 90
 91    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 92        self.client.read().clone()
 93    }
 94
 95    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
 96        self.initialize(self.new_client(cx)?).await
 97    }
 98
 99    /// Starts the context server, making sure handlers are registered before initialization happens
100    pub async fn start_with_handlers(
101        &self,
102        notification_handlers: Vec<(
103            &'static str,
104            Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
105        )>,
106        cx: &AsyncApp,
107    ) -> Result<()> {
108        let client = self.new_client(cx)?;
109        for (method, handler) in notification_handlers {
110            client.on_notification(method, handler);
111        }
112        self.initialize(client).await
113    }
114
115    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
116        Ok(match &self.configuration {
117            ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
118                client::ContextServerId(self.id.0.clone()),
119                client::ModelContextServerBinary {
120                    executable: Path::new(&command.path).to_path_buf(),
121                    args: command.args.clone(),
122                    env: command.env.clone(),
123                    timeout: command.timeout,
124                },
125                working_directory,
126                cx.clone(),
127            )?,
128            ContextServerTransport::Custom(transport) => Client::new(
129                client::ContextServerId(self.id.0.clone()),
130                self.id().0,
131                transport.clone(),
132                None,
133                cx.clone(),
134            )?,
135        })
136    }
137
138    async fn initialize(&self, client: Client) -> Result<()> {
139        log::debug!("starting context server {}", self.id);
140        let protocol = crate::protocol::ModelContextProtocol::new(client);
141        let client_info = types::Implementation {
142            name: "Zed".to_string(),
143            version: env!("CARGO_PKG_VERSION").to_string(),
144        };
145        let initialized_protocol = protocol.initialize(client_info).await?;
146
147        log::debug!(
148            "context server {} initialized: {:?}",
149            self.id,
150            initialized_protocol.initialize,
151        );
152
153        *self.client.write() = Some(Arc::new(initialized_protocol));
154        Ok(())
155    }
156
157    pub fn stop(&self) -> Result<()> {
158        let mut client = self.client.write();
159        if let Some(protocol) = client.take() {
160            drop(protocol);
161        }
162        Ok(())
163    }
164}