context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
  4};
  5use collections::HashMap;
  6use context_servers::{
  7    manager::{ContextServer, ContextServerManager},
  8    protocol::PromptInfo,
  9};
 10use gpui::{Task, WeakView, WindowContext};
 11use language::LspAdapterDelegate;
 12use std::sync::atomic::AtomicBool;
 13use std::sync::Arc;
 14use ui::{IconName, SharedString};
 15use workspace::Workspace;
 16
 17pub struct ContextServerSlashCommand {
 18    server_id: String,
 19    prompt: PromptInfo,
 20}
 21
 22impl ContextServerSlashCommand {
 23    pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
 24        Self {
 25            server_id: server.id.clone(),
 26            prompt,
 27        }
 28    }
 29}
 30
 31impl SlashCommand for ContextServerSlashCommand {
 32    fn name(&self) -> String {
 33        self.prompt.name.clone()
 34    }
 35
 36    fn description(&self) -> String {
 37        format!("Run context server command: {}", self.prompt.name)
 38    }
 39
 40    fn menu_text(&self) -> String {
 41        format!("Run '{}' from {}", self.prompt.name, self.server_id)
 42    }
 43
 44    fn requires_argument(&self) -> bool {
 45        self.prompt
 46            .arguments
 47            .as_ref()
 48            .map_or(false, |args| !args.is_empty())
 49    }
 50
 51    fn complete_argument(
 52        self: Arc<Self>,
 53        _arguments: &[String],
 54        _cancel: Arc<AtomicBool>,
 55        _workspace: Option<WeakView<Workspace>>,
 56        _cx: &mut WindowContext,
 57    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 58        Task::ready(Ok(Vec::new()))
 59    }
 60
 61    fn run(
 62        self: Arc<Self>,
 63        arguments: &[String],
 64        _workspace: WeakView<Workspace>,
 65        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
 66        cx: &mut WindowContext,
 67    ) -> Task<Result<SlashCommandOutput>> {
 68        let server_id = self.server_id.clone();
 69        let prompt_name = self.prompt.name.clone();
 70        let argument = arguments.first().cloned();
 71
 72        let manager = ContextServerManager::global(cx);
 73        let manager = manager.read(cx);
 74        if let Some(server) = manager.get_server(&server_id) {
 75            cx.foreground_executor().spawn(async move {
 76                let Some(protocol) = server.client.read().clone() else {
 77                    return Err(anyhow!("Context server not initialized"));
 78                };
 79
 80                let result = protocol
 81                    .run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
 82                    .await?;
 83
 84                Ok(SlashCommandOutput {
 85                    sections: vec![SlashCommandOutputSection {
 86                        range: 0..result.len(),
 87                        icon: IconName::ZedAssistant,
 88                        label: SharedString::from(format!("Result from {}", prompt_name)),
 89                    }],
 90                    text: result,
 91                    run_commands_in_text: false,
 92                })
 93            })
 94        } else {
 95            Task::ready(Err(anyhow!("Context server not found")))
 96        }
 97    }
 98}
 99
100fn prompt_arguments(
101    prompt: &PromptInfo,
102    argument: Option<String>,
103) -> Result<HashMap<String, String>> {
104    match &prompt.arguments {
105        Some(args) if args.len() >= 2 => Err(anyhow!(
106            "Prompt has more than one argument, which is not supported"
107        )),
108        Some(args) if args.len() == 1 => match argument {
109            Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
110            None => Err(anyhow!("Prompt expects argument but none given")),
111        },
112        Some(_) | None => Ok(HashMap::default()),
113    }
114}
115
116/// MCP servers can return prompts with multiple arguments. Since we only
117/// support one argument, we ignore all others. This is the necessary predicate
118/// for this.
119pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
120    match &prompt.arguments {
121        None => true,
122        Some(args) if args.len() == 1 => true,
123        _ => false,
124    }
125}