context_server.rs

  1pub mod client;
  2pub mod listener;
  3pub mod oauth;
  4pub mod protocol;
  5#[cfg(any(test, feature = "test-support"))]
  6pub mod test;
  7pub mod transport;
  8pub mod types;
  9
 10use collections::HashMap;
 11use http_client::HttpClient;
 12use std::path::Path;
 13use std::sync::Arc;
 14use std::time::Duration;
 15use std::{fmt::Display, path::PathBuf};
 16
 17use anyhow::Result;
 18use client::Client;
 19use gpui::AsyncApp;
 20use parking_lot::RwLock;
 21pub use settings::ContextServerCommand;
 22use url::Url;
 23
 24use crate::transport::HttpTransport;
 25
 26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 27pub struct ContextServerId(pub Arc<str>);
 28
 29impl Display for ContextServerId {
 30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 31        write!(f, "{}", self.0)
 32    }
 33}
 34
 35enum ContextServerTransport {
 36    Stdio(ContextServerCommand, Option<PathBuf>),
 37    Custom(Arc<dyn crate::transport::Transport>),
 38}
 39
 40pub struct ContextServer {
 41    id: ContextServerId,
 42    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 43    configuration: ContextServerTransport,
 44    request_timeout: Option<Duration>,
 45}
 46
 47impl ContextServer {
 48    pub fn stdio(
 49        id: ContextServerId,
 50        command: ContextServerCommand,
 51        working_directory: Option<Arc<Path>>,
 52    ) -> Self {
 53        Self {
 54            id,
 55            client: RwLock::new(None),
 56            configuration: ContextServerTransport::Stdio(
 57                command,
 58                working_directory.map(|directory| directory.to_path_buf()),
 59            ),
 60            request_timeout: None,
 61        }
 62    }
 63
 64    pub fn http(
 65        id: ContextServerId,
 66        endpoint: &Url,
 67        headers: HashMap<String, String>,
 68        http_client: Arc<dyn HttpClient>,
 69        executor: gpui::BackgroundExecutor,
 70        request_timeout: Option<Duration>,
 71    ) -> Result<Self> {
 72        let transport = match endpoint.scheme() {
 73            "http" | "https" => {
 74                log::info!("Using HTTP transport for {}", endpoint);
 75                let transport =
 76                    HttpTransport::new(http_client, endpoint.to_string(), headers, executor);
 77                Arc::new(transport) as _
 78            }
 79            _ => anyhow::bail!("unsupported MCP url scheme {}", endpoint.scheme()),
 80        };
 81        Ok(Self::new_with_timeout(id, transport, request_timeout))
 82    }
 83
 84    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 85        Self::new_with_timeout(id, transport, None)
 86    }
 87
 88    pub fn new_with_timeout(
 89        id: ContextServerId,
 90        transport: Arc<dyn crate::transport::Transport>,
 91        request_timeout: Option<Duration>,
 92    ) -> Self {
 93        Self {
 94            id,
 95            client: RwLock::new(None),
 96            configuration: ContextServerTransport::Custom(transport),
 97            request_timeout,
 98        }
 99    }
100
101    pub fn id(&self) -> ContextServerId {
102        self.id.clone()
103    }
104
105    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
106        self.client.read().clone()
107    }
108
109    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
110        self.initialize(self.new_client(cx)?).await
111    }
112
113    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
114        Ok(match &self.configuration {
115            ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
116                client::ContextServerId(self.id.0.clone()),
117                client::ModelContextServerBinary {
118                    executable: Path::new(&command.path).to_path_buf(),
119                    args: command.args.clone(),
120                    env: command.env.clone(),
121                    timeout: command.timeout,
122                },
123                working_directory,
124                cx.clone(),
125            )?,
126            ContextServerTransport::Custom(transport) => Client::new(
127                client::ContextServerId(self.id.0.clone()),
128                self.id().0,
129                transport.clone(),
130                self.request_timeout,
131                cx.clone(),
132            )?,
133        })
134    }
135
136    async fn initialize(&self, client: Client) -> Result<()> {
137        log::debug!("starting context server {}", self.id);
138        let protocol = crate::protocol::ModelContextProtocol::new(client);
139        let client_info = types::Implementation {
140            name: "Zed".to_string(),
141            version: env!("CARGO_PKG_VERSION").to_string(),
142        };
143        let initialized_protocol = protocol.initialize(client_info).await?;
144
145        log::debug!(
146            "context server {} initialized: {:?}",
147            self.id,
148            initialized_protocol.initialize,
149        );
150
151        *self.client.write() = Some(Arc::new(initialized_protocol));
152        Ok(())
153    }
154
155    pub fn stop(&self) -> Result<()> {
156        let mut client = self.client.write();
157        if let Some(protocol) = client.take() {
158            drop(protocol);
159        }
160        Ok(())
161    }
162}