assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3mod slash_command_working_set;
  4
  5pub use crate::extension_slash_command::*;
  6pub use crate::slash_command_registry::*;
  7pub use crate::slash_command_working_set::*;
  8use anyhow::Result;
  9use futures::StreamExt;
 10use futures::stream::{self, BoxStream};
 11use gpui::{App, SharedString, Task, WeakEntity, Window};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 13pub use language_model::Role;
 14use serde::{Deserialize, Serialize};
 15use std::{
 16    ops::Range,
 17    sync::{Arc, atomic::AtomicBool},
 18};
 19use workspace::{Workspace, ui::IconName};
 20
 21pub fn init(cx: &mut App) {
 22    SlashCommandRegistry::default_global(cx);
 23    extension_slash_command::init(cx);
 24}
 25
 26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 27pub enum AfterCompletion {
 28    /// Run the command
 29    Run,
 30    /// Continue composing the current argument, doesn't add a space
 31    Compose,
 32    /// Continue the command composition, adds a space
 33    Continue,
 34}
 35
 36impl From<bool> for AfterCompletion {
 37    fn from(value: bool) -> Self {
 38        if value {
 39            AfterCompletion::Run
 40        } else {
 41            AfterCompletion::Continue
 42        }
 43    }
 44}
 45
 46impl AfterCompletion {
 47    pub fn run(&self) -> bool {
 48        match self {
 49            AfterCompletion::Run => true,
 50            AfterCompletion::Compose | AfterCompletion::Continue => false,
 51        }
 52    }
 53}
 54
/// A single completion candidate for a slash command argument, as produced by
/// [`SlashCommand::complete_argument`].
#[derive(Debug)]
pub struct ArgumentCompletion {
    /// The label to display for this completion.
    pub label: CodeLabel,
    /// The new text that should be inserted into the command when this completion is accepted.
    pub new_text: String,
    /// What happens once this completion is accepted: run the command, or keep composing it
    /// (see [`AfterCompletion`]).
    pub after_completion: AfterCompletion,
    /// Whether to replace all of the previous arguments, or whether to treat this as an
    /// independent argument.
    pub replace_previous_arguments: bool,
}
 66
/// The result of running a slash command: either an error up front, or a stream of
/// [`SlashCommandEvent`]s in which each individual event may also be an error.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;

/// A command that can be invoked with a leading `/` and optional arguments.
///
/// Implementations must be `'static + Send + Sync`.
pub trait SlashCommand: 'static + Send + Sync {
    /// The name identifying this command.
    fn name(&self) -> String;
    /// The icon shown for this command. Defaults to [`IconName::Slash`].
    fn icon(&self) -> IconName {
        IconName::Slash
    }
    /// The label shown for this command. Defaults to a plain [`CodeLabel`] containing
    /// [`Self::name`].
    fn label(&self, _cx: &App) -> CodeLabel {
        CodeLabel::plain(self.name(), None)
    }
    /// A human-readable description of the command.
    fn description(&self) -> String;
    /// Text to display for this command in a menu (presumably the command picker or completion
    /// menu; confirm with callers).
    fn menu_text(&self) -> String;
    /// Asynchronously computes completions for the command's arguments.
    ///
    /// `arguments` holds the arguments typed so far. `cancel` is a shared flag that, presumably,
    /// is set when the completion request is no longer needed (TODO confirm with callers).
    fn complete_argument(
        self: Arc<Self>,
        arguments: &[String],
        cancel: Arc<AtomicBool>,
        workspace: Option<WeakEntity<Workspace>>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<Result<Vec<ArgumentCompletion>>>;
    /// Whether the command must be given at least one argument.
    fn requires_argument(&self) -> bool;
    /// Whether the command accepts arguments at all. Defaults to [`Self::requires_argument`].
    fn accepts_arguments(&self) -> bool {
        self.requires_argument()
    }
    /// Runs the command, eventually producing a [`SlashCommandResult`].
    ///
    /// `context_slash_command_output_sections` are anchor-based output sections associated with
    /// `context_buffer` (presumably sections from earlier command output; confirm with callers).
    fn run(
        self: Arc<Self>,
        arguments: &[String],
        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        context_buffer: BufferSnapshot,
        workspace: WeakEntity<Workspace>,
        // TODO: We're just using the `LspAdapterDelegate` here because that is
        // what the extension API is already expecting.
        //
        // It may be that `LspAdapterDelegate` needs a more general name, or
        // perhaps another kind of delegate is needed here.
        delegate: Option<Arc<dyn LspAdapterDelegate>>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<SlashCommandResult>;
}
107
/// A piece of content emitted by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// Textual content.
    Text {
        /// The text itself.
        text: String,
        /// Flag carried alongside the text; presumably whether slash commands appearing in
        /// `text` should themselves be run.
        run_commands_in_text: bool,
    },
}

