protocol.rs

//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read/write messages and the types from `types.rs` for
//! serialization/deserialization of messages.
 7
 8use anyhow::Result;
 9
10use crate::client::Client;
11use crate::types::{self, Request};
12
13pub struct ModelContextProtocol {
14    inner: Client,
15}
16
17impl ModelContextProtocol {
18    pub(crate) fn new(inner: Client) -> Self {
19        Self { inner }
20    }
21
22    fn supported_protocols() -> Vec<types::ProtocolVersion> {
23        vec![types::ProtocolVersion(
24            types::LATEST_PROTOCOL_VERSION.to_string(),
25        )]
26    }
27
28    pub async fn initialize(
29        self,
30        client_info: types::Implementation,
31    ) -> Result<InitializedContextServerProtocol> {
32        let params = types::InitializeParams {
33            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
34            capabilities: types::ClientCapabilities {
35                experimental: None,
36                sampling: None,
37                roots: None,
38            },
39            meta: None,
40            client_info,
41        };
42
43        let response: types::InitializeResponse = self
44            .inner
45            .request(types::request::Initialize::METHOD, params)
46            .await?;
47
48        anyhow::ensure!(
49            Self::supported_protocols().contains(&response.protocol_version),
50            "Unsupported protocol version: {:?}",
51            response.protocol_version
52        );
53
54        log::trace!("mcp server info {:?}", response.server_info);
55
56        self.inner.notify(
57            types::NotificationType::Initialized.as_str(),
58            serde_json::json!({}),
59        )?;
60
61        let initialized_protocol = InitializedContextServerProtocol {
62            inner: self.inner,
63            initialize: response,
64        };
65
66        Ok(initialized_protocol)
67    }
68}
69
70pub struct InitializedContextServerProtocol {
71    inner: Client,
72    pub initialize: types::InitializeResponse,
73}
74
/// A capability that a server may advertise in its initialize response.
///
/// Each variant corresponds to one optional field of the server capabilities
/// checked by `InitializedContextServerProtocol::capable`.
///
/// The enum is fieldless, so equality is total and hashing is trivial; `Eq`
/// and `Hash` are derived so values can be used as `HashSet`/`HashMap` keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// The `experimental` capabilities field.
    Experimental,
    /// The `logging` capabilities field.
    Logging,
    /// The `prompts` capabilities field.
    Prompts,
    /// The `resources` capabilities field.
    Resources,
    /// The `tools` capabilities field.
    Tools,
}
83
84impl InitializedContextServerProtocol {
85    /// Check if the server supports a specific capability
86    pub fn capable(&self, capability: ServerCapability) -> bool {
87        match capability {
88            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
89            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
90            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
91            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
92            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
93        }
94    }
95
96    pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
97        self.inner.request(T::METHOD, params).await
98    }
99}