context_servers: Fix argument handling (#16402)

David Soria Parra created

Change summary

crates/assistant/src/slash_command/context_server_command.rs | 39 +++--
1 file changed, 24 insertions(+), 15 deletions(-)

Detailed changes

crates/assistant/src/slash_command/context_server_command.rs 🔗

@@ -67,7 +67,11 @@ impl SlashCommand for ContextServerSlashCommand {
     ) -> Task<Result<SlashCommandOutput>> {
         let server_id = self.server_id.clone();
         let prompt_name = self.prompt.name.clone();
-        let argument = arguments.first().cloned();
+
+        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
+            Ok(args) => args,
+            Err(e) => return Task::ready(Err(e)),
+        };
 
         let manager = ContextServerManager::global(cx);
         let manager = manager.read(cx);
@@ -76,10 +80,7 @@ impl SlashCommand for ContextServerSlashCommand {
                 let Some(protocol) = server.client.read().clone() else {
                     return Err(anyhow!("Context server not initialized"));
                 };
-
-                let result = protocol
-                    .run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
-                    .await?;
+                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
 
                 Ok(SlashCommandOutput {
                     sections: vec![SlashCommandOutputSection {
@@ -97,19 +98,27 @@ impl SlashCommand for ContextServerSlashCommand {
     }
 }
 
-fn prompt_arguments(
-    prompt: &PromptInfo,
-    argument: Option<String>,
-) -> Result<HashMap<String, String>> {
+fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
     match &prompt.arguments {
-        Some(args) if args.len() >= 2 => Err(anyhow!(
+        Some(args) if args.len() > 1 => Err(anyhow!(
             "Prompt has more than one argument, which is not supported"
         )),
-        Some(args) if args.len() == 1 => match argument {
-            Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
-            None => Err(anyhow!("Prompt expects argument but none given")),
-        },
-        Some(_) | None => Ok(HashMap::default()),
+        Some(args) if args.len() == 1 => {
+            if !arguments.is_empty() {
+                let mut map = HashMap::default();
+                map.insert(args[0].name.clone(), arguments.join(" "));
+                Ok(map)
+            } else {
+                Err(anyhow!("Prompt expects argument but none given"))
+            }
+        }
+        Some(_) | None => {
+            if arguments.is_empty() {
+                Ok(HashMap::default())
+            } else {
+                Err(anyhow!("Prompt expects no arguments but some were given"))
+            }
+        }
     }
 }