context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection,
  5};
  6use collections::HashMap;
  7use context_servers::{
  8    manager::{ContextServer, ContextServerManager},
  9    protocol::PromptInfo,
 10};
 11use gpui::{Task, WeakView, WindowContext};
 12use language::{CodeLabel, LspAdapterDelegate};
 13use std::sync::atomic::AtomicBool;
 14use std::sync::Arc;
 15use text::LineEnding;
 16use ui::{IconName, SharedString};
 17use workspace::Workspace;
 18
 19pub struct ContextServerSlashCommand {
 20    server_id: String,
 21    prompt: PromptInfo,
 22}
 23
 24impl ContextServerSlashCommand {
 25    pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
 26        Self {
 27            server_id: server.id.clone(),
 28            prompt,
 29        }
 30    }
 31}
 32
 33impl SlashCommand for ContextServerSlashCommand {
 34    fn name(&self) -> String {
 35        self.prompt.name.clone()
 36    }
 37
 38    fn description(&self) -> String {
 39        format!("Run context server command: {}", self.prompt.name)
 40    }
 41
 42    fn menu_text(&self) -> String {
 43        format!("Run '{}' from {}", self.prompt.name, self.server_id)
 44    }
 45
 46    fn requires_argument(&self) -> bool {
 47        self.prompt.arguments.as_ref().map_or(false, |args| {
 48            args.iter().any(|arg| arg.required == Some(true))
 49        })
 50    }
 51
 52    fn complete_argument(
 53        self: Arc<Self>,
 54        arguments: &[String],
 55        _cancel: Arc<AtomicBool>,
 56        _workspace: Option<WeakView<Workspace>>,
 57        cx: &mut WindowContext,
 58    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 59        let server_id = self.server_id.clone();
 60        let prompt_name = self.prompt.name.clone();
 61        let manager = ContextServerManager::global(cx);
 62        let manager = manager.read(cx);
 63
 64        let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
 65            Ok(tp) => tp,
 66            Err(e) => {
 67                return Task::ready(Err(e));
 68            }
 69        };
 70        if let Some(server) = manager.get_server(&server_id) {
 71            cx.foreground_executor().spawn(async move {
 72                let Some(protocol) = server.client.read().clone() else {
 73                    return Err(anyhow!("Context server not initialized"));
 74                };
 75
 76                let completion_result = protocol
 77                    .completion(
 78                        context_servers::types::CompletionReference::Prompt(
 79                            context_servers::types::PromptReference {
 80                                r#type: context_servers::types::PromptReferenceType::Prompt,
 81                                name: prompt_name,
 82                            },
 83                        ),
 84                        arg_name,
 85                        arg_val,
 86                    )
 87                    .await?;
 88
 89                let completions = completion_result
 90                    .values
 91                    .into_iter()
 92                    .map(|value| ArgumentCompletion {
 93                        label: CodeLabel::plain(value.clone(), None),
 94                        new_text: value,
 95                        after_completion: AfterCompletion::Continue,
 96                        replace_previous_arguments: false,
 97                    })
 98                    .collect();
 99
100                Ok(completions)
101            })
102        } else {
103            Task::ready(Err(anyhow!("Context server not found")))
104        }
105    }
106
107    fn run(
108        self: Arc<Self>,
109        arguments: &[String],
110        _workspace: WeakView<Workspace>,
111        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
112        cx: &mut WindowContext,
113    ) -> Task<Result<SlashCommandOutput>> {
114        let server_id = self.server_id.clone();
115        let prompt_name = self.prompt.name.clone();
116
117        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
118            Ok(args) => args,
119            Err(e) => return Task::ready(Err(e)),
120        };
121
122        let manager = ContextServerManager::global(cx);
123        let manager = manager.read(cx);
124        if let Some(server) = manager.get_server(&server_id) {
125            cx.foreground_executor().spawn(async move {
126                let Some(protocol) = server.client.read().clone() else {
127                    return Err(anyhow!("Context server not initialized"));
128                };
129                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
130                let mut prompt = result.prompt;
131
132                // We must normalize the line endings here, since servers might return CR characters.
133                LineEnding::normalize(&mut prompt);
134
135                Ok(SlashCommandOutput {
136                    sections: vec![SlashCommandOutputSection {
137                        range: 0..(prompt.len()),
138                        icon: IconName::ZedAssistant,
139                        label: SharedString::from(
140                            result
141                                .description
142                                .unwrap_or(format!("Result from {}", prompt_name)),
143                        ),
144                    }],
145                    text: prompt,
146                    run_commands_in_text: false,
147                })
148            })
149        } else {
150            Task::ready(Err(anyhow!("Context server not found")))
151        }
152    }
153}
154
155fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
156    if arguments.is_empty() {
157        return Err(anyhow!("No arguments given"));
158    }
159
160    match &prompt.arguments {
161        Some(args) if args.len() == 1 => {
162            let arg_name = args[0].name.clone();
163            let arg_value = arguments.join(" ");
164            Ok((arg_name, arg_value))
165        }
166        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
167        None => Err(anyhow!("Prompt has no arguments")),
168    }
169}
170
171fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
172    match &prompt.arguments {
173        Some(args) if args.len() > 1 => Err(anyhow!(
174            "Prompt has more than one argument, which is not supported"
175        )),
176        Some(args) if args.len() == 1 => {
177            if !arguments.is_empty() {
178                let mut map = HashMap::default();
179                map.insert(args[0].name.clone(), arguments.join(" "));
180                Ok(map)
181            } else if arguments.is_empty() && args[0].required == Some(false) {
182                Ok(HashMap::default())
183            } else {
184                Err(anyhow!("Prompt expects argument but none given"))
185            }
186        }
187        Some(_) | None => {
188            if arguments.is_empty() {
189                Ok(HashMap::default())
190            } else {
191                Err(anyhow!("Prompt expects no arguments but some were given"))
192            }
193        }
194    }
195}
196
197/// MCP servers can return prompts with multiple arguments. Since we only
198/// support one argument, we ignore all others. This is the necessary predicate
199/// for this.
200pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
201    match &prompt.arguments {
202        None => true,
203        Some(args) if args.len() <= 1 => true,
204        _ => false,
205    }
206}