From d8d8c908ed62ec76382e0233e8945b8e12eb774b Mon Sep 17 00:00:00 2001 From: David Soria Parra <167242713+dsp-ant@users.noreply.github.com> Date: Tue, 22 Oct 2024 16:19:32 +0100 Subject: [PATCH] context_servers: Update protocol (#19547) We sadly have to change the underlying protocol once again. This will likely be the last change to the core protocol without correctly handling older versions. From here on out, we want to get better with version handling. To do so, we introduce the notion of a string protocol version to be explicit of when the underlying protocol last changed. The change also changes the return values of prompts. For now we only allow User messages from servers to match the current behaviour. We will change this once #19222 lands which will allow slash commands to insert user and assistant messages. Release Notes: - N/A --- .../slash_command/context_server_command.rs | 23 ++++++++++++++++++- crates/context_servers/src/protocol.rs | 18 +++++++++++++-- crates/context_servers/src/types.rs | 20 +++++++++++----- 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 3db057d07494cd5a67834399eaf48d29dad41f7a..9e6c4b7718889c2bc6fc1a6b97ff2fbafd9d32cc 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -145,7 +145,28 @@ impl SlashCommand for ContextServerSlashCommand { return Err(anyhow!("Context server not initialized")); }; let result = protocol.run_prompt(&prompt_name, prompt_args).await?; - let mut prompt = result.prompt; + + // Check that there are only user roles + if result + .messages + .iter() + .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User)) + { + return Err(anyhow!( + "Prompt contains non-user roles, which is not supported" + )); + } + + // Extract text from user messages into a single prompt string + let mut prompt = result + .messages + .into_iter() + .filter_map(|msg| match msg.content { + context_servers::types::SamplingContent::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join("\n\n"); // We must normalize the line endings here, since servers might return CR characters. LineEnding::normalize(&mut prompt); diff --git a/crates/context_servers/src/protocol.rs b/crates/context_servers/src/protocol.rs index 451db56ef31df0d74c00c5ea2d5bce70c99b9c6c..80a7a7f991a23f5fe963ae54e836b3240b8844c5 100644 --- a/crates/context_servers/src/protocol.rs +++ b/crates/context_servers/src/protocol.rs @@ -11,7 +11,7 @@ use collections::HashMap; use crate::client::Client; use crate::types; -const PROTOCOL_VERSION: u32 = 1; +const PROTOCOL_VERSION: &str = "2024-10-07"; pub struct ModelContextProtocol { inner: Client, @@ -22,12 +22,19 @@ impl ModelContextProtocol { Self { inner } } + fn supported_protocols() -> Vec { + vec![ + types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()), + types::ProtocolVersion::VersionNumber(1), + ] + } + pub async fn initialize( self, client_info: types::Implementation, ) -> Result { let params = types::InitializeParams { - protocol_version: PROTOCOL_VERSION, + protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()), capabilities: types::ClientCapabilities { experimental: None, sampling: None, @@ -40,6 +47,13 @@ impl ModelContextProtocol { .request(types::RequestType::Initialize.as_str(), params) .await?; + if !Self::supported_protocols().contains(&response.protocol_version) { + return Err(anyhow::anyhow!( + "Unsupported protocol version: {:?}", + response.protocol_version + )); + } + log::trace!("mcp server info {:?}", response.server_info); self.inner.notify( diff --git a/crates/context_servers/src/types.rs b/crates/context_servers/src/types.rs index 04ac87c704d06cfec85cdf5e9da36be15f4355fc..2bca0a021a129029b55d6371f9db98332418a7a5 100644 --- a/crates/context_servers/src/types.rs +++ b/crates/context_servers/src/types.rs @@ -36,10 +36,17 @@ impl RequestType { } } +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ProtocolVersion { + VersionString(String), + VersionNumber(u32), +} + #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct InitializeParams { - pub protocol_version: u32, + pub protocol_version: ProtocolVersion, pub capabilities: ClientCapabilities, pub client_info: Implementation, } @@ -131,7 +138,7 @@ pub struct CompletionArgument { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InitializeResponse { - pub protocol_version: u32, + pub protocol_version: ProtocolVersion, pub capabilities: ServerCapabilities, pub server_info: Implementation, } @@ -145,10 +152,9 @@ pub struct ResourcesReadResponse { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesListResponse { + pub resources: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub resource_templates: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub resources: Option>, + pub next_cursor: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -179,13 +185,15 @@ pub enum SamplingContent { pub struct PromptsGetResponse { #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, - pub prompt: String, + pub messages: Vec, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsListResponse { pub prompts: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, } #[derive(Debug, Deserialize)]