context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection,
  5};
  6use collections::HashMap;
  7use context_servers::{
  8    manager::{ContextServer, ContextServerManager},
  9    protocol::PromptInfo,
 10};
 11use gpui::{Task, WeakView, WindowContext};
 12use language::{CodeLabel, LspAdapterDelegate};
 13use std::sync::atomic::AtomicBool;
 14use std::sync::Arc;
 15use text::LineEnding;
 16use ui::{IconName, SharedString};
 17use workspace::Workspace;
 18
 19pub struct ContextServerSlashCommand {
 20    server_id: String,
 21    prompt: PromptInfo,
 22}
 23
 24impl ContextServerSlashCommand {
 25    pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
 26        Self {
 27            server_id: server.id.clone(),
 28            prompt,
 29        }
 30    }
 31}
 32
 33impl SlashCommand for ContextServerSlashCommand {
 34    fn name(&self) -> String {
 35        self.prompt.name.clone()
 36    }
 37
 38    fn description(&self) -> String {
 39        format!("Run context server command: {}", self.prompt.name)
 40    }
 41
 42    fn menu_text(&self) -> String {
 43        format!("Run '{}' from {}", self.prompt.name, self.server_id)
 44    }
 45
 46    fn requires_argument(&self) -> bool {
 47        self.prompt
 48            .arguments
 49            .as_ref()
 50            .map_or(false, |args| !args.is_empty())
 51    }
 52
 53    fn complete_argument(
 54        self: Arc<Self>,
 55        arguments: &[String],
 56        _cancel: Arc<AtomicBool>,
 57        _workspace: Option<WeakView<Workspace>>,
 58        cx: &mut WindowContext,
 59    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 60        let server_id = self.server_id.clone();
 61        let prompt_name = self.prompt.name.clone();
 62        let manager = ContextServerManager::global(cx);
 63        let manager = manager.read(cx);
 64
 65        let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
 66            Ok(tp) => tp,
 67            Err(e) => {
 68                return Task::ready(Err(e));
 69            }
 70        };
 71        if let Some(server) = manager.get_server(&server_id) {
 72            cx.foreground_executor().spawn(async move {
 73                let Some(protocol) = server.client.read().clone() else {
 74                    return Err(anyhow!("Context server not initialized"));
 75                };
 76
 77                let completion_result = protocol
 78                    .completion(
 79                        context_servers::types::CompletionReference::Prompt(
 80                            context_servers::types::PromptReference {
 81                                r#type: context_servers::types::PromptReferenceType::Prompt,
 82                                name: prompt_name,
 83                            },
 84                        ),
 85                        arg_name,
 86                        arg_val,
 87                    )
 88                    .await?;
 89
 90                let completions = completion_result
 91                    .values
 92                    .into_iter()
 93                    .map(|value| ArgumentCompletion {
 94                        label: CodeLabel::plain(value.clone(), None),
 95                        new_text: value,
 96                        after_completion: AfterCompletion::Continue,
 97                        replace_previous_arguments: false,
 98                    })
 99                    .collect();
100
101                Ok(completions)
102            })
103        } else {
104            Task::ready(Err(anyhow!("Context server not found")))
105        }
106    }
107
108    fn run(
109        self: Arc<Self>,
110        arguments: &[String],
111        _workspace: WeakView<Workspace>,
112        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
113        cx: &mut WindowContext,
114    ) -> Task<Result<SlashCommandOutput>> {
115        let server_id = self.server_id.clone();
116        let prompt_name = self.prompt.name.clone();
117
118        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
119            Ok(args) => args,
120            Err(e) => return Task::ready(Err(e)),
121        };
122
123        let manager = ContextServerManager::global(cx);
124        let manager = manager.read(cx);
125        if let Some(server) = manager.get_server(&server_id) {
126            cx.foreground_executor().spawn(async move {
127                let Some(protocol) = server.client.read().clone() else {
128                    return Err(anyhow!("Context server not initialized"));
129                };
130                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
131                let mut prompt = result.prompt;
132
133                // We must normalize the line endings here, since servers might return CR characters.
134                LineEnding::normalize(&mut prompt);
135
136                Ok(SlashCommandOutput {
137                    sections: vec![SlashCommandOutputSection {
138                        range: 0..(prompt.len()),
139                        icon: IconName::ZedAssistant,
140                        label: SharedString::from(
141                            result
142                                .description
143                                .unwrap_or(format!("Result from {}", prompt_name)),
144                        ),
145                    }],
146                    text: prompt,
147                    run_commands_in_text: false,
148                })
149            })
150        } else {
151            Task::ready(Err(anyhow!("Context server not found")))
152        }
153    }
154}
155
156fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
157    if arguments.is_empty() {
158        return Err(anyhow!("No arguments given"));
159    }
160
161    match &prompt.arguments {
162        Some(args) if args.len() == 1 => {
163            let arg_name = args[0].name.clone();
164            let arg_value = arguments.join(" ");
165            Ok((arg_name, arg_value))
166        }
167        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
168        None => Err(anyhow!("Prompt has no arguments")),
169    }
170}
171
172fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
173    match &prompt.arguments {
174        Some(args) if args.len() > 1 => Err(anyhow!(
175            "Prompt has more than one argument, which is not supported"
176        )),
177        Some(args) if args.len() == 1 => {
178            if !arguments.is_empty() {
179                let mut map = HashMap::default();
180                map.insert(args[0].name.clone(), arguments.join(" "));
181                Ok(map)
182            } else {
183                Err(anyhow!("Prompt expects argument but none given"))
184            }
185        }
186        Some(_) | None => {
187            if arguments.is_empty() {
188                Ok(HashMap::default())
189            } else {
190                Err(anyhow!("Prompt expects no arguments but some were given"))
191            }
192        }
193    }
194}
195
196/// MCP servers can return prompts with multiple arguments. Since we only
197/// support one argument, we ignore all others. This is the necessary predicate
198/// for this.
199pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
200    match &prompt.arguments {
201        None => true,
202        Some(args) if args.len() == 1 => true,
203        _ => false,
204    }
205}