1pub mod client;
  2pub mod listener;
  3pub mod protocol;
  4#[cfg(any(test, feature = "test-support"))]
  5pub mod test;
  6pub mod transport;
  7pub mod types;
  8
  9use std::path::Path;
 10use std::sync::Arc;
 11use std::{fmt::Display, path::PathBuf};
 12
 13use anyhow::Result;
 14use client::Client;
 15use gpui::AsyncApp;
 16use parking_lot::RwLock;
 17pub use settings::ContextServerCommand;
 18
 19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 20pub struct ContextServerId(pub Arc<str>);
 21
 22impl Display for ContextServerId {
 23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 24        write!(f, "{}", self.0)
 25    }
 26}
 27
/// The way a [`ContextServer`] builds its client each time it is started.
enum ContextServerTransport {
    /// Command-based configuration handed to `Client::stdio`. The optional
    /// path is the working directory given to [`ContextServer::stdio`].
    Stdio(ContextServerCommand, Option<PathBuf>),
    /// Caller-supplied transport, as given to [`ContextServer::new`] and
    /// handed to `Client::new`.
    Custom(Arc<dyn crate::transport::Transport>),
}
 32
/// A context server: its identifier, the transport configuration used to
/// create clients, and the initialized protocol once a start has succeeded.
pub struct ContextServer {
    /// Identifier used in log messages and passed on to every client created.
    id: ContextServerId,
    /// `None` until `initialize` stores a protocol; emptied again by `stop`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// Determines which `Client` constructor `new_client` calls.
    configuration: ContextServerTransport,
}
 38
 39impl ContextServer {
 40    pub fn stdio(
 41        id: ContextServerId,
 42        command: ContextServerCommand,
 43        working_directory: Option<Arc<Path>>,
 44    ) -> Self {
 45        Self {
 46            id,
 47            client: RwLock::new(None),
 48            configuration: ContextServerTransport::Stdio(
 49                command,
 50                working_directory.map(|directory| directory.to_path_buf()),
 51            ),
 52        }
 53    }
 54
 55    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 56        Self {
 57            id,
 58            client: RwLock::new(None),
 59            configuration: ContextServerTransport::Custom(transport),
 60        }
 61    }
 62
 63    pub fn id(&self) -> ContextServerId {
 64        self.id.clone()
 65    }
 66
 67    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 68        self.client.read().clone()
 69    }
 70
 71    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
 72        self.initialize(self.new_client(cx)?).await
 73    }
 74
 75    /// Starts the context server, making sure handlers are registered before initialization happens
 76    pub async fn start_with_handlers(
 77        &self,
 78        notification_handlers: Vec<(
 79            &'static str,
 80            Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
 81        )>,
 82        cx: &AsyncApp,
 83    ) -> Result<()> {
 84        let client = self.new_client(cx)?;
 85        for (method, handler) in notification_handlers {
 86            client.on_notification(method, handler);
 87        }
 88        self.initialize(client).await
 89    }
 90
 91    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
 92        Ok(match &self.configuration {
 93            ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
 94                client::ContextServerId(self.id.0.clone()),
 95                client::ModelContextServerBinary {
 96                    executable: Path::new(&command.path).to_path_buf(),
 97                    args: command.args.clone(),
 98                    env: command.env.clone(),
 99                    timeout: command.timeout,
100                },
101                working_directory,
102                cx.clone(),
103            )?,
104            ContextServerTransport::Custom(transport) => Client::new(
105                client::ContextServerId(self.id.0.clone()),
106                self.id().0,
107                transport.clone(),
108                None,
109                cx.clone(),
110            )?,
111        })
112    }
113
114    async fn initialize(&self, client: Client) -> Result<()> {
115        log::debug!("starting context server {}", self.id);
116        let protocol = crate::protocol::ModelContextProtocol::new(client);
117        let client_info = types::Implementation {
118            name: "Zed".to_string(),
119            version: env!("CARGO_PKG_VERSION").to_string(),
120        };
121        let initialized_protocol = protocol.initialize(client_info).await?;
122
123        log::debug!(
124            "context server {} initialized: {:?}",
125            self.id,
126            initialized_protocol.initialize,
127        );
128
129        *self.client.write() = Some(Arc::new(initialized_protocol));
130        Ok(())
131    }
132
133    pub fn stop(&self) -> Result<()> {
134        let mut client = self.client.write();
135        if let Some(protocol) = client.take() {
136            drop(protocol);
137        }
138        Ok(())
139    }
140}