1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection,
5};
6use collections::HashMap;
7use context_servers::{
8 manager::{ContextServer, ContextServerManager},
9 protocol::PromptInfo,
10};
11use gpui::{Task, WeakView, WindowContext};
12use language::{CodeLabel, LspAdapterDelegate};
13use std::sync::atomic::AtomicBool;
14use std::sync::Arc;
15use text::LineEnding;
16use ui::{IconName, SharedString};
17use workspace::Workspace;
18
19pub struct ContextServerSlashCommand {
20 server_id: String,
21 prompt: PromptInfo,
22}
23
24impl ContextServerSlashCommand {
25 pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
26 Self {
27 server_id: server.id.clone(),
28 prompt,
29 }
30 }
31}
32
33impl SlashCommand for ContextServerSlashCommand {
34 fn name(&self) -> String {
35 self.prompt.name.clone()
36 }
37
38 fn description(&self) -> String {
39 format!("Run context server command: {}", self.prompt.name)
40 }
41
42 fn menu_text(&self) -> String {
43 format!("Run '{}' from {}", self.prompt.name, self.server_id)
44 }
45
46 fn requires_argument(&self) -> bool {
47 self.prompt.arguments.as_ref().map_or(false, |args| {
48 args.iter().any(|arg| arg.required == Some(true))
49 })
50 }
51
52 fn complete_argument(
53 self: Arc<Self>,
54 arguments: &[String],
55 _cancel: Arc<AtomicBool>,
56 _workspace: Option<WeakView<Workspace>>,
57 cx: &mut WindowContext,
58 ) -> Task<Result<Vec<ArgumentCompletion>>> {
59 let server_id = self.server_id.clone();
60 let prompt_name = self.prompt.name.clone();
61 let manager = ContextServerManager::global(cx);
62 let manager = manager.read(cx);
63
64 let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
65 Ok(tp) => tp,
66 Err(e) => {
67 return Task::ready(Err(e));
68 }
69 };
70 if let Some(server) = manager.get_server(&server_id) {
71 cx.foreground_executor().spawn(async move {
72 let Some(protocol) = server.client.read().clone() else {
73 return Err(anyhow!("Context server not initialized"));
74 };
75
76 let completion_result = protocol
77 .completion(
78 context_servers::types::CompletionReference::Prompt(
79 context_servers::types::PromptReference {
80 r#type: context_servers::types::PromptReferenceType::Prompt,
81 name: prompt_name,
82 },
83 ),
84 arg_name,
85 arg_val,
86 )
87 .await?;
88
89 let completions = completion_result
90 .values
91 .into_iter()
92 .map(|value| ArgumentCompletion {
93 label: CodeLabel::plain(value.clone(), None),
94 new_text: value,
95 after_completion: AfterCompletion::Continue,
96 replace_previous_arguments: false,
97 })
98 .collect();
99
100 Ok(completions)
101 })
102 } else {
103 Task::ready(Err(anyhow!("Context server not found")))
104 }
105 }
106
107 fn run(
108 self: Arc<Self>,
109 arguments: &[String],
110 _workspace: WeakView<Workspace>,
111 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
112 cx: &mut WindowContext,
113 ) -> Task<Result<SlashCommandOutput>> {
114 let server_id = self.server_id.clone();
115 let prompt_name = self.prompt.name.clone();
116
117 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
118 Ok(args) => args,
119 Err(e) => return Task::ready(Err(e)),
120 };
121
122 let manager = ContextServerManager::global(cx);
123 let manager = manager.read(cx);
124 if let Some(server) = manager.get_server(&server_id) {
125 cx.foreground_executor().spawn(async move {
126 let Some(protocol) = server.client.read().clone() else {
127 return Err(anyhow!("Context server not initialized"));
128 };
129 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
130 let mut prompt = result.prompt;
131
132 // We must normalize the line endings here, since servers might return CR characters.
133 LineEnding::normalize(&mut prompt);
134
135 Ok(SlashCommandOutput {
136 sections: vec![SlashCommandOutputSection {
137 range: 0..(prompt.len()),
138 icon: IconName::ZedAssistant,
139 label: SharedString::from(
140 result
141 .description
142 .unwrap_or(format!("Result from {}", prompt_name)),
143 ),
144 }],
145 text: prompt,
146 run_commands_in_text: false,
147 })
148 })
149 } else {
150 Task::ready(Err(anyhow!("Context server not found")))
151 }
152 }
153}
154
155fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
156 if arguments.is_empty() {
157 return Err(anyhow!("No arguments given"));
158 }
159
160 match &prompt.arguments {
161 Some(args) if args.len() == 1 => {
162 let arg_name = args[0].name.clone();
163 let arg_value = arguments.join(" ");
164 Ok((arg_name, arg_value))
165 }
166 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
167 None => Err(anyhow!("Prompt has no arguments")),
168 }
169}
170
171fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
172 match &prompt.arguments {
173 Some(args) if args.len() > 1 => Err(anyhow!(
174 "Prompt has more than one argument, which is not supported"
175 )),
176 Some(args) if args.len() == 1 => {
177 if !arguments.is_empty() {
178 let mut map = HashMap::default();
179 map.insert(args[0].name.clone(), arguments.join(" "));
180 Ok(map)
181 } else if arguments.is_empty() && args[0].required == Some(false) {
182 Ok(HashMap::default())
183 } else {
184 Err(anyhow!("Prompt expects argument but none given"))
185 }
186 }
187 Some(_) | None => {
188 if arguments.is_empty() {
189 Ok(HashMap::default())
190 } else {
191 Err(anyhow!("Prompt expects no arguments but some were given"))
192 }
193 }
194 }
195}
196
197/// MCP servers can return prompts with multiple arguments. Since we only
198/// support one argument, we ignore all others. This is the necessary predicate
199/// for this.
200pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
201 match &prompt.arguments {
202 None => true,
203 Some(args) if args.len() <= 1 => true,
204 _ => false,
205 }
206}