//! This module implements parts of the Model Context Protocol (MCP).
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from types.rs to serialize and
//! deserialize them.
7
8use anyhow::Result;
9use collections::HashMap;
10
11use crate::client::Client;
12use crate::types;
13
14pub struct ModelContextProtocol {
15 inner: Client,
16}
17
18impl ModelContextProtocol {
19 pub(crate) fn new(inner: Client) -> Self {
20 Self { inner }
21 }
22
23 fn supported_protocols() -> Vec<types::ProtocolVersion> {
24 vec![types::ProtocolVersion(
25 types::LATEST_PROTOCOL_VERSION.to_string(),
26 )]
27 }
28
29 pub async fn initialize(
30 self,
31 client_info: types::Implementation,
32 ) -> Result<InitializedContextServerProtocol> {
33 let params = types::InitializeParams {
34 protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
35 capabilities: types::ClientCapabilities {
36 experimental: None,
37 sampling: None,
38 roots: None,
39 },
40 meta: None,
41 client_info,
42 };
43
44 let response: types::InitializeResponse = self
45 .inner
46 .request(types::RequestType::Initialize.as_str(), params)
47 .await?;
48
49 if !Self::supported_protocols().contains(&response.protocol_version) {
50 return Err(anyhow::anyhow!(
51 "Unsupported protocol version: {:?}",
52 response.protocol_version
53 ));
54 }
55
56 log::trace!("mcp server info {:?}", response.server_info);
57
58 self.inner.notify(
59 types::NotificationType::Initialized.as_str(),
60 serde_json::json!({}),
61 )?;
62
63 let initialized_protocol = InitializedContextServerProtocol {
64 inner: self.inner,
65 initialize: response,
66 };
67
68 Ok(initialized_protocol)
69 }
70}
71
72pub struct InitializedContextServerProtocol {
73 inner: Client,
74 pub initialize: types::InitializeResponse,
75}
76
/// An optional capability a server may advertise in its `initialize` response.
///
/// Each variant corresponds to one field of the response's `capabilities`.
/// The type derives `Eq` and `Hash` so it can be used in sets and as a map key.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// Corresponds to `capabilities.experimental`.
    Experimental,
    /// Corresponds to `capabilities.logging`.
    Logging,
    /// Corresponds to `capabilities.prompts`.
    Prompts,
    /// Corresponds to `capabilities.resources`.
    Resources,
    /// Corresponds to `capabilities.tools`.
    Tools,
}
85
86impl InitializedContextServerProtocol {
87 /// Check if the server supports a specific capability
88 pub fn capable(&self, capability: ServerCapability) -> bool {
89 match capability {
90 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
91 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
92 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
93 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
94 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
95 }
96 }
97
98 fn check_capability(&self, capability: ServerCapability) -> Result<()> {
99 if self.capable(capability) {
100 Ok(())
101 } else {
102 Err(anyhow::anyhow!(
103 "Server does not support {:?} capability",
104 capability
105 ))
106 }
107 }
108
109 /// List the MCP prompts.
110 pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
111 self.check_capability(ServerCapability::Prompts)?;
112
113 let response: types::PromptsListResponse = self
114 .inner
115 .request(
116 types::RequestType::PromptsList.as_str(),
117 serde_json::json!({}),
118 )
119 .await?;
120
121 Ok(response.prompts)
122 }
123
124 /// List the MCP resources.
125 pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
126 self.check_capability(ServerCapability::Resources)?;
127
128 let response: types::ResourcesListResponse = self
129 .inner
130 .request(
131 types::RequestType::ResourcesList.as_str(),
132 serde_json::json!({}),
133 )
134 .await?;
135
136 Ok(response)
137 }
138
139 /// Executes a prompt with the given arguments and returns the result.
140 pub async fn run_prompt<P: AsRef<str>>(
141 &self,
142 prompt: P,
143 arguments: HashMap<String, String>,
144 ) -> Result<types::PromptsGetResponse> {
145 self.check_capability(ServerCapability::Prompts)?;
146
147 let params = types::PromptsGetParams {
148 name: prompt.as_ref().to_string(),
149 arguments: Some(arguments),
150 meta: None,
151 };
152
153 let response: types::PromptsGetResponse = self
154 .inner
155 .request(types::RequestType::PromptsGet.as_str(), params)
156 .await?;
157
158 Ok(response)
159 }
160
161 pub async fn completion<P: Into<String>>(
162 &self,
163 reference: types::CompletionReference,
164 argument: P,
165 value: P,
166 ) -> Result<types::Completion> {
167 let params = types::CompletionCompleteParams {
168 r#ref: reference,
169 argument: types::CompletionArgument {
170 name: argument.into(),
171 value: value.into(),
172 },
173 meta: None,
174 };
175 let result: types::CompletionCompleteResponse = self
176 .inner
177 .request(types::RequestType::CompletionComplete.as_str(), params)
178 .await?;
179
180 let completion = types::Completion {
181 values: result.completion.values,
182 total: types::CompletionTotal::from_options(
183 result.completion.has_more,
184 result.completion.total,
185 ),
186 };
187
188 Ok(completion)
189 }
190
191 /// List MCP tools.
192 pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
193 self.check_capability(ServerCapability::Tools)?;
194
195 let response = self
196 .inner
197 .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
198 .await?;
199
200 Ok(response)
201 }
202
203 /// Executes a tool with the given arguments
204 pub async fn run_tool<P: AsRef<str>>(
205 &self,
206 tool: P,
207 arguments: Option<HashMap<String, serde_json::Value>>,
208 ) -> Result<types::CallToolResponse> {
209 self.check_capability(ServerCapability::Tools)?;
210
211 let params = types::CallToolParams {
212 name: tool.as_ref().to_string(),
213 arguments,
214 meta: None,
215 };
216
217 let response: types::CallToolResponse = self
218 .inner
219 .request(types::RequestType::CallTool.as_str(), params)
220 .await?;
221
222 Ok(response)
223 }
224}
225
226impl InitializedContextServerProtocol {
227 pub async fn request<R: serde::de::DeserializeOwned>(
228 &self,
229 method: &str,
230 params: impl serde::Serialize,
231 ) -> Result<R> {
232 self.inner.request(method, params).await
233 }
234}