protocol.rs

//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read/write messages and the types from types.rs for serialization/deserialization
//! of messages.
  7
  8use anyhow::Result;
  9use futures::channel::oneshot;
 10use gpui::AsyncApp;
 11use serde_json::Value;
 12
 13use crate::client::Client;
 14use crate::types::{self, Notification, Request};
 15
 16pub struct ModelContextProtocol {
 17    inner: Client,
 18}
 19
 20impl ModelContextProtocol {
 21    pub(crate) fn new(inner: Client) -> Self {
 22        Self { inner }
 23    }
 24
 25    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 26        vec![
 27            types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 28            types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
 29        ]
 30    }
 31
 32    pub async fn initialize(
 33        self,
 34        client_info: types::Implementation,
 35    ) -> Result<InitializedContextServerProtocol> {
 36        let params = types::InitializeParams {
 37            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 38            capabilities: types::ClientCapabilities {
 39                experimental: None,
 40                sampling: None,
 41                roots: None,
 42            },
 43            meta: None,
 44            client_info,
 45        };
 46
 47        let response: types::InitializeResponse = self
 48            .inner
 49            .request(types::requests::Initialize::METHOD, params)
 50            .await?;
 51
 52        anyhow::ensure!(
 53            Self::supported_protocols().contains(&response.protocol_version),
 54            "Unsupported protocol version: {:?}",
 55            response.protocol_version
 56        );
 57
 58        log::trace!("mcp server info {:?}", response.server_info);
 59
 60        let initialized_protocol = InitializedContextServerProtocol {
 61            inner: self.inner,
 62            initialize: response,
 63        };
 64
 65        initialized_protocol.notify::<types::notifications::Initialized>(())?;
 66
 67        Ok(initialized_protocol)
 68    }
 69}
 70
 71pub struct InitializedContextServerProtocol {
 72    inner: Client,
 73    pub initialize: types::InitializeResponse,
 74}
 75
 76#[derive(Debug, PartialEq, Clone, Copy)]
 77pub enum ServerCapability {
 78    Experimental,
 79    Logging,
 80    Prompts,
 81    Resources,
 82    Tools,
 83}
 84
 85impl InitializedContextServerProtocol {
 86    /// Check if the server supports a specific capability
 87    pub fn capable(&self, capability: ServerCapability) -> bool {
 88        match capability {
 89            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 90            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 91            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 92            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 93            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 94        }
 95    }
 96
 97    pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
 98        self.inner.request(T::METHOD, params).await
 99    }
100
101    pub async fn cancellable_request<T: Request>(
102        &self,
103        params: T::Params,
104        cancel_rx: oneshot::Receiver<()>,
105    ) -> Result<T::Response> {
106        self.inner
107            .cancellable_request(T::METHOD, params, cancel_rx)
108            .await
109    }
110
111    pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
112        self.inner.notify(T::METHOD, params)
113    }
114
115    pub fn on_notification<F>(&self, method: &'static str, f: F)
116    where
117        F: 'static + Send + FnMut(Value, AsyncApp),
118    {
119        self.inner.on_notification(method, f);
120    }
121}