1use anyhow::{anyhow, Result};
2use assistant_slash_command::{
3 ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
4};
5use collections::HashMap;
6use context_servers::{
7 manager::{ContextServer, ContextServerManager},
8 protocol::PromptInfo,
9};
10use gpui::{Task, WeakView, WindowContext};
11use language::LspAdapterDelegate;
12use std::sync::atomic::AtomicBool;
13use std::sync::Arc;
14use ui::{IconName, SharedString};
15use workspace::Workspace;
16
17pub struct ContextServerSlashCommand {
18 server_id: String,
19 prompt: PromptInfo,
20}
21
22impl ContextServerSlashCommand {
23 pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
24 Self {
25 server_id: server.id.clone(),
26 prompt,
27 }
28 }
29}
30
31impl SlashCommand for ContextServerSlashCommand {
32 fn name(&self) -> String {
33 self.prompt.name.clone()
34 }
35
36 fn description(&self) -> String {
37 format!("Run context server command: {}", self.prompt.name)
38 }
39
40 fn menu_text(&self) -> String {
41 format!("Run '{}' from {}", self.prompt.name, self.server_id)
42 }
43
44 fn requires_argument(&self) -> bool {
45 self.prompt
46 .arguments
47 .as_ref()
48 .map_or(false, |args| !args.is_empty())
49 }
50
51 fn complete_argument(
52 self: Arc<Self>,
53 _arguments: &[String],
54 _cancel: Arc<AtomicBool>,
55 _workspace: Option<WeakView<Workspace>>,
56 _cx: &mut WindowContext,
57 ) -> Task<Result<Vec<ArgumentCompletion>>> {
58 Task::ready(Ok(Vec::new()))
59 }
60
61 fn run(
62 self: Arc<Self>,
63 arguments: &[String],
64 _workspace: WeakView<Workspace>,
65 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
66 cx: &mut WindowContext,
67 ) -> Task<Result<SlashCommandOutput>> {
68 let server_id = self.server_id.clone();
69 let prompt_name = self.prompt.name.clone();
70
71 let prompt_args = match prompt_arguments(&self.prompt, arguments) {
72 Ok(args) => args,
73 Err(e) => return Task::ready(Err(e)),
74 };
75
76 let manager = ContextServerManager::global(cx);
77 let manager = manager.read(cx);
78 if let Some(server) = manager.get_server(&server_id) {
79 cx.foreground_executor().spawn(async move {
80 let Some(protocol) = server.client.read().clone() else {
81 return Err(anyhow!("Context server not initialized"));
82 };
83 let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
84
85 Ok(SlashCommandOutput {
86 sections: vec![SlashCommandOutputSection {
87 range: 0..result.len(),
88 icon: IconName::ZedAssistant,
89 label: SharedString::from(format!("Result from {}", prompt_name)),
90 }],
91 text: result,
92 run_commands_in_text: false,
93 })
94 })
95 } else {
96 Task::ready(Err(anyhow!("Context server not found")))
97 }
98 }
99}
100
101fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
102 match &prompt.arguments {
103 Some(args) if args.len() > 1 => Err(anyhow!(
104 "Prompt has more than one argument, which is not supported"
105 )),
106 Some(args) if args.len() == 1 => {
107 if !arguments.is_empty() {
108 let mut map = HashMap::default();
109 map.insert(args[0].name.clone(), arguments.join(" "));
110 Ok(map)
111 } else {
112 Err(anyhow!("Prompt expects argument but none given"))
113 }
114 }
115 Some(_) | None => {
116 if arguments.is_empty() {
117 Ok(HashMap::default())
118 } else {
119 Err(anyhow!("Prompt expects no arguments but some were given"))
120 }
121 }
122 }
123}
124
125/// MCP servers can return prompts with multiple arguments. Since we only
126/// support one argument, we ignore all others. This is the necessary predicate
127/// for this.
128pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
129 match &prompt.arguments {
130 None => true,
131 Some(args) if args.len() == 1 => true,
132 _ => false,
133 }
134}