context_server_command.rs

  1use super::create_label_for_command;
  2use anyhow::{anyhow, Result};
  3use assistant_slash_command::{
  4    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  5    SlashCommandOutputSection,
  6};
  7use collections::HashMap;
  8use context_servers::{
  9    manager::{ContextServer, ContextServerManager},
 10    types::Prompt,
 11};
 12use gpui::{AppContext, Task, WeakView, WindowContext};
 13use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
 14use std::sync::atomic::AtomicBool;
 15use std::sync::Arc;
 16use text::LineEnding;
 17use ui::{IconName, SharedString};
 18use workspace::Workspace;
 19
 20pub struct ContextServerSlashCommand {
 21    server_id: String,
 22    prompt: Prompt,
 23}
 24
 25impl ContextServerSlashCommand {
 26    pub fn new(server: &Arc<ContextServer>, prompt: Prompt) -> Self {
 27        Self {
 28            server_id: server.id.clone(),
 29            prompt,
 30        }
 31    }
 32}
 33
 34impl SlashCommand for ContextServerSlashCommand {
 35    fn name(&self) -> String {
 36        self.prompt.name.clone()
 37    }
 38
 39    fn label(&self, cx: &AppContext) -> language::CodeLabel {
 40        let mut parts = vec![self.prompt.name.as_str()];
 41        if let Some(args) = &self.prompt.arguments {
 42            if let Some(arg) = args.first() {
 43                parts.push(arg.name.as_str());
 44            }
 45        }
 46        create_label_for_command(&parts[0], &parts[1..], cx)
 47    }
 48
 49    fn description(&self) -> String {
 50        match &self.prompt.description {
 51            Some(desc) => desc.clone(),
 52            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 53        }
 54    }
 55
 56    fn menu_text(&self) -> String {
 57        match &self.prompt.description {
 58            Some(desc) => desc.clone(),
 59            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 60        }
 61    }
 62
 63    fn requires_argument(&self) -> bool {
 64        self.prompt.arguments.as_ref().map_or(false, |args| {
 65            args.iter().any(|arg| arg.required == Some(true))
 66        })
 67    }
 68
 69    fn complete_argument(
 70        self: Arc<Self>,
 71        arguments: &[String],
 72        _cancel: Arc<AtomicBool>,
 73        _workspace: Option<WeakView<Workspace>>,
 74        cx: &mut WindowContext,
 75    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 76        let server_id = self.server_id.clone();
 77        let prompt_name = self.prompt.name.clone();
 78        let manager = ContextServerManager::global(cx);
 79        let manager = manager.read(cx);
 80
 81        let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
 82            Ok(tp) => tp,
 83            Err(e) => {
 84                return Task::ready(Err(e));
 85            }
 86        };
 87        if let Some(server) = manager.get_server(&server_id) {
 88            cx.foreground_executor().spawn(async move {
 89                let Some(protocol) = server.client.read().clone() else {
 90                    return Err(anyhow!("Context server not initialized"));
 91                };
 92
 93                let completion_result = protocol
 94                    .completion(
 95                        context_servers::types::CompletionReference::Prompt(
 96                            context_servers::types::PromptReference {
 97                                r#type: context_servers::types::PromptReferenceType::Prompt,
 98                                name: prompt_name,
 99                            },
100                        ),
101                        arg_name,
102                        arg_val,
103                    )
104                    .await?;
105
106                let completions = completion_result
107                    .values
108                    .into_iter()
109                    .map(|value| ArgumentCompletion {
110                        label: CodeLabel::plain(value.clone(), None),
111                        new_text: value,
112                        after_completion: AfterCompletion::Continue,
113                        replace_previous_arguments: false,
114                    })
115                    .collect();
116                Ok(completions)
117            })
118        } else {
119            Task::ready(Err(anyhow!("Context server not found")))
120        }
121    }
122
123    fn run(
124        self: Arc<Self>,
125        arguments: &[String],
126        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
127        _context_buffer: BufferSnapshot,
128        _workspace: WeakView<Workspace>,
129        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
130        cx: &mut WindowContext,
131    ) -> Task<Result<SlashCommandOutput>> {
132        let server_id = self.server_id.clone();
133        let prompt_name = self.prompt.name.clone();
134
135        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
136            Ok(args) => args,
137            Err(e) => return Task::ready(Err(e)),
138        };
139
140        let manager = ContextServerManager::global(cx);
141        let manager = manager.read(cx);
142        if let Some(server) = manager.get_server(&server_id) {
143            cx.foreground_executor().spawn(async move {
144                let Some(protocol) = server.client.read().clone() else {
145                    return Err(anyhow!("Context server not initialized"));
146                };
147                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
148
149                // Check that there are only user roles
150                if result
151                    .messages
152                    .iter()
153                    .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
154                {
155                    return Err(anyhow!(
156                        "Prompt contains non-user roles, which is not supported"
157                    ));
158                }
159
160                // Extract text from user messages into a single prompt string
161                let mut prompt = result
162                    .messages
163                    .into_iter()
164                    .filter_map(|msg| match msg.content {
165                        context_servers::types::SamplingContent::Text { text } => Some(text),
166                        _ => None,
167                    })
168                    .collect::<Vec<String>>()
169                    .join("\n\n");
170
171                // We must normalize the line endings here, since servers might return CR characters.
172                LineEnding::normalize(&mut prompt);
173
174                Ok(SlashCommandOutput {
175                    sections: vec![SlashCommandOutputSection {
176                        range: 0..(prompt.len()),
177                        icon: IconName::ZedAssistant,
178                        label: SharedString::from(
179                            result
180                                .description
181                                .unwrap_or(format!("Result from {}", prompt_name)),
182                        ),
183                        metadata: None,
184                    }],
185                    text: prompt,
186                    run_commands_in_text: false,
187                })
188            })
189        } else {
190            Task::ready(Err(anyhow!("Context server not found")))
191        }
192    }
193}
194
195fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
196    if arguments.is_empty() {
197        return Err(anyhow!("No arguments given"));
198    }
199
200    match &prompt.arguments {
201        Some(args) if args.len() == 1 => {
202            let arg_name = args[0].name.clone();
203            let arg_value = arguments.join(" ");
204            Ok((arg_name, arg_value))
205        }
206        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
207        None => Err(anyhow!("Prompt has no arguments")),
208    }
209}
210
211fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
212    match &prompt.arguments {
213        Some(args) if args.len() > 1 => Err(anyhow!(
214            "Prompt has more than one argument, which is not supported"
215        )),
216        Some(args) if args.len() == 1 => {
217            if !arguments.is_empty() {
218                let mut map = HashMap::default();
219                map.insert(args[0].name.clone(), arguments.join(" "));
220                Ok(map)
221            } else if arguments.is_empty() && args[0].required == Some(false) {
222                Ok(HashMap::default())
223            } else {
224                Err(anyhow!("Prompt expects argument but none given"))
225            }
226        }
227        Some(_) | None => {
228            if arguments.is_empty() {
229                Ok(HashMap::default())
230            } else {
231                Err(anyhow!("Prompt expects no arguments but some were given"))
232            }
233        }
234    }
235}
236
237/// MCP servers can return prompts with multiple arguments. Since we only
238/// support one argument, we ignore all others. This is the necessary predicate
239/// for this.
240pub fn acceptable_prompt(prompt: &Prompt) -> bool {
241    match &prompt.arguments {
242        None => true,
243        Some(args) if args.len() <= 1 => true,
244        _ => false,
245    }
246}