1use anyhow::{Context as _, Result, anyhow};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection, SlashCommandResult,
5};
6use collections::HashMap;
7use context_server::{ContextServerId, types::Prompt};
8use gpui::{App, Entity, Task, WeakEntity, Window};
9use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
10use project::context_server_store::ContextServerStore;
11use std::sync::Arc;
12use std::sync::atomic::AtomicBool;
13use text::LineEnding;
14use ui::{IconName, SharedString};
15use workspace::Workspace;
16
17use crate::create_label_for_command;
18
19pub struct ContextServerSlashCommand {
20 store: Entity<ContextServerStore>,
21 server_id: ContextServerId,
22 prompt: Prompt,
23}
24
25impl ContextServerSlashCommand {
26 pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
27 Self {
28 server_id: id,
29 prompt,
30 store,
31 }
32 }
33}
34
35impl SlashCommand for ContextServerSlashCommand {
36 fn name(&self) -> String {
37 self.prompt.name.clone()
38 }
39
40 fn label(&self, cx: &App) -> language::CodeLabel {
41 let mut parts = vec![self.prompt.name.as_str()];
42 if let Some(args) = &self.prompt.arguments
43 && let Some(arg) = args.first() {
44 parts.push(arg.name.as_str());
45 }
46 create_label_for_command(parts[0], &parts[1..], cx)
47 }
48
49 fn description(&self) -> String {
50 match &self.prompt.description {
51 Some(desc) => desc.clone(),
52 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
53 }
54 }
55
56 fn menu_text(&self) -> String {
57 match &self.prompt.description {
58 Some(desc) => desc.clone(),
59 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
60 }
61 }
62
63 fn requires_argument(&self) -> bool {
64 self.prompt.arguments.as_ref().map_or(false, |args| {
65 args.iter().any(|arg| arg.required == Some(true))
66 })
67 }
68
69 fn complete_argument(
70 self: Arc<Self>,
71 arguments: &[String],
72 _cancel: Arc<AtomicBool>,
73 _workspace: Option<WeakEntity<Workspace>>,
74 _window: &mut Window,
75 cx: &mut App,
76 ) -> Task<Result<Vec<ArgumentCompletion>>> {
77 let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
78 return Task::ready(Err(anyhow!("Failed to complete argument")));
79 };
80
81 let server_id = self.server_id.clone();
82 let prompt_name = self.prompt.name.clone();
83
84 if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
85 cx.foreground_executor().spawn(async move {
86 let protocol = server.client().context("Context server not initialized")?;
87
88 let response = protocol
89 .request::<context_server::types::requests::CompletionComplete>(
90 context_server::types::CompletionCompleteParams {
91 reference: context_server::types::CompletionReference::Prompt(
92 context_server::types::PromptReference {
93 ty: context_server::types::PromptReferenceType::Prompt,
94 name: prompt_name,
95 },
96 ),
97 argument: context_server::types::CompletionArgument {
98 name: arg_name,
99 value: arg_value,
100 },
101 meta: None,
102 },
103 )
104 .await?;
105
106 let completions = response
107 .completion
108 .values
109 .into_iter()
110 .map(|value| ArgumentCompletion {
111 label: CodeLabel::plain(value.clone(), None),
112 new_text: value,
113 after_completion: AfterCompletion::Continue,
114 replace_previous_arguments: false,
115 })
116 .collect();
117 Ok(completions)
118 })
119 } else {
120 Task::ready(Err(anyhow!("Context server not found")))
121 }
122 }
123
124 fn run(
125 self: Arc<Self>,
126 arguments: &[String],
127 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
128 _context_buffer: BufferSnapshot,
129 _workspace: WeakEntity<Workspace>,
130 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
131 _window: &mut Window,
132 cx: &mut App,
133 ) -> Task<SlashCommandResult> {
134 let server_id = self.server_id.clone();
135 let prompt_name = self.prompt.name.clone();
136
137 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
138 Ok(args) => args,
139 Err(e) => return Task::ready(Err(e)),
140 };
141
142 let store = self.store.read(cx);
143 if let Some(server) = store.get_running_server(&server_id) {
144 cx.foreground_executor().spawn(async move {
145 let protocol = server.client().context("Context server not initialized")?;
146 let response = protocol
147 .request::<context_server::types::requests::PromptsGet>(
148 context_server::types::PromptsGetParams {
149 name: prompt_name.clone(),
150 arguments: Some(prompt_args),
151 meta: None,
152 },
153 )
154 .await?;
155
156 anyhow::ensure!(
157 response
158 .messages
159 .iter()
160 .all(|msg| matches!(msg.role, context_server::types::Role::User)),
161 "Prompt contains non-user roles, which is not supported"
162 );
163
164 // Extract text from user messages into a single prompt string
165 let mut prompt = response
166 .messages
167 .into_iter()
168 .filter_map(|msg| match msg.content {
169 context_server::types::MessageContent::Text { text, .. } => Some(text),
170 _ => None,
171 })
172 .collect::<Vec<String>>()
173 .join("\n\n");
174
175 // We must normalize the line endings here, since servers might return CR characters.
176 LineEnding::normalize(&mut prompt);
177
178 Ok(SlashCommandOutput {
179 sections: vec![SlashCommandOutputSection {
180 range: 0..(prompt.len()),
181 icon: IconName::ZedAssistant,
182 label: SharedString::from(
183 response
184 .description
185 .unwrap_or(format!("Result from {}", prompt_name)),
186 ),
187 metadata: None,
188 }],
189 text: prompt,
190 run_commands_in_text: false,
191 }
192 .to_event_stream())
193 })
194 } else {
195 Task::ready(Err(anyhow!("Context server not found")))
196 }
197 }
198}
199
200fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
201 anyhow::ensure!(!arguments.is_empty(), "No arguments given");
202
203 match &prompt.arguments {
204 Some(args) if args.len() == 1 => {
205 let arg_name = args[0].name.clone();
206 let arg_value = arguments.join(" ");
207 Ok((arg_name, arg_value))
208 }
209 Some(_) => anyhow::bail!("Prompt must have exactly one argument"),
210 None => anyhow::bail!("Prompt has no arguments"),
211 }
212}
213
214fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
215 match &prompt.arguments {
216 Some(args) if args.len() > 1 => {
217 anyhow::bail!("Prompt has more than one argument, which is not supported");
218 }
219 Some(args) if args.len() == 1 => {
220 if !arguments.is_empty() {
221 let mut map = HashMap::default();
222 map.insert(args[0].name.clone(), arguments.join(" "));
223 Ok(map)
224 } else if arguments.is_empty() && args[0].required == Some(false) {
225 Ok(HashMap::default())
226 } else {
227 anyhow::bail!("Prompt expects argument but none given");
228 }
229 }
230 Some(_) | None => {
231 anyhow::ensure!(
232 arguments.is_empty(),
233 "Prompt expects no arguments but some were given"
234 );
235 Ok(HashMap::default())
236 }
237 }
238}
239
240/// MCP servers can return prompts with multiple arguments. Since we only
241/// support one argument, we ignore all others. This is the necessary predicate
242/// for this.
243pub fn acceptable_prompt(prompt: &Prompt) -> bool {
244 match &prompt.arguments {
245 None => true,
246 Some(args) if args.len() <= 1 => true,
247 _ => false,
248 }
249}