context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
  4};
  5use collections::HashMap;
  6use context_servers::{
  7    manager::{ContextServer, ContextServerManager},
  8    protocol::PromptInfo,
  9};
 10use gpui::{Task, WeakView, WindowContext};
 11use language::LspAdapterDelegate;
 12use std::sync::atomic::AtomicBool;
 13use std::sync::Arc;
 14use ui::{IconName, SharedString};
 15use workspace::Workspace;
 16
 17pub struct ContextServerSlashCommand {
 18    server_id: String,
 19    prompt: PromptInfo,
 20}
 21
 22impl ContextServerSlashCommand {
 23    pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
 24        Self {
 25            server_id: server.id.clone(),
 26            prompt,
 27        }
 28    }
 29}
 30
 31impl SlashCommand for ContextServerSlashCommand {
 32    fn name(&self) -> String {
 33        self.prompt.name.clone()
 34    }
 35
 36    fn description(&self) -> String {
 37        format!("Run context server command: {}", self.prompt.name)
 38    }
 39
 40    fn menu_text(&self) -> String {
 41        format!("Run '{}' from {}", self.prompt.name, self.server_id)
 42    }
 43
 44    fn requires_argument(&self) -> bool {
 45        self.prompt
 46            .arguments
 47            .as_ref()
 48            .map_or(false, |args| !args.is_empty())
 49    }
 50
 51    fn complete_argument(
 52        self: Arc<Self>,
 53        _arguments: &[String],
 54        _cancel: Arc<AtomicBool>,
 55        _workspace: Option<WeakView<Workspace>>,
 56        _cx: &mut WindowContext,
 57    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 58        Task::ready(Ok(Vec::new()))
 59    }
 60
 61    fn run(
 62        self: Arc<Self>,
 63        arguments: &[String],
 64        _workspace: WeakView<Workspace>,
 65        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
 66        cx: &mut WindowContext,
 67    ) -> Task<Result<SlashCommandOutput>> {
 68        let server_id = self.server_id.clone();
 69        let prompt_name = self.prompt.name.clone();
 70
 71        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
 72            Ok(args) => args,
 73            Err(e) => return Task::ready(Err(e)),
 74        };
 75
 76        let manager = ContextServerManager::global(cx);
 77        let manager = manager.read(cx);
 78        if let Some(server) = manager.get_server(&server_id) {
 79            cx.foreground_executor().spawn(async move {
 80                let Some(protocol) = server.client.read().clone() else {
 81                    return Err(anyhow!("Context server not initialized"));
 82                };
 83                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
 84
 85                Ok(SlashCommandOutput {
 86                    sections: vec![SlashCommandOutputSection {
 87                        range: 0..(result.prompt.len()),
 88                        icon: IconName::ZedAssistant,
 89                        label: SharedString::from(
 90                            result
 91                                .description
 92                                .unwrap_or(format!("Result from {}", prompt_name)),
 93                        ),
 94                    }],
 95                    text: result.prompt,
 96                    run_commands_in_text: false,
 97                })
 98            })
 99        } else {
100            Task::ready(Err(anyhow!("Context server not found")))
101        }
102    }
103}
104
105fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
106    match &prompt.arguments {
107        Some(args) if args.len() > 1 => Err(anyhow!(
108            "Prompt has more than one argument, which is not supported"
109        )),
110        Some(args) if args.len() == 1 => {
111            if !arguments.is_empty() {
112                let mut map = HashMap::default();
113                map.insert(args[0].name.clone(), arguments.join(" "));
114                Ok(map)
115            } else {
116                Err(anyhow!("Prompt expects argument but none given"))
117            }
118        }
119        Some(_) | None => {
120            if arguments.is_empty() {
121                Ok(HashMap::default())
122            } else {
123                Err(anyhow!("Prompt expects no arguments but some were given"))
124            }
125        }
126    }
127}
128
129/// MCP servers can return prompts with multiple arguments. Since we only
130/// support one argument, we ignore all others. This is the necessary predicate
131/// for this.
132pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
133    match &prompt.arguments {
134        None => true,
135        Some(args) if args.len() == 1 => true,
136        _ => false,
137    }
138}