protocol.rs

//! This module implements parts of the Model Context Protocol (MCP).
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize them.
  7
  8use anyhow::Result;
  9use collections::HashMap;
 10
 11use crate::client::Client;
 12use crate::types;
 13
 14const PROTOCOL_VERSION: &str = "2024-10-07";
 15
 16pub struct ModelContextProtocol {
 17    inner: Client,
 18}
 19
 20impl ModelContextProtocol {
 21    pub fn new(inner: Client) -> Self {
 22        Self { inner }
 23    }
 24
 25    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 26        vec![
 27            types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
 28            types::ProtocolVersion::VersionNumber(1),
 29        ]
 30    }
 31
 32    pub async fn initialize(
 33        self,
 34        client_info: types::Implementation,
 35    ) -> Result<InitializedContextServerProtocol> {
 36        let params = types::InitializeParams {
 37            protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
 38            capabilities: types::ClientCapabilities {
 39                experimental: None,
 40                sampling: None,
 41            },
 42            client_info,
 43        };
 44
 45        let response: types::InitializeResponse = self
 46            .inner
 47            .request(types::RequestType::Initialize.as_str(), params)
 48            .await?;
 49
 50        if !Self::supported_protocols().contains(&response.protocol_version) {
 51            return Err(anyhow::anyhow!(
 52                "Unsupported protocol version: {:?}",
 53                response.protocol_version
 54            ));
 55        }
 56
 57        log::trace!("mcp server info {:?}", response.server_info);
 58
 59        self.inner.notify(
 60            types::NotificationType::Initialized.as_str(),
 61            serde_json::json!({}),
 62        )?;
 63
 64        let initialized_protocol = InitializedContextServerProtocol {
 65            inner: self.inner,
 66            initialize: response,
 67        };
 68
 69        Ok(initialized_protocol)
 70    }
 71}
 72
/// An MCP connection whose initialization handshake has completed.
///
/// Produced by [`ModelContextProtocol::initialize`]. It provides methods for issuing
/// MCP requests to the server, gated on the capabilities the server advertised.
pub struct InitializedContextServerProtocol {
    /// The JSON-RPC client that requests are sent over.
    inner: Client,
    /// The server's reply to the `initialize` request. It includes the negotiated
    /// protocol version, the server's capabilities and its server info.
    pub initialize: types::InitializeResponse,
}
 77
 78#[derive(Debug, PartialEq, Clone, Copy)]
 79pub enum ServerCapability {
 80    Experimental,
 81    Logging,
 82    Prompts,
 83    Resources,
 84    Tools,
 85}
 86
 87impl InitializedContextServerProtocol {
 88    /// Check if the server supports a specific capability
 89    pub fn capable(&self, capability: ServerCapability) -> bool {
 90        match capability {
 91            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 92            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 93            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 94            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 95            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 96        }
 97    }
 98
 99    fn check_capability(&self, capability: ServerCapability) -> Result<()> {
100        if self.capable(capability) {
101            Ok(())
102        } else {
103            Err(anyhow::anyhow!(
104                "Server does not support {:?} capability",
105                capability
106            ))
107        }
108    }
109
110    /// List the MCP prompts.
111    pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
112        self.check_capability(ServerCapability::Prompts)?;
113
114        let response: types::PromptsListResponse = self
115            .inner
116            .request(types::RequestType::PromptsList.as_str(), ())
117            .await?;
118
119        Ok(response.prompts)
120    }
121
122    /// List the MCP resources.
123    pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
124        self.check_capability(ServerCapability::Resources)?;
125
126        let response: types::ResourcesListResponse = self
127            .inner
128            .request(types::RequestType::ResourcesList.as_str(), ())
129            .await?;
130
131        Ok(response)
132    }
133
134    /// Executes a prompt with the given arguments and returns the result.
135    pub async fn run_prompt<P: AsRef<str>>(
136        &self,
137        prompt: P,
138        arguments: HashMap<String, String>,
139    ) -> Result<types::PromptsGetResponse> {
140        self.check_capability(ServerCapability::Prompts)?;
141
142        let params = types::PromptsGetParams {
143            name: prompt.as_ref().to_string(),
144            arguments: Some(arguments),
145        };
146
147        let response: types::PromptsGetResponse = self
148            .inner
149            .request(types::RequestType::PromptsGet.as_str(), params)
150            .await?;
151
152        Ok(response)
153    }
154
155    pub async fn completion<P: Into<String>>(
156        &self,
157        reference: types::CompletionReference,
158        argument: P,
159        value: P,
160    ) -> Result<types::Completion> {
161        let params = types::CompletionCompleteParams {
162            r#ref: reference,
163            argument: types::CompletionArgument {
164                name: argument.into(),
165                value: value.into(),
166            },
167        };
168        let result: types::CompletionCompleteResponse = self
169            .inner
170            .request(types::RequestType::CompletionComplete.as_str(), params)
171            .await?;
172
173        let completion = types::Completion {
174            values: result.completion.values,
175            total: types::CompletionTotal::from_options(
176                result.completion.has_more,
177                result.completion.total,
178            ),
179        };
180
181        Ok(completion)
182    }
183
184    /// List MCP tools.
185    pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
186        self.check_capability(ServerCapability::Tools)?;
187
188        let response = self
189            .inner
190            .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
191            .await?;
192
193        Ok(response)
194    }
195
196    /// Executes a tool with the given arguments
197    pub async fn run_tool<P: AsRef<str>>(
198        &self,
199        tool: P,
200        arguments: Option<HashMap<String, serde_json::Value>>,
201    ) -> Result<types::CallToolResponse> {
202        self.check_capability(ServerCapability::Tools)?;
203
204        let params = types::CallToolParams {
205            name: tool.as_ref().to_string(),
206            arguments,
207        };
208
209        let response: types::CallToolResponse = self
210            .inner
211            .request(types::RequestType::CallTool.as_str(), params)
212            .await?;
213
214        Ok(response)
215    }
216}
217
218impl InitializedContextServerProtocol {
219    pub async fn request<R: serde::de::DeserializeOwned>(
220        &self,
221        method: &str,
222        params: impl serde::Serialize,
223    ) -> Result<R> {
224        self.inner.request(method, params).await
225    }
226}