1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection, SlashCommandResult,
5};
6use collections::HashMap;
7use context_server::{
8 manager::{ContextServer, ContextServerManager},
9 types::Prompt,
10};
11use gpui::{AppContext, Model, Task, WeakView, WindowContext};
12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
13use std::sync::atomic::AtomicBool;
14use std::sync::Arc;
15use text::LineEnding;
16use ui::{IconName, SharedString};
17use workspace::Workspace;
18
19use crate::slash_command::create_label_for_command;
20
21pub struct ContextServerSlashCommand {
22 server_manager: Model<ContextServerManager>,
23 server_id: Arc<str>,
24 prompt: Prompt,
25}
26
27impl ContextServerSlashCommand {
28 pub fn new(
29 server_manager: Model<ContextServerManager>,
30 server: &Arc<ContextServer>,
31 prompt: Prompt,
32 ) -> Self {
33 Self {
34 server_id: server.id(),
35 prompt,
36 server_manager,
37 }
38 }
39}
40
41impl SlashCommand for ContextServerSlashCommand {
42 fn name(&self) -> String {
43 self.prompt.name.clone()
44 }
45
46 fn label(&self, cx: &AppContext) -> language::CodeLabel {
47 let mut parts = vec![self.prompt.name.as_str()];
48 if let Some(args) = &self.prompt.arguments {
49 if let Some(arg) = args.first() {
50 parts.push(arg.name.as_str());
51 }
52 }
53 create_label_for_command(&parts[0], &parts[1..], cx)
54 }
55
56 fn description(&self) -> String {
57 match &self.prompt.description {
58 Some(desc) => desc.clone(),
59 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
60 }
61 }
62
63 fn menu_text(&self) -> String {
64 match &self.prompt.description {
65 Some(desc) => desc.clone(),
66 None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
67 }
68 }
69
70 fn requires_argument(&self) -> bool {
71 self.prompt.arguments.as_ref().map_or(false, |args| {
72 args.iter().any(|arg| arg.required == Some(true))
73 })
74 }
75
76 fn complete_argument(
77 self: Arc<Self>,
78 arguments: &[String],
79 _cancel: Arc<AtomicBool>,
80 _workspace: Option<WeakView<Workspace>>,
81 cx: &mut WindowContext,
82 ) -> Task<Result<Vec<ArgumentCompletion>>> {
83 let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
84 return Task::ready(Err(anyhow!("Failed to complete argument")));
85 };
86
87 let server_id = self.server_id.clone();
88 let prompt_name = self.prompt.name.clone();
89
90 if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
91 cx.foreground_executor().spawn(async move {
92 let Some(protocol) = server.client() else {
93 return Err(anyhow!("Context server not initialized"));
94 };
95
96 let completion_result = protocol
97 .completion(
98 context_server::types::CompletionReference::Prompt(
99 context_server::types::PromptReference {
100 r#type: context_server::types::PromptReferenceType::Prompt,
101 name: prompt_name,
102 },
103 ),
104 arg_name,
105 arg_value,
106 )
107 .await?;
108
109 let completions = completion_result
110 .values
111 .into_iter()
112 .map(|value| ArgumentCompletion {
113 label: CodeLabel::plain(value.clone(), None),
114 new_text: value,
115 after_completion: AfterCompletion::Continue,
116 replace_previous_arguments: false,
117 })
118 .collect();
119 Ok(completions)
120 })
121 } else {
122 Task::ready(Err(anyhow!("Context server not found")))
123 }
124 }
125
126 fn run(
127 self: Arc<Self>,
128 arguments: &[String],
129 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
130 _context_buffer: BufferSnapshot,
131 _workspace: WeakView<Workspace>,
132 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
133 cx: &mut WindowContext,
134 ) -> Task<SlashCommandResult> {
135 let server_id = self.server_id.clone();
136 let prompt_name = self.prompt.name.clone();
137
138 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
139 Ok(args) => args,
140 Err(e) => return Task::ready(Err(e)),
141 };
142
143 let manager = self.server_manager.read(cx);
144 if let Some(server) = manager.get_server(&server_id) {
145 cx.foreground_executor().spawn(async move {
146 let Some(protocol) = server.client() else {
147 return Err(anyhow!("Context server not initialized"));
148 };
149 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
150
151 // Check that there are only user roles
152 if result
153 .messages
154 .iter()
155 .any(|msg| !matches!(msg.role, context_server::types::Role::User))
156 {
157 return Err(anyhow!(
158 "Prompt contains non-user roles, which is not supported"
159 ));
160 }
161
162 // Extract text from user messages into a single prompt string
163 let mut prompt = result
164 .messages
165 .into_iter()
166 .filter_map(|msg| match msg.content {
167 context_server::types::MessageContent::Text { text, .. } => Some(text),
168 _ => None,
169 })
170 .collect::<Vec<String>>()
171 .join("\n\n");
172
173 // We must normalize the line endings here, since servers might return CR characters.
174 LineEnding::normalize(&mut prompt);
175
176 Ok(SlashCommandOutput {
177 sections: vec![SlashCommandOutputSection {
178 range: 0..(prompt.len()),
179 icon: IconName::ZedAssistant,
180 label: SharedString::from(
181 result
182 .description
183 .unwrap_or(format!("Result from {}", prompt_name)),
184 ),
185 metadata: None,
186 }],
187 text: prompt,
188 run_commands_in_text: false,
189 }
190 .to_event_stream())
191 })
192 } else {
193 Task::ready(Err(anyhow!("Context server not found")))
194 }
195 }
196}
197
198fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
199 if arguments.is_empty() {
200 return Err(anyhow!("No arguments given"));
201 }
202
203 match &prompt.arguments {
204 Some(args) if args.len() == 1 => {
205 let arg_name = args[0].name.clone();
206 let arg_value = arguments.join(" ");
207 Ok((arg_name, arg_value))
208 }
209 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
210 None => Err(anyhow!("Prompt has no arguments")),
211 }
212}
213
214fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
215 match &prompt.arguments {
216 Some(args) if args.len() > 1 => Err(anyhow!(
217 "Prompt has more than one argument, which is not supported"
218 )),
219 Some(args) if args.len() == 1 => {
220 if !arguments.is_empty() {
221 let mut map = HashMap::default();
222 map.insert(args[0].name.clone(), arguments.join(" "));
223 Ok(map)
224 } else if arguments.is_empty() && args[0].required == Some(false) {
225 Ok(HashMap::default())
226 } else {
227 Err(anyhow!("Prompt expects argument but none given"))
228 }
229 }
230 Some(_) | None => {
231 if arguments.is_empty() {
232 Ok(HashMap::default())
233 } else {
234 Err(anyhow!("Prompt expects no arguments but some were given"))
235 }
236 }
237 }
238}
239
240/// MCP servers can return prompts with multiple arguments. Since we only
241/// support one argument, we ignore all others. This is the necessary predicate
242/// for this.
243pub fn acceptable_prompt(prompt: &Prompt) -> bool {
244 match &prompt.arguments {
245 None => true,
246 Some(args) if args.len() <= 1 => true,
247 _ => false,
248 }
249}