protocol.rs

//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read/write messages and the types from types.rs for serialization/deserialization
//! of messages.
  7
  8use anyhow::Result;
  9
 10use crate::client::Client;
 11use crate::types::{self, Request};
 12
 13pub struct ModelContextProtocol {
 14    inner: Client,
 15}
 16
 17impl ModelContextProtocol {
 18    pub(crate) fn new(inner: Client) -> Self {
 19        Self { inner }
 20    }
 21
 22    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 23        vec![
 24            types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 25            types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
 26        ]
 27    }
 28
 29    pub async fn initialize(
 30        self,
 31        client_info: types::Implementation,
 32    ) -> Result<InitializedContextServerProtocol> {
 33        let params = types::InitializeParams {
 34            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 35            capabilities: types::ClientCapabilities {
 36                experimental: None,
 37                sampling: None,
 38                roots: None,
 39            },
 40            meta: None,
 41            client_info,
 42        };
 43
 44        let response: types::InitializeResponse = self
 45            .inner
 46            .request(types::request::Initialize::METHOD, params)
 47            .await?;
 48
 49        anyhow::ensure!(
 50            Self::supported_protocols().contains(&response.protocol_version),
 51            "Unsupported protocol version: {:?}",
 52            response.protocol_version
 53        );
 54
 55        log::trace!("mcp server info {:?}", response.server_info);
 56
 57        self.inner.notify(
 58            types::NotificationType::Initialized.as_str(),
 59            serde_json::json!({}),
 60        )?;
 61
 62        let initialized_protocol = InitializedContextServerProtocol {
 63            inner: self.inner,
 64            initialize: response,
 65        };
 66
 67        Ok(initialized_protocol)
 68    }
 69}
 70
 71pub struct InitializedContextServerProtocol {
 72    inner: Client,
 73    pub initialize: types::InitializeResponse,
 74}
 75
 76#[derive(Debug, PartialEq, Clone, Copy)]
 77pub enum ServerCapability {
 78    Experimental,
 79    Logging,
 80    Prompts,
 81    Resources,
 82    Tools,
 83}
 84
 85impl InitializedContextServerProtocol {
 86    /// Check if the server supports a specific capability
 87    pub fn capable(&self, capability: ServerCapability) -> bool {
 88        match capability {
 89            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 90            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 91            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 92            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 93            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 94        }
 95    }
 96
 97    pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
 98        self.inner.request(T::METHOD, params).await
 99    }
100}