assistant_slash_command.rs

  1mod slash_command_registry;
  2
  3use anyhow::Result;
  4use futures::stream::{self, BoxStream};
  5use futures::StreamExt;
  6use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
  7use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
  8pub use language_model::Role;
  9use serde::{Deserialize, Serialize};
 10pub use slash_command_registry::*;
 11use std::{
 12    ops::Range,
 13    sync::{atomic::AtomicBool, Arc},
 14};
 15use workspace::{ui::IconName, Workspace};
 16
/// Initializes slash-command support for the application.
///
/// Calls [`SlashCommandRegistry::default_global`] with the app context.
/// NOTE(review): presumably this installs the global registry when it does
/// not exist yet — the registry is defined in `slash_command_registry`,
/// confirm there.
pub fn init(cx: &mut AppContext) {
    SlashCommandRegistry::default_global(cx);
}
 20
 21#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 22pub enum AfterCompletion {
 23    /// Run the command
 24    Run,
 25    /// Continue composing the current argument, doesn't add a space
 26    Compose,
 27    /// Continue the command composition, adds a space
 28    Continue,
 29}
 30
 31impl From<bool> for AfterCompletion {
 32    fn from(value: bool) -> Self {
 33        if value {
 34            AfterCompletion::Run
 35        } else {
 36            AfterCompletion::Continue
 37        }
 38    }
 39}
 40
 41impl AfterCompletion {
 42    pub fn run(&self) -> bool {
 43        match self {
 44            AfterCompletion::Run => true,
 45            AfterCompletion::Compose | AfterCompletion::Continue => false,
 46        }
 47    }
 48}
 49
/// A single completion candidate for a slash command argument.
#[derive(Debug)]
pub struct ArgumentCompletion {
    /// The label to display for this completion.
    pub label: CodeLabel,
    /// The new text that should be inserted into the command when this completion is accepted.
    pub new_text: String,
    /// What should happen after this completion is accepted (see [`AfterCompletion`]).
    pub after_completion: AfterCompletion,
    /// Whether to replace all of the previous arguments, or to treat this as an independent argument.
    pub replace_previous_arguments: bool,
}
 61
/// The outcome of running a slash command: either an up-front error, or a
/// `'static` boxed stream whose items are themselves fallible
/// [`SlashCommandEvent`]s.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 63
/// The interface implemented by every slash command.
///
/// Implementors must be `'static + Send + Sync`. The methods that do real
/// work (`complete_argument`, `run`) take `self: Arc<Self>` and return a
/// [`Task`].
pub trait SlashCommand: 'static + Send + Sync {
    /// The command's name. Also used as the text of the default [`Self::label`].
    fn name(&self) -> String;
    /// The icon associated with the command. Defaults to [`IconName::Slash`].
    fn icon(&self) -> IconName {
        IconName::Slash
    }
    /// The label for the command. Defaults to a plain [`CodeLabel`] built
    /// from [`Self::name`]; the context argument is unused by the default.
    fn label(&self, _cx: &AppContext) -> CodeLabel {
        CodeLabel::plain(self.name(), None)
    }
    /// A description of the command.
    fn description(&self) -> String;
    /// NOTE(review): presumably the text shown for this command in menus —
    /// confirm against callers.
    fn menu_text(&self) -> String;
    /// Computes completions for the command's arguments.
    ///
    /// `arguments` holds the arguments entered so far. The returned task
    /// resolves to the list of [`ArgumentCompletion`]s, or to an error.
    ///
    /// NOTE(review): `cancel` looks like a flag the caller sets to abort an
    /// in-flight completion — confirm against implementors.
    fn complete_argument(
        self: Arc<Self>,
        arguments: &[String],
        cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>>;
    /// Whether the command requires an argument.
    fn requires_argument(&self) -> bool;
    /// Whether the command accepts arguments. Defaults to
    /// [`Self::requires_argument`].
    fn accepts_arguments(&self) -> bool {
        self.requires_argument()
    }
    /// Runs the command with the given `arguments`.
    ///
    /// The returned task resolves to a [`SlashCommandResult`]: an error, or a
    /// stream of [`SlashCommandEvent`]s.
    ///
    /// NOTE(review): `context_slash_command_output_sections` and
    /// `context_buffer` appear to describe the buffer the command was invoked
    /// from and the output sections already present in it — confirm against
    /// callers.
    fn run(
        self: Arc<Self>,
        arguments: &[String],
        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        context_buffer: BufferSnapshot,
        workspace: WeakView<Workspace>,
        // TODO: We're just using the `LspAdapterDelegate` here because that is
        // what the extension API is already expecting.
        //
        // It may be that `LspAdapterDelegate` needs a more general name, or
        // perhaps another kind of delegate is needed here.
        delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<SlashCommandResult>;
}
100
/// A shareable (`Arc`), `Send + Sync` callback that produces an
/// [`AnyElement`] from an [`ElementId`], a shared `Fn(&mut WindowContext)`
/// callback, and the window context.
///
/// NOTE(review): going by the name, this renders the placeholder shown for a
/// folded region, and the inner callback is presumably the action that
/// unfolds it — confirm against callers.
pub type RenderFoldPlaceholder = Arc<
    dyn Send
        + Sync
        + Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
>;
106
/// A piece of content emitted by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// A chunk of text.
    Text {
        /// The text itself.
        text: String,
        /// Flag carried alongside the text.
        /// NOTE(review): by its name, whether slash commands found inside
        /// `text` should be run as well — confirm against consumers.
        run_commands_in_text: bool,
    },
}

