1use super::create_label_for_command;
2use anyhow::{anyhow, Result};
3use assistant_slash_command::{
4 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
5 SlashCommandOutputSection,
6};
7use collections::HashMap;
8use context_servers::{
9 manager::{ContextServer, ContextServerManager},
10 types::Prompt,
11};
12use gpui::{AppContext, Task, WeakView, WindowContext};
13use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
14use std::sync::atomic::AtomicBool;
15use std::sync::Arc;
16use text::LineEnding;
17use ui::{IconName, SharedString};
18use workspace::Workspace;
19
/// A slash command that runs one prompt provided by a context server.
pub struct ContextServerSlashCommand {
    /// Id of the context server providing `prompt`; used to look the server up
    /// in the `ContextServerManager` whenever the command is completed or run.
    server_id: String,
    /// The prompt this command runs (its name, optional description and
    /// optional argument list drive the command's name, label and text).
    prompt: Prompt,
}
24
25impl ContextServerSlashCommand {
26 pub fn new(server: &Arc<ContextServer>, prompt: Prompt) -> Self {
27 Self {
28 server_id: server.id.clone(),
29 prompt,
30 }
31 }
32}
33
34impl SlashCommand for ContextServerSlashCommand {
35 fn name(&self) -> String {
36 self.prompt.name.clone()
37 }
38
39 fn label(&self, cx: &AppContext) -> language::CodeLabel {
40 let mut parts = vec![self.prompt.name.as_str()];
41 if let Some(args) = &self.prompt.arguments {
42 if let Some(arg) = args.first() {
43 parts.push(arg.name.as_str());
44 }
45 }
46 create_label_for_command(&parts[0], &parts[1..], cx)
47 }
48
49 fn description(&self) -> String {
50 match &self.prompt.description {
51 Some(desc) => desc.clone(),
52 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
53 }
54 }
55
56 fn menu_text(&self) -> String {
57 match &self.prompt.description {
58 Some(desc) => desc.clone(),
59 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
60 }
61 }
62
63 fn requires_argument(&self) -> bool {
64 self.prompt.arguments.as_ref().map_or(false, |args| {
65 args.iter().any(|arg| arg.required == Some(true))
66 })
67 }
68
69 fn complete_argument(
70 self: Arc<Self>,
71 arguments: &[String],
72 _cancel: Arc<AtomicBool>,
73 _workspace: Option<WeakView<Workspace>>,
74 cx: &mut WindowContext,
75 ) -> Task<Result<Vec<ArgumentCompletion>>> {
76 let server_id = self.server_id.clone();
77 let prompt_name = self.prompt.name.clone();
78 let manager = ContextServerManager::global(cx);
79 let manager = manager.read(cx);
80
81 let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
82 Ok(tp) => tp,
83 Err(e) => {
84 return Task::ready(Err(e));
85 }
86 };
87 if let Some(server) = manager.get_server(&server_id) {
88 cx.foreground_executor().spawn(async move {
89 let Some(protocol) = server.client.read().clone() else {
90 return Err(anyhow!("Context server not initialized"));
91 };
92
93 let completion_result = protocol
94 .completion(
95 context_servers::types::CompletionReference::Prompt(
96 context_servers::types::PromptReference {
97 r#type: context_servers::types::PromptReferenceType::Prompt,
98 name: prompt_name,
99 },
100 ),
101 arg_name,
102 arg_val,
103 )
104 .await?;
105
106 let completions = completion_result
107 .values
108 .into_iter()
109 .map(|value| ArgumentCompletion {
110 label: CodeLabel::plain(value.clone(), None),
111 new_text: value,
112 after_completion: AfterCompletion::Continue,
113 replace_previous_arguments: false,
114 })
115 .collect();
116 Ok(completions)
117 })
118 } else {
119 Task::ready(Err(anyhow!("Context server not found")))
120 }
121 }
122
123 fn run(
124 self: Arc<Self>,
125 arguments: &[String],
126 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
127 _context_buffer: BufferSnapshot,
128 _workspace: WeakView<Workspace>,
129 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
130 cx: &mut WindowContext,
131 ) -> Task<Result<SlashCommandOutput>> {
132 let server_id = self.server_id.clone();
133 let prompt_name = self.prompt.name.clone();
134
135 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
136 Ok(args) => args,
137 Err(e) => return Task::ready(Err(e)),
138 };
139
140 let manager = ContextServerManager::global(cx);
141 let manager = manager.read(cx);
142 if let Some(server) = manager.get_server(&server_id) {
143 cx.foreground_executor().spawn(async move {
144 let Some(protocol) = server.client.read().clone() else {
145 return Err(anyhow!("Context server not initialized"));
146 };
147 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
148
149 // Check that there are only user roles
150 if result
151 .messages
152 .iter()
153 .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
154 {
155 return Err(anyhow!(
156 "Prompt contains non-user roles, which is not supported"
157 ));
158 }
159
160 // Extract text from user messages into a single prompt string
161 let mut prompt = result
162 .messages
163 .into_iter()
164 .filter_map(|msg| match msg.content {
165 context_servers::types::SamplingContent::Text { text } => Some(text),
166 _ => None,
167 })
168 .collect::<Vec<String>>()
169 .join("\n\n");
170
171 // We must normalize the line endings here, since servers might return CR characters.
172 LineEnding::normalize(&mut prompt);
173
174 Ok(SlashCommandOutput {
175 sections: vec![SlashCommandOutputSection {
176 range: 0..(prompt.len()),
177 icon: IconName::ZedAssistant,
178 label: SharedString::from(
179 result
180 .description
181 .unwrap_or(format!("Result from {}", prompt_name)),
182 ),
183 metadata: None,
184 }],
185 text: prompt,
186 run_commands_in_text: false,
187 })
188 })
189 } else {
190 Task::ready(Err(anyhow!("Context server not found")))
191 }
192 }
193}
194
195fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
196 if arguments.is_empty() {
197 return Err(anyhow!("No arguments given"));
198 }
199
200 match &prompt.arguments {
201 Some(args) if args.len() == 1 => {
202 let arg_name = args[0].name.clone();
203 let arg_value = arguments.join(" ");
204 Ok((arg_name, arg_value))
205 }
206 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
207 None => Err(anyhow!("Prompt has no arguments")),
208 }
209}
210
211fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
212 match &prompt.arguments {
213 Some(args) if args.len() > 1 => Err(anyhow!(
214 "Prompt has more than one argument, which is not supported"
215 )),
216 Some(args) if args.len() == 1 => {
217 if !arguments.is_empty() {
218 let mut map = HashMap::default();
219 map.insert(args[0].name.clone(), arguments.join(" "));
220 Ok(map)
221 } else if arguments.is_empty() && args[0].required == Some(false) {
222 Ok(HashMap::default())
223 } else {
224 Err(anyhow!("Prompt expects argument but none given"))
225 }
226 }
227 Some(_) | None => {
228 if arguments.is_empty() {
229 Ok(HashMap::default())
230 } else {
231 Err(anyhow!("Prompt expects no arguments but some were given"))
232 }
233 }
234 }
235}
236
237/// MCP servers can return prompts with multiple arguments. Since we only
238/// support one argument, we ignore all others. This is the necessary predicate
239/// for this.
240pub fn acceptable_prompt(prompt: &Prompt) -> bool {
241 match &prompt.arguments {
242 None => true,
243 Some(args) if args.len() <= 1 => true,
244 _ => false,
245 }
246}