1//! This module implements parts of the Model Context Protocol.
2//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize them.
7
8use anyhow::Result;
9
10use crate::client::Client;
11use crate::types::{self, Notification, Request};
12
13pub struct ModelContextProtocol {
14 inner: Client,
15}
16
17impl ModelContextProtocol {
18 pub(crate) fn new(inner: Client) -> Self {
19 Self { inner }
20 }
21
22 fn supported_protocols() -> Vec<types::ProtocolVersion> {
23 vec![
24 types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
25 types::ProtocolVersion(types::VERSION_2024_11_05.to_string()),
26 ]
27 }
28
29 pub async fn initialize(
30 self,
31 client_info: types::Implementation,
32 ) -> Result<InitializedContextServerProtocol> {
33 let params = types::InitializeParams {
34 protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
35 capabilities: types::ClientCapabilities {
36 experimental: None,
37 sampling: None,
38 roots: None,
39 },
40 meta: None,
41 client_info,
42 };
43
44 let response: types::InitializeResponse = self
45 .inner
46 .request(types::requests::Initialize::METHOD, params)
47 .await?;
48
49 anyhow::ensure!(
50 Self::supported_protocols().contains(&response.protocol_version),
51 "Unsupported protocol version: {:?}",
52 response.protocol_version
53 );
54
55 log::trace!("mcp server info {:?}", response.server_info);
56
57 let initialized_protocol = InitializedContextServerProtocol {
58 inner: self.inner,
59 initialize: response,
60 };
61
62 initialized_protocol.notify::<types::notifications::Initialized>(())?;
63
64 Ok(initialized_protocol)
65 }
66}
67
68pub struct InitializedContextServerProtocol {
69 inner: Client,
70 pub initialize: types::InitializeResponse,
71}
72
/// Optional server capabilities that can be queried with
/// [`InitializedContextServerProtocol::capable`].
///
/// Each variant corresponds to one field of the capabilities in the server's
/// initialize response. `Eq` and `Hash` are derived so capabilities can be
/// used as keys in hashed collections.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// Corresponds to the `experimental` field of the server capabilities.
    Experimental,
    /// Corresponds to the `logging` field of the server capabilities.
    Logging,
    /// Corresponds to the `prompts` field of the server capabilities.
    Prompts,
    /// Corresponds to the `resources` field of the server capabilities.
    Resources,
    /// Corresponds to the `tools` field of the server capabilities.
    Tools,
}
81
82impl InitializedContextServerProtocol {
83 /// Check if the server supports a specific capability
84 pub fn capable(&self, capability: ServerCapability) -> bool {
85 match capability {
86 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
87 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
88 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
89 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
90 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
91 }
92 }
93
94 pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
95 self.inner.request(T::METHOD, params).await
96 }
97
98 pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
99 self.inner.notify(T::METHOD, params)
100 }
101}