1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
4};
5use collections::HashMap;
6use context_servers::{
7 manager::{ContextServer, ContextServerManager},
8 protocol::PromptInfo,
9};
10use gpui::{Task, WeakView, WindowContext};
11use language::LspAdapterDelegate;
12use std::sync::atomic::AtomicBool;
13use std::sync::Arc;
14use ui::{IconName, SharedString};
15use workspace::Workspace;
16
17pub struct ContextServerSlashCommand {
18 server_id: String,
19 prompt: PromptInfo,
20}
21
22impl ContextServerSlashCommand {
23 pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
24 Self {
25 server_id: server.id.clone(),
26 prompt,
27 }
28 }
29}
30
31impl SlashCommand for ContextServerSlashCommand {
32 fn name(&self) -> String {
33 self.prompt.name.clone()
34 }
35
36 fn description(&self) -> String {
37 format!("Run context server command: {}", self.prompt.name)
38 }
39
40 fn menu_text(&self) -> String {
41 format!("Run '{}' from {}", self.prompt.name, self.server_id)
42 }
43
44 fn requires_argument(&self) -> bool {
45 self.prompt
46 .arguments
47 .as_ref()
48 .map_or(false, |args| !args.is_empty())
49 }
50
51 fn complete_argument(
52 self: Arc<Self>,
53 _arguments: &[String],
54 _cancel: Arc<AtomicBool>,
55 _workspace: Option<WeakView<Workspace>>,
56 _cx: &mut WindowContext,
57 ) -> Task<Result<Vec<ArgumentCompletion>>> {
58 Task::ready(Ok(Vec::new()))
59 }
60
61 fn run(
62 self: Arc<Self>,
63 arguments: &[String],
64 _workspace: WeakView<Workspace>,
65 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
66 cx: &mut WindowContext,
67 ) -> Task<Result<SlashCommandOutput>> {
68 let server_id = self.server_id.clone();
69 let prompt_name = self.prompt.name.clone();
70
71 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
72 Ok(args) => args,
73 Err(e) => return Task::ready(Err(e)),
74 };
75
76 let manager = ContextServerManager::global(cx);
77 let manager = manager.read(cx);
78 if let Some(server) = manager.get_server(&server_id) {
79 cx.foreground_executor().spawn(async move {
80 let Some(protocol) = server.client.read().clone() else {
81 return Err(anyhow!("Context server not initialized"));
82 };
83 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
84
85 Ok(SlashCommandOutput {
86 sections: vec![SlashCommandOutputSection {
87 range: 0..(result.prompt.len()),
88 icon: IconName::ZedAssistant,
89 label: SharedString::from(
90 result
91 .description
92 .unwrap_or(format!("Result from {}", prompt_name)),
93 ),
94 }],
95 text: result.prompt,
96 run_commands_in_text: false,
97 })
98 })
99 } else {
100 Task::ready(Err(anyhow!("Context server not found")))
101 }
102 }
103}
104
105fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
106 match &prompt.arguments {
107 Some(args) if args.len() > 1 => Err(anyhow!(
108 "Prompt has more than one argument, which is not supported"
109 )),
110 Some(args) if args.len() == 1 => {
111 if !arguments.is_empty() {
112 let mut map = HashMap::default();
113 map.insert(args[0].name.clone(), arguments.join(" "));
114 Ok(map)
115 } else {
116 Err(anyhow!("Prompt expects argument but none given"))
117 }
118 }
119 Some(_) | None => {
120 if arguments.is_empty() {
121 Ok(HashMap::default())
122 } else {
123 Err(anyhow!("Prompt expects no arguments but some were given"))
124 }
125 }
126 }
127}
128
129/// MCP servers can return prompts with multiple arguments. Since we only
130/// support one argument, we ignore all others. This is the necessary predicate
131/// for this.
132pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
133 match &prompt.arguments {
134 None => true,
135 Some(args) if args.len() == 1 => true,
136 _ => false,
137 }
138}