assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3mod slash_command_working_set;
  4
  5pub use crate::extension_slash_command::*;
  6pub use crate::slash_command_registry::*;
  7pub use crate::slash_command_working_set::*;
  8use anyhow::Result;
  9use futures::stream::{self, BoxStream};
 10use futures::StreamExt;
 11use gpui::{App, SharedString, Task, WeakEntity, Window};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 13pub use language_model::Role;
 14use serde::{Deserialize, Serialize};
 15use std::{
 16    ops::Range,
 17    sync::{atomic::AtomicBool, Arc},
 18};
 19use workspace::{ui::IconName, Workspace};
 20
 21pub fn init(cx: &mut App) {
 22    SlashCommandRegistry::default_global(cx);
 23    extension_slash_command::init(cx);
 24}
 25
 26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 27pub enum AfterCompletion {
 28    /// Run the command
 29    Run,
 30    /// Continue composing the current argument, doesn't add a space
 31    Compose,
 32    /// Continue the command composition, adds a space
 33    Continue,
 34}
 35
 36impl From<bool> for AfterCompletion {
 37    fn from(value: bool) -> Self {
 38        if value {
 39            AfterCompletion::Run
 40        } else {
 41            AfterCompletion::Continue
 42        }
 43    }
 44}
 45
 46impl AfterCompletion {
 47    pub fn run(&self) -> bool {
 48        match self {
 49            AfterCompletion::Run => true,
 50            AfterCompletion::Compose | AfterCompletion::Continue => false,
 51        }
 52    }
 53}
 54
 55#[derive(Debug)]
 56pub struct ArgumentCompletion {
 57    /// The label to display for this completion.
 58    pub label: CodeLabel,
 59    /// The new text that should be inserted into the command when this completion is accepted.
 60    pub new_text: String,
 61    /// Whether the command should be run when accepting this completion.
 62    pub after_completion: AfterCompletion,
 63    /// Whether to replace the all arguments, or whether to treat this as an independent argument.
 64    pub replace_previous_arguments: bool,
 65}
 66
 67pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 68
 69pub trait SlashCommand: 'static + Send + Sync {
 70    fn name(&self) -> String;
 71    fn icon(&self) -> IconName {
 72        IconName::Slash
 73    }
 74    fn label(&self, _cx: &App) -> CodeLabel {
 75        CodeLabel::plain(self.name(), None)
 76    }
 77    fn description(&self) -> String;
 78    fn menu_text(&self) -> String;
 79    fn complete_argument(
 80        self: Arc<Self>,
 81        arguments: &[String],
 82        cancel: Arc<AtomicBool>,
 83        workspace: Option<WeakEntity<Workspace>>,
 84        window: &mut Window,
 85        cx: &mut App,
 86    ) -> Task<Result<Vec<ArgumentCompletion>>>;
 87    fn requires_argument(&self) -> bool;
 88    fn accepts_arguments(&self) -> bool {
 89        self.requires_argument()
 90    }
 91    #[allow(clippy::too_many_arguments)]
 92    fn run(
 93        self: Arc<Self>,
 94        arguments: &[String],
 95        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
 96        context_buffer: BufferSnapshot,
 97        workspace: WeakEntity<Workspace>,
 98        // TODO: We're just using the `LspAdapterDelegate` here because that is
 99        // what the extension API is already expecting.
100        //
101        // It may be that `LspAdapterDelegate` needs a more general name, or
102        // perhaps another kind of delegate is needed here.
103        delegate: Option<Arc<dyn LspAdapterDelegate>>,
104        window: &mut Window,
105        cx: &mut App,
106    ) -> Task<SlashCommandResult>;
107}
108
/// A chunk of content produced by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// A run of text.
    Text {
        /// The text itself.
        text: String,
        /// Presumably whether slash commands appearing inside `text` should
        /// themselves be run — confirm with the consumer.
        run_commands_in_text: bool,
    },
}

