1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection, SlashCommandResult,
5};
6use collections::HashMap;
7use context_servers::{
8 manager::{ContextServer, ContextServerManager},
9 types::Prompt,
10};
11use gpui::{AppContext, Task, WeakView, WindowContext};
12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
13use std::sync::atomic::AtomicBool;
14use std::sync::Arc;
15use text::LineEnding;
16use ui::{IconName, SharedString};
17use workspace::Workspace;
18
19use crate::slash_command::create_label_for_command;
20
21pub struct ContextServerSlashCommand {
22 server_id: String,
23 prompt: Prompt,
24}
25
26impl ContextServerSlashCommand {
27 pub fn new(server: &Arc<ContextServer>, prompt: Prompt) -> Self {
28 Self {
29 server_id: server.id.clone(),
30 prompt,
31 }
32 }
33}
34
35impl SlashCommand for ContextServerSlashCommand {
36 fn name(&self) -> String {
37 self.prompt.name.clone()
38 }
39
40 fn label(&self, cx: &AppContext) -> language::CodeLabel {
41 let mut parts = vec![self.prompt.name.as_str()];
42 if let Some(args) = &self.prompt.arguments {
43 if let Some(arg) = args.first() {
44 parts.push(arg.name.as_str());
45 }
46 }
47 create_label_for_command(&parts[0], &parts[1..], cx)
48 }
49
50 fn description(&self) -> String {
51 match &self.prompt.description {
52 Some(desc) => desc.clone(),
53 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
54 }
55 }
56
57 fn menu_text(&self) -> String {
58 match &self.prompt.description {
59 Some(desc) => desc.clone(),
60 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
61 }
62 }
63
64 fn requires_argument(&self) -> bool {
65 self.prompt.arguments.as_ref().map_or(false, |args| {
66 args.iter().any(|arg| arg.required == Some(true))
67 })
68 }
69
70 fn complete_argument(
71 self: Arc<Self>,
72 arguments: &[String],
73 _cancel: Arc<AtomicBool>,
74 _workspace: Option<WeakView<Workspace>>,
75 cx: &mut WindowContext,
76 ) -> Task<Result<Vec<ArgumentCompletion>>> {
77 let server_id = self.server_id.clone();
78 let prompt_name = self.prompt.name.clone();
79 let manager = ContextServerManager::global(cx);
80 let manager = manager.read(cx);
81
82 let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
83 Ok(tp) => tp,
84 Err(e) => {
85 return Task::ready(Err(e));
86 }
87 };
88 if let Some(server) = manager.get_server(&server_id) {
89 cx.foreground_executor().spawn(async move {
90 let Some(protocol) = server.client.read().clone() else {
91 return Err(anyhow!("Context server not initialized"));
92 };
93
94 let completion_result = protocol
95 .completion(
96 context_servers::types::CompletionReference::Prompt(
97 context_servers::types::PromptReference {
98 r#type: context_servers::types::PromptReferenceType::Prompt,
99 name: prompt_name,
100 },
101 ),
102 arg_name,
103 arg_val,
104 )
105 .await?;
106
107 let completions = completion_result
108 .values
109 .into_iter()
110 .map(|value| ArgumentCompletion {
111 label: CodeLabel::plain(value.clone(), None),
112 new_text: value,
113 after_completion: AfterCompletion::Continue,
114 replace_previous_arguments: false,
115 })
116 .collect();
117 Ok(completions)
118 })
119 } else {
120 Task::ready(Err(anyhow!("Context server not found")))
121 }
122 }
123
124 fn run(
125 self: Arc<Self>,
126 arguments: &[String],
127 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
128 _context_buffer: BufferSnapshot,
129 _workspace: WeakView<Workspace>,
130 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
131 cx: &mut WindowContext,
132 ) -> Task<SlashCommandResult> {
133 let server_id = self.server_id.clone();
134 let prompt_name = self.prompt.name.clone();
135
136 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
137 Ok(args) => args,
138 Err(e) => return Task::ready(Err(e)),
139 };
140
141 let manager = ContextServerManager::global(cx);
142 let manager = manager.read(cx);
143 if let Some(server) = manager.get_server(&server_id) {
144 cx.foreground_executor().spawn(async move {
145 let Some(protocol) = server.client.read().clone() else {
146 return Err(anyhow!("Context server not initialized"));
147 };
148 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
149
150 // Check that there are only user roles
151 if result
152 .messages
153 .iter()
154 .any(|msg| !matches!(msg.role, context_servers::types::SamplingRole::User))
155 {
156 return Err(anyhow!(
157 "Prompt contains non-user roles, which is not supported"
158 ));
159 }
160
161 // Extract text from user messages into a single prompt string
162 let mut prompt = result
163 .messages
164 .into_iter()
165 .filter_map(|msg| match msg.content {
166 context_servers::types::SamplingContent::Text { text } => Some(text),
167 _ => None,
168 })
169 .collect::<Vec<String>>()
170 .join("\n\n");
171
172 // We must normalize the line endings here, since servers might return CR characters.
173 LineEnding::normalize(&mut prompt);
174
175 Ok(SlashCommandOutput {
176 sections: vec![SlashCommandOutputSection {
177 range: 0..(prompt.len()),
178 icon: IconName::ZedAssistant,
179 label: SharedString::from(
180 result
181 .description
182 .unwrap_or(format!("Result from {}", prompt_name)),
183 ),
184 metadata: None,
185 }],
186 text: prompt,
187 run_commands_in_text: false,
188 }
189 .to_event_stream())
190 })
191 } else {
192 Task::ready(Err(anyhow!("Context server not found")))
193 }
194 }
195}
196
197fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
198 if arguments.is_empty() {
199 return Err(anyhow!("No arguments given"));
200 }
201
202 match &prompt.arguments {
203 Some(args) if args.len() == 1 => {
204 let arg_name = args[0].name.clone();
205 let arg_value = arguments.join(" ");
206 Ok((arg_name, arg_value))
207 }
208 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
209 None => Err(anyhow!("Prompt has no arguments")),
210 }
211}
212
213fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
214 match &prompt.arguments {
215 Some(args) if args.len() > 1 => Err(anyhow!(
216 "Prompt has more than one argument, which is not supported"
217 )),
218 Some(args) if args.len() == 1 => {
219 if !arguments.is_empty() {
220 let mut map = HashMap::default();
221 map.insert(args[0].name.clone(), arguments.join(" "));
222 Ok(map)
223 } else if arguments.is_empty() && args[0].required == Some(false) {
224 Ok(HashMap::default())
225 } else {
226 Err(anyhow!("Prompt expects argument but none given"))
227 }
228 }
229 Some(_) | None => {
230 if arguments.is_empty() {
231 Ok(HashMap::default())
232 } else {
233 Err(anyhow!("Prompt expects no arguments but some were given"))
234 }
235 }
236 }
237}
238
239/// MCP servers can return prompts with multiple arguments. Since we only
240/// support one argument, we ignore all others. This is the necessary predicate
241/// for this.
242pub fn acceptable_prompt(prompt: &Prompt) -> bool {
243 match &prompt.arguments {
244 None => true,
245 Some(args) if args.len() <= 1 => true,
246 _ => false,
247 }
248}