protocol.rs

//! This module implements parts of the Model Context Protocol (MCP).
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize them.
  7
  8use anyhow::Result;
  9use collections::HashMap;
 10
 11use crate::client::Client;
 12use crate::types;
 13
 14pub struct ModelContextProtocol {
 15    inner: Client,
 16}
 17
 18impl ModelContextProtocol {
 19    pub(crate) fn new(inner: Client) -> Self {
 20        Self { inner }
 21    }
 22
 23    fn supported_protocols() -> Vec<types::ProtocolVersion> {
 24        vec![types::ProtocolVersion(
 25            types::LATEST_PROTOCOL_VERSION.to_string(),
 26        )]
 27    }
 28
 29    pub async fn initialize(
 30        self,
 31        client_info: types::Implementation,
 32    ) -> Result<InitializedContextServerProtocol> {
 33        let params = types::InitializeParams {
 34            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
 35            capabilities: types::ClientCapabilities {
 36                experimental: None,
 37                sampling: None,
 38                roots: None,
 39            },
 40            meta: None,
 41            client_info,
 42        };
 43
 44        let response: types::InitializeResponse = self
 45            .inner
 46            .request(types::RequestType::Initialize.as_str(), params)
 47            .await?;
 48
 49        if !Self::supported_protocols().contains(&response.protocol_version) {
 50            return Err(anyhow::anyhow!(
 51                "Unsupported protocol version: {:?}",
 52                response.protocol_version
 53            ));
 54        }
 55
 56        log::trace!("mcp server info {:?}", response.server_info);
 57
 58        self.inner.notify(
 59            types::NotificationType::Initialized.as_str(),
 60            serde_json::json!({}),
 61        )?;
 62
 63        let initialized_protocol = InitializedContextServerProtocol {
 64            inner: self.inner,
 65            initialize: response,
 66        };
 67
 68        Ok(initialized_protocol)
 69    }
 70}
 71
 72pub struct InitializedContextServerProtocol {
 73    inner: Client,
 74    pub initialize: types::InitializeResponse,
 75}
 76
 77#[derive(Debug, PartialEq, Clone, Copy)]
 78pub enum ServerCapability {
 79    Experimental,
 80    Logging,
 81    Prompts,
 82    Resources,
 83    Tools,
 84}
 85
 86impl InitializedContextServerProtocol {
 87    /// Check if the server supports a specific capability
 88    pub fn capable(&self, capability: ServerCapability) -> bool {
 89        match capability {
 90            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
 91            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
 92            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
 93            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
 94            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
 95        }
 96    }
 97
 98    fn check_capability(&self, capability: ServerCapability) -> Result<()> {
 99        if self.capable(capability) {
100            Ok(())
101        } else {
102            Err(anyhow::anyhow!(
103                "Server does not support {:?} capability",
104                capability
105            ))
106        }
107    }
108
109    /// List the MCP prompts.
110    pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
111        self.check_capability(ServerCapability::Prompts)?;
112
113        let response: types::PromptsListResponse = self
114            .inner
115            .request(
116                types::RequestType::PromptsList.as_str(),
117                serde_json::json!({}),
118            )
119            .await?;
120
121        Ok(response.prompts)
122    }
123
124    /// List the MCP resources.
125    pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
126        self.check_capability(ServerCapability::Resources)?;
127
128        let response: types::ResourcesListResponse = self
129            .inner
130            .request(
131                types::RequestType::ResourcesList.as_str(),
132                serde_json::json!({}),
133            )
134            .await?;
135
136        Ok(response)
137    }
138
139    /// Executes a prompt with the given arguments and returns the result.
140    pub async fn run_prompt<P: AsRef<str>>(
141        &self,
142        prompt: P,
143        arguments: HashMap<String, String>,
144    ) -> Result<types::PromptsGetResponse> {
145        self.check_capability(ServerCapability::Prompts)?;
146
147        let params = types::PromptsGetParams {
148            name: prompt.as_ref().to_string(),
149            arguments: Some(arguments),
150            meta: None,
151        };
152
153        let response: types::PromptsGetResponse = self
154            .inner
155            .request(types::RequestType::PromptsGet.as_str(), params)
156            .await?;
157
158        Ok(response)
159    }
160
161    pub async fn completion<P: Into<String>>(
162        &self,
163        reference: types::CompletionReference,
164        argument: P,
165        value: P,
166    ) -> Result<types::Completion> {
167        let params = types::CompletionCompleteParams {
168            r#ref: reference,
169            argument: types::CompletionArgument {
170                name: argument.into(),
171                value: value.into(),
172            },
173            meta: None,
174        };
175        let result: types::CompletionCompleteResponse = self
176            .inner
177            .request(types::RequestType::CompletionComplete.as_str(), params)
178            .await?;
179
180        let completion = types::Completion {
181            values: result.completion.values,
182            total: types::CompletionTotal::from_options(
183                result.completion.has_more,
184                result.completion.total,
185            ),
186        };
187
188        Ok(completion)
189    }
190
191    /// List MCP tools.
192    pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
193        self.check_capability(ServerCapability::Tools)?;
194
195        let response = self
196            .inner
197            .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
198            .await?;
199
200        Ok(response)
201    }
202
203    /// Executes a tool with the given arguments
204    pub async fn run_tool<P: AsRef<str>>(
205        &self,
206        tool: P,
207        arguments: Option<HashMap<String, serde_json::Value>>,
208    ) -> Result<types::CallToolResponse> {
209        self.check_capability(ServerCapability::Tools)?;
210
211        let params = types::CallToolParams {
212            name: tool.as_ref().to_string(),
213            arguments,
214            meta: None,
215        };
216
217        let response: types::CallToolResponse = self
218            .inner
219            .request(types::RequestType::CallTool.as_str(), params)
220            .await?;
221
222        Ok(response)
223    }
224}
225
226impl InitializedContextServerProtocol {
227    pub async fn request<R: serde::de::DeserializeOwned>(
228        &self,
229        method: &str,
230        params: impl serde::Serialize,
231    ) -> Result<R> {
232        self.inner.request(method, params).await
233    }
234}