context_server_command.rs

  1use anyhow::{anyhow, Result};
  2use assistant_slash_command::{
  3    AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
  4    SlashCommandOutputSection, SlashCommandResult,
  5};
  6use collections::HashMap;
  7use context_server::{
  8    manager::{ContextServer, ContextServerManager},
  9    types::Prompt,
 10};
 11use gpui::{AppContext, Model, Task, WeakView, WindowContext};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
 13use std::sync::atomic::AtomicBool;
 14use std::sync::Arc;
 15use text::LineEnding;
 16use ui::{IconName, SharedString};
 17use workspace::Workspace;
 18
 19use crate::slash_command::create_label_for_command;
 20
 21pub struct ContextServerSlashCommand {
 22    server_manager: Model<ContextServerManager>,
 23    server_id: Arc<str>,
 24    prompt: Prompt,
 25}
 26
 27impl ContextServerSlashCommand {
 28    pub fn new(
 29        server_manager: Model<ContextServerManager>,
 30        server: &Arc<ContextServer>,
 31        prompt: Prompt,
 32    ) -> Self {
 33        Self {
 34            server_id: server.id(),
 35            prompt,
 36            server_manager,
 37        }
 38    }
 39}
 40
 41impl SlashCommand for ContextServerSlashCommand {
 42    fn name(&self) -> String {
 43        self.prompt.name.clone()
 44    }
 45
 46    fn label(&self, cx: &AppContext) -> language::CodeLabel {
 47        let mut parts = vec![self.prompt.name.as_str()];
 48        if let Some(args) = &self.prompt.arguments {
 49            if let Some(arg) = args.first() {
 50                parts.push(arg.name.as_str());
 51            }
 52        }
 53        create_label_for_command(&parts[0], &parts[1..], cx)
 54    }
 55
 56    fn description(&self) -> String {
 57        match &self.prompt.description {
 58            Some(desc) => desc.clone(),
 59            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 60        }
 61    }
 62
 63    fn menu_text(&self) -> String {
 64        match &self.prompt.description {
 65            Some(desc) => desc.clone(),
 66            None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
 67        }
 68    }
 69
 70    fn requires_argument(&self) -> bool {
 71        self.prompt.arguments.as_ref().map_or(false, |args| {
 72            args.iter().any(|arg| arg.required == Some(true))
 73        })
 74    }
 75
 76    fn complete_argument(
 77        self: Arc<Self>,
 78        arguments: &[String],
 79        _cancel: Arc<AtomicBool>,
 80        _workspace: Option<WeakView<Workspace>>,
 81        cx: &mut WindowContext,
 82    ) -> Task<Result<Vec<ArgumentCompletion>>> {
 83        let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
 84            return Task::ready(Err(anyhow!("Failed to complete argument")));
 85        };
 86
 87        let server_id = self.server_id.clone();
 88        let prompt_name = self.prompt.name.clone();
 89
 90        if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
 91            cx.foreground_executor().spawn(async move {
 92                let Some(protocol) = server.client() else {
 93                    return Err(anyhow!("Context server not initialized"));
 94                };
 95
 96                let completion_result = protocol
 97                    .completion(
 98                        context_server::types::CompletionReference::Prompt(
 99                            context_server::types::PromptReference {
100                                r#type: context_server::types::PromptReferenceType::Prompt,
101                                name: prompt_name,
102                            },
103                        ),
104                        arg_name,
105                        arg_value,
106                    )
107                    .await?;
108
109                let completions = completion_result
110                    .values
111                    .into_iter()
112                    .map(|value| ArgumentCompletion {
113                        label: CodeLabel::plain(value.clone(), None),
114                        new_text: value,
115                        after_completion: AfterCompletion::Continue,
116                        replace_previous_arguments: false,
117                    })
118                    .collect();
119                Ok(completions)
120            })
121        } else {
122            Task::ready(Err(anyhow!("Context server not found")))
123        }
124    }
125
126    fn run(
127        self: Arc<Self>,
128        arguments: &[String],
129        _context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
130        _context_buffer: BufferSnapshot,
131        _workspace: WeakView<Workspace>,
132        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
133        cx: &mut WindowContext,
134    ) -> Task<SlashCommandResult> {
135        let server_id = self.server_id.clone();
136        let prompt_name = self.prompt.name.clone();
137
138        let prompt_args = match prompt_arguments(&self.prompt, arguments) {
139            Ok(args) => args,
140            Err(e) => return Task::ready(Err(e)),
141        };
142
143        let manager = self.server_manager.read(cx);
144        if let Some(server) = manager.get_server(&server_id) {
145            cx.foreground_executor().spawn(async move {
146                let Some(protocol) = server.client() else {
147                    return Err(anyhow!("Context server not initialized"));
148                };
149                let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
150
151                // Check that there are only user roles
152                if result
153                    .messages
154                    .iter()
155                    .any(|msg| !matches!(msg.role, context_server::types::Role::User))
156                {
157                    return Err(anyhow!(
158                        "Prompt contains non-user roles, which is not supported"
159                    ));
160                }
161
162                // Extract text from user messages into a single prompt string
163                let mut prompt = result
164                    .messages
165                    .into_iter()
166                    .filter_map(|msg| match msg.content {
167                        context_server::types::MessageContent::Text { text, .. } => Some(text),
168                        _ => None,
169                    })
170                    .collect::<Vec<String>>()
171                    .join("\n\n");
172
173                // We must normalize the line endings here, since servers might return CR characters.
174                LineEnding::normalize(&mut prompt);
175
176                Ok(SlashCommandOutput {
177                    sections: vec![SlashCommandOutputSection {
178                        range: 0..(prompt.len()),
179                        icon: IconName::ZedAssistant,
180                        label: SharedString::from(
181                            result
182                                .description
183                                .unwrap_or(format!("Result from {}", prompt_name)),
184                        ),
185                        metadata: None,
186                    }],
187                    text: prompt,
188                    run_commands_in_text: false,
189                }
190                .to_event_stream())
191            })
192        } else {
193            Task::ready(Err(anyhow!("Context server not found")))
194        }
195    }
196}
197
198fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
199    if arguments.is_empty() {
200        return Err(anyhow!("No arguments given"));
201    }
202
203    match &prompt.arguments {
204        Some(args) if args.len() == 1 => {
205            let arg_name = args[0].name.clone();
206            let arg_value = arguments.join(" ");
207            Ok((arg_name, arg_value))
208        }
209        Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
210        None => Err(anyhow!("Prompt has no arguments")),
211    }
212}
213
214fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
215    match &prompt.arguments {
216        Some(args) if args.len() > 1 => Err(anyhow!(
217            "Prompt has more than one argument, which is not supported"
218        )),
219        Some(args) if args.len() == 1 => {
220            if !arguments.is_empty() {
221                let mut map = HashMap::default();
222                map.insert(args[0].name.clone(), arguments.join(" "));
223                Ok(map)
224            } else if arguments.is_empty() && args[0].required == Some(false) {
225                Ok(HashMap::default())
226            } else {
227                Err(anyhow!("Prompt expects argument but none given"))
228            }
229        }
230        Some(_) | None => {
231            if arguments.is_empty() {
232                Ok(HashMap::default())
233            } else {
234                Err(anyhow!("Prompt expects no arguments but some were given"))
235            }
236        }
237    }
238}
239
240/// MCP servers can return prompts with multiple arguments. Since we only
241/// support one argument, we ignore all others. This is the necessary predicate
242/// for this.
243pub fn acceptable_prompt(prompt: &Prompt) -> bool {
244    match &prompt.arguments {
245        None => true,
246        Some(args) if args.len() <= 1 => true,
247        _ => false,
248    }
249}