//! This module implements parts of the Model Context Protocol.
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read/write messages, and the types from `types.rs` for serialization and
//! deserialization of messages.
7
8use anyhow::Result;
9use collections::HashMap;
10
11use crate::client::Client;
12use crate::types;
13
/// Protocol version string this client sends in its `initialize` request.
///
/// The same string (together with the numeric version `1`) is also what the
/// client accepts back from the server during the handshake.
const PROTOCOL_VERSION: &str = "2024-10-07";
15
16pub struct ModelContextProtocol {
17 inner: Client,
18}
19
20impl ModelContextProtocol {
21 pub fn new(inner: Client) -> Self {
22 Self { inner }
23 }
24
25 fn supported_protocols() -> Vec<types::ProtocolVersion> {
26 vec![
27 types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
28 types::ProtocolVersion::VersionNumber(1),
29 ]
30 }
31
32 pub async fn initialize(
33 self,
34 client_info: types::Implementation,
35 ) -> Result<InitializedContextServerProtocol> {
36 let params = types::InitializeParams {
37 protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
38 capabilities: types::ClientCapabilities {
39 experimental: None,
40 sampling: None,
41 },
42 client_info,
43 };
44
45 let response: types::InitializeResponse = self
46 .inner
47 .request(types::RequestType::Initialize.as_str(), params)
48 .await?;
49
50 if !Self::supported_protocols().contains(&response.protocol_version) {
51 return Err(anyhow::anyhow!(
52 "Unsupported protocol version: {:?}",
53 response.protocol_version
54 ));
55 }
56
57 log::trace!("mcp server info {:?}", response.server_info);
58
59 self.inner.notify(
60 types::NotificationType::Initialized.as_str(),
61 serde_json::json!({}),
62 )?;
63
64 let initialized_protocol = InitializedContextServerProtocol {
65 inner: self.inner,
66 initialize: response,
67 };
68
69 Ok(initialized_protocol)
70 }
71}
72
73pub struct InitializedContextServerProtocol {
74 inner: Client,
75 pub initialize: types::InitializeResponse,
76}
77
/// A capability group that an MCP server may advertise in its `initialize`
/// response.
///
/// Each variant corresponds to one optional field of the server's advertised
/// capabilities; see [`InitializedContextServerProtocol::capable`].
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the enum can be
/// used as a key in hash-based collections.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// The `experimental` capability field.
    Experimental,
    /// The `logging` capability field.
    Logging,
    /// The `prompts` capability field.
    Prompts,
    /// The `resources` capability field.
    Resources,
    /// The `tools` capability field.
    Tools,
}
86
87impl InitializedContextServerProtocol {
88 /// Check if the server supports a specific capability
89 pub fn capable(&self, capability: ServerCapability) -> bool {
90 match capability {
91 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
92 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
93 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
94 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
95 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
96 }
97 }
98
99 fn check_capability(&self, capability: ServerCapability) -> Result<()> {
100 if self.capable(capability) {
101 Ok(())
102 } else {
103 Err(anyhow::anyhow!(
104 "Server does not support {:?} capability",
105 capability
106 ))
107 }
108 }
109
110 /// List the MCP prompts.
111 pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
112 self.check_capability(ServerCapability::Prompts)?;
113
114 let response: types::PromptsListResponse = self
115 .inner
116 .request(types::RequestType::PromptsList.as_str(), ())
117 .await?;
118
119 Ok(response.prompts)
120 }
121
122 /// List the MCP resources.
123 pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
124 self.check_capability(ServerCapability::Resources)?;
125
126 let response: types::ResourcesListResponse = self
127 .inner
128 .request(types::RequestType::ResourcesList.as_str(), ())
129 .await?;
130
131 Ok(response)
132 }
133
134 /// Executes a prompt with the given arguments and returns the result.
135 pub async fn run_prompt<P: AsRef<str>>(
136 &self,
137 prompt: P,
138 arguments: HashMap<String, String>,
139 ) -> Result<types::PromptsGetResponse> {
140 self.check_capability(ServerCapability::Prompts)?;
141
142 let params = types::PromptsGetParams {
143 name: prompt.as_ref().to_string(),
144 arguments: Some(arguments),
145 };
146
147 let response: types::PromptsGetResponse = self
148 .inner
149 .request(types::RequestType::PromptsGet.as_str(), params)
150 .await?;
151
152 Ok(response)
153 }
154
155 pub async fn completion<P: Into<String>>(
156 &self,
157 reference: types::CompletionReference,
158 argument: P,
159 value: P,
160 ) -> Result<types::Completion> {
161 let params = types::CompletionCompleteParams {
162 r#ref: reference,
163 argument: types::CompletionArgument {
164 name: argument.into(),
165 value: value.into(),
166 },
167 };
168 let result: types::CompletionCompleteResponse = self
169 .inner
170 .request(types::RequestType::CompletionComplete.as_str(), params)
171 .await?;
172
173 let completion = types::Completion {
174 values: result.completion.values,
175 total: types::CompletionTotal::from_options(
176 result.completion.has_more,
177 result.completion.total,
178 ),
179 };
180
181 Ok(completion)
182 }
183
184 /// List MCP tools.
185 pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
186 self.check_capability(ServerCapability::Tools)?;
187
188 let response = self
189 .inner
190 .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
191 .await?;
192
193 Ok(response)
194 }
195
196 /// Executes a tool with the given arguments
197 pub async fn run_tool<P: AsRef<str>>(
198 &self,
199 tool: P,
200 arguments: Option<HashMap<String, serde_json::Value>>,
201 ) -> Result<types::CallToolResponse> {
202 self.check_capability(ServerCapability::Tools)?;
203
204 let params = types::CallToolParams {
205 name: tool.as_ref().to_string(),
206 arguments,
207 };
208
209 let response: types::CallToolResponse = self
210 .inner
211 .request(types::RequestType::CallTool.as_str(), params)
212 .await?;
213
214 Ok(response)
215 }
216}
217
218impl InitializedContextServerProtocol {
219 pub async fn request<R: serde::de::DeserializeOwned>(
220 &self,
221 method: &str,
222 params: impl serde::Serialize,
223 ) -> Result<R> {
224 self.inner.request(method, params).await
225 }
226}