context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection,
  5};
  6use collections::HashMap;
  7use context_servers::{
  8    manager::{ContextServer, ContextServerManager},
  9    protocol::PromptInfo,
 10};
 11use gpui::{Task, WeakView, WindowContext};
 12use language::{CodeLabel, LspAdapterDelegate};
 13use std::sync::atomic::AtomicBool;
 14use std::sync::Arc;
 15use ui::{IconName, SharedString};
 16use workspace::Workspace;
 17
 18pub struct ContextServerSlashCommand {
 19    server_id: String,
 20    prompt: PromptInfo,
 21}
 22
 23impl ContextServerSlashCommand {
 24    pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
 25        Self {
 26            server_id: server.id.clone(),
 27            prompt,
 28        }
 29    }
 30}
 31
 32impl SlashCommand for ContextServerSlashCommand {
 33    fn name(&self) -> String {
 34        self.prompt.name.clone()
 35    }
 36
 37    fn description(&self) -> String {
 38        format!("Run context server command: {}", self.prompt.name)
 39    }
 40
 41    fn menu_text(&self) -> String {
 42        format!("Run '{}' from {}", self.prompt.name, self.server_id)
 43    }
 44
 45    fn requires_argument(&self) -> bool {
 46        self.prompt
 47            .arguments
 48            .as_ref()
 49            .map_or(false, |args| !args.is_empty())
 50    }
 51
 52    fn complete_argument(
 53        self: Arc<Self>,
 54        arguments: &[String],
 55        _cancel: Arc<AtomicBool>,
 56        _workspace: Option<WeakView<Workspace>>,
 57        cx: &mut WindowContext,
 58    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 59        let server_id = self.server_id.clone();
 60        let prompt_name = self.prompt.name.clone();
 61        let manager = ContextServerManager::global(cx);
 62        let manager = manager.read(cx);
 63
 64        let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
 65            Ok(tp) => tp,
 66            Err(e) => {
 67                return Task::ready(Err(e));
 68            }
 69        };
 70        if let Some(server) = manager.get_server(&server_id) {
 71            cx.foreground_executor().spawn(async move {
 72                let Some(protocol) = server.client.read().clone() else {
 73                    return Err(anyhow!("Context server not initialized"));
 74                };
 75
 76                let completion_result = protocol
 77                    .completion(
 78                        context_servers::types::CompletionReference::Prompt(
 79                            context_servers::types::PromptReference {
 80                                r#type: context_servers::types::PromptReferenceType::Prompt,
 81                                name: prompt_name,
 82                            },
 83                        ),
 84                        arg_name,
 85                        arg_val,
 86                    )
 87                    .await?;
 88
 89                let completions = completion_result
 90                    .values
 91                    .into_iter()
 92                    .map(|value| ArgumentCompletion {
 93                        label: CodeLabel::plain(value.clone(), None),
 94                        new_text: value,
 95                        after_completion: AfterCompletion::Continue,
 96                        replace_previous_arguments: false,
 97                    })
 98                    .collect();
 99
100                Ok(completions)
101            })
102        } else {
103            Task::ready(Err(anyhow!("Context server not found")))
104        }
105    }
106
107    fn run(
108        self: Arc<Self>,
109        arguments: &[String],
110        _workspace: WeakView<Workspace>,
111        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
112        cx: &mut WindowContext,
113    ) -> Task<Result<SlashCommandOutput>> {
114        let server_id = self.server_id.clone();
115        let prompt_name = self.prompt.name.clone();
116
117        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
118            Ok(args) => args,
119            Err(e) => return Task::ready(Err(e)),
120        };
121
122        let manager = ContextServerManager::global(cx);
123        let manager = manager.read(cx);
124        if let Some(server) = manager.get_server(&server_id) {
125            cx.foreground_executor().spawn(async move {
126                let Some(protocol) = server.client.read().clone() else {
127                    return Err(anyhow!("Context server not initialized"));
128                };
129                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
130
131                Ok(SlashCommandOutput {
132                    sections: vec![SlashCommandOutputSection {
133                        range: 0..(result.prompt.len()),
134                        icon: IconName::ZedAssistant,
135                        label: SharedString::from(
136                            result
137                                .description
138                                .unwrap_or(format!("Result from {}", prompt_name)),
139                        ),
140                    }],
141                    text: result.prompt,
142                    run_commands_in_text: false,
143                })
144            })
145        } else {
146            Task::ready(Err(anyhow!("Context server not found")))
147        }
148    }
149}
150
151fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
152    if arguments.is_empty() {
153        return Err(anyhow!("No arguments given"));
154    }
155
156    match &prompt.arguments {
157        Some(args) if args.len() == 1 => {
158            let arg_name = args[0].name.clone();
159            let arg_value = arguments.join(" ");
160            Ok((arg_name, arg_value))
161        }
162        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
163        None => Err(anyhow!("Prompt has no arguments")),
164    }
165}
166
167fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
168    match &prompt.arguments {
169        Some(args) if args.len() > 1 => Err(anyhow!(
170            "Prompt has more than one argument, which is not supported"
171        )),
172        Some(args) if args.len() == 1 => {
173            if !arguments.is_empty() {
174                let mut map = HashMap::default();
175                map.insert(args[0].name.clone(), arguments.join(" "));
176                Ok(map)
177            } else {
178                Err(anyhow!("Prompt expects argument but none given"))
179            }
180        }
181        Some(_) | None => {
182            if arguments.is_empty() {
183                Ok(HashMap::default())
184            } else {
185                Err(anyhow!("Prompt expects no arguments but some were given"))
186            }
187        }
188    }
189}
190
191/// MCP servers can return prompts with multiple arguments. Since we only
192/// support one argument, we ignore all others. This is the necessary predicate
193/// for this.
194pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
195    match &prompt.arguments {
196        None => true,
197        Some(args) if args.len() == 1 => true,
198        _ => false,
199    }
200}