context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection,
  5};
  6use collections::HashMap;
  7use context_servers::{
  8    manager::{ContextServer, ContextServerManager},
  9    protocol::PromptInfo,
 10};
 11use gpui::{Task, WeakView, WindowContext};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
 13use std::sync::atomic::AtomicBool;
 14use std::sync::Arc;
 15use text::LineEnding;
 16use ui::{IconName, SharedString};
 17use workspace::Workspace;
 18
 19pub struct ContextServerSlashCommand {
 20    server_id: String,
 21    prompt: PromptInfo,
 22}
 23
 24impl ContextServerSlashCommand {
 25    pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
 26        Self {
 27            server_id: server.id.clone(),
 28            prompt,
 29        }
 30    }
 31}
 32
 33impl SlashCommand for ContextServerSlashCommand {
 34    fn name(&self) -> String {
 35        self.prompt.name.clone()
 36    }
 37
 38    fn description(&self) -> String {
 39        format!("Run context server command: {}", self.prompt.name)
 40    }
 41
 42    fn menu_text(&self) -> String {
 43        format!("Run '{}' from {}", self.prompt.name, self.server_id)
 44    }
 45
 46    fn requires_argument(&self) -> bool {
 47        self.prompt.arguments.as_ref().map_or(false, |args| {
 48            args.iter().any(|arg| arg.required == Some(true))
 49        })
 50    }
 51
 52    fn complete_argument(
 53        self: Arc<Self>,
 54        arguments: &[String],
 55        _cancel: Arc<AtomicBool>,
 56        _workspace: Option<WeakView<Workspace>>,
 57        cx: &mut WindowContext,
 58    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 59        let server_id = self.server_id.clone();
 60        let prompt_name = self.prompt.name.clone();
 61        let manager = ContextServerManager::global(cx);
 62        let manager = manager.read(cx);
 63
 64        let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
 65            Ok(tp) => tp,
 66            Err(e) => {
 67                return Task::ready(Err(e));
 68            }
 69        };
 70        if let Some(server) = manager.get_server(&server_id) {
 71            cx.foreground_executor().spawn(async move {
 72                let Some(protocol) = server.client.read().clone() else {
 73                    return Err(anyhow!("Context server not initialized"));
 74                };
 75
 76                let completion_result = protocol
 77                    .completion(
 78                        context_servers::types::CompletionReference::Prompt(
 79                            context_servers::types::PromptReference {
 80                                r#type: context_servers::types::PromptReferenceType::Prompt,
 81                                name: prompt_name,
 82                            },
 83                        ),
 84                        arg_name,
 85                        arg_val,
 86                    )
 87                    .await?;
 88
 89                let completions = completion_result
 90                    .values
 91                    .into_iter()
 92                    .map(|value| ArgumentCompletion {
 93                        label: CodeLabel::plain(value.clone(), None),
 94                        new_text: value,
 95                        after_completion: AfterCompletion::Continue,
 96                        replace_previous_arguments: false,
 97                    })
 98                    .collect();
 99                Ok(completions)
100            })
101        } else {
102            Task::ready(Err(anyhow!("Context server not found")))
103        }
104    }
105
106    fn run(
107        self: Arc<Self>,
108        arguments: &[String],
109        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
110        _context_buffer: BufferSnapshot,
111        _workspace: WeakView<Workspace>,
112        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
113        cx: &mut WindowContext,
114    ) -> Task<Result<SlashCommandOutput>> {
115        let server_id = self.server_id.clone();
116        let prompt_name = self.prompt.name.clone();
117
118        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
119            Ok(args) => args,
120            Err(e) => return Task::ready(Err(e)),
121        };
122
123        let manager = ContextServerManager::global(cx);
124        let manager = manager.read(cx);
125        if let Some(server) = manager.get_server(&server_id) {
126            cx.foreground_executor().spawn(async move {
127                let Some(protocol) = server.client.read().clone() else {
128                    return Err(anyhow!("Context server not initialized"));
129                };
130                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
131                let mut prompt = result.prompt;
132
133                // We must normalize the line endings here, since servers might return CR characters.
134                LineEnding::normalize(&mut prompt);
135
136                Ok(SlashCommandOutput {
137                    sections: vec![SlashCommandOutputSection {
138                        range: 0..(prompt.len()),
139                        icon: IconName::ZedAssistant,
140                        label: SharedString::from(
141                            result
142                                .description
143                                .unwrap_or(format!("Result from {}", prompt_name)),
144                        ),
145                        metadata: None,
146                    }],
147                    text: prompt,
148                    run_commands_in_text: false,
149                })
150            })
151        } else {
152            Task::ready(Err(anyhow!("Context server not found")))
153        }
154    }
155}
156
157fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
158    if arguments.is_empty() {
159        return Err(anyhow!("No arguments given"));
160    }
161
162    match &prompt.arguments {
163        Some(args) if args.len() == 1 => {
164            let arg_name = args[0].name.clone();
165            let arg_value = arguments.join(" ");
166            Ok((arg_name, arg_value))
167        }
168        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
169        None => Err(anyhow!("Prompt has no arguments")),
170    }
171}
172
173fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
174    match &prompt.arguments {
175        Some(args) if args.len() > 1 => Err(anyhow!(
176            "Prompt has more than one argument, which is not supported"
177        )),
178        Some(args) if args.len() == 1 => {
179            if !arguments.is_empty() {
180                let mut map = HashMap::default();
181                map.insert(args[0].name.clone(), arguments.join(" "));
182                Ok(map)
183            } else if arguments.is_empty() && args[0].required == Some(false) {
184                Ok(HashMap::default())
185            } else {
186                Err(anyhow!("Prompt expects argument but none given"))
187            }
188        }
189        Some(_) | None => {
190            if arguments.is_empty() {
191                Ok(HashMap::default())
192            } else {
193                Err(anyhow!("Prompt expects no arguments but some were given"))
194            }
195        }
196    }
197}
198
199/// MCP servers can return prompts with multiple arguments. Since we only
200/// support one argument, we ignore all others. This is the necessary predicate
201/// for this.
202pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
203    match &prompt.arguments {
204        None => true,
205        Some(args) if args.len() <= 1 => true,
206        _ => false,
207    }
208}