protocol.rs

//! This module implements parts of the Model Context Protocol (MCP).
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize them.
  7
  8use anyhow::Result;
  9use collections::HashMap;
 10
 11use crate::client::Client;
 12use crate::types;
 13
/// MCP protocol version string this client requests in its `Initialize` request; it is also
/// one of the versions accepted in the server's reply.
const PROTOCOL_VERSION: &str = "2024-10-07";
 15
/// An MCP connection on which the initialization handshake has not been performed yet.
///
/// Call [`ModelContextProtocol::initialize`] to perform the handshake and obtain an
/// [`InitializedContextServerProtocol`].
pub struct ModelContextProtocol {
    /// Underlying JSON-RPC client used to send requests and notifications.
    inner: Client,
}
 19
 20impl ModelContextProtocol {
 21    pub fn new(inner: Client) -> Self {
 22        Self { inner }
 23    }
 24
 25    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 26        vec![
 27            types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
 28            types::ProtocolVersion::VersionNumber(1),
 29        ]
 30    }
 31
 32    pub async fn initialize(
 33        self,
 34        client_info: types::Implementation,
 35    ) -> Result<InitializedContextServerProtocol> {
 36        let params = types::InitializeParams {
 37            protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
 38            capabilities: types::ClientCapabilities {
 39                experimental: None,
 40                sampling: None,
 41            },
 42            client_info,
 43        };
 44
 45        let response: types::InitializeResponse = self
 46            .inner
 47            .request(types::RequestType::Initialize.as_str(), params)
 48            .await?;
 49
 50        if !Self::supported_protocols().contains(&response.protocol_version) {
 51            return Err(anyhow::anyhow!(
 52                "Unsupported protocol version: {:?}",
 53                response.protocol_version
 54            ));
 55        }
 56
 57        log::trace!("mcp server info {:?}", response.server_info);
 58
 59        self.inner.notify(
 60            types::NotificationType::Initialized.as_str(),
 61            serde_json::json!({}),
 62        )?;
 63
 64        let initialized_protocol = InitializedContextServerProtocol {
 65            inner: self.inner,
 66            initialize: response,
 67        };
 68
 69        Ok(initialized_protocol)
 70    }
 71}
 72
/// An MCP connection that has completed the initialization handshake.
///
/// Produced by [`ModelContextProtocol::initialize`].
pub struct InitializedContextServerProtocol {
    /// Underlying JSON-RPC client used to send requests.
    inner: Client,
    /// The server's response to the `Initialize` request, including the capabilities and
    /// server info it advertised.
    pub initialize: types::InitializeResponse,
}
 77
 78#[derive(Debug, PartialEq, Clone, Copy)]
 79pub enum ServerCapability {
 80    Experimental,
 81    Logging,
 82    Prompts,
 83    Resources,
 84    Tools,
 85}
 86
 87impl InitializedContextServerProtocol {
 88    /// Check if the server supports a specific capability
 89    pub fn capable(&self, capability: ServerCapability) -> bool {
 90        match capability {
 91            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 92            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 93            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 94            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 95            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 96        }
 97    }
 98
 99    fn check_capability(&self, capability: ServerCapability) -> Result<()> {
100        if self.capable(capability) {
101            Ok(())
102        } else {
103            Err(anyhow::anyhow!(
104                "Server does not support {:?} capability",
105                capability
106            ))
107        }
108    }
109
110    /// List the MCP prompts.
111    pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
112        self.check_capability(ServerCapability::Prompts)?;
113
114        let response: types::PromptsListResponse = self
115            .inner
116            .request(
117                types::RequestType::PromptsList.as_str(),
118                serde_json::json!({}),
119            )
120            .await?;
121
122        Ok(response.prompts)
123    }
124
125    /// List the MCP resources.
126    pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
127        self.check_capability(ServerCapability::Resources)?;
128
129        let response: types::ResourcesListResponse = self
130            .inner
131            .request(
132                types::RequestType::ResourcesList.as_str(),
133                serde_json::json!({}),
134            )
135            .await?;
136
137        Ok(response)
138    }
139
140    /// Executes a prompt with the given arguments and returns the result.
141    pub async fn run_prompt<P: AsRef<str>>(
142        &self,
143        prompt: P,
144        arguments: HashMap<String, String>,
145    ) -> Result<types::PromptsGetResponse> {
146        self.check_capability(ServerCapability::Prompts)?;
147
148        let params = types::PromptsGetParams {
149            name: prompt.as_ref().to_string(),
150            arguments: Some(arguments),
151        };
152
153        let response: types::PromptsGetResponse = self
154            .inner
155            .request(types::RequestType::PromptsGet.as_str(), params)
156            .await?;
157
158        Ok(response)
159    }
160
161    pub async fn completion<P: Into<String>>(
162        &self,
163        reference: types::CompletionReference,
164        argument: P,
165        value: P,
166    ) -> Result<types::Completion> {
167        let params = types::CompletionCompleteParams {
168            r#ref: reference,
169            argument: types::CompletionArgument {
170                name: argument.into(),
171                value: value.into(),
172            },
173        };
174        let result: types::CompletionCompleteResponse = self
175            .inner
176            .request(types::RequestType::CompletionComplete.as_str(), params)
177            .await?;
178
179        let completion = types::Completion {
180            values: result.completion.values,
181            total: types::CompletionTotal::from_options(
182                result.completion.has_more,
183                result.completion.total,
184            ),
185        };
186
187        Ok(completion)
188    }
189
190    /// List MCP tools.
191    pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
192        self.check_capability(ServerCapability::Tools)?;
193
194        let response = self
195            .inner
196            .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
197            .await?;
198
199        Ok(response)
200    }
201
202    /// Executes a tool with the given arguments
203    pub async fn run_tool<P: AsRef<str>>(
204        &self,
205        tool: P,
206        arguments: Option<HashMap<String, serde_json::Value>>,
207    ) -> Result<types::CallToolResponse> {
208        self.check_capability(ServerCapability::Tools)?;
209
210        let params = types::CallToolParams {
211            name: tool.as_ref().to_string(),
212            arguments,
213        };
214
215        let response: types::CallToolResponse = self
216            .inner
217            .request(types::RequestType::CallTool.as_str(), params)
218            .await?;
219
220        Ok(response)
221    }
222}
223
224impl InitializedContextServerProtocol {
225    pub async fn request<R: serde::de::DeserializeOwned>(
226        &self,
227        method: &str,
228        params: impl serde::Serialize,
229    ) -> Result<R> {
230        self.inner.request(method, params).await
231    }
232}