1use anyhow::{Result, anyhow};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection, SlashCommandResult,
5};
6use collections::HashMap;
7use context_server::{
8 manager::{ContextServer, ContextServerManager},
9 types::Prompt,
10};
11use gpui::{App, Entity, Task, WeakEntity, Window};
12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
13use std::sync::Arc;
14use std::sync::atomic::AtomicBool;
15use text::LineEnding;
16use ui::{IconName, SharedString};
17use workspace::Workspace;
18
19use crate::create_label_for_command;
20
21pub struct ContextServerSlashCommand {
22 server_manager: Entity<ContextServerManager>,
23 server_id: Arc<str>,
24 prompt: Prompt,
25}
26
27impl ContextServerSlashCommand {
28 pub fn new(
29 server_manager: Entity<ContextServerManager>,
30 server: &Arc<ContextServer>,
31 prompt: Prompt,
32 ) -> Self {
33 Self {
34 server_id: server.id(),
35 prompt,
36 server_manager,
37 }
38 }
39}
40
41impl SlashCommand for ContextServerSlashCommand {
42 fn name(&self) -> String {
43 self.prompt.name.clone()
44 }
45
46 fn label(&self, cx: &App) -> language::CodeLabel {
47 let mut parts = vec![self.prompt.name.as_str()];
48 if let Some(args) = &self.prompt.arguments {
49 if let Some(arg) = args.first() {
50 parts.push(arg.name.as_str());
51 }
52 }
53 create_label_for_command(&parts[0], &parts[1..], cx)
54 }
55
56 fn description(&self) -> String {
57 match &self.prompt.description {
58 Some(desc) => desc.clone(),
59 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
60 }
61 }
62
63 fn menu_text(&self) -> String {
64 match &self.prompt.description {
65 Some(desc) => desc.clone(),
66 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
67 }
68 }
69
70 fn requires_argument(&self) -> bool {
71 self.prompt.arguments.as_ref().map_or(false, |args| {
72 args.iter().any(|arg| arg.required == Some(true))
73 })
74 }
75
76 fn complete_argument(
77 self: Arc<Self>,
78 arguments: &[String],
79 _cancel: Arc<AtomicBool>,
80 _workspace: Option<WeakEntity<Workspace>>,
81 _window: &mut Window,
82 cx: &mut App,
83 ) -> Task<Result<Vec<ArgumentCompletion>>> {
84 let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
85 return Task::ready(Err(anyhow!("Failed to complete argument")));
86 };
87
88 let server_id = self.server_id.clone();
89 let prompt_name = self.prompt.name.clone();
90
91 if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
92 cx.foreground_executor().spawn(async move {
93 let Some(protocol) = server.client() else {
94 return Err(anyhow!("Context server not initialized"));
95 };
96
97 let completion_result = protocol
98 .completion(
99 context_server::types::CompletionReference::Prompt(
100 context_server::types::PromptReference {
101 r#type: context_server::types::PromptReferenceType::Prompt,
102 name: prompt_name,
103 },
104 ),
105 arg_name,
106 arg_value,
107 )
108 .await?;
109
110 let completions = completion_result
111 .values
112 .into_iter()
113 .map(|value| ArgumentCompletion {
114 label: CodeLabel::plain(value.clone(), None),
115 new_text: value,
116 after_completion: AfterCompletion::Continue,
117 replace_previous_arguments: false,
118 })
119 .collect();
120 Ok(completions)
121 })
122 } else {
123 Task::ready(Err(anyhow!("Context server not found")))
124 }
125 }
126
127 fn run(
128 self: Arc<Self>,
129 arguments: &[String],
130 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
131 _context_buffer: BufferSnapshot,
132 _workspace: WeakEntity<Workspace>,
133 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
134 _window: &mut Window,
135 cx: &mut App,
136 ) -> Task<SlashCommandResult> {
137 let server_id = self.server_id.clone();
138 let prompt_name = self.prompt.name.clone();
139
140 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
141 Ok(args) => args,
142 Err(e) => return Task::ready(Err(e)),
143 };
144
145 let manager = self.server_manager.read(cx);
146 if let Some(server) = manager.get_server(&server_id) {
147 cx.foreground_executor().spawn(async move {
148 let Some(protocol) = server.client() else {
149 return Err(anyhow!("Context server not initialized"));
150 };
151 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
152
153 // Check that there are only user roles
154 if result
155 .messages
156 .iter()
157 .any(|msg| !matches!(msg.role, context_server::types::Role::User))
158 {
159 return Err(anyhow!(
160 "Prompt contains non-user roles, which is not supported"
161 ));
162 }
163
164 // Extract text from user messages into a single prompt string
165 let mut prompt = result
166 .messages
167 .into_iter()
168 .filter_map(|msg| match msg.content {
169 context_server::types::MessageContent::Text { text, .. } => Some(text),
170 _ => None,
171 })
172 .collect::<Vec<String>>()
173 .join("\n\n");
174
175 // We must normalize the line endings here, since servers might return CR characters.
176 LineEnding::normalize(&mut prompt);
177
178 Ok(SlashCommandOutput {
179 sections: vec![SlashCommandOutputSection {
180 range: 0..(prompt.len()),
181 icon: IconName::ZedAssistant,
182 label: SharedString::from(
183 result
184 .description
185 .unwrap_or(format!("Result from {}", prompt_name)),
186 ),
187 metadata: None,
188 }],
189 text: prompt,
190 run_commands_in_text: false,
191 }
192 .to_event_stream())
193 })
194 } else {
195 Task::ready(Err(anyhow!("Context server not found")))
196 }
197 }
198}
199
200fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
201 if arguments.is_empty() {
202 return Err(anyhow!("No arguments given"));
203 }
204
205 match &prompt.arguments {
206 Some(args) if args.len() == 1 => {
207 let arg_name = args[0].name.clone();
208 let arg_value = arguments.join(" ");
209 Ok((arg_name, arg_value))
210 }
211 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
212 None => Err(anyhow!("Prompt has no arguments")),
213 }
214}
215
216fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
217 match &prompt.arguments {
218 Some(args) if args.len() > 1 => Err(anyhow!(
219 "Prompt has more than one argument, which is not supported"
220 )),
221 Some(args) if args.len() == 1 => {
222 if !arguments.is_empty() {
223 let mut map = HashMap::default();
224 map.insert(args[0].name.clone(), arguments.join(" "));
225 Ok(map)
226 } else if arguments.is_empty() && args[0].required == Some(false) {
227 Ok(HashMap::default())
228 } else {
229 Err(anyhow!("Prompt expects argument but none given"))
230 }
231 }
232 Some(_) | None => {
233 if arguments.is_empty() {
234 Ok(HashMap::default())
235 } else {
236 Err(anyhow!("Prompt expects no arguments but some were given"))
237 }
238 }
239 }
240}
241
242/// MCP servers can return prompts with multiple arguments. Since we only
243/// support one argument, we ignore all others. This is the necessary predicate
244/// for this.
245pub fn acceptable_prompt(prompt: &Prompt) -> bool {
246 match &prompt.arguments {
247 None => true,
248 Some(args) if args.len() <= 1 => true,
249 _ => false,
250 }
251}