context_server.rs

  1pub mod client;
  2pub mod protocol;
  3#[cfg(any(test, feature = "test-support"))]
  4pub mod test;
  5pub mod transport;
  6pub mod types;
  7
  8use std::fmt::Display;
  9use std::path::Path;
 10use std::sync::Arc;
 11
 12use anyhow::Result;
 13use client::Client;
 14use collections::HashMap;
 15use gpui::AsyncApp;
 16use parking_lot::RwLock;
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use util::redact::should_redact;
 20
 21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 22pub struct ContextServerId(pub Arc<str>);
 23
 24impl Display for ContextServerId {
 25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 26        write!(f, "{}", self.0)
 27    }
 28}
 29
/// Launch configuration for a context server that is started from a local command.
///
/// Used by `ContextServer::start` to build the executable, arguments and
/// environment handed to `Client::stdio`.
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
pub struct ContextServerCommand {
    /// Executable to run. (De)serialized under the `command` key rather than `path`.
    #[serde(rename = "command")]
    pub path: String,
    /// Arguments passed to the executable.
    pub args: Vec<String>,
    /// Optional environment variables for the launched process.
    pub env: Option<HashMap<String, String>>,
}
 37
 38impl std::fmt::Debug for ContextServerCommand {
 39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 40        let filtered_env = self.env.as_ref().map(|env| {
 41            env.iter()
 42                .map(|(k, v)| (k, if should_redact(k) { "[REDACTED]" } else { v }))
 43                .collect::<Vec<_>>()
 44        });
 45
 46        f.debug_struct("ContextServerCommand")
 47            .field("path", &self.path)
 48            .field("args", &self.args)
 49            .field("env", &filtered_env)
 50            .finish()
 51    }
 52}
 53
/// How a [`ContextServer`] builds its client when started.
enum ContextServerTransport {
    /// Build the client with `Client::stdio` from the given launch command.
    Stdio(ContextServerCommand),
    /// Build the client with `Client::new` on a caller-supplied transport.
    Custom(Arc<dyn crate::transport::Transport>),
}
 58
/// A context server: its identity, how to connect to it, and (once started)
/// its initialized protocol client.
pub struct ContextServer {
    /// Identifier used in log messages and passed through to the client.
    id: ContextServerId,
    /// The initialized protocol; `None` until a successful `start` and again after `stop`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// Connection settings consulted by `start` to construct the client.
    configuration: ContextServerTransport,
}
 64
 65impl ContextServer {
 66    pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
 67        Self {
 68            id,
 69            client: RwLock::new(None),
 70            configuration: ContextServerTransport::Stdio(command),
 71        }
 72    }
 73
 74    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 75        Self {
 76            id,
 77            client: RwLock::new(None),
 78            configuration: ContextServerTransport::Custom(transport),
 79        }
 80    }
 81
 82    pub fn id(&self) -> ContextServerId {
 83        self.id.clone()
 84    }
 85
 86    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 87        self.client.read().clone()
 88    }
 89
 90    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
 91        let client = match &self.configuration {
 92            ContextServerTransport::Stdio(command) => Client::stdio(
 93                client::ContextServerId(self.id.0.clone()),
 94                client::ModelContextServerBinary {
 95                    executable: Path::new(&command.path).to_path_buf(),
 96                    args: command.args.clone(),
 97                    env: command.env.clone(),
 98                },
 99                cx.clone(),
100            )?,
101            ContextServerTransport::Custom(transport) => Client::new(
102                client::ContextServerId(self.id.0.clone()),
103                self.id().0,
104                transport.clone(),
105                cx.clone(),
106            )?,
107        };
108        self.initialize(client).await
109    }
110
111    async fn initialize(&self, client: Client) -> Result<()> {
112        log::info!("starting context server {}", self.id);
113        let protocol = crate::protocol::ModelContextProtocol::new(client);
114        let client_info = types::Implementation {
115            name: "Zed".to_string(),
116            version: env!("CARGO_PKG_VERSION").to_string(),
117        };
118        let initialized_protocol = protocol.initialize(client_info).await?;
119
120        log::debug!(
121            "context server {} initialized: {:?}",
122            self.id,
123            initialized_protocol.initialize,
124        );
125
126        *self.client.write() = Some(Arc::new(initialized_protocol));
127        Ok(())
128    }
129
130    pub fn stop(&self) -> Result<()> {
131        let mut client = self.client.write();
132        if let Some(protocol) = client.take() {
133            drop(protocol);
134        }
135        Ok(())
136    }
137}