assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3mod slash_command_working_set;
  4
  5pub use crate::extension_slash_command::*;
  6pub use crate::slash_command_registry::*;
  7pub use crate::slash_command_working_set::*;
  8use anyhow::Result;
  9use futures::StreamExt;
 10use futures::stream::{self, BoxStream};
 11use gpui::{App, SharedString, Task, WeakEntity, Window};
 12use language::HighlightId;
 13use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 14pub use language_model::Role;
 15use serde::{Deserialize, Serialize};
 16use std::{
 17    ops::Range,
 18    sync::{Arc, atomic::AtomicBool},
 19};
 20use ui::ActiveTheme;
 21use workspace::{Workspace, ui::IconName};
 22
/// Initializes slash-command support for the application.
///
/// Calls [`SlashCommandRegistry::default_global`] and then
/// `extension_slash_command::init`, both with the same app context.
pub fn init(cx: &mut App) {
    // NOTE(review): presumably this makes sure the global registry exists
    // before extension commands are initialized — confirm before reordering.
    SlashCommandRegistry::default_global(cx);
    extension_slash_command::init(cx);
}
 27
 28#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 29pub enum AfterCompletion {
 30    /// Run the command
 31    Run,
 32    /// Continue composing the current argument, doesn't add a space
 33    Compose,
 34    /// Continue the command composition, adds a space
 35    Continue,
 36}
 37
 38impl From<bool> for AfterCompletion {
 39    fn from(value: bool) -> Self {
 40        if value {
 41            AfterCompletion::Run
 42        } else {
 43            AfterCompletion::Continue
 44        }
 45    }
 46}
 47
 48impl AfterCompletion {
 49    pub fn run(&self) -> bool {
 50        match self {
 51            AfterCompletion::Run => true,
 52            AfterCompletion::Compose | AfterCompletion::Continue => false,
 53        }
 54    }
 55}
 56
/// A single completion candidate for a slash command argument.
#[derive(Debug)]
pub struct ArgumentCompletion {
    /// The label to display for this completion.
    pub label: CodeLabel,
    /// The new text that should be inserted into the command when this completion is accepted.
    pub new_text: String,
    /// What to do after this completion is accepted: run the command, keep
    /// composing the current argument, or continue with the next one.
    pub after_completion: AfterCompletion,
    /// Whether to replace all of the arguments, or whether to treat this as an independent argument.
    pub replace_previous_arguments: bool,
}
 68
/// The outcome of running a slash command: on success, a boxed `'static`
/// stream whose items are fallible [`SlashCommandEvent`]s.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 70
/// The interface implemented by every slash command.
///
/// Implementors must be `'static + Send + Sync`. The methods that do work
/// (`complete_argument` and `run`) take `self` as an `Arc<Self>`.
pub trait SlashCommand: 'static + Send + Sync {
    /// The command's name. Also used as the text of the default [`label`](Self::label).
    fn name(&self) -> String;
    /// The icon for this command. Defaults to `IconName::Slash`.
    fn icon(&self) -> IconName {
        IconName::Slash
    }
    /// The label for this command. Defaults to a plain label built from
    /// [`name`](Self::name).
    fn label(&self, _cx: &App) -> CodeLabel {
        CodeLabel::plain(self.name(), None)
    }
    /// A description of the command.
    fn description(&self) -> String;
    /// The text for this command's menu entry.
    // NOTE(review): which menu this appears in is not visible here — confirm.
    fn menu_text(&self) -> String;
    /// Computes completions for the command's arguments.
    ///
    /// `arguments` are the arguments given so far. The returned task resolves
    /// to the list of candidate completions, or an error.
    // NOTE(review): `cancel` is presumably set by the caller to abandon an
    // in-flight completion — confirm against callers.
    fn complete_argument(
        self: Arc<Self>,
        arguments: &[String],
        cancel: Arc<AtomicBool>,
        workspace: Option<WeakEntity<Workspace>>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<Result<Vec<ArgumentCompletion>>>;
    /// Whether the command requires an argument.
    fn requires_argument(&self) -> bool;
    /// Whether the command accepts arguments. Defaults to the value of
    /// [`requires_argument`](Self::requires_argument).
    fn accepts_arguments(&self) -> bool {
        self.requires_argument()
    }
    /// Runs the command with the given `arguments`.
    ///
    /// The returned task resolves to a [`SlashCommandResult`], i.e. a stream
    /// of [`SlashCommandEvent`]s on success.
    fn run(
        self: Arc<Self>,
        arguments: &[String],
        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        context_buffer: BufferSnapshot,
        workspace: WeakEntity<Workspace>,
        // TODO: We're just using the `LspAdapterDelegate` here because that is
        // what the extension API is already expecting.
        //
        // It may be that `LspAdapterDelegate` needs a more general name, or
        // perhaps another kind of delegate is needed here.
        delegate: Option<Arc<dyn LspAdapterDelegate>>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<SlashCommandResult>;
}
109
/// A piece of content emitted by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// A run of text.
    Text {
        /// The text itself.
        text: String,
        // NOTE(review): presumably whether slash commands appearing in `text`
        // should themselves be run — confirm against the consumer.
        run_commands_in_text: bool,
    },
}

