1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection,
5};
6use collections::HashMap;
7use context_servers::{
8 manager::{ContextServer, ContextServerManager},
9 protocol::PromptInfo,
10};
11use gpui::{Task, WeakView, WindowContext};
12use language::{CodeLabel, LspAdapterDelegate};
13use std::sync::atomic::AtomicBool;
14use std::sync::Arc;
15use ui::{IconName, SharedString};
16use workspace::Workspace;
17
18pub struct ContextServerSlashCommand {
19 server_id: String,
20 prompt: PromptInfo,
21}
22
23impl ContextServerSlashCommand {
24 pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
25 Self {
26 server_id: server.id.clone(),
27 prompt,
28 }
29 }
30}
31
32impl SlashCommand for ContextServerSlashCommand {
33 fn name(&self) -> String {
34 self.prompt.name.clone()
35 }
36
37 fn description(&self) -> String {
38 format!("Run context server command: {}", self.prompt.name)
39 }
40
41 fn menu_text(&self) -> String {
42 format!("Run '{}' from {}", self.prompt.name, self.server_id)
43 }
44
45 fn requires_argument(&self) -> bool {
46 self.prompt
47 .arguments
48 .as_ref()
49 .map_or(false, |args| !args.is_empty())
50 }
51
52 fn complete_argument(
53 self: Arc<Self>,
54 arguments: &[String],
55 _cancel: Arc<AtomicBool>,
56 _workspace: Option<WeakView<Workspace>>,
57 cx: &mut WindowContext,
58 ) -> Task<Result<Vec<ArgumentCompletion>>> {
59 let server_id = self.server_id.clone();
60 let prompt_name = self.prompt.name.clone();
61 let manager = ContextServerManager::global(cx);
62 let manager = manager.read(cx);
63
64 let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
65 Ok(tp) => tp,
66 Err(e) => {
67 return Task::ready(Err(e));
68 }
69 };
70 if let Some(server) = manager.get_server(&server_id) {
71 cx.foreground_executor().spawn(async move {
72 let Some(protocol) = server.client.read().clone() else {
73 return Err(anyhow!("Context server not initialized"));
74 };
75
76 let completion_result = protocol
77 .completion(
78 context_servers::types::CompletionReference::Prompt(
79 context_servers::types::PromptReference {
80 r#type: context_servers::types::PromptReferenceType::Prompt,
81 name: prompt_name,
82 },
83 ),
84 arg_name,
85 arg_val,
86 )
87 .await?;
88
89 let completions = completion_result
90 .values
91 .into_iter()
92 .map(|value| ArgumentCompletion {
93 label: CodeLabel::plain(value.clone(), None),
94 new_text: value,
95 after_completion: AfterCompletion::Continue,
96 replace_previous_arguments: false,
97 })
98 .collect();
99
100 Ok(completions)
101 })
102 } else {
103 Task::ready(Err(anyhow!("Context server not found")))
104 }
105 }
106
107 fn run(
108 self: Arc<Self>,
109 arguments: &[String],
110 _workspace: WeakView<Workspace>,
111 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
112 cx: &mut WindowContext,
113 ) -> Task<Result<SlashCommandOutput>> {
114 let server_id = self.server_id.clone();
115 let prompt_name = self.prompt.name.clone();
116
117 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
118 Ok(args) => args,
119 Err(e) => return Task::ready(Err(e)),
120 };
121
122 let manager = ContextServerManager::global(cx);
123 let manager = manager.read(cx);
124 if let Some(server) = manager.get_server(&server_id) {
125 cx.foreground_executor().spawn(async move {
126 let Some(protocol) = server.client.read().clone() else {
127 return Err(anyhow!("Context server not initialized"));
128 };
129 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
130
131 Ok(SlashCommandOutput {
132 sections: vec![SlashCommandOutputSection {
133 range: 0..(result.prompt.len()),
134 icon: IconName::ZedAssistant,
135 label: SharedString::from(
136 result
137 .description
138 .unwrap_or(format!("Result from {}", prompt_name)),
139 ),
140 }],
141 text: result.prompt,
142 run_commands_in_text: false,
143 })
144 })
145 } else {
146 Task::ready(Err(anyhow!("Context server not found")))
147 }
148 }
149}
150
151fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
152 if arguments.is_empty() {
153 return Err(anyhow!("No arguments given"));
154 }
155
156 match &prompt.arguments {
157 Some(args) if args.len() == 1 => {
158 let arg_name = args[0].name.clone();
159 let arg_value = arguments.join(" ");
160 Ok((arg_name, arg_value))
161 }
162 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
163 None => Err(anyhow!("Prompt has no arguments")),
164 }
165}
166
167fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
168 match &prompt.arguments {
169 Some(args) if args.len() > 1 => Err(anyhow!(
170 "Prompt has more than one argument, which is not supported"
171 )),
172 Some(args) if args.len() == 1 => {
173 if !arguments.is_empty() {
174 let mut map = HashMap::default();
175 map.insert(args[0].name.clone(), arguments.join(" "));
176 Ok(map)
177 } else {
178 Err(anyhow!("Prompt expects argument but none given"))
179 }
180 }
181 Some(_) | None => {
182 if arguments.is_empty() {
183 Ok(HashMap::default())
184 } else {
185 Err(anyhow!("Prompt expects no arguments but some were given"))
186 }
187 }
188 }
189}
190
191/// MCP servers can return prompts with multiple arguments. Since we only
192/// support one argument, we ignore all others. This is the necessary predicate
193/// for this.
194pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
195 match &prompt.arguments {
196 None => true,
197 Some(args) if args.len() == 1 => true,
198 _ => false,
199 }
200}