protocol.rs

//! This module implements parts of the Model Context Protocol (MCP).
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize them.
  7
  8use anyhow::Result;
  9
 10use crate::client::Client;
 11use crate::types::{self, Notification, Request};
 12
/// An MCP connection that has not yet completed the `initialize` handshake.
///
/// Call [`ModelContextProtocol::initialize`] to perform the handshake; it consumes
/// this value and yields an [`InitializedContextServerProtocol`].
pub struct ModelContextProtocol {
    /// Underlying JSON-RPC client used to send messages to the server.
    inner: Client,
}
 16
 17impl ModelContextProtocol {
 18    pub(crate) fn new(inner: Client) -> Self {
 19        Self { inner }
 20    }
 21
 22    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 23        vec![
 24            types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 25            types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
 26        ]
 27    }
 28
 29    pub async fn initialize(
 30        self,
 31        client_info: types::Implementation,
 32    ) -> Result<InitializedContextServerProtocol> {
 33        let params = types::InitializeParams {
 34            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 35            capabilities: types::ClientCapabilities {
 36                experimental: None,
 37                sampling: None,
 38                roots: None,
 39            },
 40            meta: None,
 41            client_info,
 42        };
 43
 44        let response: types::InitializeResponse = self
 45            .inner
 46            .request(types::requests::Initialize::METHOD, params)
 47            .await?;
 48
 49        anyhow::ensure!(
 50            Self::supported_protocols().contains(&response.protocol_version),
 51            "Unsupported protocol version: {:?}",
 52            response.protocol_version
 53        );
 54
 55        log::trace!("mcp server info {:?}", response.server_info);
 56
 57        let initialized_protocol = InitializedContextServerProtocol {
 58            inner: self.inner,
 59            initialize: response,
 60        };
 61
 62        initialized_protocol.notify::<types::notifications::Initialized>(())?;
 63
 64        Ok(initialized_protocol)
 65    }
 66}
 67
/// An MCP connection on which the initialization handshake has completed.
///
/// Produced by [`ModelContextProtocol::initialize`]. Use it to send typed requests and
/// notifications to the server.
pub struct InitializedContextServerProtocol {
    /// Underlying JSON-RPC client used to send messages to the server.
    inner: Client,
    /// The server's response to the `Initialize` request: its protocol version,
    /// advertised capabilities, and server info.
    pub initialize: types::InitializeResponse,
}
 72
 73#[derive(Debug, PartialEq, Clone, Copy)]
 74pub enum ServerCapability {
 75    Experimental,
 76    Logging,
 77    Prompts,
 78    Resources,
 79    Tools,
 80}
 81
 82impl InitializedContextServerProtocol {
 83    /// Check if the server supports a specific capability
 84    pub fn capable(&self, capability: ServerCapability) -> bool {
 85        match capability {
 86            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 87            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 88            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 89            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 90            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 91        }
 92    }
 93
 94    pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
 95        self.inner.request(T::METHOD, params).await
 96    }
 97
 98    pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
 99        self.inner.notify(T::METHOD, params)
100    }
101}