assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3mod slash_command_working_set;
  4
  5pub use crate::extension_slash_command::*;
  6pub use crate::slash_command_registry::*;
  7pub use crate::slash_command_working_set::*;
  8use anyhow::Result;
  9use futures::stream::{self, BoxStream};
 10use futures::StreamExt;
 11use gpui::{AppContext, SharedString, Task, WeakView, WindowContext};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 13pub use language_model::Role;
 14use serde::{Deserialize, Serialize};
 15use std::{
 16    ops::Range,
 17    sync::{atomic::AtomicBool, Arc},
 18};
 19use workspace::{ui::IconName, Workspace};
 20
 21pub fn init(cx: &mut AppContext) {
 22    SlashCommandRegistry::default_global(cx);
 23    extension_slash_command::init(cx);
 24}
 25
 26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 27pub enum AfterCompletion {
 28    /// Run the command
 29    Run,
 30    /// Continue composing the current argument, doesn't add a space
 31    Compose,
 32    /// Continue the command composition, adds a space
 33    Continue,
 34}
 35
 36impl From<bool> for AfterCompletion {
 37    fn from(value: bool) -> Self {
 38        if value {
 39            AfterCompletion::Run
 40        } else {
 41            AfterCompletion::Continue
 42        }
 43    }
 44}
 45
 46impl AfterCompletion {
 47    pub fn run(&self) -> bool {
 48        match self {
 49            AfterCompletion::Run => true,
 50            AfterCompletion::Compose | AfterCompletion::Continue => false,
 51        }
 52    }
 53}
 54
 55#[derive(Debug)]
 56pub struct ArgumentCompletion {
 57    /// The label to display for this completion.
 58    pub label: CodeLabel,
 59    /// The new text that should be inserted into the command when this completion is accepted.
 60    pub new_text: String,
 61    /// Whether the command should be run when accepting this completion.
 62    pub after_completion: AfterCompletion,
 63    /// Whether to replace the all arguments, or whether to treat this as an independent argument.
 64    pub replace_previous_arguments: bool,
 65}
 66
 67pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 68
 69pub trait SlashCommand: 'static + Send + Sync {
 70    fn name(&self) -> String;
 71    fn icon(&self) -> IconName {
 72        IconName::Slash
 73    }
 74    fn label(&self, _cx: &AppContext) -> CodeLabel {
 75        CodeLabel::plain(self.name(), None)
 76    }
 77    fn description(&self) -> String;
 78    fn menu_text(&self) -> String;
 79    fn complete_argument(
 80        self: Arc<Self>,
 81        arguments: &[String],
 82        cancel: Arc<AtomicBool>,
 83        workspace: Option<WeakView<Workspace>>,
 84        cx: &mut WindowContext,
 85    ) -> Task<Result<Vec<ArgumentCompletion>>>;
 86    fn requires_argument(&self) -> bool;
 87    fn accepts_arguments(&self) -> bool {
 88        self.requires_argument()
 89    }
 90    fn run(
 91        self: Arc<Self>,
 92        arguments: &[String],
 93        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
 94        context_buffer: BufferSnapshot,
 95        workspace: WeakView<Workspace>,
 96        // TODO: We're just using the `LspAdapterDelegate` here because that is
 97        // what the extension API is already expecting.
 98        //
 99        // It may be that `LspAdapterDelegate` needs a more general name, or
100        // perhaps another kind of delegate is needed here.
101        delegate: Option<Arc<dyn LspAdapterDelegate>>,
102        cx: &mut WindowContext,
103    ) -> Task<SlashCommandResult>;
104}
105
/// A piece of content emitted by a slash command.
#[derive(PartialEq, Debug)]
pub enum SlashCommandContent {
    /// A run of text. `run_commands_in_text` presumably controls whether slash
    /// commands appearing inside `text` are themselves executed — confirm.
    Text {
        text: String,
        run_commands_in_text: bool,
    },
}

