1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
4 SlashCommandOutputSection,
5};
6use collections::HashMap;
7use context_servers::{
8 manager::{ContextServer, ContextServerManager},
9 protocol::PromptInfo,
10};
11use gpui::{Task, WeakView, WindowContext};
12use language::{CodeLabel, LspAdapterDelegate};
13use std::sync::atomic::AtomicBool;
14use std::sync::Arc;
15use text::LineEnding;
16use ui::{IconName, SharedString};
17use workspace::Workspace;
18
/// A slash command backed by a single prompt exposed by a context server.
pub struct ContextServerSlashCommand {
    /// Id of the context server that provides `prompt`; used to look the server
    /// up in the global `ContextServerManager` whenever the command is used.
    server_id: String,
    /// Prompt metadata (its name and declared arguments) as reported by the server.
    prompt: PromptInfo,
}
23
24impl ContextServerSlashCommand {
25 pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
26 Self {
27 server_id: server.id.clone(),
28 prompt,
29 }
30 }
31}
32
33impl SlashCommand for ContextServerSlashCommand {
34 fn name(&self) -> String {
35 self.prompt.name.clone()
36 }
37
38 fn description(&self) -> String {
39 format!("Run context server command: {}", self.prompt.name)
40 }
41
42 fn menu_text(&self) -> String {
43 format!("Run '{}' from {}", self.prompt.name, self.server_id)
44 }
45
46 fn requires_argument(&self) -> bool {
47 self.prompt
48 .arguments
49 .as_ref()
50 .map_or(false, |args| !args.is_empty())
51 }
52
53 fn complete_argument(
54 self: Arc<Self>,
55 arguments: &[String],
56 _cancel: Arc<AtomicBool>,
57 _workspace: Option<WeakView<Workspace>>,
58 cx: &mut WindowContext,
59 ) -> Task<Result<Vec<ArgumentCompletion>>> {
60 let server_id = self.server_id.clone();
61 let prompt_name = self.prompt.name.clone();
62 let manager = ContextServerManager::global(cx);
63 let manager = manager.read(cx);
64
65 let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
66 Ok(tp) => tp,
67 Err(e) => {
68 return Task::ready(Err(e));
69 }
70 };
71 if let Some(server) = manager.get_server(&server_id) {
72 cx.foreground_executor().spawn(async move {
73 let Some(protocol) = server.client.read().clone() else {
74 return Err(anyhow!("Context server not initialized"));
75 };
76
77 let completion_result = protocol
78 .completion(
79 context_servers::types::CompletionReference::Prompt(
80 context_servers::types::PromptReference {
81 r#type: context_servers::types::PromptReferenceType::Prompt,
82 name: prompt_name,
83 },
84 ),
85 arg_name,
86 arg_val,
87 )
88 .await?;
89
90 let completions = completion_result
91 .values
92 .into_iter()
93 .map(|value| ArgumentCompletion {
94 label: CodeLabel::plain(value.clone(), None),
95 new_text: value,
96 after_completion: AfterCompletion::Continue,
97 replace_previous_arguments: false,
98 })
99 .collect();
100
101 Ok(completions)
102 })
103 } else {
104 Task::ready(Err(anyhow!("Context server not found")))
105 }
106 }
107
108 fn run(
109 self: Arc<Self>,
110 arguments: &[String],
111 _workspace: WeakView<Workspace>,
112 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
113 cx: &mut WindowContext,
114 ) -> Task<Result<SlashCommandOutput>> {
115 let server_id = self.server_id.clone();
116 let prompt_name = self.prompt.name.clone();
117
118 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
119 Ok(args) => args,
120 Err(e) => return Task::ready(Err(e)),
121 };
122
123 let manager = ContextServerManager::global(cx);
124 let manager = manager.read(cx);
125 if let Some(server) = manager.get_server(&server_id) {
126 cx.foreground_executor().spawn(async move {
127 let Some(protocol) = server.client.read().clone() else {
128 return Err(anyhow!("Context server not initialized"));
129 };
130 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
131 let mut prompt = result.prompt;
132
133 // We must normalize the line endings here, since servers might return CR characters.
134 LineEnding::normalize(&mut prompt);
135
136 Ok(SlashCommandOutput {
137 sections: vec![SlashCommandOutputSection {
138 range: 0..(prompt.len()),
139 icon: IconName::ZedAssistant,
140 label: SharedString::from(
141 result
142 .description
143 .unwrap_or(format!("Result from {}", prompt_name)),
144 ),
145 }],
146 text: prompt,
147 run_commands_in_text: false,
148 })
149 })
150 } else {
151 Task::ready(Err(anyhow!("Context server not found")))
152 }
153 }
154}
155
156fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
157 if arguments.is_empty() {
158 return Err(anyhow!("No arguments given"));
159 }
160
161 match &prompt.arguments {
162 Some(args) if args.len() == 1 => {
163 let arg_name = args[0].name.clone();
164 let arg_value = arguments.join(" ");
165 Ok((arg_name, arg_value))
166 }
167 Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
168 None => Err(anyhow!("Prompt has no arguments")),
169 }
170}
171
172fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
173 match &prompt.arguments {
174 Some(args) if args.len() > 1 => Err(anyhow!(
175 "Prompt has more than one argument, which is not supported"
176 )),
177 Some(args) if args.len() == 1 => {
178 if !arguments.is_empty() {
179 let mut map = HashMap::default();
180 map.insert(args[0].name.clone(), arguments.join(" "));
181 Ok(map)
182 } else {
183 Err(anyhow!("Prompt expects argument but none given"))
184 }
185 }
186 Some(_) | None => {
187 if arguments.is_empty() {
188 Ok(HashMap::default())
189 } else {
190 Err(anyhow!("Prompt expects no arguments but some were given"))
191 }
192 }
193 }
194}
195
196/// MCP servers can return prompts with multiple arguments. Since we only
197/// support one argument, we ignore all others. This is the necessary predicate
198/// for this.
199pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
200 match &prompt.arguments {
201 None => true,
202 Some(args) if args.len() == 1 => true,
203 _ => false,
204 }
205}