/// Wraps a string slice as [`SlashCommandContent::Text`], with
/// `run_commands_in_text` set to `false`.
impl From<&str> for SlashCommandContent {
    fn from(text: &str) -> Self {
        let text = text.to_owned();
        SlashCommandContent::Text {
            text,
            run_commands_in_text: false,
        }
    }
}
124}
125
126#[derive(Debug, PartialEq)]
127pub enum SlashCommandEvent {
128    StartMessage {
129        role: Role,
130        merge_same_roles: bool,
131    },
132    StartSection {
133        icon: IconName,
134        label: SharedString,
135        metadata: Option<serde_json::Value>,
136    },
137    Content(SlashCommandContent),
138    EndSection,
139}
140
141#[derive(Debug, Default, PartialEq, Clone)]
142pub struct SlashCommandOutput {
143    pub text: String,
144    pub sections: Vec<SlashCommandOutputSection<usize>>,
145    pub run_commands_in_text: bool,
146}
147
148impl SlashCommandOutput {
149    pub fn ensure_valid_section_ranges(&mut self) {
150        for section in &mut self.sections {
151            section.range.start = section.range.start.min(self.text.len());
152            section.range.end = section.range.end.min(self.text.len());
153            while !self.text.is_char_boundary(section.range.start) {
154                section.range.start -= 1;
155            }
156            while !self.text.is_char_boundary(section.range.end) {
157                section.range.end += 1;
158            }
159        }
160    }
161
162    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
163    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
164        self.ensure_valid_section_ranges();
165
166        let mut events = Vec::new();
167
168        let mut section_endpoints = Vec::new();
169        for section in self.sections {
170            section_endpoints.push((
171                section.range.start,
172                SlashCommandEvent::StartSection {
173                    icon: section.icon,
174                    label: section.label,
175                    metadata: section.metadata,
176                },
177            ));
178            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
179        }
180        section_endpoints.sort_by_key(|(offset, _)| *offset);
181
182        let mut content_offset = 0;
183        for (endpoint_offset, endpoint) in section_endpoints {
184            if content_offset < endpoint_offset {
185                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
186                    text: self.text[content_offset..endpoint_offset].to_string(),
187                    run_commands_in_text: self.run_commands_in_text,
188                })));
189                content_offset = endpoint_offset;
190            }
191
192            events.push(Ok(endpoint));
193        }
194
195        if content_offset < self.text.len() {
196            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
197                text: self.text[content_offset..].to_string(),
198                run_commands_in_text: self.run_commands_in_text,
199            })));
200        }
201
202        stream::iter(events).boxed()
203    }
204
205    pub async fn from_event_stream(
206        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
207    ) -> Result<SlashCommandOutput> {
208        let mut output = SlashCommandOutput::default();
209        let mut section_stack = Vec::new();
210
211        while let Some(event) = events.next().await {
212            match event? {
213                SlashCommandEvent::StartSection {
214                    icon,
215                    label,
216                    metadata,
217                } => {
218                    let start = output.text.len();
219                    section_stack.push(SlashCommandOutputSection {
220                        range: start..start,
221                        icon,
222                        label,
223                        metadata,
224                    });
225                }
226                SlashCommandEvent::Content(SlashCommandContent::Text {
227                    text,
228                    run_commands_in_text,
229                }) => {
230                    output.text.push_str(&text);
231                    output.run_commands_in_text = run_commands_in_text;
232
233                    if let Some(section) = section_stack.last_mut() {
234                        section.range.end = output.text.len();
235                    }
236                }
237                SlashCommandEvent::EndSection => {
238                    if let Some(section) = section_stack.pop() {
239                        output.sections.push(section);
240                    }
241                }
242                SlashCommandEvent::StartMessage { .. } => {}
243            }
244        }
245
246        while let Some(section) = section_stack.pop() {
247            output.sections.push(section);
248        }
249
250        Ok(output)
251    }
252}
253
254#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
255pub struct SlashCommandOutputSection<T> {
256    pub range: Range<T>,
257    pub icon: IconName,
258    pub label: SharedString,
259    pub metadata: Option<serde_json::Value>,
260}
261
262impl SlashCommandOutputSection<language::Anchor> {
263    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
264        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
265    }
266}
267
/// The byte ranges of a parsed slash command, relative to the start of the line.
pub struct SlashCommandLine {
    /// The range within the line containing the command name (the text after `/`).
    pub name: Range<usize>,
    /// Ranges within the line containing the command arguments.
    ///
    /// When the line ends in whitespace after the name or an argument, the last
    /// range is an empty range just past that whitespace.
    pub arguments: Vec<Range<usize>>,
}