impl From<&str> for SlashCommandContent {
    /// Wraps `text` as [`SlashCommandContent::Text`] with
    /// `run_commands_in_text` set to `false`.
    fn from(text: &str) -> Self {
        let owned = String::from(text);
        Self::Text {
            text: owned,
            run_commands_in_text: false,
        }
    }
}
122
123#[derive(Debug, PartialEq)]
124pub enum SlashCommandEvent {
125    StartMessage {
126        role: Role,
127        merge_same_roles: bool,
128    },
129    StartSection {
130        icon: IconName,
131        label: SharedString,
132        metadata: Option<serde_json::Value>,
133    },
134    Content(SlashCommandContent),
135    EndSection,
136}
137
138#[derive(Debug, Default, PartialEq, Clone)]
139pub struct SlashCommandOutput {
140    pub text: String,
141    pub sections: Vec<SlashCommandOutputSection<usize>>,
142    pub run_commands_in_text: bool,
143}
144
145impl SlashCommandOutput {
146    pub fn ensure_valid_section_ranges(&mut self) {
147        for section in &mut self.sections {
148            section.range.start = section.range.start.min(self.text.len());
149            section.range.end = section.range.end.min(self.text.len());
150            while !self.text.is_char_boundary(section.range.start) {
151                section.range.start -= 1;
152            }
153            while !self.text.is_char_boundary(section.range.end) {
154                section.range.end += 1;
155            }
156        }
157    }
158
159    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
160    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
161        self.ensure_valid_section_ranges();
162
163        let mut events = Vec::new();
164
165        let mut section_endpoints = Vec::new();
166        for section in self.sections {
167            section_endpoints.push((
168                section.range.start,
169                SlashCommandEvent::StartSection {
170                    icon: section.icon,
171                    label: section.label,
172                    metadata: section.metadata,
173                },
174            ));
175            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
176        }
177        section_endpoints.sort_by_key(|(offset, _)| *offset);
178
179        let mut content_offset = 0;
180        for (endpoint_offset, endpoint) in section_endpoints {
181            if content_offset < endpoint_offset {
182                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
183                    text: self.text[content_offset..endpoint_offset].to_string(),
184                    run_commands_in_text: self.run_commands_in_text,
185                })));
186                content_offset = endpoint_offset;
187            }
188
189            events.push(Ok(endpoint));
190        }
191
192        if content_offset < self.text.len() {
193            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
194                text: self.text[content_offset..].to_string(),
195                run_commands_in_text: self.run_commands_in_text,
196            })));
197        }
198
199        stream::iter(events).boxed()
200    }
201
202    pub async fn from_event_stream(
203        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
204    ) -> Result<SlashCommandOutput> {
205        let mut output = SlashCommandOutput::default();
206        let mut section_stack = Vec::new();
207
208        while let Some(event) = events.next().await {
209            match event? {
210                SlashCommandEvent::StartSection {
211                    icon,
212                    label,
213                    metadata,
214                } => {
215                    let start = output.text.len();
216                    section_stack.push(SlashCommandOutputSection {
217                        range: start..start,
218                        icon,
219                        label,
220                        metadata,
221                    });
222                }
223                SlashCommandEvent::Content(SlashCommandContent::Text {
224                    text,
225                    run_commands_in_text,
226                }) => {
227                    output.text.push_str(&text);
228                    output.run_commands_in_text = run_commands_in_text;
229
230                    if let Some(section) = section_stack.last_mut() {
231                        section.range.end = output.text.len();
232                    }
233                }
234                SlashCommandEvent::EndSection => {
235                    if let Some(section) = section_stack.pop() {
236                        output.sections.push(section);
237                    }
238                }
239                SlashCommandEvent::StartMessage { .. } => {}
240            }
241        }
242
243        while let Some(section) = section_stack.pop() {
244            output.sections.push(section);
245        }
246
247        Ok(output)
248    }
249}
250
251#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
252pub struct SlashCommandOutputSection<T> {
253    pub range: Range<T>,
254    pub icon: IconName,
255    pub label: SharedString,
256    pub metadata: Option<serde_json::Value>,
257}
258
259impl SlashCommandOutputSection<language::Anchor> {
260    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
261        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
262    }
263}
264
/// The byte ranges of a slash command invocation within a single line of text.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SlashCommandLine {
    /// The range within the line containing the command name.
    pub name: Range<usize>,
    /// Ranges within the line containing the command arguments.
    ///
    /// Arguments are separated by whitespace. When the line ends in whitespace
    /// after the name, the last entry is an empty range just past that
    /// whitespace.
    pub arguments: Vec<Range<usize>>,
}

