1//! This module implements parts of the Model Context Protocol.
2//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
5//! read/write messages and the types from types.rs for serialization/deserialization
6//! of messages.
7
8use anyhow::Result;
9use collections::HashMap;
10
11use crate::client::Client;
12use crate::types;
13
/// MCP protocol version string this client sends in its `initialize`
/// request; it is also one of the versions accepted in the server's reply.
const PROTOCOL_VERSION: &str = "2024-10-07";
15
16pub struct ModelContextProtocol {
17 inner: Client,
18}
19
20impl ModelContextProtocol {
21 pub fn new(inner: Client) -> Self {
22 Self { inner }
23 }
24
25 fn supported_protocols() -> Vec<types::ProtocolVersion> {
26 vec![
27 types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
28 types::ProtocolVersion::VersionNumber(1),
29 ]
30 }
31
32 pub async fn initialize(
33 self,
34 client_info: types::Implementation,
35 ) -> Result<InitializedContextServerProtocol> {
36 let params = types::InitializeParams {
37 protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
38 capabilities: types::ClientCapabilities {
39 experimental: None,
40 sampling: None,
41 },
42 client_info,
43 };
44
45 let response: types::InitializeResponse = self
46 .inner
47 .request(types::RequestType::Initialize.as_str(), params)
48 .await?;
49
50 if !Self::supported_protocols().contains(&response.protocol_version) {
51 return Err(anyhow::anyhow!(
52 "Unsupported protocol version: {:?}",
53 response.protocol_version
54 ));
55 }
56
57 log::trace!("mcp server info {:?}", response.server_info);
58
59 self.inner.notify(
60 types::NotificationType::Initialized.as_str(),
61 serde_json::json!({}),
62 )?;
63
64 let initialized_protocol = InitializedContextServerProtocol {
65 inner: self.inner,
66 initialize: response,
67 };
68
69 Ok(initialized_protocol)
70 }
71}
72
73pub struct InitializedContextServerProtocol {
74 inner: Client,
75 pub initialize: types::InitializeResponse,
76}
77
/// A capability a server may advertise in its `initialize` response.
///
/// Query support with `InitializedContextServerProtocol::capable`.
/// `Eq` and `Hash` are derived so capabilities can be stored in sets and
/// used as map keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// Non-standard, experimental features.
    Experimental,
    /// Server-side logging.
    Logging,
    /// Prompt listing and retrieval.
    Prompts,
    /// Resource listing.
    Resources,
    /// Tool listing and invocation.
    Tools,
}
86
87impl InitializedContextServerProtocol {
88 /// Check if the server supports a specific capability
89 pub fn capable(&self, capability: ServerCapability) -> bool {
90 match capability {
91 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
92 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
93 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
94 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
95 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
96 }
97 }
98
99 fn check_capability(&self, capability: ServerCapability) -> Result<()> {
100 if self.capable(capability) {
101 Ok(())
102 } else {
103 Err(anyhow::anyhow!(
104 "Server does not support {:?} capability",
105 capability
106 ))
107 }
108 }
109
110 /// List the MCP prompts.
111 pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
112 self.check_capability(ServerCapability::Prompts)?;
113
114 let response: types::PromptsListResponse = self
115 .inner
116 .request(
117 types::RequestType::PromptsList.as_str(),
118 serde_json::json!({}),
119 )
120 .await?;
121
122 Ok(response.prompts)
123 }
124
125 /// List the MCP resources.
126 pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
127 self.check_capability(ServerCapability::Resources)?;
128
129 let response: types::ResourcesListResponse = self
130 .inner
131 .request(
132 types::RequestType::ResourcesList.as_str(),
133 serde_json::json!({}),
134 )
135 .await?;
136
137 Ok(response)
138 }
139
140 /// Executes a prompt with the given arguments and returns the result.
141 pub async fn run_prompt<P: AsRef<str>>(
142 &self,
143 prompt: P,
144 arguments: HashMap<String, String>,
145 ) -> Result<types::PromptsGetResponse> {
146 self.check_capability(ServerCapability::Prompts)?;
147
148 let params = types::PromptsGetParams {
149 name: prompt.as_ref().to_string(),
150 arguments: Some(arguments),
151 };
152
153 let response: types::PromptsGetResponse = self
154 .inner
155 .request(types::RequestType::PromptsGet.as_str(), params)
156 .await?;
157
158 Ok(response)
159 }
160
161 pub async fn completion<P: Into<String>>(
162 &self,
163 reference: types::CompletionReference,
164 argument: P,
165 value: P,
166 ) -> Result<types::Completion> {
167 let params = types::CompletionCompleteParams {
168 r#ref: reference,
169 argument: types::CompletionArgument {
170 name: argument.into(),
171 value: value.into(),
172 },
173 };
174 let result: types::CompletionCompleteResponse = self
175 .inner
176 .request(types::RequestType::CompletionComplete.as_str(), params)
177 .await?;
178
179 let completion = types::Completion {
180 values: result.completion.values,
181 total: types::CompletionTotal::from_options(
182 result.completion.has_more,
183 result.completion.total,
184 ),
185 };
186
187 Ok(completion)
188 }
189
190 /// List MCP tools.
191 pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
192 self.check_capability(ServerCapability::Tools)?;
193
194 let response = self
195 .inner
196 .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
197 .await?;
198
199 Ok(response)
200 }
201
202 /// Executes a tool with the given arguments
203 pub async fn run_tool<P: AsRef<str>>(
204 &self,
205 tool: P,
206 arguments: Option<HashMap<String, serde_json::Value>>,
207 ) -> Result<types::CallToolResponse> {
208 self.check_capability(ServerCapability::Tools)?;
209
210 let params = types::CallToolParams {
211 name: tool.as_ref().to_string(),
212 arguments,
213 };
214
215 let response: types::CallToolResponse = self
216 .inner
217 .request(types::RequestType::CallTool.as_str(), params)
218 .await?;
219
220 Ok(response)
221 }
222}
223
224impl InitializedContextServerProtocol {
225 pub async fn request<R: serde::de::DeserializeOwned>(
226 &self,
227 method: &str,
228 params: impl serde::Serialize,
229 ) -> Result<R> {
230 self.inner.request(method, params).await
231 }
232}