// protocol.rs

//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize them.
  7
  8use anyhow::Result;
  9use collections::HashMap;
 10
 11use crate::client::Client;
 12use crate::types;
 13
 14pub struct ModelContextProtocol {
 15    inner: Client,
 16}
 17
 18impl ModelContextProtocol {
 19    pub(crate) fn new(inner: Client) -> Self {
 20        Self { inner }
 21    }
 22
 23    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 24        vec![types::ProtocolVersion(
 25            types::LATEST_PROTOCOL_VERSION.to_string(),
 26        )]
 27    }
 28
 29    pub async fn initialize(
 30        self,
 31        client_info: types::Implementation,
 32    ) -> Result<InitializedContextServerProtocol> {
 33        let params = types::InitializeParams {
 34            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 35            capabilities: types::ClientCapabilities {
 36                experimental: None,
 37                sampling: None,
 38                roots: None,
 39            },
 40            meta: None,
 41            client_info,
 42        };
 43
 44        let response: types::InitializeResponse = self
 45            .inner
 46            .request(types::RequestType::Initialize.as_str(), params)
 47            .await?;
 48
 49        anyhow::ensure!(
 50            Self::supported_protocols().contains(&response.protocol_version),
 51            "Unsupported protocol version: {:?}",
 52            response.protocol_version
 53        );
 54
 55        log::trace!("mcp server info {:?}", response.server_info);
 56
 57        self.inner.notify(
 58            types::NotificationType::Initialized.as_str(),
 59            serde_json::json!({}),
 60        )?;
 61
 62        let initialized_protocol = InitializedContextServerProtocol {
 63            inner: self.inner,
 64            initialize: response,
 65        };
 66
 67        Ok(initialized_protocol)
 68    }
 69}
 70
 71pub struct InitializedContextServerProtocol {
 72    inner: Client,
 73    pub initialize: types::InitializeResponse,
 74}
 75
 76#[derive(Debug, PartialEq, Clone, Copy)]
 77pub enum ServerCapability {
 78    Experimental,
 79    Logging,
 80    Prompts,
 81    Resources,
 82    Tools,
 83}
 84
 85impl InitializedContextServerProtocol {
 86    /// Check if the server supports a specific capability
 87    pub fn capable(&self, capability: ServerCapability) -> bool {
 88        match capability {
 89            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 90            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 91            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 92            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 93            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 94        }
 95    }
 96
 97    fn check_capability(&self, capability: ServerCapability) -> Result<()> {
 98        anyhow::ensure!(
 99            self.capable(capability),
100            "Server does not support {capability:?} capability"
101        );
102        Ok(())
103    }
104
105    /// List the MCP prompts.
106    pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
107        self.check_capability(ServerCapability::Prompts)?;
108
109        let response: types::PromptsListResponse = self
110            .inner
111            .request(
112                types::RequestType::PromptsList.as_str(),
113                serde_json::json!({}),
114            )
115            .await?;
116
117        Ok(response.prompts)
118    }
119
120    /// List the MCP resources.
121    pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
122        self.check_capability(ServerCapability::Resources)?;
123
124        let response: types::ResourcesListResponse = self
125            .inner
126            .request(
127                types::RequestType::ResourcesList.as_str(),
128                serde_json::json!({}),
129            )
130            .await?;
131
132        Ok(response)
133    }
134
135    /// Executes a prompt with the given arguments and returns the result.
136    pub async fn run_prompt<P: AsRef<str>>(
137        &self,
138        prompt: P,
139        arguments: HashMap<String, String>,
140    ) -> Result<types::PromptsGetResponse> {
141        self.check_capability(ServerCapability::Prompts)?;
142
143        let params = types::PromptsGetParams {
144            name: prompt.as_ref().to_string(),
145            arguments: Some(arguments),
146            meta: None,
147        };
148
149        let response: types::PromptsGetResponse = self
150            .inner
151            .request(types::RequestType::PromptsGet.as_str(), params)
152            .await?;
153
154        Ok(response)
155    }
156
157    pub async fn completion<P: Into<String>>(
158        &self,
159        reference: types::CompletionReference,
160        argument: P,
161        value: P,
162    ) -> Result<types::Completion> {
163        let params = types::CompletionCompleteParams {
164            r#ref: reference,
165            argument: types::CompletionArgument {
166                name: argument.into(),
167                value: value.into(),
168            },
169            meta: None,
170        };
171        let result: types::CompletionCompleteResponse = self
172            .inner
173            .request(types::RequestType::CompletionComplete.as_str(), params)
174            .await?;
175
176        let completion = types::Completion {
177            values: result.completion.values,
178            total: types::CompletionTotal::from_options(
179                result.completion.has_more,
180                result.completion.total,
181            ),
182        };
183
184        Ok(completion)
185    }
186
187    /// List MCP tools.
188    pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
189        self.check_capability(ServerCapability::Tools)?;
190
191        let response = self
192            .inner
193            .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
194            .await?;
195
196        Ok(response)
197    }
198
199    /// Executes a tool with the given arguments
200    pub async fn run_tool<P: AsRef<str>>(
201        &self,
202        tool: P,
203        arguments: Option<HashMap<String, serde_json::Value>>,
204    ) -> Result<types::CallToolResponse> {
205        self.check_capability(ServerCapability::Tools)?;
206
207        let params = types::CallToolParams {
208            name: tool.as_ref().to_string(),
209            arguments,
210            meta: None,
211        };
212
213        let response: types::CallToolResponse = self
214            .inner
215            .request(types::RequestType::CallTool.as_str(), params)
216            .await?;
217
218        Ok(response)
219    }
220}
221
222impl InitializedContextServerProtocol {
223    pub async fn request<R: serde::de::DeserializeOwned>(
224        &self,
225        method: &str,
226        params: impl serde::Serialize,
227    ) -> Result<R> {
228        self.inner.request(method, params).await
229    }
230}