impl From<&str> for SlashCommandContent {
    /// Wraps `text` as [`SlashCommandContent::Text`] with `run_commands_in_text` disabled.
    fn from(text: &str) -> Self {
        let owned = text.to_owned();
        Self::Text {
            text: owned,
            run_commands_in_text: false,
        }
    }
}
124
/// A single event in the output stream of a slash command.
///
/// Output is described as a flat sequence: `Content` events carry text, while
/// `StartSection`/`EndSection` pairs bracket the content belonging to a section. Sections may
/// nest; an `EndSection` closes the most recently opened section.
#[derive(Debug, PartialEq)]
pub enum SlashCommandEvent {
    /// Starts a new message with the given role. `merge_same_roles` presumably allows merging
    /// into a preceding message with the same role (TODO confirm with consumers).
    StartMessage {
        role: Role,
        merge_same_roles: bool,
    },
    /// Opens a section; all content up to the matching `EndSection` belongs to it.
    StartSection {
        icon: IconName,
        label: SharedString,
        /// Arbitrary JSON metadata attached to the section.
        metadata: Option<serde_json::Value>,
    },
    /// A chunk of output content.
    Content(SlashCommandContent),
    /// Closes the most recently opened section.
    EndSection,
}
139
/// The fully collected output of a slash command: its text plus labeled sections of that text.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
    /// The complete output text.
    pub text: String,
    /// Sections of `text`, expressed as byte ranges into it.
    pub sections: Vec<SlashCommandOutputSection<usize>>,
    /// Copied onto every `Content` event when converting to an event stream; presumably
    /// controls whether slash commands inside `text` are themselves run.
    pub run_commands_in_text: bool,
}
146
147impl SlashCommandOutput {
148    pub fn ensure_valid_section_ranges(&mut self) {
149        for section in &mut self.sections {
150            section.range.start = section.range.start.min(self.text.len());
151            section.range.end = section.range.end.min(self.text.len());
152            while !self.text.is_char_boundary(section.range.start) {
153                section.range.start -= 1;
154            }
155            while !self.text.is_char_boundary(section.range.end) {
156                section.range.end += 1;
157            }
158        }
159    }
160
161    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
162    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
163        self.ensure_valid_section_ranges();
164
165        let mut events = Vec::new();
166
167        let mut section_endpoints = Vec::new();
168        for section in self.sections {
169            section_endpoints.push((
170                section.range.start,
171                SlashCommandEvent::StartSection {
172                    icon: section.icon,
173                    label: section.label,
174                    metadata: section.metadata,
175                },
176            ));
177            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
178        }
179        section_endpoints.sort_by_key(|(offset, _)| *offset);
180
181        let mut content_offset = 0;
182        for (endpoint_offset, endpoint) in section_endpoints {
183            if content_offset < endpoint_offset {
184                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
185                    text: self.text[content_offset..endpoint_offset].to_string(),
186                    run_commands_in_text: self.run_commands_in_text,
187                })));
188                content_offset = endpoint_offset;
189            }
190
191            events.push(Ok(endpoint));
192        }
193
194        if content_offset < self.text.len() {
195            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
196                text: self.text[content_offset..].to_string(),
197                run_commands_in_text: self.run_commands_in_text,
198            })));
199        }
200
201        stream::iter(events).boxed()
202    }
203
204    pub async fn from_event_stream(
205        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
206    ) -> Result<SlashCommandOutput> {
207        let mut output = SlashCommandOutput::default();
208        let mut section_stack = Vec::new();
209
210        while let Some(event) = events.next().await {
211            match event? {
212                SlashCommandEvent::StartSection {
213                    icon,
214                    label,
215                    metadata,
216                } => {
217                    let start = output.text.len();
218                    section_stack.push(SlashCommandOutputSection {
219                        range: start..start,
220                        icon,
221                        label,
222                        metadata,
223                    });
224                }
225                SlashCommandEvent::Content(SlashCommandContent::Text {
226                    text,
227                    run_commands_in_text,
228                }) => {
229                    output.text.push_str(&text);
230                    output.run_commands_in_text = run_commands_in_text;
231
232                    if let Some(section) = section_stack.last_mut() {
233                        section.range.end = output.text.len();
234                    }
235                }
236                SlashCommandEvent::EndSection => {
237                    if let Some(section) = section_stack.pop() {
238                        output.sections.push(section);
239                    }
240                }
241                SlashCommandEvent::StartMessage { .. } => {}
242            }
243        }
244
245        while let Some(section) = section_stack.pop() {
246            output.sections.push(section);
247        }
248
249        Ok(output)
250    }
251}
252
/// A labeled range of slash command output.
///
/// `T` is the position type of the range: byte offsets (`usize`) into
/// [`SlashCommandOutput::text`], or `language::Anchor`s when referring to positions in a buffer.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
    /// The span of output covered by this section.
    pub range: Range<T>,
    /// The icon displayed for this section.
    pub icon: IconName,
    /// The label displayed for this section.
    pub label: SharedString,
    /// Arbitrary JSON metadata attached to the section.
    pub metadata: Option<serde_json::Value>,
}
260
261impl SlashCommandOutputSection<language::Anchor> {
262    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
263        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
264    }
265}
266
/// The location of a slash command invocation within a single line of text.
///
/// All ranges are byte offsets into the parsed line.
pub struct SlashCommandLine {
    /// Byte range of the command name, excluding the leading `/`.
    pub name: Range<usize>,
    /// Byte ranges of the whitespace-separated arguments that follow the name. The last range
    /// may be empty when the line ends in whitespace.
    pub arguments: Vec<Range<usize>>,
}