impl SlashCommandLine {
    /// Parses `line` as a slash command invocation.
    ///
    /// The line may start with whitespace, followed by `/` and a command name.
    /// The name must begin with an alphabetic character and runs until the
    /// next whitespace. Everything after it is split on whitespace into
    /// arguments. All ranges are byte offsets into `line`.
    ///
    /// Returns `None` if anything other than whitespace precedes the `/`, if
    /// the character right after the `/` is not alphabetic, or if the line
    /// contains no `/` at all. A bare `/` yields an empty name.
    pub fn parse(line: &str) -> Option<Self> {
        let mut parsed: Option<Self> = None;

        for (offset, ch) in line.char_indices() {
            // Byte offset just past `ch`.
            let after = offset + ch.len_utf8();

            if let Some(command) = &mut parsed {
                let is_whitespace = ch.is_whitespace();
                if let Some(argument) = command.arguments.last_mut() {
                    // Arguments are separated by whitespace. Extra whitespace
                    // just slides the pending (still empty) argument forward.
                    if !is_whitespace {
                        argument.end = after;
                    } else if argument.is_empty() {
                        *argument = after..after;
                    } else {
                        argument.end = offset;
                        command.arguments.push(after..after);
                    }
                } else if command.name.is_empty() {
                    // The command name must begin with a letter.
                    if !ch.is_alphabetic() {
                        return None;
                    }
                    command.name.end = after;
                } else if is_whitespace {
                    // The first whitespace after the name ends it and opens the
                    // first argument.
                    command.arguments.push(after..after);
                } else {
                    command.name.end = after;
                }
            } else if ch == '/' {
                // Commands start with a slash.
                parsed = Some(Self {
                    name: after..after,
                    arguments: Vec::new(),
                });
            } else if !ch.is_whitespace() {
                // The line can't contain anything before the slash except for
                // whitespace.
                return None;
            }
        }

        parsed
    }
}
325
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds a text `Content` event with `run_commands_in_text` disabled.
    fn content(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Builds a `StartSection` event.
    fn start_section(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds an output section covering `range`.
    fn section(
        range: std::ops::Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Asserts that `output` converts to exactly `expected_events` (dropping
    /// any errors), and that rebuilding an output from its events yields an
    /// output equal to `output`.
    async fn assert_round_trip(
        output: SlashCommandOutput,
        expected_events: Vec<SlashCommandEvent>,
    ) {
        let events: Vec<SlashCommandEvent> = output
            .clone()
            .to_event_stream()
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(Result::ok)
            .collect();
        assert_eq!(events, expected_events);

        let rebuilt = SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
            .await
            .unwrap();
        assert_eq!(rebuilt, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // Test basic output consisting of a single section.
        let text = "Hello, world!";
        assert_round_trip(
            SlashCommandOutput {
                text: text.to_string(),
                sections: vec![section(0..text.len(), IconName::Code, "Section 1", None)],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::Code, "Section 1", None),
                content("Hello, world!"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Test output where the sections do not comprise all of the text.
        assert_round_trip(
            SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::Check, "Fruit", None),
                content("Apple\n"),
                SlashCommandEvent::EndSection,
                content("Cucumber\n"),
                start_section(IconName::Check, "Fruit", None),
                content("Banana\n"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Test output consisting of multiple sections.
        assert_round_trip(
            SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(0..6, IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                    section(7..13, IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                    section(14..20, IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                    section(21..27, IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                content("Line 1"),
                SlashCommandEvent::EndSection,
                content("\n"),
                start_section(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                content("Line 2"),
                SlashCommandEvent::EndSection,
                content("\n"),
                start_section(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                content("Line 3"),
                SlashCommandEvent::EndSection,
                content("\n"),
                start_section(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                content("Line 4"),
                SlashCommandEvent::EndSection,
                content("\n"),
            ],
        )
        .await;
    }
}