1use anyhow::{Context as _, Result, anyhow};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection, SlashCommandResult,
  5};
  6use collections::HashMap;
  7use context_server::{ContextServerId, types::Prompt};
  8use gpui::{App, Entity, Task, WeakEntity, Window};
  9use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
 10use project::context_server_store::ContextServerStore;
 11use std::sync::Arc;
 12use std::sync::atomic::AtomicBool;
 13use text::LineEnding;
 14use ui::{IconName, SharedString};
 15use workspace::Workspace;
 16
 17use crate::create_label_for_command;
 18
 19pub struct ContextServerSlashCommand {
 20    store: Entity<ContextServerStore>,
 21    server_id: ContextServerId,
 22    prompt: Prompt,
 23}
 24
 25impl ContextServerSlashCommand {
 26    pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
 27        Self {
 28            server_id: id,
 29            prompt,
 30            store,
 31        }
 32    }
 33}
 34
 35impl SlashCommand for ContextServerSlashCommand {
 36    fn name(&self) -> String {
 37        self.prompt.name.clone()
 38    }
 39
 40    fn label(&self, cx: &App) -> language::CodeLabel {
 41        let mut parts = vec![self.prompt.name.as_str()];
 42        if let Some(args) = &self.prompt.arguments
 43            && let Some(arg) = args.first()
 44        {
 45            parts.push(arg.name.as_str());
 46        }
 47        create_label_for_command(parts[0], &parts[1..], cx)
 48    }
 49
 50    fn description(&self) -> String {
 51        match &self.prompt.description {
 52            Some(desc) => desc.clone(),
 53            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 54        }
 55    }
 56
 57    fn menu_text(&self) -> String {
 58        match &self.prompt.description {
 59            Some(desc) => desc.clone(),
 60            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 61        }
 62    }
 63
 64    fn requires_argument(&self) -> bool {
 65        self.prompt
 66            .arguments
 67            .as_ref()
 68            .is_some_and(|args| args.iter().any(|arg| arg.required == Some(true)))
 69    }
 70
 71    fn complete_argument(
 72        self: Arc<Self>,
 73        arguments: &[String],
 74        _cancel: Arc<AtomicBool>,
 75        _workspace: Option<WeakEntity<Workspace>>,
 76        _window: &mut Window,
 77        cx: &mut App,
 78    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 79        let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
 80            return Task::ready(Err(anyhow!("Failed to complete argument")));
 81        };
 82
 83        let server_id = self.server_id.clone();
 84        let prompt_name = self.prompt.name.clone();
 85
 86        if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
 87            cx.foreground_executor().spawn(async move {
 88                let protocol = server.client().context("Context server not initialized")?;
 89
 90                let response = protocol
 91                    .request::<context_server::types::requests::CompletionComplete>(
 92                        context_server::types::CompletionCompleteParams {
 93                            reference: context_server::types::CompletionReference::Prompt(
 94                                context_server::types::PromptReference {
 95                                    ty: context_server::types::PromptReferenceType::Prompt,
 96                                    name: prompt_name,
 97                                },
 98                            ),
 99                            argument: context_server::types::CompletionArgument {
100                                name: arg_name,
101                                value: arg_value,
102                            },
103                            meta: None,
104                        },
105                    )
106                    .await?;
107
108                let completions = response
109                    .completion
110                    .values
111                    .into_iter()
112                    .map(|value| ArgumentCompletion {
113                        label: CodeLabel::plain(value.clone(), None),
114                        new_text: value,
115                        after_completion: AfterCompletion::Continue,
116                        replace_previous_arguments: false,
117                    })
118                    .collect();
119                Ok(completions)
120            })
121        } else {
122            Task::ready(Err(anyhow!("Context server not found")))
123        }
124    }
125
126    fn run(
127        self: Arc<Self>,
128        arguments: &[String],
129        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
130        _context_buffer: BufferSnapshot,
131        _workspace: WeakEntity<Workspace>,
132        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
133        _window: &mut Window,
134        cx: &mut App,
135    ) -> Task<SlashCommandResult> {
136        let server_id = self.server_id.clone();
137        let prompt_name = self.prompt.name.clone();
138
139        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
140            Ok(args) => args,
141            Err(e) => return Task::ready(Err(e)),
142        };
143
144        let store = self.store.read(cx);
145        if let Some(server) = store.get_running_server(&server_id) {
146            cx.foreground_executor().spawn(async move {
147                let protocol = server.client().context("Context server not initialized")?;
148                let response = protocol
149                    .request::<context_server::types::requests::PromptsGet>(
150                        context_server::types::PromptsGetParams {
151                            name: prompt_name.clone(),
152                            arguments: Some(prompt_args),
153                            meta: None,
154                        },
155                    )
156                    .await?;
157
158                anyhow::ensure!(
159                    response
160                        .messages
161                        .iter()
162                        .all(|msg| matches!(msg.role, context_server::types::Role::User)),
163                    "Prompt contains non-user roles, which is not supported"
164                );
165
166                // Extract text from user messages into a single prompt string
167                let mut prompt = response
168                    .messages
169                    .into_iter()
170                    .filter_map(|msg| match msg.content {
171                        context_server::types::MessageContent::Text { text, .. } => Some(text),
172                        _ => None,
173                    })
174                    .collect::<Vec<String>>()
175                    .join("\n\n");
176
177                // We must normalize the line endings here, since servers might return CR characters.
178                LineEnding::normalize(&mut prompt);
179
180                Ok(SlashCommandOutput {
181                    sections: vec![SlashCommandOutputSection {
182                        range: 0..(prompt.len()),
183                        icon: IconName::ZedAssistant,
184                        label: SharedString::from(
185                            response
186                                .description
187                                .unwrap_or(format!("Result from {}", prompt_name)),
188                        ),
189                        metadata: None,
190                    }],
191                    text: prompt,
192                    run_commands_in_text: false,
193                }
194                .into_event_stream())
195            })
196        } else {
197            Task::ready(Err(anyhow!("Context server not found")))
198        }
199    }
200}
201
202fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
203    anyhow::ensure!(!arguments.is_empty(), "No arguments given");
204
205    match &prompt.arguments {
206        Some(args) if args.len() == 1 => {
207            let arg_name = args[0].name.clone();
208            let arg_value = arguments.join(" ");
209            Ok((arg_name, arg_value))
210        }
211        Some(_) => anyhow::bail!("Prompt must have exactly one argument"),
212        None => anyhow::bail!("Prompt has no arguments"),
213    }
214}
215
216fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
217    match &prompt.arguments {
218        Some(args) if args.len() > 1 => {
219            anyhow::bail!("Prompt has more than one argument, which is not supported");
220        }
221        Some(args) if args.len() == 1 => {
222            if !arguments.is_empty() {
223                let mut map = HashMap::default();
224                map.insert(args[0].name.clone(), arguments.join(" "));
225                Ok(map)
226            } else if arguments.is_empty() && args[0].required == Some(false) {
227                Ok(HashMap::default())
228            } else {
229                anyhow::bail!("Prompt expects argument but none given");
230            }
231        }
232        Some(_) | None => {
233            anyhow::ensure!(
234                arguments.is_empty(),
235                "Prompt expects no arguments but some were given"
236            );
237            Ok(HashMap::default())
238        }
239    }
240}
241
242/// MCP servers can return prompts with multiple arguments. Since we only
243/// support one argument, we ignore all others. This is the necessary predicate
244/// for this.
245pub fn acceptable_prompt(prompt: &Prompt) -> bool {
246    match &prompt.arguments {
247        None => true,
248        Some(args) if args.len() <= 1 => true,
249        _ => false,
250    }
251}