1use anyhow::{Result, anyhow};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection, SlashCommandResult,
5};
6use collections::HashMap;
7use context_server::{ContextServerId, types::Prompt};
8use gpui::{App, Entity, Task, WeakEntity, Window};
9use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
10use project::context_server_store::ContextServerStore;
11use std::sync::Arc;
12use std::sync::atomic::AtomicBool;
13use text::LineEnding;
14use ui::{IconName, SharedString};
15use workspace::Workspace;
16
17use crate::create_label_for_command;
18
19pub struct ContextServerSlashCommand {
20 store: Entity<ContextServerStore>,
21 server_id: ContextServerId,
22 prompt: Prompt,
23}
24
25impl ContextServerSlashCommand {
26 pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
27 Self {
28 server_id: id,
29 prompt,
30 store,
31 }
32 }
33}
34
35impl SlashCommand for ContextServerSlashCommand {
36 fn name(&self) -> String {
37 self.prompt.name.clone()
38 }
39
40 fn label(&self, cx: &App) -> language::CodeLabel {
41 let mut parts = vec![self.prompt.name.as_str()];
42 if let Some(args) = &self.prompt.arguments {
43 if let Some(arg) = args.first() {
44 parts.push(arg.name.as_str());
45 }
46 }
47 create_label_for_command(&parts[0], &parts[1..], cx)
48 }
49
50 fn description(&self) -> String {
51 match &self.prompt.description {
52 Some(desc) => desc.clone(),
53 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
54 }
55 }
56
57 fn menu_text(&self) -> String {
58 match &self.prompt.description {
59 Some(desc) => desc.clone(),
60 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
61 }
62 }
63
64 fn requires_argument(&self) -> bool {
65 self.prompt.arguments.as_ref().map_or(false, |args| {
66 args.iter().any(|arg| arg.required == Some(true))
67 })
68 }
69
70 fn complete_argument(
71 self: Arc<Self>,
72 arguments: &[String],
73 _cancel: Arc<AtomicBool>,
74 _workspace: Option<WeakEntity<Workspace>>,
75 _window: &mut Window,
76 cx: &mut App,
77 ) -> Task<Result<Vec<ArgumentCompletion>>> {
78 let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
79 return Task::ready(Err(anyhow!("Failed to complete argument")));
80 };
81
82 let server_id = self.server_id.clone();
83 let prompt_name = self.prompt.name.clone();
84
85 if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
86 cx.foreground_executor().spawn(async move {
87 let Some(protocol) = server.client() else {
88 return Err(anyhow!("Context server not initialized"));
89 };
90
91 let completion_result = protocol
92 .completion(
93 context_server::types::CompletionReference::Prompt(
94 context_server::types::PromptReference {
95 r#type: context_server::types::PromptReferenceType::Prompt,
96 name: prompt_name,
97 },
98 ),
99 arg_name,
100 arg_value,
101 )
102 .await?;
103
104 let completions = completion_result
105 .values
106 .into_iter()
107 .map(|value| ArgumentCompletion {
108 label: CodeLabel::plain(value.clone(), None),
109 new_text: value,
110 after_completion: AfterCompletion::Continue,
111 replace_previous_arguments: false,
112 })
113 .collect();
114 Ok(completions)
115 })
116 } else {
117 Task::ready(Err(anyhow!("Context server not found")))
118 }
119 }
120
121 fn run(
122 self: Arc<Self>,
123 arguments: &[String],
124 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
125 _context_buffer: BufferSnapshot,
126 _workspace: WeakEntity<Workspace>,
127 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
128 _window: &mut Window,
129 cx: &mut App,
130 ) -> Task<SlashCommandResult> {
131 let server_id = self.server_id.clone();
132 let prompt_name = self.prompt.name.clone();
133
134 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
135 Ok(args) => args,
136 Err(e) => return Task::ready(Err(e)),
137 };
138
139 let store = self.store.read(cx);
140 if let Some(server) = store.get_running_server(&server_id) {
141 cx.foreground_executor().spawn(async move {
142 let Some(protocol) = server.client() else {
143 return Err(anyhow!("Context server not initialized"));
144 };
145 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
146
147 // Check that there are only user roles
148 if result
149 .messages
150 .iter()
151 .any(|msg| !matches!(msg.role, context_server::types::Role::User))
152 {
153 return Err(anyhow!(
154 "Prompt contains non-user roles, which is not supported"
155 ));
156 }
157
158 // Extract text from user messages into a single prompt string
159 let mut prompt = result
160 .messages
161 .into_iter()
162 .filter_map(|msg| match msg.content {
163 context_server::types::MessageContent::Text { text, .. } => Some(text),
164 _ => None,
165 })
166 .collect::<Vec<String>>()
167 .join("\n\n");
168
169 // We must normalize the line endings here, since servers might return CR characters.
170 LineEnding::normalize(&mut prompt);
171
172 Ok(SlashCommandOutput {
173 sections: vec![SlashCommandOutputSection {
174 range: 0..(prompt.len()),
175 icon: IconName::ZedAssistant,
176 label: SharedString::from(
177 result
178 .description
179 .unwrap_or(format!("Result from {}", prompt_name)),
180 ),
181 metadata: None,
182 }],
183 text: prompt,
184 run_commands_in_text: false,
185 }
186 .to_event_stream())
187 })
188 } else {
189 Task::ready(Err(anyhow!("Context server not found")))
190 }
191 }
192}
193
194fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
195 if arguments.is_empty() {
196 return Err(anyhow!("No arguments given"));
197 }
198
199 match &prompt.arguments {
200 Some(args) if args.len() == 1 => {
201 let arg_name = args[0].name.clone();
202 let arg_value = arguments.join(" ");
203 Ok((arg_name, arg_value))
204 }
205 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
206 None => Err(anyhow!("Prompt has no arguments")),
207 }
208}
209
210fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
211 match &prompt.arguments {
212 Some(args) if args.len() > 1 => Err(anyhow!(
213 "Prompt has more than one argument, which is not supported"
214 )),
215 Some(args) if args.len() == 1 => {
216 if !arguments.is_empty() {
217 let mut map = HashMap::default();
218 map.insert(args[0].name.clone(), arguments.join(" "));
219 Ok(map)
220 } else if arguments.is_empty() && args[0].required == Some(false) {
221 Ok(HashMap::default())
222 } else {
223 Err(anyhow!("Prompt expects argument but none given"))
224 }
225 }
226 Some(_) | None => {
227 if arguments.is_empty() {
228 Ok(HashMap::default())
229 } else {
230 Err(anyhow!("Prompt expects no arguments but some were given"))
231 }
232 }
233 }
234}
235
236/// MCP servers can return prompts with multiple arguments. Since we only
237/// support one argument, we ignore all others. This is the necessary predicate
238/// for this.
239pub fn acceptable_prompt(prompt: &Prompt) -> bool {
240 match &prompt.arguments {
241 None => true,
242 Some(args) if args.len() <= 1 => true,
243 _ => false,
244 }
245}