impl From<&str> for SlashCommandContent {
    /// Wraps `text` in [`SlashCommandContent::Text`] with
    /// `run_commands_in_text` set to `false`.
    fn from(text: &str) -> Self {
        let text = text.to_owned();
        Self::Text {
            text,
            run_commands_in_text: false,
        }
    }
}
126
/// One event in the stream produced by running a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandEvent {
    /// Marks the start of a message with the given `role`.
    // NOTE(review): `merge_same_roles` presumably controls whether this
    // message is merged into a preceding one with the same role — confirm.
    StartMessage {
        role: Role,
        merge_same_roles: bool,
    },
    /// Opens a section; content emitted until the matching
    /// [`SlashCommandEvent::EndSection`] belongs to it.
    StartSection {
        icon: IconName,
        label: SharedString,
        metadata: Option<serde_json::Value>,
    },
    /// A piece of content.
    Content(SlashCommandContent),
    /// Closes the most recently opened section.
    EndSection,
}
141
/// The complete, non-streamed output of a slash command.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
    /// The full output text.
    pub text: String,
    /// Sections of the output; each range holds byte offsets into `text`.
    pub sections: Vec<SlashCommandOutputSection<usize>>,
    /// The `run_commands_in_text` flag attached to every content event
    /// generated from this output.
    pub run_commands_in_text: bool,
}
148
149impl SlashCommandOutput {
150    pub fn ensure_valid_section_ranges(&mut self) {
151        for section in &mut self.sections {
152            section.range.start = section.range.start.min(self.text.len());
153            section.range.end = section.range.end.min(self.text.len());
154            while !self.text.is_char_boundary(section.range.start) {
155                section.range.start -= 1;
156            }
157            while !self.text.is_char_boundary(section.range.end) {
158                section.range.end += 1;
159            }
160        }
161    }
162
163    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
164    pub fn into_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
165        self.ensure_valid_section_ranges();
166
167        let mut events = Vec::new();
168
169        let mut section_endpoints = Vec::new();
170        for section in self.sections {
171            section_endpoints.push((
172                section.range.start,
173                SlashCommandEvent::StartSection {
174                    icon: section.icon,
175                    label: section.label,
176                    metadata: section.metadata,
177                },
178            ));
179            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
180        }
181        section_endpoints.sort_by_key(|(offset, _)| *offset);
182
183        let mut content_offset = 0;
184        for (endpoint_offset, endpoint) in section_endpoints {
185            if content_offset < endpoint_offset {
186                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
187                    text: self.text[content_offset..endpoint_offset].to_string(),
188                    run_commands_in_text: self.run_commands_in_text,
189                })));
190                content_offset = endpoint_offset;
191            }
192
193            events.push(Ok(endpoint));
194        }
195
196        if content_offset < self.text.len() {
197            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
198                text: self.text[content_offset..].to_string(),
199                run_commands_in_text: self.run_commands_in_text,
200            })));
201        }
202
203        stream::iter(events).boxed()
204    }
205
206    pub async fn from_event_stream(
207        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
208    ) -> Result<SlashCommandOutput> {
209        let mut output = SlashCommandOutput::default();
210        let mut section_stack = Vec::new();
211
212        while let Some(event) = events.next().await {
213            match event? {
214                SlashCommandEvent::StartSection {
215                    icon,
216                    label,
217                    metadata,
218                } => {
219                    let start = output.text.len();
220                    section_stack.push(SlashCommandOutputSection {
221                        range: start..start,
222                        icon,
223                        label,
224                        metadata,
225                    });
226                }
227                SlashCommandEvent::Content(SlashCommandContent::Text {
228                    text,
229                    run_commands_in_text,
230                }) => {
231                    output.text.push_str(&text);
232                    output.run_commands_in_text = run_commands_in_text;
233
234                    if let Some(section) = section_stack.last_mut() {
235                        section.range.end = output.text.len();
236                    }
237                }
238                SlashCommandEvent::EndSection => {
239                    if let Some(section) = section_stack.pop() {
240                        output.sections.push(section);
241                    }
242                }
243                SlashCommandEvent::StartMessage { .. } => {}
244            }
245        }
246
247        while let Some(section) = section_stack.pop() {
248            output.sections.push(section);
249        }
250
251        Ok(output)
252    }
253}
254
/// A labelled section of slash command output.
///
/// `T` is the position type of the range: `usize` offsets in
/// [`SlashCommandOutput`], `language::Anchor` for sections tracked in a buffer.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
    /// The span of output covered by this section.
    pub range: Range<T>,
    /// The icon associated with this section.
    pub icon: IconName,
    /// The label associated with this section.
    pub label: SharedString,
    /// Optional JSON metadata attached to this section.
    pub metadata: Option<serde_json::Value>,
}
262
263impl SlashCommandOutputSection<language::Anchor> {
264    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
265        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
266    }
267}
268
/// The byte ranges of a slash command's name and arguments within one line.
pub struct SlashCommandLine {
    /// The range within the line containing the command name.
    pub name: Range<usize>,
    /// Ranges within the line containing the command arguments.
    pub arguments: Vec<Range<usize>>,
}

