// context_server.rs

  1pub mod client;
  2pub mod listener;
  3pub mod protocol;
  4#[cfg(any(test, feature = "test-support"))]
  5pub mod test;
  6pub mod transport;
  7pub mod types;
  8
  9use std::path::Path;
 10use std::sync::Arc;
 11use std::{fmt::Display, path::PathBuf};
 12
 13use anyhow::Result;
 14use client::Client;
 15use collections::HashMap;
 16use gpui::AsyncApp;
 17use parking_lot::RwLock;
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20pub use settings::ContextServerCommand;
 21use util::redact::should_redact;
 22
 23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 24pub struct ContextServerId(pub Arc<str>);
 25
 26impl Display for ContextServerId {
 27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 28        write!(f, "{}", self.0)
 29    }
 30}
 31
/// How a [`ContextServer`] obtains the connection for its client each time it is started.
enum ContextServerTransport {
    /// Launch `command` through [`Client::stdio`], optionally inside the given working
    /// directory.
    Stdio(ContextServerCommand, Option<PathBuf>),
    /// Use a caller-supplied transport as-is (see [`ContextServer::new`]).
    Custom(Arc<dyn crate::transport::Transport>),
}
 36
/// A Model Context Protocol server: its identifier, how to connect to it, and the
/// initialized protocol handle once it has been started.
pub struct ContextServer {
    id: ContextServerId,
    /// `None` until `start`/`start_with_handlers` succeeds, and `None` again after `stop`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// Used to build a fresh client on every start.
    configuration: ContextServerTransport,
}
 42
 43impl ContextServer {
 44    pub fn stdio(
 45        id: ContextServerId,
 46        command: ContextServerCommand,
 47        working_directory: Option<Arc<Path>>,
 48    ) -> Self {
 49        Self {
 50            id,
 51            client: RwLock::new(None),
 52            configuration: ContextServerTransport::Stdio(
 53                command,
 54                working_directory.map(|directory| directory.to_path_buf()),
 55            ),
 56        }
 57    }
 58
 59    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 60        Self {
 61            id,
 62            client: RwLock::new(None),
 63            configuration: ContextServerTransport::Custom(transport),
 64        }
 65    }
 66
 67    pub fn id(&self) -> ContextServerId {
 68        self.id.clone()
 69    }
 70
 71    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 72        self.client.read().clone()
 73    }
 74
 75    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
 76        self.initialize(self.new_client(cx)?).await
 77    }
 78
 79    /// Starts the context server, making sure handlers are registered before initialization happens
 80    pub async fn start_with_handlers(
 81        &self,
 82        notification_handlers: Vec<(
 83            &'static str,
 84            Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
 85        )>,
 86        cx: &AsyncApp,
 87    ) -> Result<()> {
 88        let client = self.new_client(cx)?;
 89        for (method, handler) in notification_handlers {
 90            client.on_notification(method, handler);
 91        }
 92        self.initialize(client).await
 93    }
 94
 95    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
 96        Ok(match &self.configuration {
 97            ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
 98                client::ContextServerId(self.id.0.clone()),
 99                client::ModelContextServerBinary {
100                    executable: Path::new(&command.path).to_path_buf(),
101                    args: command.args.clone(),
102                    env: command.env.clone(),
103                    timeout: command.timeout,
104                },
105                working_directory,
106                cx.clone(),
107            )?,
108            ContextServerTransport::Custom(transport) => Client::new(
109                client::ContextServerId(self.id.0.clone()),
110                self.id().0,
111                transport.clone(),
112                None,
113                cx.clone(),
114            )?,
115        })
116    }
117
118    async fn initialize(&self, client: Client) -> Result<()> {
119        log::debug!("starting context server {}", self.id);
120        let protocol = crate::protocol::ModelContextProtocol::new(client);
121        let client_info = types::Implementation {
122            name: "Zed".to_string(),
123            version: env!("CARGO_PKG_VERSION").to_string(),
124        };
125        let initialized_protocol = protocol.initialize(client_info).await?;
126
127        log::debug!(
128            "context server {} initialized: {:?}",
129            self.id,
130            initialized_protocol.initialize,
131        );
132
133        *self.client.write() = Some(Arc::new(initialized_protocol));
134        Ok(())
135    }
136
137    pub fn stop(&self) -> Result<()> {
138        let mut client = self.client.write();
139        if let Some(protocol) = client.take() {
140            drop(protocol);
141        }
142        Ok(())
143    }
144}