@@ -145,7 +145,28 @@ impl SlashCommand for ContextServerSlashCommand {
return Err(anyhow!("Context server not initialized"));
};
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
- let mut prompt = result.prompt;
+
+ // Check that there are only user roles
+ if result
+ .messages
+ .iter()
+ .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
+ {
+ return Err(anyhow!(
+ "Prompt contains non-user roles, which is not supported"
+ ));
+ }
+
+ // Extract text from user messages into a single prompt string
+ let mut prompt = result
+ .messages
+ .into_iter()
+ .filter_map(|msg| match msg.content {
+ context_servers::types::SamplingContent::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n\n");
// We must normalize the line endings here, since servers might return CR characters.
LineEnding::normalize(&mut prompt);
@@ -11,7 +11,7 @@ use collections::HashMap;
use crate::client::Client;
use crate::types;
-const PROTOCOL_VERSION: u32 = 1;
+const PROTOCOL_VERSION: &str = "2024-10-07";
pub struct ModelContextProtocol {
inner: Client,
@@ -22,12 +22,19 @@ impl ModelContextProtocol {
Self { inner }
}
+ fn supported_protocols() -> Vec<types::ProtocolVersion> {
+ vec![
+ types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
+ types::ProtocolVersion::VersionNumber(1),
+ ]
+ }
+
pub async fn initialize(
self,
client_info: types::Implementation,
) -> Result<InitializedContextServerProtocol> {
let params = types::InitializeParams {
- protocol_version: PROTOCOL_VERSION,
+ protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
capabilities: types::ClientCapabilities {
experimental: None,
sampling: None,
@@ -40,6 +47,13 @@ impl ModelContextProtocol {
.request(types::RequestType::Initialize.as_str(), params)
.await?;
+ if !Self::supported_protocols().contains(&response.protocol_version) {
+ return Err(anyhow::anyhow!(
+ "Unsupported protocol version: {:?}",
+ response.protocol_version
+ ));
+ }
+
log::trace!("mcp server info {:?}", response.server_info);
self.inner.notify(
@@ -36,10 +36,17 @@ impl RequestType {
}
}
+#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ProtocolVersion {
+ VersionString(String),
+ VersionNumber(u32),
+}
+
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
- pub protocol_version: u32,
+ pub protocol_version: ProtocolVersion,
pub capabilities: ClientCapabilities,
pub client_info: Implementation,
}
@@ -131,7 +138,7 @@ pub struct CompletionArgument {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
- pub protocol_version: u32,
+ pub protocol_version: ProtocolVersion,
pub capabilities: ServerCapabilities,
pub server_info: Implementation,
}
@@ -145,10 +152,9 @@ pub struct ResourcesReadResponse {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesListResponse {
+ pub resources: Vec<Resource>,
#[serde(skip_serializing_if = "Option::is_none")]
- pub resource_templates: Option<Vec<ResourceTemplate>>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub resources: Option<Vec<Resource>>,
+ pub next_cursor: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -179,13 +185,15 @@ pub enum SamplingContent {
pub struct PromptsGetResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
- pub prompt: String,
+ pub messages: Vec<SamplingMessage>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsListResponse {
pub prompts: Vec<Prompt>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub next_cursor: Option<String>,
}
#[derive(Debug, Deserialize)]