context_server.rs

  1pub mod client;
  2pub mod protocol;
  3#[cfg(any(test, feature = "test-support"))]
  4pub mod test;
  5pub mod transport;
  6pub mod types;
  7
  8use std::fmt::Display;
  9use std::path::Path;
 10use std::sync::Arc;
 11
 12use anyhow::Result;
 13use client::Client;
 14use collections::HashMap;
 15use gpui::AsyncApp;
 16use parking_lot::RwLock;
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19
 20#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 21pub struct ContextServerId(pub Arc<str>);
 22
 23impl Display for ContextServerId {
 24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 25        write!(f, "{}", self.0)
 26    }
 27}
 28
 29#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 30pub struct ContextServerCommand {
 31    pub path: String,
 32    pub args: Vec<String>,
 33    pub env: Option<HashMap<String, String>>,
 34}
 35
 36enum ContextServerTransport {
 37    Stdio(ContextServerCommand),
 38    Custom(Arc<dyn crate::transport::Transport>),
 39}
 40
 41pub struct ContextServer {
 42    id: ContextServerId,
 43    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 44    configuration: ContextServerTransport,
 45}
 46
 47impl ContextServer {
 48    pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
 49        Self {
 50            id,
 51            client: RwLock::new(None),
 52            configuration: ContextServerTransport::Stdio(command),
 53        }
 54    }
 55
 56    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 57        Self {
 58            id,
 59            client: RwLock::new(None),
 60            configuration: ContextServerTransport::Custom(transport),
 61        }
 62    }
 63
 64    pub fn id(&self) -> ContextServerId {
 65        self.id.clone()
 66    }
 67
 68    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 69        self.client.read().clone()
 70    }
 71
 72    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
 73        let client = match &self.configuration {
 74            ContextServerTransport::Stdio(command) => Client::stdio(
 75                client::ContextServerId(self.id.0.clone()),
 76                client::ModelContextServerBinary {
 77                    executable: Path::new(&command.path).to_path_buf(),
 78                    args: command.args.clone(),
 79                    env: command.env.clone(),
 80                },
 81                cx.clone(),
 82            )?,
 83            ContextServerTransport::Custom(transport) => Client::new(
 84                client::ContextServerId(self.id.0.clone()),
 85                self.id().0,
 86                transport.clone(),
 87                cx.clone(),
 88            )?,
 89        };
 90        self.initialize(client).await
 91    }
 92
 93    async fn initialize(&self, client: Client) -> Result<()> {
 94        log::info!("starting context server {}", self.id);
 95        let protocol = crate::protocol::ModelContextProtocol::new(client);
 96        let client_info = types::Implementation {
 97            name: "Zed".to_string(),
 98            version: env!("CARGO_PKG_VERSION").to_string(),
 99        };
100        let initialized_protocol = protocol.initialize(client_info).await?;
101
102        log::debug!(
103            "context server {} initialized: {:?}",
104            self.id,
105            initialized_protocol.initialize,
106        );
107
108        *self.client.write() = Some(Arc::new(initialized_protocol));
109        Ok(())
110    }
111
112    pub fn stop(&self) -> Result<()> {
113        let mut client = self.client.write();
114        if let Some(protocol) = client.take() {
115            drop(protocol);
116        }
117        Ok(())
118    }
119}