protocol.rs

//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from types.rs to serialize and
//! deserialize them.
  7
  8use anyhow::Result;
  9use gpui::AsyncApp;
 10use serde_json::Value;
 11
 12use crate::client::Client;
 13use crate::types::{self, Notification, Request};
 14
 15pub struct ModelContextProtocol {
 16    inner: Client,
 17}
 18
 19impl ModelContextProtocol {
 20    pub(crate) fn new(inner: Client) -> Self {
 21        Self { inner }
 22    }
 23
 24    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 25        vec![
 26            types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 27            types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
 28        ]
 29    }
 30
 31    pub async fn initialize(
 32        self,
 33        client_info: types::Implementation,
 34    ) -> Result<InitializedContextServerProtocol> {
 35        let params = types::InitializeParams {
 36            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 37            capabilities: types::ClientCapabilities {
 38                experimental: None,
 39                sampling: None,
 40                roots: None,
 41            },
 42            meta: None,
 43            client_info,
 44        };
 45
 46        let response: types::InitializeResponse = self
 47            .inner
 48            .request(types::requests::Initialize::METHOD, params)
 49            .await?;
 50
 51        anyhow::ensure!(
 52            Self::supported_protocols().contains(&response.protocol_version),
 53            "Unsupported protocol version: {:?}",
 54            response.protocol_version
 55        );
 56
 57        log::trace!("mcp server info {:?}", response.server_info);
 58
 59        let initialized_protocol = InitializedContextServerProtocol {
 60            inner: self.inner,
 61            initialize: response,
 62        };
 63
 64        initialized_protocol.notify::<types::notifications::Initialized>(())?;
 65
 66        Ok(initialized_protocol)
 67    }
 68}
 69
 70pub struct InitializedContextServerProtocol {
 71    inner: Client,
 72    pub initialize: types::InitializeResponse,
 73}
 74
 75#[derive(Debug, PartialEq, Clone, Copy)]
 76pub enum ServerCapability {
 77    Experimental,
 78    Logging,
 79    Prompts,
 80    Resources,
 81    Tools,
 82}
 83
 84impl InitializedContextServerProtocol {
 85    /// Check if the server supports a specific capability
 86    pub fn capable(&self, capability: ServerCapability) -> bool {
 87        match capability {
 88            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 89            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 90            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 91            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 92            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 93        }
 94    }
 95
 96    pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
 97        self.inner.request(T::METHOD, params).await
 98    }
 99
100    pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
101        self.inner.notify(T::METHOD, params)
102    }
103
104    pub fn on_notification<F>(&self, method: &'static str, f: F)
105    where
106        F: 'static + Send + FnMut(Value, AsyncApp),
107    {
108        self.inner.on_notification(method, f);
109    }
110}