context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection, SlashCommandResult,
  5};
  6use collections::HashMap;
  7use context_servers::{
  8    manager::{ContextServer, ContextServerManager},
  9    types::Prompt,
 10};
 11use gpui::{AppContext, Task, WeakView, WindowContext};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
 13use std::sync::atomic::AtomicBool;
 14use std::sync::Arc;
 15use text::LineEnding;
 16use ui::{IconName, SharedString};
 17use workspace::Workspace;
 18
 19use crate::slash_command::create_label_for_command;
 20
 21pub struct ContextServerSlashCommand {
 22    server_id: String,
 23    prompt: Prompt,
 24}
 25
 26impl ContextServerSlashCommand {
 27    pub fn new(server: &Arc<ContextServer>, prompt: Prompt) -> Self {
 28        Self {
 29            server_id: server.id.clone(),
 30            prompt,
 31        }
 32    }
 33}
 34
 35impl SlashCommand for ContextServerSlashCommand {
 36    fn name(&self) -> String {
 37        self.prompt.name.clone()
 38    }
 39
 40    fn label(&self, cx: &AppContext) -> language::CodeLabel {
 41        let mut parts = vec![self.prompt.name.as_str()];
 42        if let Some(args) = &self.prompt.arguments {
 43            if let Some(arg) = args.first() {
 44                parts.push(arg.name.as_str());
 45            }
 46        }
 47        create_label_for_command(&parts[0], &parts[1..], cx)
 48    }
 49
 50    fn description(&self) -> String {
 51        match &self.prompt.description {
 52            Some(desc) => desc.clone(),
 53            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 54        }
 55    }
 56
 57    fn menu_text(&self) -> String {
 58        match &self.prompt.description {
 59            Some(desc) => desc.clone(),
 60            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 61        }
 62    }
 63
 64    fn requires_argument(&self) -> bool {
 65        self.prompt.arguments.as_ref().map_or(false, |args| {
 66            args.iter().any(|arg| arg.required == Some(true))
 67        })
 68    }
 69
 70    fn complete_argument(
 71        self: Arc<Self>,
 72        arguments: &[String],
 73        _cancel: Arc<AtomicBool>,
 74        _workspace: Option<WeakView<Workspace>>,
 75        cx: &mut WindowContext,
 76    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 77        let server_id = self.server_id.clone();
 78        let prompt_name = self.prompt.name.clone();
 79        let manager = ContextServerManager::global(cx);
 80        let manager = manager.read(cx);
 81
 82        let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
 83            Ok(tp) => tp,
 84            Err(e) => {
 85                return Task::ready(Err(e));
 86            }
 87        };
 88        if let Some(server) = manager.get_server(&server_id) {
 89            cx.foreground_executor().spawn(async move {
 90                let Some(protocol) = server.client.read().clone() else {
 91                    return Err(anyhow!("Context server not initialized"));
 92                };
 93
 94                let completion_result = protocol
 95                    .completion(
 96                        context_servers::types::CompletionReference::Prompt(
 97                            context_servers::types::PromptReference {
 98                                r#type: context_servers::types::PromptReferenceType::Prompt,
 99                                name: prompt_name,
100                            },
101                        ),
102                        arg_name,
103                        arg_val,
104                    )
105                    .await?;
106
107                let completions = completion_result
108                    .values
109                    .into_iter()
110                    .map(|value| ArgumentCompletion {
111                        label: CodeLabel::plain(value.clone(), None),
112                        new_text: value,
113                        after_completion: AfterCompletion::Continue,
114                        replace_previous_arguments: false,
115                    })
116                    .collect();
117                Ok(completions)
118            })
119        } else {
120            Task::ready(Err(anyhow!("Context server not found")))
121        }
122    }
123
124    fn run(
125        self: Arc<Self>,
126        arguments: &[String],
127        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
128        _context_buffer: BufferSnapshot,
129        _workspace: WeakView<Workspace>,
130        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
131        cx: &mut WindowContext,
132    ) -> Task<SlashCommandResult> {
133        let server_id = self.server_id.clone();
134        let prompt_name = self.prompt.name.clone();
135
136        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
137            Ok(args) => args,
138            Err(e) => return Task::ready(Err(e)),
139        };
140
141        let manager = ContextServerManager::global(cx);
142        let manager = manager.read(cx);
143        if let Some(server) = manager.get_server(&server_id) {
144            cx.foreground_executor().spawn(async move {
145                let Some(protocol) = server.client.read().clone() else {
146                    return Err(anyhow!("Context server not initialized"));
147                };
148                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
149
150                // Check that there are only user roles
151                if result
152                    .messages
153                    .iter()
154                    .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
155                {
156                    return Err(anyhow!(
157                        "Prompt contains non-user roles, which is not supported"
158                    ));
159                }
160
161                // Extract text from user messages into a single prompt string
162                let mut prompt = result
163                    .messages
164                    .into_iter()
165                    .filter_map(|msg| match msg.content {
166                        context_servers::types::SamplingContent::Text { text } => Some(text),
167                        _ => None,
168                    })
169                    .collect::<Vec<String>>()
170                    .join("\n\n");
171
172                // We must normalize the line endings here, since servers might return CR characters.
173                LineEnding::normalize(&mut prompt);
174
175                Ok(SlashCommandOutput {
176                    sections: vec![SlashCommandOutputSection {
177                        range: 0..(prompt.len()),
178                        icon: IconName::ZedAssistant,
179                        label: SharedString::from(
180                            result
181                                .description
182                                .unwrap_or(format!("Result from {}", prompt_name)),
183                        ),
184                        metadata: None,
185                    }],
186                    text: prompt,
187                    run_commands_in_text: false,
188                }
189                .to_event_stream())
190            })
191        } else {
192            Task::ready(Err(anyhow!("Context server not found")))
193        }
194    }
195}
196
197fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
198    if arguments.is_empty() {
199        return Err(anyhow!("No arguments given"));
200    }
201
202    match &prompt.arguments {
203        Some(args) if args.len() == 1 => {
204            let arg_name = args[0].name.clone();
205            let arg_value = arguments.join(" ");
206            Ok((arg_name, arg_value))
207        }
208        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
209        None => Err(anyhow!("Prompt has no arguments")),
210    }
211}
212
213fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
214    match &prompt.arguments {
215        Some(args) if args.len() > 1 => Err(anyhow!(
216            "Prompt has more than one argument, which is not supported"
217        )),
218        Some(args) if args.len() == 1 => {
219            if !arguments.is_empty() {
220                let mut map = HashMap::default();
221                map.insert(args[0].name.clone(), arguments.join(" "));
222                Ok(map)
223            } else if arguments.is_empty() && args[0].required == Some(false) {
224                Ok(HashMap::default())
225            } else {
226                Err(anyhow!("Prompt expects argument but none given"))
227            }
228        }
229        Some(_) | None => {
230            if arguments.is_empty() {
231                Ok(HashMap::default())
232            } else {
233                Err(anyhow!("Prompt expects no arguments but some were given"))
234            }
235        }
236    }
237}
238
239/// MCP servers can return prompts with multiple arguments. Since we only
240/// support one argument, we ignore all others. This is the necessary predicate
241/// for this.
242pub fn acceptable_prompt(prompt: &Prompt) -> bool {
243    match &prompt.arguments {
244        None => true,
245        Some(args) if args.len() <= 1 => true,
246        _ => false,
247    }
248}