impl SlashCommandLine {
    /// Parses `line` as a slash command invocation.
    ///
    /// Only whitespace may precede the `/`, and the first character after it must be
    /// alphabetic. Otherwise, or if the line contains no `/` at all, `None` is returned. A bare
    /// `/` yields an empty name and no arguments.
    ///
    /// The name runs until the first whitespace character. After that, every run of whitespace
    /// starts a new (initially empty) argument range. As a result, trailing whitespace leaves an
    /// empty final argument positioned at the end of the line.
    pub fn parse(line: &str) -> Option<Self> {
        let mut parsed: Option<Self> = None;
        for (offset, ch) in line.char_indices() {
            let after = offset + ch.len_utf8();
            let is_space = ch.is_whitespace();

            if let Some(command) = parsed.as_mut() {
                if let Some(current) = command.arguments.last_mut() {
                    // Inside the argument list: whitespace separates arguments, anything
                    // else extends the argument being read.
                    if !is_space {
                        current.end = after;
                    } else if current.is_empty() {
                        // Consecutive whitespace keeps sliding the empty argument forward.
                        *current = after..after;
                    } else {
                        current.end = offset;
                        command.arguments.push(after..after);
                    }
                } else if command.name.is_empty() {
                    // The command name has to begin with a letter.
                    if !ch.is_alphabetic() {
                        return None;
                    }
                    command.name.end = after;
                } else if is_space {
                    // The first whitespace ends the name and opens the argument list.
                    command.arguments = vec![after..after];
                } else {
                    command.name.end = after;
                }
            } else if ch == '/' {
                parsed = Some(SlashCommandLine {
                    name: after..after,
                    arguments: Vec::new(),
                });
            } else if !is_space {
                // Nothing but whitespace is allowed before the slash.
                return None;
            }
        }
        parsed
    }
}
327
// Round-trip tests between `SlashCommandOutput` and its event-stream representation.
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    // Each scenario checks both directions: `to_event_stream` must produce exactly the listed
    // events, and `from_event_stream` applied to that stream must rebuild the original output.
    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // Test basic output consisting of a single section.
        {
            let text = "Hello, world!".to_string();
            let range = 0..text.len();
            let output = SlashCommandOutput {
                text,
                sections: vec![SlashCommandOutputSection {
                    range,
                    icon: IconName::Code,
                    label: "Section 1".into(),
                    metadata: None,
                }],
                run_commands_in_text: false,
            };

            // `to_event_stream` only ever yields `Ok` events, so dropping errors loses nothing.
            let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
            let events = events
                .into_iter()
                .filter_map(|event| event.ok())
                .collect::<Vec<_>>();

            assert_eq!(
                events,
                vec![
                    SlashCommandEvent::StartSection {
                        icon: IconName::Code,
                        label: "Section 1".into(),
                        metadata: None
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Hello, world!".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection
                ]
            );

            let new_output =
                SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                    .await
                    .unwrap();

            assert_eq!(new_output, output);
        }

        // Test output where the sections do not cover all of the text: the uncovered
        // "Cucumber\n" must come through as content outside any section.
        {
            let text = "Apple\nCucumber\nBanana\n".to_string();
            let output = SlashCommandOutput {
                text,
                sections: vec![
                    SlashCommandOutputSection {
                        range: 0..6,
                        icon: IconName::Check,
                        label: "Fruit".into(),
                        metadata: None,
                    },
                    SlashCommandOutputSection {
                        range: 15..22,
                        icon: IconName::Check,
                        label: "Fruit".into(),
                        metadata: None,
                    },
                ],
                run_commands_in_text: false,
            };

            let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
            let events = events
                .into_iter()
                .filter_map(|event| event.ok())
                .collect::<Vec<_>>();

            assert_eq!(
                events,
                vec![
                    SlashCommandEvent::StartSection {
                        icon: IconName::Check,
                        label: "Fruit".into(),
                        metadata: None
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Apple\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Cucumber\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::StartSection {
                        icon: IconName::Check,
                        label: "Fruit".into(),
                        metadata: None
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Banana\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection
                ]
            );

            let new_output =
                SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                    .await
                    .unwrap();

            assert_eq!(new_output, output);
        }

        // Test output consisting of multiple sections, each carrying distinct metadata, with
        // the newline separators left outside every section.
        {
            let text = "Line 1\nLine 2\nLine 3\nLine 4\n".to_string();
            let output = SlashCommandOutput {
                text,
                sections: vec![
                    SlashCommandOutputSection {
                        range: 0..6,
                        icon: IconName::FileCode,
                        label: "Section 1".into(),
                        metadata: Some(json!({ "a": true })),
                    },
                    SlashCommandOutputSection {
                        range: 7..13,
                        icon: IconName::FileDoc,
                        label: "Section 2".into(),
                        metadata: Some(json!({ "b": true })),
                    },
                    SlashCommandOutputSection {
                        range: 14..20,
                        icon: IconName::FileGit,
                        label: "Section 3".into(),
                        metadata: Some(json!({ "c": true })),
                    },
                    SlashCommandOutputSection {
                        range: 21..27,
                        icon: IconName::FileToml,
                        label: "Section 4".into(),
                        metadata: Some(json!({ "d": true })),
                    },
                ],
                run_commands_in_text: false,
            };

            let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
            let events = events
                .into_iter()
                .filter_map(|event| event.ok())
                .collect::<Vec<_>>();

            assert_eq!(
                events,
                vec![
                    SlashCommandEvent::StartSection {
                        icon: IconName::FileCode,
                        label: "Section 1".into(),
                        metadata: Some(json!({ "a": true }))
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Line 1".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::StartSection {
                        icon: IconName::FileDoc,
                        label: "Section 2".into(),
                        metadata: Some(json!({ "b": true }))
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Line 2".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::StartSection {
                        icon: IconName::FileGit,
                        label: "Section 3".into(),
                        metadata: Some(json!({ "c": true }))
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Line 3".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::StartSection {
                        icon: IconName::FileToml,
                        label: "Section 4".into(),
                        metadata: Some(json!({ "d": true }))
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Line 4".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "\n".into(),
                        run_commands_in_text: false
                    }),
                ]
            );

            let new_output =
                SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                    .await
                    .unwrap();

            assert_eq!(new_output, output);
        }
    }
}