impl SlashCommandLine {
    /// Parses `line` as a slash command invocation.
    ///
    /// The line may start with whitespace, followed by a `/` and a command
    /// name whose first character is alphabetic. Whitespace-separated
    /// arguments may follow the name. All ranges are byte offsets into `line`.
    ///
    /// Returns `None` when the line has no `/`, when anything other than
    /// whitespace precedes the `/`, or when the character right after the `/`
    /// is not alphabetic. If whitespace follows the name, `arguments` always
    /// ends with a (possibly empty) range for the argument being typed.
    pub fn parse(line: &str) -> Option<Self> {
        let mut parsed: Option<Self> = None;

        for (start, ch) in line.char_indices() {
            let end = start + ch.len_utf8();
            let is_space = ch.is_whitespace();

            match parsed.as_mut() {
                // Commands start with a slash.
                None if ch == '/' => {
                    parsed = Some(SlashCommandLine {
                        name: end..end,
                        arguments: Vec::new(),
                    });
                }
                // The line can't contain anything before the slash except for whitespace.
                None if is_space => {}
                None => return None,
                Some(command) => {
                    if let Some(current) = command.arguments.last_mut() {
                        // Arguments are separated by whitespace; the last one
                        // runs to the end of the line.
                        if !is_space {
                            current.end = end;
                        }
                        // The explicit deref keeps this call on the range
                        // itself rather than on the `&mut` reference.
                        else if (*current).is_empty() {
                            // Still skipping whitespace before the argument.
                            *current = end..end;
                        } else {
                            // Whitespace ends this argument and opens the next.
                            current.end = start;
                            command.arguments.push(end..end);
                        }
                    }
                    // The command name must begin with a letter.
                    else if command.name.is_empty() {
                        if ch.is_alphabetic() {
                            command.name.end = end;
                        } else {
                            return None;
                        }
                    }
                    // The command name ends at the first whitespace character.
                    else if is_space {
                        command.arguments.push(end..end);
                    } else {
                        command.name.end = end;
                    }
                }
            }
        }

        parsed
    }
}
329
330pub fn create_label_for_command(command_name: &str, arguments: &[&str], cx: &App) -> CodeLabel {
331    let mut label = CodeLabel::default();
332    label.push_str(command_name, None);
333    label.push_str(" ", None);
334    label.push_str(
335        &arguments.join(" "),
336        cx.theme().syntax().highlight_id("comment").map(HighlightId),
337    );
338    label.filter_range = 0..command_name.len();
339    label
340}
341
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds an output section with the given attributes.
    fn section(
        range: Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds the `StartSection` event expected for a section with the given attributes.
    fn section_start(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds the `Content` event expected for `text` with `run_commands_in_text` unset.
    fn plain_text(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Checks that `output` converts to exactly `expected_events`, and that
    /// rebuilding an output from its event stream yields `output` again.
    async fn assert_round_trip(output: SlashCommandOutput, expected_events: Vec<SlashCommandEvent>) {
        let events = output
            .clone()
            .into_event_stream()
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(|event| event.ok())
            .collect::<Vec<_>>();
        assert_eq!(events, expected_events);

        let new_output = SlashCommandOutput::from_event_stream(output.clone().into_event_stream())
            .await
            .unwrap();
        assert_eq!(new_output, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // Test basic output consisting of a single section.
        {
            let text = "Hello, world!".to_string();
            let range = 0..text.len();
            let output = SlashCommandOutput {
                text,
                sections: vec![section(range, IconName::Code, "Section 1", None)],
                run_commands_in_text: false,
            };

            assert_round_trip(
                output,
                vec![
                    section_start(IconName::Code, "Section 1", None),
                    plain_text("Hello, world!"),
                    SlashCommandEvent::EndSection,
                ],
            )
            .await;
        }

        // Test output where the sections do not comprise all of the text.
        {
            let output = SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            };

            assert_round_trip(
                output,
                vec![
                    section_start(IconName::Check, "Fruit", None),
                    plain_text("Apple\n"),
                    SlashCommandEvent::EndSection,
                    plain_text("Cucumber\n"),
                    section_start(IconName::Check, "Fruit", None),
                    plain_text("Banana\n"),
                    SlashCommandEvent::EndSection,
                ],
            )
            .await;
        }

        // Test output consisting of multiple sections.
        {
            let output = SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(
                        0..6,
                        IconName::FileCode,
                        "Section 1",
                        Some(json!({ "a": true })),
                    ),
                    section(
                        7..13,
                        IconName::FileDoc,
                        "Section 2",
                        Some(json!({ "b": true })),
                    ),
                    section(
                        14..20,
                        IconName::FileGit,
                        "Section 3",
                        Some(json!({ "c": true })),
                    ),
                    section(
                        21..27,
                        IconName::FileToml,
                        "Section 4",
                        Some(json!({ "d": true })),
                    ),
                ],
                run_commands_in_text: false,
            };

            assert_round_trip(
                output,
                vec![
                    section_start(
                        IconName::FileCode,
                        "Section 1",
                        Some(json!({ "a": true })),
                    ),
                    plain_text("Line 1"),
                    SlashCommandEvent::EndSection,
                    plain_text("\n"),
                    section_start(
                        IconName::FileDoc,
                        "Section 2",
                        Some(json!({ "b": true })),
                    ),
                    plain_text("Line 2"),
                    SlashCommandEvent::EndSection,
                    plain_text("\n"),
                    section_start(
                        IconName::FileGit,
                        "Section 3",
                        Some(json!({ "c": true })),
                    ),
                    plain_text("Line 3"),
                    SlashCommandEvent::EndSection,
                    plain_text("\n"),
                    section_start(
                        IconName::FileToml,
                        "Section 4",
                        Some(json!({ "d": true })),
                    ),
                    plain_text("Line 4"),
                    SlashCommandEvent::EndSection,
                    plain_text("\n"),
                ],
            )
            .await;
        }
    }
}