context_server.rs

  1pub mod client;
  2pub mod protocol;
  3pub mod transport;
  4pub mod types;
  5
  6use std::fmt::Display;
  7use std::path::Path;
  8use std::sync::Arc;
  9
 10use anyhow::Result;
 11use client::Client;
 12use collections::HashMap;
 13use gpui::AsyncApp;
 14use parking_lot::RwLock;
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17
 18#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 19pub struct ContextServerId(pub Arc<str>);
 20
 21impl Display for ContextServerId {
 22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 23        write!(f, "{}", self.0)
 24    }
 25}
 26
 27#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 28pub struct ContextServerCommand {
 29    pub path: String,
 30    pub args: Vec<String>,
 31    pub env: Option<HashMap<String, String>>,
 32}
 33
 34enum ContextServerTransport {
 35    Stdio(ContextServerCommand),
 36    Custom(Arc<dyn crate::transport::Transport>),
 37}
 38
 39pub struct ContextServer {
 40    id: ContextServerId,
 41    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 42    configuration: ContextServerTransport,
 43}
 44
 45impl ContextServer {
 46    pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
 47        Self {
 48            id,
 49            client: RwLock::new(None),
 50            configuration: ContextServerTransport::Stdio(command),
 51        }
 52    }
 53
 54    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
 55        Self {
 56            id,
 57            client: RwLock::new(None),
 58            configuration: ContextServerTransport::Custom(transport),
 59        }
 60    }
 61
 62    pub fn id(&self) -> ContextServerId {
 63        self.id.clone()
 64    }
 65
 66    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 67        self.client.read().clone()
 68    }
 69
 70    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
 71        let client = match &self.configuration {
 72            ContextServerTransport::Stdio(command) => Client::stdio(
 73                client::ContextServerId(self.id.0.clone()),
 74                client::ModelContextServerBinary {
 75                    executable: Path::new(&command.path).to_path_buf(),
 76                    args: command.args.clone(),
 77                    env: command.env.clone(),
 78                },
 79                cx.clone(),
 80            )?,
 81            ContextServerTransport::Custom(transport) => Client::new(
 82                client::ContextServerId(self.id.0.clone()),
 83                self.id().0,
 84                transport.clone(),
 85                cx.clone(),
 86            )?,
 87        };
 88        self.initialize(client).await
 89    }
 90
 91    async fn initialize(&self, client: Client) -> Result<()> {
 92        log::info!("starting context server {}", self.id);
 93        let protocol = crate::protocol::ModelContextProtocol::new(client);
 94        let client_info = types::Implementation {
 95            name: "Zed".to_string(),
 96            version: env!("CARGO_PKG_VERSION").to_string(),
 97        };
 98        let initialized_protocol = protocol.initialize(client_info).await?;
 99
100        log::debug!(
101            "context server {} initialized: {:?}",
102            self.id,
103            initialized_protocol.initialize,
104        );
105
106        *self.client.write() = Some(Arc::new(initialized_protocol));
107        Ok(())
108    }
109
110    pub fn stop(&self) -> Result<()> {
111        let mut client = self.client.write();
112        if let Some(protocol) = client.take() {
113            drop(protocol);
114        }
115        Ok(())
116    }
117}