1//! This module implements parts of the Model Context Protocol.
2//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize them.
7
8use anyhow::Result;
9use collections::HashMap;
10
11use crate::client::Client;
12use crate::types;
13
/// An MCP connection that has not yet performed the initialization handshake.
///
/// Call [`ModelContextProtocol::initialize`] to perform the handshake; it consumes
/// this value and yields an [`InitializedContextServerProtocol`].
pub struct ModelContextProtocol {
    /// JSON-RPC client used to send requests and notifications to the server.
    inner: Client,
}
17
18impl ModelContextProtocol {
19 pub(crate) fn new(inner: Client) -> Self {
20 Self { inner }
21 }
22
23 fn supported_protocols() -> Vec<types::ProtocolVersion> {
24 vec![types::ProtocolVersion(
25 types::LATEST_PROTOCOL_VERSION.to_string(),
26 )]
27 }
28
29 pub async fn initialize(
30 self,
31 client_info: types::Implementation,
32 ) -> Result<InitializedContextServerProtocol> {
33 let params = types::InitializeParams {
34 protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
35 capabilities: types::ClientCapabilities {
36 experimental: None,
37 sampling: None,
38 roots: None,
39 },
40 meta: None,
41 client_info,
42 };
43
44 let response: types::InitializeResponse = self
45 .inner
46 .request(types::RequestType::Initialize.as_str(), params)
47 .await?;
48
49 anyhow::ensure!(
50 Self::supported_protocols().contains(&response.protocol_version),
51 "Unsupported protocol version: {:?}",
52 response.protocol_version
53 );
54
55 log::trace!("mcp server info {:?}", response.server_info);
56
57 self.inner.notify(
58 types::NotificationType::Initialized.as_str(),
59 serde_json::json!({}),
60 )?;
61
62 let initialized_protocol = InitializedContextServerProtocol {
63 inner: self.inner,
64 initialize: response,
65 };
66
67 Ok(initialized_protocol)
68 }
69}
70
/// An MCP connection on which the initialization handshake has completed.
///
/// Values of this type are only produced by `ModelContextProtocol::initialize`,
/// after the server's protocol version has been accepted and the `Initialized`
/// notification has been sent.
pub struct InitializedContextServerProtocol {
    /// JSON-RPC client used to send requests to the server.
    inner: Client,
    /// The server's reply to the `initialize` request, including its protocol
    /// version, server info and advertised capabilities.
    pub initialize: types::InitializeResponse,
}
75
/// A capability a server may advertise in its `initialize` response.
///
/// Each variant corresponds to one optional field of the server's advertised
/// capabilities, as checked by `InitializedContextServerProtocol::capable`.
// `Eq` and `Hash` are derived so capabilities can be stored in sets and used as map keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    Experimental,
    Logging,
    Prompts,
    Resources,
    Tools,
}
84
85impl InitializedContextServerProtocol {
86 /// Check if the server supports a specific capability
87 pub fn capable(&self, capability: ServerCapability) -> bool {
88 match capability {
89 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
90 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
91 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
92 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
93 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
94 }
95 }
96
97 fn check_capability(&self, capability: ServerCapability) -> Result<()> {
98 anyhow::ensure!(
99 self.capable(capability),
100 "Server does not support {capability:?} capability"
101 );
102 Ok(())
103 }
104
105 /// List the MCP prompts.
106 pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
107 self.check_capability(ServerCapability::Prompts)?;
108
109 let response: types::PromptsListResponse = self
110 .inner
111 .request(
112 types::RequestType::PromptsList.as_str(),
113 serde_json::json!({}),
114 )
115 .await?;
116
117 Ok(response.prompts)
118 }
119
120 /// List the MCP resources.
121 pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
122 self.check_capability(ServerCapability::Resources)?;
123
124 let response: types::ResourcesListResponse = self
125 .inner
126 .request(
127 types::RequestType::ResourcesList.as_str(),
128 serde_json::json!({}),
129 )
130 .await?;
131
132 Ok(response)
133 }
134
135 /// Executes a prompt with the given arguments and returns the result.
136 pub async fn run_prompt<P: AsRef<str>>(
137 &self,
138 prompt: P,
139 arguments: HashMap<String, String>,
140 ) -> Result<types::PromptsGetResponse> {
141 self.check_capability(ServerCapability::Prompts)?;
142
143 let params = types::PromptsGetParams {
144 name: prompt.as_ref().to_string(),
145 arguments: Some(arguments),
146 meta: None,
147 };
148
149 let response: types::PromptsGetResponse = self
150 .inner
151 .request(types::RequestType::PromptsGet.as_str(), params)
152 .await?;
153
154 Ok(response)
155 }
156
157 pub async fn completion<P: Into<String>>(
158 &self,
159 reference: types::CompletionReference,
160 argument: P,
161 value: P,
162 ) -> Result<types::Completion> {
163 let params = types::CompletionCompleteParams {
164 r#ref: reference,
165 argument: types::CompletionArgument {
166 name: argument.into(),
167 value: value.into(),
168 },
169 meta: None,
170 };
171 let result: types::CompletionCompleteResponse = self
172 .inner
173 .request(types::RequestType::CompletionComplete.as_str(), params)
174 .await?;
175
176 let completion = types::Completion {
177 values: result.completion.values,
178 total: types::CompletionTotal::from_options(
179 result.completion.has_more,
180 result.completion.total,
181 ),
182 };
183
184 Ok(completion)
185 }
186
187 /// List MCP tools.
188 pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
189 self.check_capability(ServerCapability::Tools)?;
190
191 let response = self
192 .inner
193 .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
194 .await?;
195
196 Ok(response)
197 }
198
199 /// Executes a tool with the given arguments
200 pub async fn run_tool<P: AsRef<str>>(
201 &self,
202 tool: P,
203 arguments: Option<HashMap<String, serde_json::Value>>,
204 ) -> Result<types::CallToolResponse> {
205 self.check_capability(ServerCapability::Tools)?;
206
207 let params = types::CallToolParams {
208 name: tool.as_ref().to_string(),
209 arguments,
210 meta: None,
211 };
212
213 let response: types::CallToolResponse = self
214 .inner
215 .request(types::RequestType::CallTool.as_str(), params)
216 .await?;
217
218 Ok(response)
219 }
220}
221
222impl InitializedContextServerProtocol {
223 pub async fn request<R: serde::de::DeserializeOwned>(
224 &self,
225 method: &str,
226 params: impl serde::Serialize,
227 ) -> Result<R> {
228 self.inner.request(method, params).await
229 }
230}