//! This module implements parts of the Model Context Protocol (MCP).
//!
//! It handles the lifecycle messages and provides a general interface for
//! interacting with an MCP server. It uses the generic JSON-RPC client to
//! read and write messages, and the types from `types.rs` for serialization
//! and deserialization of messages.
7
8use anyhow::Result;
9use collections::HashMap;
10
11use crate::client::Client;
12use crate::types;
13
14pub use types::PromptInfo;
15
/// Protocol version advertised to the server in the initialize request
/// (see [`ModelContextProtocol::initialize`]).
// NOTE(review): this is a plain integer because `types::InitializeParams::protocol_version`
// is declared as one here; confirm it matches what target servers expect.
const PROTOCOL_VERSION: u32 = 1;
17
/// An MCP connection on which the initialization handshake has not yet been performed.
///
/// Consume it with [`ModelContextProtocol::initialize`] to obtain an
/// [`InitializedContextServerProtocol`].
pub struct ModelContextProtocol {
    // JSON-RPC client used to exchange messages with the server.
    inner: Client,
}
21
22impl ModelContextProtocol {
23 pub fn new(inner: Client) -> Self {
24 Self { inner }
25 }
26
27 pub async fn initialize(
28 self,
29 client_info: types::EntityInfo,
30 ) -> Result<InitializedContextServerProtocol> {
31 let params = types::InitializeParams {
32 protocol_version: PROTOCOL_VERSION,
33 capabilities: types::ClientCapabilities {
34 experimental: None,
35 sampling: None,
36 },
37 client_info,
38 };
39
40 let response: types::InitializeResponse = self
41 .inner
42 .request(types::RequestType::Initialize.as_str(), params)
43 .await?;
44
45 log::trace!("mcp server info {:?}", response.server_info);
46
47 self.inner.notify(
48 types::NotificationType::Initialized.as_str(),
49 serde_json::json!({}),
50 )?;
51
52 let initialized_protocol = InitializedContextServerProtocol {
53 inner: self.inner,
54 initialize: response,
55 };
56
57 Ok(initialized_protocol)
58 }
59}
60
/// An MCP connection whose initialization handshake has completed.
///
/// Produced by [`ModelContextProtocol::initialize`].
pub struct InitializedContextServerProtocol {
    // JSON-RPC client used for all requests made after the handshake.
    inner: Client,
    /// The server's response to the initialize request. It includes the server info and
    /// the capabilities the server advertised.
    pub initialize: types::InitializeResponse,
}
65
/// An optional capability that a server may advertise in its initialize response.
///
/// Check for a capability with [`InitializedContextServerProtocol::capable`]. The type
/// derives `Eq` and `Hash` so capabilities can be collected into sets or used as map keys.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum ServerCapability {
    /// Server-specific experimental features.
    Experimental,
    /// Server-side logging.
    Logging,
    /// Prompt listing and retrieval.
    Prompts,
    /// Resource access.
    Resources,
    /// Tool invocation.
    Tools,
}
74
75impl InitializedContextServerProtocol {
76 /// Check if the server supports a specific capability
77 pub fn capable(&self, capability: ServerCapability) -> bool {
78 match capability {
79 ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
80 ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
81 ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
82 ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
83 ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
84 }
85 }
86
87 fn check_capability(&self, capability: ServerCapability) -> Result<()> {
88 if self.capable(capability) {
89 Ok(())
90 } else {
91 Err(anyhow::anyhow!(
92 "Server does not support {:?} capability",
93 capability
94 ))
95 }
96 }
97
98 /// List the MCP prompts.
99 pub async fn list_prompts(&self) -> Result<Vec<types::PromptInfo>> {
100 self.check_capability(ServerCapability::Prompts)?;
101
102 let response: types::PromptsListResponse = self
103 .inner
104 .request(types::RequestType::PromptsList.as_str(), ())
105 .await?;
106
107 Ok(response.prompts)
108 }
109
110 /// Executes a prompt with the given arguments and returns the result.
111 pub async fn run_prompt<P: AsRef<str>>(
112 &self,
113 prompt: P,
114 arguments: HashMap<String, String>,
115 ) -> Result<String> {
116 self.check_capability(ServerCapability::Prompts)?;
117
118 let params = types::PromptsGetParams {
119 name: prompt.as_ref().to_string(),
120 arguments: Some(arguments),
121 };
122
123 let response: types::PromptsGetResponse = self
124 .inner
125 .request(types::RequestType::PromptsGet.as_str(), params)
126 .await?;
127
128 Ok(response.prompt)
129 }
130}
131
132impl InitializedContextServerProtocol {
133 pub async fn request<R: serde::de::DeserializeOwned>(
134 &self,
135 method: &str,
136 params: impl serde::Serialize,
137 ) -> Result<R> {
138 self.inner.request(method, params).await
139 }
140}