@@ -87,7 +87,7 @@ impl SlashCommand for ContextServerSlashCommand {
let protocol = server.client().context("Context server not initialized")?;
let response = protocol
- .request::<context_server::types::request::CompletionComplete>(
+ .request::<context_server::types::requests::CompletionComplete>(
context_server::types::CompletionCompleteParams {
reference: context_server::types::CompletionReference::Prompt(
context_server::types::PromptReference {
@@ -145,7 +145,7 @@ impl SlashCommand for ContextServerSlashCommand {
cx.foreground_executor().spawn(async move {
let protocol = server.client().context("Context server not initialized")?;
let response = protocol
- .request::<context_server::types::request::PromptsGet>(
+ .request::<context_server::types::requests::PromptsGet>(
context_server::types::PromptsGetParams {
name: prompt_name.clone(),
arguments: Some(prompt_args),
@@ -8,7 +8,7 @@
use anyhow::Result;
use crate::client::Client;
-use crate::types::{self, Request};
+use crate::types::{self, Notification, Request};
pub struct ModelContextProtocol {
inner: Client,
@@ -43,7 +43,7 @@ impl ModelContextProtocol {
let response: types::InitializeResponse = self
.inner
- .request(types::request::Initialize::METHOD, params)
+ .request(types::requests::Initialize::METHOD, params)
.await?;
anyhow::ensure!(
@@ -54,16 +54,13 @@ impl ModelContextProtocol {
log::trace!("mcp server info {:?}", response.server_info);
- self.inner.notify(
- types::NotificationType::Initialized.as_str(),
- serde_json::json!({}),
- )?;
-
let initialized_protocol = InitializedContextServerProtocol {
inner: self.inner,
initialize: response,
};
+ initialized_protocol.notify::<types::notifications::Initialized>(())?;
+
Ok(initialized_protocol)
}
}
@@ -97,4 +94,8 @@ impl InitializedContextServerProtocol {
pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
self.inner.request(T::METHOD, params).await
}
+
+ pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
+ self.inner.notify(T::METHOD, params)
+ }
}
@@ -6,7 +6,7 @@ use url::Url;
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
pub const VERSION_2024_11_05: &str = "2024-11-05";
-pub mod request {
+pub mod requests {
use super::*;
macro_rules! request {
@@ -83,6 +83,57 @@ pub trait Request {
const METHOD: &'static str;
}
+pub mod notifications {
+ use super::*;
+
+ macro_rules! notification {
+ ($method:expr, $name:ident, $params:ty) => {
+ pub struct $name;
+
+ impl Notification for $name {
+ type Params = $params;
+ const METHOD: &'static str = $method;
+ }
+ };
+ }
+
+ notification!("notifications/initialized", Initialized, ());
+ notification!("notifications/progress", Progress, ProgressParams);
+ notification!("notifications/message", Message, MessageParams);
+ notification!(
+ "notifications/resources/updated",
+ ResourcesUpdated,
+ ResourcesUpdatedParams
+ );
+ notification!(
+ "notifications/resources/list_changed",
+ ResourcesListChanged,
+ ()
+ );
+ notification!("notifications/tools/list_changed", ToolsListChanged, ());
+ notification!("notifications/prompts/list_changed", PromptsListChanged, ());
+ notification!("notifications/roots/list_changed", RootsListChanged, ());
+}
+
+pub trait Notification {
+ type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
+ const METHOD: &'static str;
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct MessageParams {
+ pub level: LoggingLevel,
+ pub logger: Option<String>,
+ pub data: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesUpdatedParams {
+ pub uri: String,
+}
+
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProtocolVersion(pub String);
@@ -560,34 +611,6 @@ pub struct ModelHint {
pub name: Option<String>,
}
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub enum NotificationType {
- Initialized,
- Progress,
- Message,
- ResourcesUpdated,
- ResourcesListChanged,
- ToolsListChanged,
- PromptsListChanged,
- RootsListChanged,
-}
-
-impl NotificationType {
- pub fn as_str(&self) -> &'static str {
- match self {
- NotificationType::Initialized => "notifications/initialized",
- NotificationType::Progress => "notifications/progress",
- NotificationType::Message => "notifications/message",
- NotificationType::ResourcesUpdated => "notifications/resources/updated",
- NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
- NotificationType::ToolsListChanged => "notifications/tools/list_changed",
- NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
- NotificationType::RootsListChanged => "notifications/roots/list_changed",
- }
- }
-}
-
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum ClientNotification {
@@ -608,7 +631,7 @@ pub enum ProgressToken {
Number(f64),
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProgressParams {
pub progress_token: ProgressToken,