1//! This module implements parts of the Model Context Protocol.
  2//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
  5//! read/write messages and the types from types.rs for serialization/deserialization
  6//! of messages.
  7
  8use std::time::Duration;
  9
 10use anyhow::Result;
 11use futures::channel::oneshot;
 12use gpui::AsyncApp;
 13use serde_json::Value;
 14
 15use crate::client::Client;
 16use crate::types::{self, Notification, Request};
 17
/// An MCP connection that has not yet been initialized.
///
/// Calling `initialize` consumes this value and, on success, yields an
/// [`InitializedContextServerProtocol`].
pub struct ModelContextProtocol {
    /// Client used to send the initialize request; it is moved into the
    /// initialized protocol once initialization succeeds.
    inner: Client,
}
 21
 22impl ModelContextProtocol {
 23    pub(crate) const fn new(inner: Client) -> Self {
 24        Self { inner }
 25    }
 26
 27    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 28        vec![
 29            types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 30            types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
 31        ]
 32    }
 33
 34    pub async fn initialize(
 35        self,
 36        client_info: types::Implementation,
 37    ) -> Result<InitializedContextServerProtocol> {
 38        let params = types::InitializeParams {
 39            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 40            capabilities: types::ClientCapabilities {
 41                experimental: None,
 42                sampling: None,
 43                roots: None,
 44            },
 45            meta: None,
 46            client_info,
 47        };
 48
 49        let response: types::InitializeResponse = self
 50            .inner
 51            .request(types::requests::Initialize::METHOD, params)
 52            .await?;
 53
 54        anyhow::ensure!(
 55            Self::supported_protocols().contains(&response.protocol_version),
 56            "Unsupported protocol version: {:?}",
 57            response.protocol_version
 58        );
 59
 60        log::trace!("mcp server info {:?}", response.server_info);
 61
 62        let initialized_protocol = InitializedContextServerProtocol {
 63            inner: self.inner,
 64            initialize: response,
 65        };
 66
 67        initialized_protocol.notify::<types::notifications::Initialized>(())?;
 68
 69        Ok(initialized_protocol)
 70    }
 71}
 72
/// An MCP connection whose initialize request has completed.
///
/// Produced by `ModelContextProtocol::initialize`.
pub struct InitializedContextServerProtocol {
    /// Client that requests, notifications and notification handlers are
    /// forwarded to.
    inner: Client,
    /// The server's response to the initialize request; `capable` reads the
    /// server capabilities from it.
    pub initialize: types::InitializeResponse,
}
 77
 78#[derive(Debug, PartialEq, Clone, Copy)]
 79pub enum ServerCapability {
 80    Experimental,
 81    Logging,
 82    Prompts,
 83    Resources,
 84    Tools,
 85}
 86
 87impl InitializedContextServerProtocol {
 88    /// Check if the server supports a specific capability
 89    pub const fn capable(&self, capability: ServerCapability) -> bool {
 90        match capability {
 91            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 92            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 93            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 94            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 95            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 96        }
 97    }
 98
 99    pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
100        self.inner.request(T::METHOD, params).await
101    }
102
103    pub async fn request_with<T: Request>(
104        &self,
105        params: T::Params,
106        cancel_rx: Option<oneshot::Receiver<()>>,
107        timeout: Option<Duration>,
108    ) -> Result<T::Response> {
109        self.inner
110            .request_with(T::METHOD, params, cancel_rx, timeout)
111            .await
112    }
113
114    pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
115        self.inner.notify(T::METHOD, params)
116    }
117
118    pub fn on_notification(
119        &self,
120        method: &'static str,
121        f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
122    ) {
123        self.inner.on_notification(method, f);
124    }
125}