context_server_command.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection, SlashCommandResult,
  5};
  6use collections::HashMap;
  7use context_server::{ContextServerId, types::Prompt};
  8use gpui::{App, Entity, Task, WeakEntity, Window};
  9use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
 10use project::context_server_store::ContextServerStore;
 11use std::sync::Arc;
 12use std::sync::atomic::AtomicBool;
 13use text::LineEnding;
 14use ui::{IconName, SharedString};
 15use workspace::Workspace;
 16
 17use crate::create_label_for_command;
 18
 19pub struct ContextServerSlashCommand {
 20    store: Entity<ContextServerStore>,
 21    server_id: ContextServerId,
 22    prompt: Prompt,
 23}
 24
 25impl ContextServerSlashCommand {
 26    pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
 27        Self {
 28            server_id: id,
 29            prompt,
 30            store,
 31        }
 32    }
 33}
 34
 35impl SlashCommand for ContextServerSlashCommand {
 36    fn name(&self) -> String {
 37        self.prompt.name.clone()
 38    }
 39
 40    fn label(&self, cx: &App) -> language::CodeLabel {
 41        let mut parts = vec![self.prompt.name.as_str()];
 42        if let Some(args) = &self.prompt.arguments
 43            && let Some(arg) = args.first() {
 44                parts.push(arg.name.as_str());
 45            }
 46        create_label_for_command(parts[0], &parts[1..], cx)
 47    }
 48
 49    fn description(&self) -> String {
 50        match &self.prompt.description {
 51            Some(desc) => desc.clone(),
 52            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 53        }
 54    }
 55
 56    fn menu_text(&self) -> String {
 57        match &self.prompt.description {
 58            Some(desc) => desc.clone(),
 59            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 60        }
 61    }
 62
 63    fn requires_argument(&self) -> bool {
 64        self.prompt.arguments.as_ref().map_or(false, |args| {
 65            args.iter().any(|arg| arg.required == Some(true))
 66        })
 67    }
 68
 69    fn complete_argument(
 70        self: Arc<Self>,
 71        arguments: &[String],
 72        _cancel: Arc<AtomicBool>,
 73        _workspace: Option<WeakEntity<Workspace>>,
 74        _window: &mut Window,
 75        cx: &mut App,
 76    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 77        let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
 78            return Task::ready(Err(anyhow!("Failed to complete argument")));
 79        };
 80
 81        let server_id = self.server_id.clone();
 82        let prompt_name = self.prompt.name.clone();
 83
 84        if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
 85            cx.foreground_executor().spawn(async move {
 86                let protocol = server.client().context("Context server not initialized")?;
 87
 88                let response = protocol
 89                    .request::<context_server::types::requests::CompletionComplete>(
 90                        context_server::types::CompletionCompleteParams {
 91                            reference: context_server::types::CompletionReference::Prompt(
 92                                context_server::types::PromptReference {
 93                                    ty: context_server::types::PromptReferenceType::Prompt,
 94                                    name: prompt_name,
 95                                },
 96                            ),
 97                            argument: context_server::types::CompletionArgument {
 98                                name: arg_name,
 99                                value: arg_value,
100                            },
101                            meta: None,
102                        },
103                    )
104                    .await?;
105
106                let completions = response
107                    .completion
108                    .values
109                    .into_iter()
110                    .map(|value| ArgumentCompletion {
111                        label: CodeLabel::plain(value.clone(), None),
112                        new_text: value,
113                        after_completion: AfterCompletion::Continue,
114                        replace_previous_arguments: false,
115                    })
116                    .collect();
117                Ok(completions)
118            })
119        } else {
120            Task::ready(Err(anyhow!("Context server not found")))
121        }
122    }
123
124    fn run(
125        self: Arc<Self>,
126        arguments: &[String],
127        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
128        _context_buffer: BufferSnapshot,
129        _workspace: WeakEntity<Workspace>,
130        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
131        _window: &mut Window,
132        cx: &mut App,
133    ) -> Task<SlashCommandResult> {
134        let server_id = self.server_id.clone();
135        let prompt_name = self.prompt.name.clone();
136
137        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
138            Ok(args) => args,
139            Err(e) => return Task::ready(Err(e)),
140        };
141
142        let store = self.store.read(cx);
143        if let Some(server) = store.get_running_server(&server_id) {
144            cx.foreground_executor().spawn(async move {
145                let protocol = server.client().context("Context server not initialized")?;
146                let response = protocol
147                    .request::<context_server::types::requests::PromptsGet>(
148                        context_server::types::PromptsGetParams {
149                            name: prompt_name.clone(),
150                            arguments: Some(prompt_args),
151                            meta: None,
152                        },
153                    )
154                    .await?;
155
156                anyhow::ensure!(
157                    response
158                        .messages
159                        .iter()
160                        .all(|msg| matches!(msg.role, context_server::types::Role::User)),
161                    "Prompt contains non-user roles, which is not supported"
162                );
163
164                // Extract text from user messages into a single prompt string
165                let mut prompt = response
166                    .messages
167                    .into_iter()
168                    .filter_map(|msg| match msg.content {
169                        context_server::types::MessageContent::Text { text, .. } => Some(text),
170                        _ => None,
171                    })
172                    .collect::<Vec<String>>()
173                    .join("\n\n");
174
175                // We must normalize the line endings here, since servers might return CR characters.
176                LineEnding::normalize(&mut prompt);
177
178                Ok(SlashCommandOutput {
179                    sections: vec![SlashCommandOutputSection {
180                        range: 0..(prompt.len()),
181                        icon: IconName::ZedAssistant,
182                        label: SharedString::from(
183                            response
184                                .description
185                                .unwrap_or(format!("Result from {}", prompt_name)),
186                        ),
187                        metadata: None,
188                    }],
189                    text: prompt,
190                    run_commands_in_text: false,
191                }
192                .to_event_stream())
193            })
194        } else {
195            Task::ready(Err(anyhow!("Context server not found")))
196        }
197    }
198}
199
200fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
201    anyhow::ensure!(!arguments.is_empty(), "No arguments given");
202
203    match &prompt.arguments {
204        Some(args) if args.len() == 1 => {
205            let arg_name = args[0].name.clone();
206            let arg_value = arguments.join(" ");
207            Ok((arg_name, arg_value))
208        }
209        Some(_) => anyhow::bail!("Prompt must have exactly one argument"),
210        None => anyhow::bail!("Prompt has no arguments"),
211    }
212}
213
214fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
215    match &prompt.arguments {
216        Some(args) if args.len() > 1 => {
217            anyhow::bail!("Prompt has more than one argument, which is not supported");
218        }
219        Some(args) if args.len() == 1 => {
220            if !arguments.is_empty() {
221                let mut map = HashMap::default();
222                map.insert(args[0].name.clone(), arguments.join(" "));
223                Ok(map)
224            } else if arguments.is_empty() && args[0].required == Some(false) {
225                Ok(HashMap::default())
226            } else {
227                anyhow::bail!("Prompt expects argument but none given");
228            }
229        }
230        Some(_) | None => {
231            anyhow::ensure!(
232                arguments.is_empty(),
233                "Prompt expects no arguments but some were given"
234            );
235            Ok(HashMap::default())
236        }
237    }
238}
239
240/// MCP servers can return prompts with multiple arguments. Since we only
241/// support one argument, we ignore all others. This is the necessary predicate
242/// for this.
243pub fn acceptable_prompt(prompt: &Prompt) -> bool {
244    match &prompt.arguments {
245        None => true,
246        Some(args) if args.len() <= 1 => true,
247        _ => false,
248    }
249}