impl From<&str> for SlashCommandContent {
    /// Wraps `text` in a [`SlashCommandContent::Text`] with
    /// `run_commands_in_text` set to `false`.
    fn from(text: &str) -> Self {
        let text = text.to_owned();
        Self::Text {
            text,
            run_commands_in_text: false,
        }
    }
}
123
/// One event in the stream produced by a running slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandEvent {
    /// Marks the start of a message with the given role.
    ///
    /// Ignored by [`SlashCommandOutput::from_event_stream`].
    StartMessage {
        role: Role,
        // NOTE(review): presumably controls whether this message is merged
        // into a preceding message of the same role — confirm against consumers.
        merge_same_roles: bool,
    },
    /// Opens a section. Content emitted until the matching
    /// [`SlashCommandEvent::EndSection`] belongs to it.
    StartSection {
        icon: IconName,
        label: SharedString,
        metadata: Option<serde_json::Value>,
    },
    /// A piece of output content.
    Content(SlashCommandContent),
    /// Closes the most recently opened section.
    ///
    /// When a stream is folded by [`SlashCommandOutput::from_event_stream`],
    /// this metadata replaces the metadata given at `StartSection`.
    EndSection {
        metadata: Option<serde_json::Value>,
    },
}
140
/// The complete, materialized output of a slash command.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
    /// The full output text.
    pub text: String,
    /// Labelled sections of `text`, expressed as byte ranges into it.
    pub sections: Vec<SlashCommandOutputSection<usize>>,
    /// Flag attached to every text content event produced from this output.
    /// NOTE(review): by its name, whether slash commands found inside `text`
    /// should be run as well — confirm against consumers.
    pub run_commands_in_text: bool,
}
147
148impl SlashCommandOutput {
149    pub fn ensure_valid_section_ranges(&mut self) {
150        for section in &mut self.sections {
151            section.range.start = section.range.start.min(self.text.len());
152            section.range.end = section.range.end.min(self.text.len());
153            while !self.text.is_char_boundary(section.range.start) {
154                section.range.start -= 1;
155            }
156            while !self.text.is_char_boundary(section.range.end) {
157                section.range.end += 1;
158            }
159        }
160    }
161
162    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
163    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
164        self.ensure_valid_section_ranges();
165
166        let mut events = Vec::new();
167        let mut last_section_end = 0;
168
169        for section in self.sections {
170            if last_section_end < section.range.start {
171                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
172                    text: self
173                        .text
174                        .get(last_section_end..section.range.start)
175                        .unwrap_or_default()
176                        .to_string(),
177                    run_commands_in_text: self.run_commands_in_text,
178                })));
179            }
180
181            events.push(Ok(SlashCommandEvent::StartSection {
182                icon: section.icon,
183                label: section.label,
184                metadata: section.metadata.clone(),
185            }));
186            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
187                text: self
188                    .text
189                    .get(section.range.start..section.range.end)
190                    .unwrap_or_default()
191                    .to_string(),
192                run_commands_in_text: self.run_commands_in_text,
193            })));
194            events.push(Ok(SlashCommandEvent::EndSection {
195                metadata: section.metadata,
196            }));
197
198            last_section_end = section.range.end;
199        }
200
201        if last_section_end < self.text.len() {
202            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
203                text: self.text[last_section_end..].to_string(),
204                run_commands_in_text: self.run_commands_in_text,
205            })));
206        }
207
208        stream::iter(events).boxed()
209    }
210
211    pub async fn from_event_stream(
212        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
213    ) -> Result<SlashCommandOutput> {
214        let mut output = SlashCommandOutput::default();
215        let mut section_stack = Vec::new();
216
217        while let Some(event) = events.next().await {
218            match event? {
219                SlashCommandEvent::StartSection {
220                    icon,
221                    label,
222                    metadata,
223                } => {
224                    let start = output.text.len();
225                    section_stack.push(SlashCommandOutputSection {
226                        range: start..start,
227                        icon,
228                        label,
229                        metadata,
230                    });
231                }
232                SlashCommandEvent::Content(SlashCommandContent::Text {
233                    text,
234                    run_commands_in_text,
235                }) => {
236                    output.text.push_str(&text);
237                    output.run_commands_in_text = run_commands_in_text;
238
239                    if let Some(section) = section_stack.last_mut() {
240                        section.range.end = output.text.len();
241                    }
242                }
243                SlashCommandEvent::EndSection { metadata } => {
244                    if let Some(mut section) = section_stack.pop() {
245                        section.metadata = metadata;
246                        output.sections.push(section);
247                    }
248                }
249                SlashCommandEvent::StartMessage { .. } => {}
250            }
251        }
252
253        while let Some(section) = section_stack.pop() {
254            output.sections.push(section);
255        }
256
257        Ok(output)
258    }
259}
260
/// A labelled region of slash command output.
///
/// `T` is the position type of `range`. This file uses `usize` (byte offsets
/// into [`SlashCommandOutput::text`]) and `language::Anchor` (positions in a
/// buffer).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
    /// The region covered by this section.
    pub range: Range<T>,
    /// Icon associated with the section.
    pub icon: IconName,
    /// Label associated with the section.
    pub label: SharedString,
    /// Optional free-form JSON metadata attached to the section.
    pub metadata: Option<serde_json::Value>,
}
268
269impl SlashCommandOutputSection<language::Anchor> {
270    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
271        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
272    }
273}
274
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds an output section covering `range`.
    fn section(
        range: Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds the `StartSection` event expected for a section.
    fn start(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds a text `Content` event with `run_commands_in_text: false`.
    fn content(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Builds the `EndSection` event expected for a section.
    fn end(metadata: Option<serde_json::Value>) -> SlashCommandEvent {
        SlashCommandEvent::EndSection { metadata }
    }

    /// Checks that `output` streams as exactly `expected_events`, and that
    /// folding that stream back yields `output` again.
    async fn assert_round_trip(
        output: SlashCommandOutput,
        expected_events: Vec<SlashCommandEvent>,
    ) {
        let streamed = output.clone().to_event_stream().collect::<Vec<_>>().await;
        let events = streamed
            .into_iter()
            .filter_map(Result::ok)
            .collect::<Vec<_>>();

        assert_eq!(events, expected_events);

        let new_output = SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
            .await
            .unwrap();

        assert_eq!(new_output, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // Test basic output consisting of a single section.
        let text = "Hello, world!".to_string();
        let range = 0..text.len();
        assert_round_trip(
            SlashCommandOutput {
                text,
                sections: vec![section(range, IconName::Code, "Section 1", None)],
                run_commands_in_text: false,
            },
            vec![
                start(IconName::Code, "Section 1", None),
                content("Hello, world!"),
                end(None),
            ],
        )
        .await;

        // Test output where the sections do not comprise all of the text.
        assert_round_trip(
            SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            },
            vec![
                start(IconName::Check, "Fruit", None),
                content("Apple\n"),
                end(None),
                content("Cucumber\n"),
                start(IconName::Check, "Fruit", None),
                content("Banana\n"),
                end(None),
            ],
        )
        .await;

        // Test output consisting of multiple sections.
        assert_round_trip(
            SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(
                        0..6,
                        IconName::FileCode,
                        "Section 1",
                        Some(json!({ "a": true })),
                    ),
                    section(
                        7..13,
                        IconName::FileDoc,
                        "Section 2",
                        Some(json!({ "b": true })),
                    ),
                    section(
                        14..20,
                        IconName::FileGit,
                        "Section 3",
                        Some(json!({ "c": true })),
                    ),
                    section(
                        21..27,
                        IconName::FileToml,
                        "Section 4",
                        Some(json!({ "d": true })),
                    ),
                ],
                run_commands_in_text: false,
            },
            vec![
                start(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                content("Line 1"),
                end(Some(json!({ "a": true }))),
                content("\n"),
                start(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                content("Line 2"),
                end(Some(json!({ "b": true }))),
                content("\n"),
                start(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                content("Line 3"),
                end(Some(json!({ "c": true }))),
                content("\n"),
                start(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                content("Line 4"),
                end(Some(json!({ "d": true }))),
                content("\n"),
            ],
        )
        .await;
    }
}