1//! This module implements parts of the Model Context Protocol.
2//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` to serialize and
//! deserialize them.
7
8use anyhow::Result;
9use collections::HashMap;
10
11use crate::client::Client;
12use crate::types;
13
/// Protocol version this client advertises in the `initialize` request
/// (see [`ModelContextProtocol::initialize`]).
const PROTOCOL_VERSION: u32 = 1;
15
16pub struct ModelContextProtocol {
17 inner: Client,
18}
19
20impl ModelContextProtocol {
21 pub fn new(inner: Client) -> Self {
22 Self { inner }
23 }
24
25 pub async fn initialize(
26 self,
27 client_info: types::Implementation,
28 ) -> Result<InitializedContextServerProtocol> {
29 let params = types::InitializeParams {
30 protocol_version: PROTOCOL_VERSION,
31 capabilities: types::ClientCapabilities {
32 experimental: None,
33 sampling: None,
34 },
35 client_info,
36 };
37
38 let response: types::InitializeResponse = self
39 .inner
40 .request(types::RequestType::Initialize.as_str(), params)
41 .await?;
42
43 log::trace!("mcp server info {:?}", response.server_info);
44
45 self.inner.notify(
46 types::NotificationType::Initialized.as_str(),
47 serde_json::json!({}),
48 )?;
49
50 let initialized_protocol = InitializedContextServerProtocol {
51 inner: self.inner,
52 initialize: response,
53 };
54
55 Ok(initialized_protocol)
56 }
57}
58
/// An MCP connection whose initialization handshake has completed.
///
/// Produced by [`ModelContextProtocol::initialize`].
pub struct InitializedContextServerProtocol {
    /// JSON-RPC client used for all requests to the server.
    inner: Client,
    /// The server's response to the `initialize` request, including the
    /// capabilities it advertised.
    pub initialize: types::InitializeResponse,
}
63
/// A capability a server may advertise in its `initialize` response.
///
/// Each variant corresponds to one optional field of the server's advertised
/// capabilities; a capability counts as supported when that field is present.
/// The type is `Eq + Hash` so it can be used as a set or map key.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// The server advertised `experimental` capabilities.
    Experimental,
    /// The server advertised the `logging` capability.
    Logging,
    /// The server advertised the `prompts` capability.
    Prompts,
    /// The server advertised the `resources` capability.
    Resources,
    /// The server advertised the `tools` capability.
    Tools,
}
72
73impl InitializedContextServerProtocol {
74 /// Check if the server supports a specific capability
75 pub fn capable(&self, capability: ServerCapability) -> bool {
76 match capability {
77 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
78 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
79 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
80 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
81 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
82 }
83 }
84
85 fn check_capability(&self, capability: ServerCapability) -> Result<()> {
86 if self.capable(capability) {
87 Ok(())
88 } else {
89 Err(anyhow::anyhow!(
90 "Server does not support {:?} capability",
91 capability
92 ))
93 }
94 }
95
96 /// List the MCP prompts.
97 pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
98 self.check_capability(ServerCapability::Prompts)?;
99
100 let response: types::PromptsListResponse = self
101 .inner
102 .request(types::RequestType::PromptsList.as_str(), ())
103 .await?;
104
105 Ok(response.prompts)
106 }
107
108 /// List the MCP resources.
109 pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
110 self.check_capability(ServerCapability::Resources)?;
111
112 let response: types::ResourcesListResponse = self
113 .inner
114 .request(types::RequestType::ResourcesList.as_str(), ())
115 .await?;
116
117 Ok(response)
118 }
119
120 /// Executes a prompt with the given arguments and returns the result.
121 pub async fn run_prompt<P: AsRef<str>>(
122 &self,
123 prompt: P,
124 arguments: HashMap<String, String>,
125 ) -> Result<types::PromptsGetResponse> {
126 self.check_capability(ServerCapability::Prompts)?;
127
128 let params = types::PromptsGetParams {
129 name: prompt.as_ref().to_string(),
130 arguments: Some(arguments),
131 };
132
133 let response: types::PromptsGetResponse = self
134 .inner
135 .request(types::RequestType::PromptsGet.as_str(), params)
136 .await?;
137
138 Ok(response)
139 }
140
141 pub async fn completion<P: Into<String>>(
142 &self,
143 reference: types::CompletionReference,
144 argument: P,
145 value: P,
146 ) -> Result<types::Completion> {
147 let params = types::CompletionCompleteParams {
148 r#ref: reference,
149 argument: types::CompletionArgument {
150 name: argument.into(),
151 value: value.into(),
152 },
153 };
154 let result: types::CompletionCompleteResponse = self
155 .inner
156 .request(types::RequestType::CompletionComplete.as_str(), params)
157 .await?;
158
159 let completion = types::Completion {
160 values: result.completion.values,
161 total: types::CompletionTotal::from_options(
162 result.completion.has_more,
163 result.completion.total,
164 ),
165 };
166
167 Ok(completion)
168 }
169}
170
171impl InitializedContextServerProtocol {
172 pub async fn request<R: serde::de::DeserializeOwned>(
173 &self,
174 method: &str,
175 params: impl serde::Serialize,
176 ) -> Result<R> {
177 self.inner.request(method, params).await
178 }
179}