context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
  4};
  5use collections::HashMap;
  6use context_servers::{
  7    manager::{ContextServer, ContextServerManager},
  8    protocol::PromptInfo,
  9};
 10use gpui::{Task, WeakView, WindowContext};
 11use language::LspAdapterDelegate;
 12use std::sync::atomic::AtomicBool;
 13use std::sync::Arc;
 14use ui::{IconName, SharedString};
 15use workspace::Workspace;
 16
/// A slash command backed by one prompt offered by a context server.
///
/// Running the command asks the server identified by `server_id` to run
/// `prompt`.
pub struct ContextServerSlashCommand {
    /// Id of the context server that provides the prompt. It is used to look the
    /// server up in the global `ContextServerManager` each time the command runs.
    server_id: String,
    /// Prompt metadata. Its `name` becomes the command name, and its declared
    /// `arguments` decide whether the command requires an argument.
    prompt: PromptInfo,
}
 21
 22impl ContextServerSlashCommand {
 23    pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
 24        Self {
 25            server_id: server.id.clone(),
 26            prompt,
 27        }
 28    }
 29}
 30
 31impl SlashCommand for ContextServerSlashCommand {
 32    fn name(&self) -> String {
 33        self.prompt.name.clone()
 34    }
 35
 36    fn description(&self) -> String {
 37        format!("Run context server command: {}", self.prompt.name)
 38    }
 39
 40    fn menu_text(&self) -> String {
 41        format!("Run '{}' from {}", self.prompt.name, self.server_id)
 42    }
 43
 44    fn requires_argument(&self) -> bool {
 45        self.prompt
 46            .arguments
 47            .as_ref()
 48            .map_or(false, |args| !args.is_empty())
 49    }
 50
 51    fn complete_argument(
 52        self: Arc<Self>,
 53        _arguments: &[String],
 54        _cancel: Arc<AtomicBool>,
 55        _workspace: Option<WeakView<Workspace>>,
 56        _cx: &mut WindowContext,
 57    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 58        Task::ready(Ok(Vec::new()))
 59    }
 60
 61    fn run(
 62        self: Arc<Self>,
 63        arguments: &[String],
 64        _workspace: WeakView<Workspace>,
 65        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
 66        cx: &mut WindowContext,
 67    ) -> Task<Result<SlashCommandOutput>> {
 68        let server_id = self.server_id.clone();
 69        let prompt_name = self.prompt.name.clone();
 70
 71        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
 72            Ok(args) => args,
 73            Err(e) => return Task::ready(Err(e)),
 74        };
 75
 76        let manager = ContextServerManager::global(cx);
 77        let manager = manager.read(cx);
 78        if let Some(server) = manager.get_server(&server_id) {
 79            cx.foreground_executor().spawn(async move {
 80                let Some(protocol) = server.client.read().clone() else {
 81                    return Err(anyhow!("Context server not initialized"));
 82                };
 83                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
 84
 85                Ok(SlashCommandOutput {
 86                    sections: vec![SlashCommandOutputSection {
 87                        range: 0..result.len(),
 88                        icon: IconName::ZedAssistant,
 89                        label: SharedString::from(format!("Result from {}", prompt_name)),
 90                    }],
 91                    text: result,
 92                    run_commands_in_text: false,
 93                })
 94            })
 95        } else {
 96            Task::ready(Err(anyhow!("Context server not found")))
 97        }
 98    }
 99}
100
101fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
102    match &prompt.arguments {
103        Some(args) if args.len() > 1 => Err(anyhow!(
104            "Prompt has more than one argument, which is not supported"
105        )),
106        Some(args) if args.len() == 1 => {
107            if !arguments.is_empty() {
108                let mut map = HashMap::default();
109                map.insert(args[0].name.clone(), arguments.join(" "));
110                Ok(map)
111            } else {
112                Err(anyhow!("Prompt expects argument but none given"))
113            }
114        }
115        Some(_) | None => {
116            if arguments.is_empty() {
117                Ok(HashMap::default())
118            } else {
119                Err(anyhow!("Prompt expects no arguments but some were given"))
120            }
121        }
122    }
123}
124
125/// MCP servers can return prompts with multiple arguments. Since we only
126/// support one argument, we ignore all others. This is the necessary predicate
127/// for this.
128pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
129    match &prompt.arguments {
130        None => true,
131        Some(args) if args.len() == 1 => true,
132        _ => false,
133    }
134}