context_server.rs

  1pub mod client;
  2pub mod listener;
  3pub mod protocol;
  4#[cfg(any(test, feature = "test-support"))]
  5pub mod test;
  6pub mod transport;
  7pub mod types;
  8
  9use collections::HashMap;
 10use http_client::HttpClient;
 11use std::path::Path;
 12use std::sync::Arc;
 13use std::time::Duration;
 14use std::{fmt::Display, path::PathBuf};
 15
 16use anyhow::Result;
 17use client::Client;
 18use gpui::AsyncApp;
 19use parking_lot::RwLock;
 20pub use settings::ContextServerCommand;
 21use url::Url;
 22
 23use crate::transport::HttpTransport;
 24
 25#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 26pub struct ContextServerId(pub Arc<str>);
 27
 28impl Display for ContextServerId {
 29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 30        write!(f, "{}", self.0)
 31    }
 32}
 33
 34enum ContextServerTransport {
 35    Stdio(ContextServerCommand, Option<PathBuf>),
 36    Custom(Arc<dyn crate::transport::Transport>),
 37}
 38
 39pub struct ContextServer {
 40    id: ContextServerId,
 41    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 42    configuration: ContextServerTransport,
 43    request_timeout: Option<Duration>,
 44}
 45
 46impl ContextServer {
 47    pub fn stdio(
 48        id: ContextServerId,
 49        command: ContextServerCommand,
 50        working_directory: Option<Arc<Path>>,
 51    ) -> Self {
 52        Self {
 53            id,
 54            client: RwLock::new(None),
 55            configuration: ContextServerTransport::Stdio(
 56                command,
 57                working_directory.map(|directory| directory.to_path_buf()),
 58            ),
 59            request_timeout: None, // Stdio handles timeout through command
 60        }
 61    }
 62
 63    pub fn http(
 64        id: ContextServerId,
 65        endpoint: &Url,
 66        headers: HashMap<String, String>,
 67        http_client: Arc<dyn HttpClient>,
 68        executor: gpui::BackgroundExecutor,
 69        request_timeout: Option<Duration>,
 70    ) -> Result<Self> {
 71        let transport = match endpoint.scheme() {
 72            "http" | "https" => {
 73                log::info!("Using HTTP transport for {}", endpoint);
 74                let transport =
 75                    HttpTransport::new(http_client, endpoint.to_string(), headers, executor);
 76                Arc::new(transport) as _
 77            }
 78            _ => anyhow::bail!("unsupported MCP url scheme {}", endpoint.scheme()),
 79        };
 80        Ok(Self::new_with_timeout(id, transport, request_timeout))
 81    }
 82
 83    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 84        Self::new_with_timeout(id, transport, None)
 85    }
 86
 87    pub fn new_with_timeout(
 88        id: ContextServerId,
 89        transport: Arc<dyn crate::transport::Transport>,
 90        request_timeout: Option<Duration>,
 91    ) -> Self {
 92        Self {
 93            id,
 94            client: RwLock::new(None),
 95            configuration: ContextServerTransport::Custom(transport),
 96            request_timeout,
 97        }
 98    }
 99
100    pub fn id(&self) -> ContextServerId {
101        self.id.clone()
102    }
103
104    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
105        self.client.read().clone()
106    }
107
108    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
109        self.initialize(self.new_client(cx)?).await
110    }
111
112    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
113        Ok(match &self.configuration {
114            ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
115                client::ContextServerId(self.id.0.clone()),
116                client::ModelContextServerBinary {
117                    executable: Path::new(&command.path).to_path_buf(),
118                    args: command.args.clone(),
119                    env: command.env.clone(),
120                    timeout: command.timeout,
121                },
122                working_directory,
123                cx.clone(),
124            )?,
125            ContextServerTransport::Custom(transport) => Client::new(
126                client::ContextServerId(self.id.0.clone()),
127                self.id().0,
128                transport.clone(),
129                self.request_timeout,
130                cx.clone(),
131            )?,
132        })
133    }
134
135    async fn initialize(&self, client: Client) -> Result<()> {
136        log::debug!("starting context server {}", self.id);
137        let protocol = crate::protocol::ModelContextProtocol::new(client);
138        let client_info = types::Implementation {
139            name: "Zed".to_string(),
140            version: env!("CARGO_PKG_VERSION").to_string(),
141        };
142        let initialized_protocol = protocol.initialize(client_info).await?;
143
144        log::debug!(
145            "context server {} initialized: {:?}",
146            self.id,
147            initialized_protocol.initialize,
148        );
149
150        *self.client.write() = Some(Arc::new(initialized_protocol));
151        Ok(())
152    }
153
154    pub fn stop(&self) -> Result<()> {
155        let mut client = self.client.write();
156        if let Some(protocol) = client.take() {
157            drop(protocol);
158        }
159        Ok(())
160    }
161}