impl SlashCommandLine {
    /// Parses `line` as a slash command invocation.
    ///
    /// Parsing rules:
    /// - only whitespace may precede the `/`;
    /// - the name must begin with an alphabetic character and runs until the
    ///   first whitespace character;
    /// - arguments are separated by whitespace, and runs of whitespace count as a
    ///   single separator.
    ///
    /// Returns `None` when `line` is not a slash command. A bare `/` yields an
    /// empty name and no arguments.
    pub fn parse(line: &str) -> Option<Self> {
        let mut parsed: Option<Self> = None;
        for (offset, ch) in line.char_indices() {
            let after = offset + ch.len_utf8();
            match parsed.as_mut() {
                // Commands start with a slash; only whitespace may come before it.
                None if ch == '/' => {
                    parsed = Some(SlashCommandLine {
                        name: after..after,
                        arguments: Vec::new(),
                    });
                }
                None if ch.is_whitespace() => {}
                None => return None,
                Some(command) => match command.arguments.last_mut() {
                    // Inside the argument list.
                    Some(current) => {
                        if !ch.is_whitespace() {
                            current.end = after;
                        } else if current.is_empty() {
                            // Extra whitespace: slide the pending empty argument forward.
                            *current = after..after;
                        } else {
                            // Whitespace ends the current argument and opens the next one.
                            current.end = offset;
                            command.arguments.push(after..after);
                        }
                    }
                    // Inside (or at the very start of) the command name.
                    None => {
                        if command.name.is_empty() && !ch.is_alphabetic() {
                            // The command name must begin with a letter.
                            return None;
                        } else if !command.name.is_empty() && ch.is_whitespace() {
                            // The name ends at the first whitespace character.
                            command.arguments.push(after..after);
                        } else {
                            command.name.end = after;
                        }
                    }
                },
            }
        }
        parsed
    }
}
328
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds a text content event with `run_commands_in_text` disabled.
    fn text_event(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Builds a section-start event.
    fn start_event(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds an output section over the byte range `range`.
    fn section(
        range: Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Performs two checks on `output`:
    /// - converting it to events yields exactly `expected_events` (error items
    ///   are dropped before comparing);
    /// - converting those events back reproduces `output`.
    async fn assert_round_trip(output: SlashCommandOutput, expected_events: Vec<SlashCommandEvent>) {
        let events = output
            .clone()
            .to_event_stream()
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(|event| event.ok())
            .collect::<Vec<_>>();
        assert_eq!(events, expected_events);

        let round_tripped =
            SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                .await
                .unwrap();
        assert_eq!(round_tripped, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // A single section spanning the whole text.
        let hello = "Hello, world!";
        assert_round_trip(
            SlashCommandOutput {
                text: hello.to_string(),
                sections: vec![section(0..hello.len(), IconName::Code, "Section 1", None)],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::Code, "Section 1", None),
                text_event("Hello, world!"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Sections that leave part of the text uncovered.
        assert_round_trip(
            SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::Check, "Fruit", None),
                text_event("Apple\n"),
                SlashCommandEvent::EndSection,
                text_event("Cucumber\n"),
                start_event(IconName::Check, "Fruit", None),
                text_event("Banana\n"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Several sections with metadata, separated by newlines.
        assert_round_trip(
            SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(0..6, IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                    section(7..13, IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                    section(14..20, IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                    section(21..27, IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                text_event("Line 1"),
                SlashCommandEvent::EndSection,
                text_event("\n"),
                start_event(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                text_event("Line 2"),
                SlashCommandEvent::EndSection,
                text_event("\n"),
                start_event(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                text_event("Line 3"),
                SlashCommandEvent::EndSection,
                text_event("\n"),
                start_event(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                text_event("Line 4"),
                SlashCommandEvent::EndSection,
                text_event("\n"),
            ],
        )
        .await;
    }
}