context_servers: Completion support for context server slash commands (#17085)

David Soria Parra created

This PR adds support for completions via MCP. The protocol now supports
a new request type "completion/complete"
that can either complete a resource URI template (which we currently
don't use in Zed), or a prompt argument.
We use this to add autocompletion to our context server slash commands!


https://github.com/user-attachments/assets/08c9cf04-cbeb-49a7-903f-5049fb3b3d9f



Release Notes:

- context_servers: Added support for argument completions for context
server prompts. These show up as regular completions to slash commands.

Change summary

crates/assistant/src/slash_command/context_server_command.rs | 72 ++++
crates/context_servers/src/protocol.rs                       | 29 ++
crates/context_servers/src/types.rs                          | 83 ++++++
3 files changed, 179 insertions(+), 5 deletions(-)

Detailed changes

crates/assistant/src/slash_command/context_server_command.rs 🔗

@@ -1,6 +1,7 @@
 use anyhow::{anyhow, Result};
 use assistant_slash_command::{
-    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
+    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
+    SlashCommandOutputSection,
 };
 use collections::HashMap;
 use context_servers::{
@@ -8,7 +9,7 @@ use context_servers::{
     protocol::PromptInfo,
 };
 use gpui::{Task, WeakView, WindowContext};
-use language::LspAdapterDelegate;
+use language::{CodeLabel, LspAdapterDelegate};
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use ui::{IconName, SharedString};
@@ -50,12 +51,57 @@ impl SlashCommand for ContextServerSlashCommand {
 
     fn complete_argument(
         self: Arc<Self>,
-        _arguments: &[String],
+        arguments: &[String],
         _cancel: Arc<AtomicBool>,
         _workspace: Option<WeakView<Workspace>>,
-        _cx: &mut WindowContext,
+        cx: &mut WindowContext,
     ) -> Task<Result<Vec<ArgumentCompletion>>> {
-        Task::ready(Ok(Vec::new()))
+        let server_id = self.server_id.clone();
+        let prompt_name = self.prompt.name.clone();
+        let manager = ContextServerManager::global(cx);
+        let manager = manager.read(cx);
+
+        let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
+            Ok(tp) => tp,
+            Err(e) => {
+                return Task::ready(Err(e));
+            }
+        };
+        if let Some(server) = manager.get_server(&server_id) {
+            cx.foreground_executor().spawn(async move {
+                let Some(protocol) = server.client.read().clone() else {
+                    return Err(anyhow!("Context server not initialized"));
+                };
+
+                let completion_result = protocol
+                    .completion(
+                        context_servers::types::CompletionReference::Prompt(
+                            context_servers::types::PromptReference {
+                                r#type: context_servers::types::PromptReferenceType::Prompt,
+                                name: prompt_name,
+                            },
+                        ),
+                        arg_name,
+                        arg_val,
+                    )
+                    .await?;
+
+                let completions = completion_result
+                    .values
+                    .into_iter()
+                    .map(|value| ArgumentCompletion {
+                        label: CodeLabel::plain(value.clone(), None),
+                        new_text: value,
+                        after_completion: AfterCompletion::Continue,
+                        replace_previous_arguments: false,
+                    })
+                    .collect();
+
+                Ok(completions)
+            })
+        } else {
+            Task::ready(Err(anyhow!("Context server not found")))
+        }
     }
 
     fn run(
@@ -102,6 +148,22 @@ impl SlashCommand for ContextServerSlashCommand {
     }
 }
 
+fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
+    if arguments.is_empty() {
+        return Err(anyhow!("No arguments given"));
+    }
+
+    match &prompt.arguments {
+        Some(args) if args.len() == 1 => {
+            let arg_name = args[0].name.clone();
+            let arg_value = arguments.join(" ");
+            Ok((arg_name, arg_value))
+        }
+        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
+        None => Err(anyhow!("Prompt has no arguments")),
+    }
+}
+
 fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
     match &prompt.arguments {
         Some(args) if args.len() > 1 => Err(anyhow!(

crates/context_servers/src/protocol.rs 🔗

@@ -127,6 +127,35 @@ impl InitializedContextServerProtocol {
 
         Ok(response)
     }
+
+    pub async fn completion<P: Into<String>>(
+        &self,
+        reference: types::CompletionReference,
+        argument: P,
+        value: P,
+    ) -> Result<types::Completion> {
+        let params = types::CompletionCompleteParams {
+            r#ref: reference,
+            argument: types::CompletionArgument {
+                name: argument.into(),
+                value: value.into(),
+            },
+        };
+        let result: types::CompletionCompleteResponse = self
+            .inner
+            .request(types::RequestType::CompletionComplete.as_str(), params)
+            .await?;
+
+        let completion = types::Completion {
+            values: result.completion.values,
+            total: types::CompletionTotal::from_options(
+                result.completion.has_more,
+                result.completion.total,
+            ),
+        };
+
+        Ok(completion)
+    }
 }
 
 impl InitializedContextServerProtocol {

crates/context_servers/src/types.rs 🔗

@@ -14,6 +14,7 @@ pub enum RequestType {
     LoggingSetLevel,
     PromptsGet,
     PromptsList,
+    CompletionComplete,
 }
 
 impl RequestType {
@@ -28,6 +29,7 @@ impl RequestType {
             RequestType::LoggingSetLevel => "logging/setLevel",
             RequestType::PromptsGet => "prompts/get",
             RequestType::PromptsList => "prompts/list",
+            RequestType::CompletionComplete => "completion/complete",
         }
     }
 }
@@ -78,6 +80,50 @@ pub struct PromptsGetParams {
     pub arguments: Option<HashMap<String, String>>,
 }
 
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CompletionCompleteParams {
+    pub r#ref: CompletionReference,
+    pub argument: CompletionArgument,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(untagged)]
+pub enum CompletionReference {
+    Prompt(PromptReference),
+    Resource(ResourceReference),
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptReference {
+    pub r#type: PromptReferenceType,
+    pub name: String,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "snake_case")]
+pub enum PromptReferenceType {
+    #[serde(rename = "ref/prompt")]
+    Prompt,
+    #[serde(rename = "ref/resource")]
+    Resource,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourceReference {
+    pub r#type: String,
+    pub uri: String,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CompletionArgument {
+    pub name: String,
+    pub value: String,
+}
+
 #[derive(Debug, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct InitializeResponse {
@@ -112,6 +158,20 @@ pub struct PromptsListResponse {
     pub prompts: Vec<PromptInfo>,
 }
 
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CompletionCompleteResponse {
+    pub completion: CompletionResult,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CompletionResult {
+    pub values: Vec<String>,
+    pub total: Option<u32>,
+    pub has_more: Option<bool>,
+}
+
 #[derive(Debug, Deserialize, Clone)]
 #[serde(rename_all = "camelCase")]
 pub struct PromptInfo {
@@ -233,3 +293,26 @@ pub struct ProgressParams {
     pub progress: f64,
     pub total: Option<f64>,
 }
+
+// Helper Types that don't map directly to the protocol
+
+pub enum CompletionTotal {
+    Exact(u32),
+    HasMore,
+    Unknown,
+}
+
+impl CompletionTotal {
+    pub fn from_options(has_more: Option<bool>, total: Option<u32>) -> Self {
+        match (has_more, total) {
+            (_, Some(count)) => CompletionTotal::Exact(count),
+            (Some(true), _) => CompletionTotal::HasMore,
+            _ => CompletionTotal::Unknown,
+        }
+    }
+}
+
+pub struct Completion {
+    pub values: Vec<String>,
+    pub total: CompletionTotal,
+}