context_server_command.rs

  1use anyhow::{Result, anyhow};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection, SlashCommandResult,
  5};
  6use collections::HashMap;
  7use context_server::{
  8    manager::{ContextServer, ContextServerManager},
  9    types::Prompt,
 10};
 11use gpui::{App, Entity, Task, WeakEntity, Window};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
 13use std::sync::Arc;
 14use std::sync::atomic::AtomicBool;
 15use text::LineEnding;
 16use ui::{IconName, SharedString};
 17use workspace::Workspace;
 18
 19use crate::create_label_for_command;
 20
 21pub struct ContextServerSlashCommand {
 22    server_manager: Entity<ContextServerManager>,
 23    server_id: Arc<str>,
 24    prompt: Prompt,
 25}
 26
 27impl ContextServerSlashCommand {
 28    pub fn new(
 29        server_manager: Entity<ContextServerManager>,
 30        server: &Arc<ContextServer>,
 31        prompt: Prompt,
 32    ) -> Self {
 33        Self {
 34            server_id: server.id(),
 35            prompt,
 36            server_manager,
 37        }
 38    }
 39}
 40
 41impl SlashCommand for ContextServerSlashCommand {
 42    fn name(&self) -> String {
 43        self.prompt.name.clone()
 44    }
 45
 46    fn label(&self, cx: &App) -> language::CodeLabel {
 47        let mut parts = vec![self.prompt.name.as_str()];
 48        if let Some(args) = &self.prompt.arguments {
 49            if let Some(arg) = args.first() {
 50                parts.push(arg.name.as_str());
 51            }
 52        }
 53        create_label_for_command(&parts[0], &parts[1..], cx)
 54    }
 55
 56    fn description(&self) -> String {
 57        match &self.prompt.description {
 58            Some(desc) => desc.clone(),
 59            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 60        }
 61    }
 62
 63    fn menu_text(&self) -> String {
 64        match &self.prompt.description {
 65            Some(desc) => desc.clone(),
 66            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 67        }
 68    }
 69
 70    fn requires_argument(&self) -> bool {
 71        self.prompt.arguments.as_ref().map_or(false, |args| {
 72            args.iter().any(|arg| arg.required == Some(true))
 73        })
 74    }
 75
 76    fn complete_argument(
 77        self: Arc<Self>,
 78        arguments: &[String],
 79        _cancel: Arc<AtomicBool>,
 80        _workspace: Option<WeakEntity<Workspace>>,
 81        _window: &mut Window,
 82        cx: &mut App,
 83    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 84        let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
 85            return Task::ready(Err(anyhow!("Failed to complete argument")));
 86        };
 87
 88        let server_id = self.server_id.clone();
 89        let prompt_name = self.prompt.name.clone();
 90
 91        if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
 92            cx.foreground_executor().spawn(async move {
 93                let Some(protocol) = server.client() else {
 94                    return Err(anyhow!("Context server not initialized"));
 95                };
 96
 97                let completion_result = protocol
 98                    .completion(
 99                        context_server::types::CompletionReference::Prompt(
100                            context_server::types::PromptReference {
101                                r#type: context_server::types::PromptReferenceType::Prompt,
102                                name: prompt_name,
103                            },
104                        ),
105                        arg_name,
106                        arg_value,
107                    )
108                    .await?;
109
110                let completions = completion_result
111                    .values
112                    .into_iter()
113                    .map(|value| ArgumentCompletion {
114                        label: CodeLabel::plain(value.clone(), None),
115                        new_text: value,
116                        after_completion: AfterCompletion::Continue,
117                        replace_previous_arguments: false,
118                    })
119                    .collect();
120                Ok(completions)
121            })
122        } else {
123            Task::ready(Err(anyhow!("Context server not found")))
124        }
125    }
126
127    fn run(
128        self: Arc<Self>,
129        arguments: &[String],
130        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
131        _context_buffer: BufferSnapshot,
132        _workspace: WeakEntity<Workspace>,
133        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
134        _window: &mut Window,
135        cx: &mut App,
136    ) -> Task<SlashCommandResult> {
137        let server_id = self.server_id.clone();
138        let prompt_name = self.prompt.name.clone();
139
140        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
141            Ok(args) => args,
142            Err(e) => return Task::ready(Err(e)),
143        };
144
145        let manager = self.server_manager.read(cx);
146        if let Some(server) = manager.get_server(&server_id) {
147            cx.foreground_executor().spawn(async move {
148                let Some(protocol) = server.client() else {
149                    return Err(anyhow!("Context server not initialized"));
150                };
151                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
152
153                // Check that there are only user roles
154                if result
155                    .messages
156                    .iter()
157                    .any(|msg| !matches!(msg.role, context_server::types::Role::User))
158                {
159                    return Err(anyhow!(
160                        "Prompt contains non-user roles, which is not supported"
161                    ));
162                }
163
164                // Extract text from user messages into a single prompt string
165                let mut prompt = result
166                    .messages
167                    .into_iter()
168                    .filter_map(|msg| match msg.content {
169                        context_server::types::MessageContent::Text { text, .. } => Some(text),
170                        _ => None,
171                    })
172                    .collect::<Vec<String>>()
173                    .join("\n\n");
174
175                // We must normalize the line endings here, since servers might return CR characters.
176                LineEnding::normalize(&mut prompt);
177
178                Ok(SlashCommandOutput {
179                    sections: vec![SlashCommandOutputSection {
180                        range: 0..(prompt.len()),
181                        icon: IconName::ZedAssistant,
182                        label: SharedString::from(
183                            result
184                                .description
185                                .unwrap_or(format!("Result from {}", prompt_name)),
186                        ),
187                        metadata: None,
188                    }],
189                    text: prompt,
190                    run_commands_in_text: false,
191                }
192                .to_event_stream())
193            })
194        } else {
195            Task::ready(Err(anyhow!("Context server not found")))
196        }
197    }
198}
199
200fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
201    if arguments.is_empty() {
202        return Err(anyhow!("No arguments given"));
203    }
204
205    match &prompt.arguments {
206        Some(args) if args.len() == 1 => {
207            let arg_name = args[0].name.clone();
208            let arg_value = arguments.join(" ");
209            Ok((arg_name, arg_value))
210        }
211        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
212        None => Err(anyhow!("Prompt has no arguments")),
213    }
214}
215
216fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
217    match &prompt.arguments {
218        Some(args) if args.len() > 1 => Err(anyhow!(
219            "Prompt has more than one argument, which is not supported"
220        )),
221        Some(args) if args.len() == 1 => {
222            if !arguments.is_empty() {
223                let mut map = HashMap::default();
224                map.insert(args[0].name.clone(), arguments.join(" "));
225                Ok(map)
226            } else if arguments.is_empty() && args[0].required == Some(false) {
227                Ok(HashMap::default())
228            } else {
229                Err(anyhow!("Prompt expects argument but none given"))
230            }
231        }
232        Some(_) | None => {
233            if arguments.is_empty() {
234                Ok(HashMap::default())
235            } else {
236                Err(anyhow!("Prompt expects no arguments but some were given"))
237            }
238        }
239    }
240}
241
242/// MCP servers can return prompts with multiple arguments. Since we only
243/// support one argument, we ignore all others. This is the necessary predicate
244/// for this.
245pub fn acceptable_prompt(prompt: &Prompt) -> bool {
246    match &prompt.arguments {
247        None => true,
248        Some(args) if args.len() <= 1 => true,
249        _ => false,
250    }
251}