context_server: Make notifications type safe (#32396)

Bennet Bo Fenner created

Follow up to #32254 

Release Notes:

- N/A

Change summary

crates/agent/src/context_server_tool.rs                       |  2 
crates/agent/src/thread_store.rs                              |  2 
crates/assistant_context_editor/src/context_store.rs          |  2 
crates/assistant_slash_commands/src/context_server_command.rs |  4 
crates/context_server/src/protocol.rs                         | 15 
crates/context_server/src/test.rs                             |  2 
crates/context_server/src/types.rs                            | 83 +++-
7 files changed, 67 insertions(+), 43 deletions(-)

Detailed changes

crates/agent/src/context_server_tool.rs 🔗

@@ -105,7 +105,7 @@ impl Tool for ContextServerTool {
                     arguments
                 );
                 let response = protocol
-                    .request::<context_server::types::request::CallTool>(
+                    .request::<context_server::types::requests::CallTool>(
                         context_server::types::CallToolParams {
                             name: tool_name,
                             arguments,

crates/agent/src/thread_store.rs 🔗

@@ -562,7 +562,7 @@ impl ThreadStore {
 
             if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                 if let Some(response) = protocol
-                    .request::<context_server::types::request::ListTools>(())
+                    .request::<context_server::types::requests::ListTools>(())
                     .await
                     .log_err()
                 {

crates/assistant_context_editor/src/context_store.rs 🔗

@@ -869,7 +869,7 @@ impl ContextStore {
 
             if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
                 if let Some(response) = protocol
-                    .request::<context_server::types::request::PromptsList>(())
+                    .request::<context_server::types::requests::PromptsList>(())
                     .await
                     .log_err()
                 {

crates/assistant_slash_commands/src/context_server_command.rs 🔗

@@ -87,7 +87,7 @@ impl SlashCommand for ContextServerSlashCommand {
                 let protocol = server.client().context("Context server not initialized")?;
 
                 let response = protocol
-                    .request::<context_server::types::request::CompletionComplete>(
+                    .request::<context_server::types::requests::CompletionComplete>(
                         context_server::types::CompletionCompleteParams {
                             reference: context_server::types::CompletionReference::Prompt(
                                 context_server::types::PromptReference {
@@ -145,7 +145,7 @@ impl SlashCommand for ContextServerSlashCommand {
             cx.foreground_executor().spawn(async move {
                 let protocol = server.client().context("Context server not initialized")?;
                 let response = protocol
-                    .request::<context_server::types::request::PromptsGet>(
+                    .request::<context_server::types::requests::PromptsGet>(
                         context_server::types::PromptsGetParams {
                             name: prompt_name.clone(),
                             arguments: Some(prompt_args),

crates/context_server/src/protocol.rs 🔗

@@ -8,7 +8,7 @@
 use anyhow::Result;
 
 use crate::client::Client;
-use crate::types::{self, Request};
+use crate::types::{self, Notification, Request};
 
 pub struct ModelContextProtocol {
     inner: Client,
@@ -43,7 +43,7 @@ impl ModelContextProtocol {
 
         let response: types::InitializeResponse = self
             .inner
-            .request(types::request::Initialize::METHOD, params)
+            .request(types::requests::Initialize::METHOD, params)
             .await?;
 
         anyhow::ensure!(
@@ -54,16 +54,13 @@ impl ModelContextProtocol {
 
         log::trace!("mcp server info {:?}", response.server_info);
 
-        self.inner.notify(
-            types::NotificationType::Initialized.as_str(),
-            serde_json::json!({}),
-        )?;
-
         let initialized_protocol = InitializedContextServerProtocol {
             inner: self.inner,
             initialize: response,
         };
 
+        initialized_protocol.notify::<types::notifications::Initialized>(())?;
+
         Ok(initialized_protocol)
     }
 }
@@ -97,4 +94,8 @@ impl InitializedContextServerProtocol {
     pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
         self.inner.request(T::METHOD, params).await
     }
+
+    pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
+        self.inner.notify(T::METHOD, params)
+    }
 }

crates/context_server/src/test.rs 🔗

@@ -14,7 +14,7 @@ pub fn create_fake_transport(
     executor: BackgroundExecutor,
 ) -> FakeTransport {
     let name = name.into();
-    FakeTransport::new(executor).on_request::<crate::types::request::Initialize>(move |_params| {
+    FakeTransport::new(executor).on_request::<crate::types::requests::Initialize>(move |_params| {
         create_initialize_response(name.clone())
     })
 }

crates/context_server/src/types.rs 🔗

@@ -6,7 +6,7 @@ use url::Url;
 pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
 pub const VERSION_2024_11_05: &str = "2024-11-05";
 
-pub mod request {
+pub mod requests {
     use super::*;
 
     macro_rules! request {
@@ -83,6 +83,57 @@ pub trait Request {
     const METHOD: &'static str;
 }
 
+pub mod notifications {
+    use super::*;
+
+    macro_rules! notification {
+        ($method:expr, $name:ident, $params:ty) => {
+            pub struct $name;
+
+            impl Notification for $name {
+                type Params = $params;
+                const METHOD: &'static str = $method;
+            }
+        };
+    }
+
+    notification!("notifications/initialized", Initialized, ());
+    notification!("notifications/progress", Progress, ProgressParams);
+    notification!("notifications/message", Message, MessageParams);
+    notification!(
+        "notifications/resources/updated",
+        ResourcesUpdated,
+        ResourcesUpdatedParams
+    );
+    notification!(
+        "notifications/resources/list_changed",
+        ResourcesListChanged,
+        ()
+    );
+    notification!("notifications/tools/list_changed", ToolsListChanged, ());
+    notification!("notifications/prompts/list_changed", PromptsListChanged, ());
+    notification!("notifications/roots/list_changed", RootsListChanged, ());
+}
+
+pub trait Notification {
+    type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
+    const METHOD: &'static str;
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct MessageParams {
+    pub level: LoggingLevel,
+    pub logger: Option<String>,
+    pub data: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesUpdatedParams {
+    pub uri: String,
+}
+
 #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(transparent)]
 pub struct ProtocolVersion(pub String);
@@ -560,34 +611,6 @@ pub struct ModelHint {
     pub name: Option<String>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub enum NotificationType {
-    Initialized,
-    Progress,
-    Message,
-    ResourcesUpdated,
-    ResourcesListChanged,
-    ToolsListChanged,
-    PromptsListChanged,
-    RootsListChanged,
-}
-
-impl NotificationType {
-    pub fn as_str(&self) -> &'static str {
-        match self {
-            NotificationType::Initialized => "notifications/initialized",
-            NotificationType::Progress => "notifications/progress",
-            NotificationType::Message => "notifications/message",
-            NotificationType::ResourcesUpdated => "notifications/resources/updated",
-            NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
-            NotificationType::ToolsListChanged => "notifications/tools/list_changed",
-            NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
-            NotificationType::RootsListChanged => "notifications/roots/list_changed",
-        }
-    }
-}
-
 #[derive(Debug, Serialize)]
 #[serde(untagged)]
 pub enum ClientNotification {
@@ -608,7 +631,7 @@ pub enum ProgressToken {
     Number(f64),
 }
 
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct ProgressParams {
     pub progress_token: ProgressToken,