context_servers: Upgrade protocol to version 2024-11-05 (#20615)

David Soria Parra and Marshall Bowers created

This updates context servers to the most recent version

Release Notes:

- N/A

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>

Change summary

crates/assistant/src/slash_command/context_server_command.rs |   4 
crates/assistant/src/tools/context_server_tool.rs            |  20 
crates/context_servers/src/client.rs                         |   7 
crates/context_servers/src/protocol.rs                       |  16 
crates/context_servers/src/types.rs                          | 167 +++++
5 files changed, 180 insertions(+), 34 deletions(-)

Detailed changes

crates/assistant/src/slash_command/context_server_command.rs 🔗

@@ -152,7 +152,7 @@ impl SlashCommand for ContextServerSlashCommand {
                 if result
                     .messages
                     .iter()
-                    .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
+                    .any(|msg| !matches!(msg.role, context_servers::types::Role::User))
                 {
                     return Err(anyhow!(
                         "Prompt contains non-user roles, which is not supported"
@@ -164,7 +164,7 @@ impl SlashCommand for ContextServerSlashCommand {
                     .messages
                     .into_iter()
                     .filter_map(|msg| match msg.content {
-                        context_servers::types::SamplingContent::Text { text } => Some(text),
+                        context_servers::types::MessageContent::Text { text } => Some(text),
                         _ => None,
                     })
                     .collect::<Vec<String>>()

crates/assistant/src/tools/context_server_tool.rs 🔗

@@ -74,11 +74,21 @@ impl Tool for ContextServerTool {
                     );
                     let response = protocol.run_tool(tool_name, arguments).await?;
 
-                    let tool_result = match response.tool_result {
-                        serde_json::Value::String(s) => s,
-                        _ => serde_json::to_string(&response.tool_result)?,
-                    };
-                    Ok(tool_result)
+                    let mut result = String::new();
+                    for content in response.content {
+                        match content {
+                            types::ToolResponseContent::Text { text } => {
+                                result.push_str(&text);
+                            }
+                            types::ToolResponseContent::Image { .. } => {
+                                log::warn!("Ignoring image content from tool response");
+                            }
+                            types::ToolResponseContent::Resource { .. } => {
+                                log::warn!("Ignoring resource content from tool response");
+                            }
+                        }
+                    }
+                    Ok(result)
                 }
             })
         } else {

crates/context_servers/src/client.rs 🔗

@@ -25,6 +25,13 @@ use util::TryFutureExt;
 const JSON_RPC_VERSION: &str = "2.0";
 const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
 
+// Standard JSON-RPC error codes
+pub const PARSE_ERROR: i32 = -32700;
+pub const INVALID_REQUEST: i32 = -32600;
+pub const METHOD_NOT_FOUND: i32 = -32601;
+pub const INVALID_PARAMS: i32 = -32602;
+pub const INTERNAL_ERROR: i32 = -32603;
+
 type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
 type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;
 

crates/context_servers/src/protocol.rs 🔗

@@ -11,8 +11,6 @@ use collections::HashMap;
 use crate::client::Client;
 use crate::types;
 
-const PROTOCOL_VERSION: &str = "2024-10-07";
-
 pub struct ModelContextProtocol {
     inner: Client,
 }
@@ -23,10 +21,9 @@ impl ModelContextProtocol {
     }
 
     fn supported_protocols() -> Vec<types::ProtocolVersion> {
-        vec![
-            types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
-            types::ProtocolVersion::VersionNumber(1),
-        ]
+        vec![types::ProtocolVersion(
+            types::LATEST_PROTOCOL_VERSION.to_string(),
+        )]
     }
 
     pub async fn initialize(
@@ -34,11 +31,13 @@ impl ModelContextProtocol {
         client_info: types::Implementation,
     ) -> Result<InitializedContextServerProtocol> {
         let params = types::InitializeParams {
-            protocol_version: types::ProtocolVersion::VersionString(PROTOCOL_VERSION.to_string()),
+            protocol_version: types::ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
             capabilities: types::ClientCapabilities {
                 experimental: None,
                 sampling: None,
+                roots: None,
             },
+            meta: None,
             client_info,
         };
 
@@ -148,6 +147,7 @@ impl InitializedContextServerProtocol {
         let params = types::PromptsGetParams {
             name: prompt.as_ref().to_string(),
             arguments: Some(arguments),
+            meta: None,
         };
 
         let response: types::PromptsGetResponse = self
@@ -170,6 +170,7 @@ impl InitializedContextServerProtocol {
                 name: argument.into(),
                 value: value.into(),
             },
+            meta: None,
         };
         let result: types::CompletionCompleteResponse = self
             .inner
@@ -210,6 +211,7 @@ impl InitializedContextServerProtocol {
         let params = types::CallToolParams {
             name: tool.as_ref().to_string(),
             arguments,
+            meta: None,
         };
 
         let response: types::CallToolResponse = self

crates/context_servers/src/types.rs 🔗

@@ -2,8 +2,8 @@ use collections::HashMap;
 use serde::{Deserialize, Serialize};
 use url::Url;
 
-#[derive(Debug, Serialize)]
-#[serde(rename_all = "camelCase")]
+pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
+
 pub enum RequestType {
     Initialize,
     CallTool,
@@ -18,6 +18,7 @@ pub enum RequestType {
     Ping,
     ListTools,
     ListResourceTemplates,
+    ListRoots,
 }
 
 impl RequestType {
@@ -36,16 +37,14 @@ impl RequestType {
             RequestType::Ping => "ping",
             RequestType::ListTools => "tools/list",
             RequestType::ListResourceTemplates => "resources/templates/list",
+            RequestType::ListRoots => "roots/list",
         }
     }
 }
 
 #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
-#[serde(untagged)]
-pub enum ProtocolVersion {
-    VersionString(String),
-    VersionNumber(u32),
-}
+#[serde(transparent)]
+pub struct ProtocolVersion(pub String);
 
 #[derive(Debug, Serialize)]
 #[serde(rename_all = "camelCase")]
@@ -53,6 +52,8 @@ pub struct InitializeParams {
     pub protocol_version: ProtocolVersion,
     pub capabilities: ClientCapabilities,
     pub client_info: Implementation,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
@@ -61,30 +62,40 @@ pub struct CallToolParams {
     pub name: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub arguments: Option<HashMap<String, serde_json::Value>>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct ResourcesUnsubscribeParams {
     pub uri: Url,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct ResourcesSubscribeParams {
     pub uri: Url,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct ResourcesReadParams {
     pub uri: Url,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct LoggingSetLevelParams {
     pub level: LoggingLevel,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
@@ -93,6 +104,8 @@ pub struct PromptsGetParams {
     pub name: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub arguments: Option<HashMap<String, String>>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
@@ -100,6 +113,8 @@ pub struct PromptsGetParams {
 pub struct CompletionCompleteParams {
     pub r#ref: CompletionReference,
     pub argument: CompletionArgument,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
@@ -145,12 +160,16 @@ pub struct InitializeResponse {
     pub protocol_version: ProtocolVersion,
     pub capabilities: ServerCapabilities,
     pub server_info: Implementation,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct ResourcesReadResponse {
-    pub contents: Vec<ResourceContent>,
+    pub contents: Vec<ResourceContents>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -159,29 +178,39 @@ pub struct ResourcesListResponse {
     pub resources: Vec<Resource>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub next_cursor: Option<String>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
-
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct SamplingMessage {
-    pub role: SamplingRole,
-    pub content: SamplingContent,
+    pub role: Role,
+    pub content: MessageContent,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptMessage {
+    pub role: Role,
+    pub content: MessageContent,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "lowercase")]
-pub enum SamplingRole {
+pub enum Role {
     User,
     Assistant,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(tag = "type")]
-pub enum SamplingContent {
+pub enum MessageContent {
     #[serde(rename = "text")]
     Text { text: String },
     #[serde(rename = "image")]
     Image { data: String, mime_type: String },
+    #[serde(rename = "resource")]
+    Resource { resource: ResourceContents },
 }
 
 #[derive(Debug, Deserialize)]
@@ -189,7 +218,9 @@ pub enum SamplingContent {
 pub struct PromptsGetResponse {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub description: Option<String>,
-    pub messages: Vec<SamplingMessage>,
+    pub messages: Vec<PromptMessage>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -198,12 +229,16 @@ pub struct PromptsListResponse {
     pub prompts: Vec<Prompt>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub next_cursor: Option<String>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct CompletionCompleteResponse {
     pub completion: CompletionResult,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -214,6 +249,8 @@ pub struct CompletionResult {
     pub total: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub has_more: Option<bool>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]
@@ -243,6 +280,8 @@ pub struct ClientCapabilities {
     pub experimental: Option<HashMap<String, serde_json::Value>>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub sampling: Option<serde_json::Value>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub roots: Option<RootsCapabilities>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -283,6 +322,13 @@ pub struct ToolsCapabilities {
     pub list_changed: Option<bool>,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct RootsCapabilities {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub list_changed: Option<bool>,
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct Tool {
@@ -312,14 +358,28 @@ pub struct Resource {
 
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
-pub struct ResourceContent {
+pub struct ResourceContents {
     pub uri: Url,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub mime_type: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TextResourceContents {
+    pub uri: Url,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub text: Option<String>,
+    pub mime_type: Option<String>,
+    pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct BlobResourceContents {
+    pub uri: Url,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub blob: Option<String>,
+    pub mime_type: Option<String>,
+    pub blob: String,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -338,8 +398,32 @@ pub struct ResourceTemplate {
 pub enum LoggingLevel {
     Debug,
     Info,
+    Notice,
     Warning,
     Error,
+    Critical,
+    Alert,
+    Emergency,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelPreferences {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub hints: Option<Vec<ModelHint>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cost_priority: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub speed_priority: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub intelligence_priority: Option<f64>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelHint {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
 }
 
 #[derive(Debug, Serialize)]
@@ -352,6 +436,7 @@ pub enum NotificationType {
     ResourcesListChanged,
     ToolsListChanged,
     PromptsListChanged,
+    RootsListChanged,
 }
 
 impl NotificationType {
@@ -364,6 +449,7 @@ impl NotificationType {
             NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
             NotificationType::ToolsListChanged => "notifications/tools/list_changed",
             NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
+            NotificationType::RootsListChanged => "notifications/roots/list_changed",
         }
     }
 }
@@ -373,6 +459,14 @@ impl NotificationType {
 pub enum ClientNotification {
     Initialized,
     Progress(ProgressParams),
+    RootsListChanged,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ProgressToken {
+    String(String),
+    Number(f64),
 }
 
 #[derive(Debug, Serialize)]
@@ -382,10 +476,10 @@ pub struct ProgressParams {
     pub progress: f64,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub total: Option<f64>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
 }
 
-pub type ProgressToken = String;
-
 pub enum CompletionTotal {
     Exact(u32),
     HasMore,
@@ -410,7 +504,22 @@ pub struct Completion {
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct CallToolResponse {
-    pub tool_result: serde_json::Value,
+    pub content: Vec<ToolResponseContent>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub is_error: Option<bool>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum ToolResponseContent {
+    #[serde(rename = "text")]
+    Text { text: String },
+    #[serde(rename = "image")]
+    Image { data: String, mime_type: String },
+    #[serde(rename = "resource")]
+    Resource { resource: ResourceContents },
 }
 
 #[derive(Debug, Deserialize)]
@@ -419,4 +528,22 @@ pub struct ListToolsResponse {
     pub tools: Vec<Tool>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub next_cursor: Option<String>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ListRootsResponse {
+    pub roots: Vec<Root>,
+    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
+    pub meta: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Root {
+    pub uri: Url,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
 }