1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
4};
5use collections::HashMap;
6use context_servers::{
7 manager::{ContextServer, ContextServerManager},
8 protocol::PromptInfo,
9};
10use gpui::{Task, WeakView, WindowContext};
11use language::LspAdapterDelegate;
12use std::sync::atomic::AtomicBool;
13use std::sync::Arc;
14use ui::{IconName, SharedString};
15use workspace::Workspace;
16
17pub struct ContextServerSlashCommand {
18 server_id: String,
19 prompt: PromptInfo,
20}
21
22impl ContextServerSlashCommand {
23 pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
24 Self {
25 server_id: server.id.clone(),
26 prompt,
27 }
28 }
29}
30
31impl SlashCommand for ContextServerSlashCommand {
32 fn name(&self) -> String {
33 self.prompt.name.clone()
34 }
35
36 fn description(&self) -> String {
37 format!("Run context server command: {}", self.prompt.name)
38 }
39
40 fn menu_text(&self) -> String {
41 format!("Run '{}' from {}", self.prompt.name, self.server_id)
42 }
43
44 fn requires_argument(&self) -> bool {
45 self.prompt
46 .arguments
47 .as_ref()
48 .map_or(false, |args| !args.is_empty())
49 }
50
51 fn complete_argument(
52 self: Arc<Self>,
53 _arguments: &[String],
54 _cancel: Arc<AtomicBool>,
55 _workspace: Option<WeakView<Workspace>>,
56 _cx: &mut WindowContext,
57 ) -> Task<Result<Vec<ArgumentCompletion>>> {
58 Task::ready(Ok(Vec::new()))
59 }
60
61 fn run(
62 self: Arc<Self>,
63 arguments: &[String],
64 _workspace: WeakView<Workspace>,
65 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
66 cx: &mut WindowContext,
67 ) -> Task<Result<SlashCommandOutput>> {
68 let server_id = self.server_id.clone();
69 let prompt_name = self.prompt.name.clone();
70 let argument = arguments.first().cloned();
71
72 let manager = ContextServerManager::global(cx);
73 let manager = manager.read(cx);
74 if let Some(server) = manager.get_server(&server_id) {
75 cx.foreground_executor().spawn(async move {
76 let Some(protocol) = server.client.read().clone() else {
77 return Err(anyhow!("Context server not initialized"));
78 };
79
80 let result = protocol
81 .run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
82 .await?;
83
84 Ok(SlashCommandOutput {
85 sections: vec![SlashCommandOutputSection {
86 range: 0..result.len(),
87 icon: IconName::ZedAssistant,
88 label: SharedString::from(format!("Result from {}", prompt_name)),
89 }],
90 text: result,
91 run_commands_in_text: false,
92 })
93 })
94 } else {
95 Task::ready(Err(anyhow!("Context server not found")))
96 }
97 }
98}
99
100fn prompt_arguments(
101 prompt: &PromptInfo,
102 argument: Option<String>,
103) -> Result<HashMap<String, String>> {
104 match &prompt.arguments {
105 Some(args) if args.len() >= 2 => Err(anyhow!(
106 "Prompt has more than one argument, which is not supported"
107 )),
108 Some(args) if args.len() == 1 => match argument {
109 Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
110 None => Err(anyhow!("Prompt expects argument but none given")),
111 },
112 Some(_) | None => Ok(HashMap::default()),
113 }
114}
115
116/// MCP servers can return prompts with multiple arguments. Since we only
117/// support one argument, we ignore all others. This is the necessary predicate
118/// for this.
119pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
120 match &prompt.arguments {
121 None => true,
122 Some(args) if args.len() == 1 => true,
123 _ => false,
124 }
125}