1use anyhow::{Context as _, Result, anyhow};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection, SlashCommandResult,
5};
6use collections::HashMap;
7use context_server::{ContextServerId, types::Prompt};
8use gpui::{App, Entity, Task, WeakEntity, Window};
9use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
10use project::context_server_store::ContextServerStore;
11use std::sync::Arc;
12use std::sync::atomic::AtomicBool;
13use text::LineEnding;
14use ui::{IconName, SharedString};
15use workspace::Workspace;
16
17use crate::create_label_for_command;
18
19pub struct ContextServerSlashCommand {
20 store: Entity<ContextServerStore>,
21 server_id: ContextServerId,
22 prompt: Prompt,
23}
24
25impl ContextServerSlashCommand {
26 pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
27 Self {
28 server_id: id,
29 prompt,
30 store,
31 }
32 }
33}
34
35impl SlashCommand for ContextServerSlashCommand {
36 fn name(&self) -> String {
37 self.prompt.name.clone()
38 }
39
40 fn label(&self, cx: &App) -> language::CodeLabel {
41 let mut parts = vec![self.prompt.name.as_str()];
42 if let Some(args) = &self.prompt.arguments
43 && let Some(arg) = args.first()
44 {
45 parts.push(arg.name.as_str());
46 }
47 create_label_for_command(parts[0], &parts[1..], cx)
48 }
49
50 fn description(&self) -> String {
51 match &self.prompt.description {
52 Some(desc) => desc.clone(),
53 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
54 }
55 }
56
57 fn menu_text(&self) -> String {
58 match &self.prompt.description {
59 Some(desc) => desc.clone(),
60 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
61 }
62 }
63
64 fn requires_argument(&self) -> bool {
65 self.prompt
66 .arguments
67 .as_ref()
68 .is_some_and(|args| args.iter().any(|arg| arg.required == Some(true)))
69 }
70
71 fn complete_argument(
72 self: Arc<Self>,
73 arguments: &[String],
74 _cancel: Arc<AtomicBool>,
75 _workspace: Option<WeakEntity<Workspace>>,
76 _window: &mut Window,
77 cx: &mut App,
78 ) -> Task<Result<Vec<ArgumentCompletion>>> {
79 let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
80 return Task::ready(Err(anyhow!("Failed to complete argument")));
81 };
82
83 let server_id = self.server_id.clone();
84 let prompt_name = self.prompt.name.clone();
85
86 if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
87 cx.foreground_executor().spawn(async move {
88 let protocol = server.client().context("Context server not initialized")?;
89
90 let response = protocol
91 .request::<context_server::types::requests::CompletionComplete>(
92 context_server::types::CompletionCompleteParams {
93 reference: context_server::types::CompletionReference::Prompt(
94 context_server::types::PromptReference {
95 ty: context_server::types::PromptReferenceType::Prompt,
96 name: prompt_name,
97 },
98 ),
99 argument: context_server::types::CompletionArgument {
100 name: arg_name,
101 value: arg_value,
102 },
103 meta: None,
104 },
105 )
106 .await?;
107
108 let completions = response
109 .completion
110 .values
111 .into_iter()
112 .map(|value| ArgumentCompletion {
113 label: CodeLabel::plain(value.clone(), None),
114 new_text: value,
115 after_completion: AfterCompletion::Continue,
116 replace_previous_arguments: false,
117 })
118 .collect();
119 Ok(completions)
120 })
121 } else {
122 Task::ready(Err(anyhow!("Context server not found")))
123 }
124 }
125
126 fn run(
127 self: Arc<Self>,
128 arguments: &[String],
129 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
130 _context_buffer: BufferSnapshot,
131 _workspace: WeakEntity<Workspace>,
132 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
133 _window: &mut Window,
134 cx: &mut App,
135 ) -> Task<SlashCommandResult> {
136 let server_id = self.server_id.clone();
137 let prompt_name = self.prompt.name.clone();
138
139 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
140 Ok(args) => args,
141 Err(e) => return Task::ready(Err(e)),
142 };
143
144 let store = self.store.read(cx);
145 if let Some(server) = store.get_running_server(&server_id) {
146 cx.foreground_executor().spawn(async move {
147 let protocol = server.client().context("Context server not initialized")?;
148 let response = protocol
149 .request::<context_server::types::requests::PromptsGet>(
150 context_server::types::PromptsGetParams {
151 name: prompt_name.clone(),
152 arguments: Some(prompt_args),
153 meta: None,
154 },
155 )
156 .await?;
157
158 anyhow::ensure!(
159 response
160 .messages
161 .iter()
162 .all(|msg| matches!(msg.role, context_server::types::Role::User)),
163 "Prompt contains non-user roles, which is not supported"
164 );
165
166 // Extract text from user messages into a single prompt string
167 let mut prompt = response
168 .messages
169 .into_iter()
170 .filter_map(|msg| match msg.content {
171 context_server::types::MessageContent::Text { text, .. } => Some(text),
172 _ => None,
173 })
174 .collect::<Vec<String>>()
175 .join("\n\n");
176
177 // We must normalize the line endings here, since servers might return CR characters.
178 LineEnding::normalize(&mut prompt);
179
180 Ok(SlashCommandOutput {
181 sections: vec![SlashCommandOutputSection {
182 range: 0..(prompt.len()),
183 icon: IconName::ZedAssistant,
184 label: SharedString::from(
185 response
186 .description
187 .unwrap_or(format!("Result from {}", prompt_name)),
188 ),
189 metadata: None,
190 }],
191 text: prompt,
192 run_commands_in_text: false,
193 }
194 .into_event_stream())
195 })
196 } else {
197 Task::ready(Err(anyhow!("Context server not found")))
198 }
199 }
200}
201
202fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
203 anyhow::ensure!(!arguments.is_empty(), "No arguments given");
204
205 match &prompt.arguments {
206 Some(args) if args.len() == 1 => {
207 let arg_name = args[0].name.clone();
208 let arg_value = arguments.join(" ");
209 Ok((arg_name, arg_value))
210 }
211 Some(_) => anyhow::bail!("Prompt must have exactly one argument"),
212 None => anyhow::bail!("Prompt has no arguments"),
213 }
214}
215
216fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
217 match &prompt.arguments {
218 Some(args) if args.len() > 1 => {
219 anyhow::bail!("Prompt has more than one argument, which is not supported");
220 }
221 Some(args) if args.len() == 1 => {
222 if !arguments.is_empty() {
223 let mut map = HashMap::default();
224 map.insert(args[0].name.clone(), arguments.join(" "));
225 Ok(map)
226 } else if arguments.is_empty() && args[0].required == Some(false) {
227 Ok(HashMap::default())
228 } else {
229 anyhow::bail!("Prompt expects argument but none given");
230 }
231 }
232 Some(_) | None => {
233 anyhow::ensure!(
234 arguments.is_empty(),
235 "Prompt expects no arguments but some were given"
236 );
237 Ok(HashMap::default())
238 }
239 }
240}
241
242/// MCP servers can return prompts with multiple arguments. Since we only
243/// support one argument, we ignore all others. This is the necessary predicate
244/// for this.
245pub fn acceptable_prompt(prompt: &Prompt) -> bool {
246 match &prompt.arguments {
247 None => true,
248 Some(args) if args.len() <= 1 => true,
249 _ => false,
250 }
251}