context_server_command.rs

  1use anyhow::{Result, anyhow};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection, SlashCommandResult,
  5};
  6use collections::HashMap;
  7use context_server::{ContextServerId, types::Prompt};
  8use gpui::{App, Entity, Task, WeakEntity, Window};
  9use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
 10use project::context_server_store::ContextServerStore;
 11use std::sync::Arc;
 12use std::sync::atomic::AtomicBool;
 13use text::LineEnding;
 14use ui::{IconName, SharedString};
 15use workspace::Workspace;
 16
 17use crate::create_label_for_command;
 18
 19pub struct ContextServerSlashCommand {
 20    store: Entity<ContextServerStore>,
 21    server_id: ContextServerId,
 22    prompt: Prompt,
 23}
 24
 25impl ContextServerSlashCommand {
 26    pub fn new(store: Entity<ContextServerStore>, id: ContextServerId, prompt: Prompt) -> Self {
 27        Self {
 28            server_id: id,
 29            prompt,
 30            store,
 31        }
 32    }
 33}
 34
 35impl SlashCommand for ContextServerSlashCommand {
 36    fn name(&self) -> String {
 37        self.prompt.name.clone()
 38    }
 39
 40    fn label(&self, cx: &App) -> language::CodeLabel {
 41        let mut parts = vec![self.prompt.name.as_str()];
 42        if let Some(args) = &self.prompt.arguments {
 43            if let Some(arg) = args.first() {
 44                parts.push(arg.name.as_str());
 45            }
 46        }
 47        create_label_for_command(&parts[0], &parts[1..], cx)
 48    }
 49
 50    fn description(&self) -> String {
 51        match &self.prompt.description {
 52            Some(desc) => desc.clone(),
 53            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 54        }
 55    }
 56
 57    fn menu_text(&self) -> String {
 58        match &self.prompt.description {
 59            Some(desc) => desc.clone(),
 60            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 61        }
 62    }
 63
 64    fn requires_argument(&self) -> bool {
 65        self.prompt.arguments.as_ref().map_or(false, |args| {
 66            args.iter().any(|arg| arg.required == Some(true))
 67        })
 68    }
 69
 70    fn complete_argument(
 71        self: Arc<Self>,
 72        arguments: &[String],
 73        _cancel: Arc<AtomicBool>,
 74        _workspace: Option<WeakEntity<Workspace>>,
 75        _window: &mut Window,
 76        cx: &mut App,
 77    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 78        let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
 79            return Task::ready(Err(anyhow!("Failed to complete argument")));
 80        };
 81
 82        let server_id = self.server_id.clone();
 83        let prompt_name = self.prompt.name.clone();
 84
 85        if let Some(server) = self.store.read(cx).get_running_server(&server_id) {
 86            cx.foreground_executor().spawn(async move {
 87                let Some(protocol) = server.client() else {
 88                    return Err(anyhow!("Context server not initialized"));
 89                };
 90
 91                let completion_result = protocol
 92                    .completion(
 93                        context_server::types::CompletionReference::Prompt(
 94                            context_server::types::PromptReference {
 95                                r#type: context_server::types::PromptReferenceType::Prompt,
 96                                name: prompt_name,
 97                            },
 98                        ),
 99                        arg_name,
100                        arg_value,
101                    )
102                    .await?;
103
104                let completions = completion_result
105                    .values
106                    .into_iter()
107                    .map(|value| ArgumentCompletion {
108                        label: CodeLabel::plain(value.clone(), None),
109                        new_text: value,
110                        after_completion: AfterCompletion::Continue,
111                        replace_previous_arguments: false,
112                    })
113                    .collect();
114                Ok(completions)
115            })
116        } else {
117            Task::ready(Err(anyhow!("Context server not found")))
118        }
119    }
120
121    fn run(
122        self: Arc<Self>,
123        arguments: &[String],
124        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
125        _context_buffer: BufferSnapshot,
126        _workspace: WeakEntity<Workspace>,
127        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
128        _window: &mut Window,
129        cx: &mut App,
130    ) -> Task<SlashCommandResult> {
131        let server_id = self.server_id.clone();
132        let prompt_name = self.prompt.name.clone();
133
134        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
135            Ok(args) => args,
136            Err(e) => return Task::ready(Err(e)),
137        };
138
139        let store = self.store.read(cx);
140        if let Some(server) = store.get_running_server(&server_id) {
141            cx.foreground_executor().spawn(async move {
142                let Some(protocol) = server.client() else {
143                    return Err(anyhow!("Context server not initialized"));
144                };
145                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
146
147                // Check that there are only user roles
148                if result
149                    .messages
150                    .iter()
151                    .any(|msg| !matches!(msg.role, context_server::types::Role::User))
152                {
153                    return Err(anyhow!(
154                        "Prompt contains non-user roles, which is not supported"
155                    ));
156                }
157
158                // Extract text from user messages into a single prompt string
159                let mut prompt = result
160                    .messages
161                    .into_iter()
162                    .filter_map(|msg| match msg.content {
163                        context_server::types::MessageContent::Text { text, .. } => Some(text),
164                        _ => None,
165                    })
166                    .collect::<Vec<String>>()
167                    .join("\n\n");
168
169                // We must normalize the line endings here, since servers might return CR characters.
170                LineEnding::normalize(&mut prompt);
171
172                Ok(SlashCommandOutput {
173                    sections: vec![SlashCommandOutputSection {
174                        range: 0..(prompt.len()),
175                        icon: IconName::ZedAssistant,
176                        label: SharedString::from(
177                            result
178                                .description
179                                .unwrap_or(format!("Result from {}", prompt_name)),
180                        ),
181                        metadata: None,
182                    }],
183                    text: prompt,
184                    run_commands_in_text: false,
185                }
186                .to_event_stream())
187            })
188        } else {
189            Task::ready(Err(anyhow!("Context server not found")))
190        }
191    }
192}
193
194fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
195    if arguments.is_empty() {
196        return Err(anyhow!("No arguments given"));
197    }
198
199    match &prompt.arguments {
200        Some(args) if args.len() == 1 => {
201            let arg_name = args[0].name.clone();
202            let arg_value = arguments.join(" ");
203            Ok((arg_name, arg_value))
204        }
205        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
206        None => Err(anyhow!("Prompt has no arguments")),
207    }
208}
209
210fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
211    match &prompt.arguments {
212        Some(args) if args.len() > 1 => Err(anyhow!(
213            "Prompt has more than one argument, which is not supported"
214        )),
215        Some(args) if args.len() == 1 => {
216            if !arguments.is_empty() {
217                let mut map = HashMap::default();
218                map.insert(args[0].name.clone(), arguments.join(" "));
219                Ok(map)
220            } else if arguments.is_empty() && args[0].required == Some(false) {
221                Ok(HashMap::default())
222            } else {
223                Err(anyhow!("Prompt expects argument but none given"))
224            }
225        }
226        Some(_) | None => {
227            if arguments.is_empty() {
228                Ok(HashMap::default())
229            } else {
230                Err(anyhow!("Prompt expects no arguments but some were given"))
231            }
232        }
233    }
234}
235
236/// MCP servers can return prompts with multiple arguments. Since we only
237/// support one argument, we ignore all others. This is the necessary predicate
238/// for this.
239pub fn acceptable_prompt(prompt: &Prompt) -> bool {
240    match &prompt.arguments {
241        None => true,
242        Some(args) if args.len() <= 1 => true,
243        _ => false,
244    }
245}