1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection,
5};
6use collections::HashMap;
7use context_servers::{
8 manager::{ContextServer, ContextServerManager},
9 protocol::PromptInfo,
10};
11use gpui::{Task, WeakView, WindowContext};
12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
13use std::sync::atomic::AtomicBool;
14use std::sync::Arc;
15use text::LineEnding;
16use ui::{IconName, SharedString};
17use workspace::Workspace;
18
19pub struct ContextServerSlashCommand {
20 server_id: String,
21 prompt: PromptInfo,
22}
23
24impl ContextServerSlashCommand {
25 pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
26 Self {
27 server_id: server.id.clone(),
28 prompt,
29 }
30 }
31}
32
33impl SlashCommand for ContextServerSlashCommand {
34 fn name(&self) -> String {
35 self.prompt.name.clone()
36 }
37
38 fn description(&self) -> String {
39 format!("Run context server command: {}", self.prompt.name)
40 }
41
42 fn menu_text(&self) -> String {
43 format!("Run '{}' from {}", self.prompt.name, self.server_id)
44 }
45
46 fn requires_argument(&self) -> bool {
47 self.prompt.arguments.as_ref().map_or(false, |args| {
48 args.iter().any(|arg| arg.required == Some(true))
49 })
50 }
51
52 fn complete_argument(
53 self: Arc<Self>,
54 arguments: &[String],
55 _cancel: Arc<AtomicBool>,
56 _workspace: Option<WeakView<Workspace>>,
57 cx: &mut WindowContext,
58 ) -> Task<Result<Vec<ArgumentCompletion>>> {
59 let server_id = self.server_id.clone();
60 let prompt_name = self.prompt.name.clone();
61 let manager = ContextServerManager::global(cx);
62 let manager = manager.read(cx);
63
64 let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
65 Ok(tp) => tp,
66 Err(e) => {
67 return Task::ready(Err(e));
68 }
69 };
70 if let Some(server) = manager.get_server(&server_id) {
71 cx.foreground_executor().spawn(async move {
72 let Some(protocol) = server.client.read().clone() else {
73 return Err(anyhow!("Context server not initialized"));
74 };
75
76 let completion_result = protocol
77 .completion(
78 context_servers::types::CompletionReference::Prompt(
79 context_servers::types::PromptReference {
80 r#type: context_servers::types::PromptReferenceType::Prompt,
81 name: prompt_name,
82 },
83 ),
84 arg_name,
85 arg_val,
86 )
87 .await?;
88
89 let completions = completion_result
90 .values
91 .into_iter()
92 .map(|value| ArgumentCompletion {
93 label: CodeLabel::plain(value.clone(), None),
94 new_text: value,
95 after_completion: AfterCompletion::Continue,
96 replace_previous_arguments: false,
97 })
98 .collect();
99 Ok(completions)
100 })
101 } else {
102 Task::ready(Err(anyhow!("Context server not found")))
103 }
104 }
105
106 fn run(
107 self: Arc<Self>,
108 arguments: &[String],
109 _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
110 _context_buffer: BufferSnapshot,
111 _workspace: WeakView<Workspace>,
112 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
113 cx: &mut WindowContext,
114 ) -> Task<Result<SlashCommandOutput>> {
115 let server_id = self.server_id.clone();
116 let prompt_name = self.prompt.name.clone();
117
118 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
119 Ok(args) => args,
120 Err(e) => return Task::ready(Err(e)),
121 };
122
123 let manager = ContextServerManager::global(cx);
124 let manager = manager.read(cx);
125 if let Some(server) = manager.get_server(&server_id) {
126 cx.foreground_executor().spawn(async move {
127 let Some(protocol) = server.client.read().clone() else {
128 return Err(anyhow!("Context server not initialized"));
129 };
130 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
131 let mut prompt = result.prompt;
132
133 // We must normalize the line endings here, since servers might return CR characters.
134 LineEnding::normalize(&mut prompt);
135
136 Ok(SlashCommandOutput {
137 sections: vec![SlashCommandOutputSection {
138 range: 0..(prompt.len()),
139 icon: IconName::ZedAssistant,
140 label: SharedString::from(
141 result
142 .description
143 .unwrap_or(format!("Result from {}", prompt_name)),
144 ),
145 metadata: None,
146 }],
147 text: prompt,
148 run_commands_in_text: false,
149 })
150 })
151 } else {
152 Task::ready(Err(anyhow!("Context server not found")))
153 }
154 }
155}
156
157fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
158 if arguments.is_empty() {
159 return Err(anyhow!("No arguments given"));
160 }
161
162 match &prompt.arguments {
163 Some(args) if args.len() == 1 => {
164 let arg_name = args[0].name.clone();
165 let arg_value = arguments.join(" ");
166 Ok((arg_name, arg_value))
167 }
168 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
169 None => Err(anyhow!("Prompt has no arguments")),
170 }
171}
172
173fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
174 match &prompt.arguments {
175 Some(args) if args.len() > 1 => Err(anyhow!(
176 "Prompt has more than one argument, which is not supported"
177 )),
178 Some(args) if args.len() == 1 => {
179 if !arguments.is_empty() {
180 let mut map = HashMap::default();
181 map.insert(args[0].name.clone(), arguments.join(" "));
182 Ok(map)
183 } else if arguments.is_empty() && args[0].required == Some(false) {
184 Ok(HashMap::default())
185 } else {
186 Err(anyhow!("Prompt expects argument but none given"))
187 }
188 }
189 Some(_) | None => {
190 if arguments.is_empty() {
191 Ok(HashMap::default())
192 } else {
193 Err(anyhow!("Prompt expects no arguments but some were given"))
194 }
195 }
196 }
197}
198
199/// MCP servers can return prompts with multiple arguments. Since we only
200/// support one argument, we ignore all others. This is the necessary predicate
201/// for this.
202pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
203 match &prompt.arguments {
204 None => true,
205 Some(args) if args.len() <= 1 => true,
206 _ => false,
207 }
208}