context_server.rs

  1pub mod client;
  2pub mod listener;
  3pub mod protocol;
  4#[cfg(any(test, feature = "test-support"))]
  5pub mod test;
  6pub mod transport;
  7pub mod types;
  8
  9use std::path::Path;
 10use std::sync::Arc;
 11use std::{fmt::Display, path::PathBuf};
 12
 13use anyhow::Result;
 14use client::Client;
 15use collections::HashMap;
 16use gpui::AsyncApp;
 17use parking_lot::RwLock;
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20use util::redact::should_redact;
 21
 22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 23pub struct ContextServerId(pub Arc<str>);
 24
 25impl Display for ContextServerId {
 26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 27        write!(f, "{}", self.0)
 28    }
 29}
 30
 31#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
 32pub struct ContextServerCommand {
 33    #[serde(rename = "command")]
 34    pub path: PathBuf,
 35    pub args: Vec<String>,
 36    pub env: Option<HashMap<String, String>>,
 37}
 38
 39impl std::fmt::Debug for ContextServerCommand {
 40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 41        let filtered_env = self.env.as_ref().map(|env| {
 42            env.iter()
 43                .map(|(k, v)| (k, if should_redact(k) { "[REDACTED]" } else { v }))
 44                .collect::<Vec<_>>()
 45        });
 46
 47        f.debug_struct("ContextServerCommand")
 48            .field("path", &self.path)
 49            .field("args", &self.args)
 50            .field("env", &filtered_env)
 51            .finish()
 52    }
 53}
 54
 55enum ContextServerTransport {
 56    Stdio(ContextServerCommand, Option<PathBuf>),
 57    Custom(Arc<dyn crate::transport::Transport>),
 58}
 59
 60pub struct ContextServer {
 61    id: ContextServerId,
 62    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 63    configuration: ContextServerTransport,
 64}
 65
 66impl ContextServer {
 67    pub fn stdio(
 68        id: ContextServerId,
 69        command: ContextServerCommand,
 70        working_directory: Option<Arc<Path>>,
 71    ) -> Self {
 72        Self {
 73            id,
 74            client: RwLock::new(None),
 75            configuration: ContextServerTransport::Stdio(
 76                command,
 77                working_directory.map(|directory| directory.to_path_buf()),
 78            ),
 79        }
 80    }
 81
 82    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 83        Self {
 84            id,
 85            client: RwLock::new(None),
 86            configuration: ContextServerTransport::Custom(transport),
 87        }
 88    }
 89
 90    pub fn id(&self) -> ContextServerId {
 91        self.id.clone()
 92    }
 93
 94    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 95        self.client.read().clone()
 96    }
 97
 98    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
 99        let client = match &self.configuration {
100            ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
101                client::ContextServerId(self.id.0.clone()),
102                client::ModelContextServerBinary {
103                    executable: Path::new(&command.path).to_path_buf(),
104                    args: command.args.clone(),
105                    env: command.env.clone(),
106                },
107                working_directory,
108                cx.clone(),
109            )?,
110            ContextServerTransport::Custom(transport) => Client::new(
111                client::ContextServerId(self.id.0.clone()),
112                self.id().0,
113                transport.clone(),
114                cx.clone(),
115            )?,
116        };
117        self.initialize(client).await
118    }
119
120    async fn initialize(&self, client: Client) -> Result<()> {
121        log::info!("starting context server {}", self.id);
122        let protocol = crate::protocol::ModelContextProtocol::new(client);
123        let client_info = types::Implementation {
124            name: "Zed".to_string(),
125            version: env!("CARGO_PKG_VERSION").to_string(),
126        };
127        let initialized_protocol = protocol.initialize(client_info).await?;
128
129        log::debug!(
130            "context server {} initialized: {:?}",
131            self.id,
132            initialized_protocol.initialize,
133        );
134
135        *self.client.write() = Some(Arc::new(initialized_protocol));
136        Ok(())
137    }
138
139    pub fn stop(&self) -> Result<()> {
140        let mut client = self.client.write();
141        if let Some(protocol) = client.take() {
142            drop(protocol);
143        }
144        Ok(())
145    }
146}