assistant_slash_command.rs

  1mod slash_command_registry;
  2
  3use anyhow::Result;
  4use futures::stream::{self, BoxStream};
  5use futures::StreamExt;
  6use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
  7use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
  8use serde::{Deserialize, Serialize};
  9pub use slash_command_registry::*;
 10use std::{
 11    ops::Range,
 12    sync::{atomic::AtomicBool, Arc},
 13};
 14use workspace::{ui::IconName, Workspace};
 15
/// Initializes the slash command crate.
///
/// Calls [`SlashCommandRegistry::default_global`]; judging by its name this sets up the
/// global slash command registry — NOTE(review): confirm in `slash_command_registry`.
pub fn init(cx: &mut AppContext) {
    SlashCommandRegistry::default_global(cx);
}
 19
 20#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 21pub enum AfterCompletion {
 22    /// Run the command
 23    Run,
 24    /// Continue composing the current argument, doesn't add a space
 25    Compose,
 26    /// Continue the command composition, adds a space
 27    Continue,
 28}
 29
 30impl From<bool> for AfterCompletion {
 31    fn from(value: bool) -> Self {
 32        if value {
 33            AfterCompletion::Run
 34        } else {
 35            AfterCompletion::Continue
 36        }
 37    }
 38}
 39
 40impl AfterCompletion {
 41    pub fn run(&self) -> bool {
 42        match self {
 43            AfterCompletion::Run => true,
 44            AfterCompletion::Compose | AfterCompletion::Continue => false,
 45        }
 46    }
 47}
 48
/// A single completion offered for a slash command argument.
#[derive(Debug)]
pub struct ArgumentCompletion {
    /// The label to display for this completion.
    pub label: CodeLabel,
    /// The new text that should be inserted into the command when this completion is accepted.
    pub new_text: String,
    /// Whether the command should be run when accepting this completion.
    pub after_completion: AfterCompletion,
    /// Whether to replace all of the previous arguments, or to treat this as an independent argument.
    pub replace_previous_arguments: bool,
}
 60
/// Result of [`SlashCommand::run`]: either an error, or a stream of output events in which
/// each individual event may itself be an error.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 62
/// A slash command implementation.
///
/// Implementors must be `Send + Sync + 'static`; completion and execution take the command as
/// `Arc<Self>` so they can hand it to the returned [`Task`].
pub trait SlashCommand: 'static + Send + Sync {
    /// The command's name.
    fn name(&self) -> String;
    /// The label to display for the command. Defaults to the plain [`name`](Self::name)
    /// without any highlighting.
    fn label(&self, _cx: &AppContext) -> CodeLabel {
        CodeLabel::plain(self.name(), None)
    }
    /// A human-readable description of the command.
    fn description(&self) -> String;
    /// Text for the command's menu entry.
    /// NOTE(review): confirm where callers display this versus `description`.
    fn menu_text(&self) -> String;
    /// Computes completions for the command's arguments, given the arguments entered so far.
    ///
    /// `cancel` is a shared flag; presumably the caller sets it when the completions are no
    /// longer needed and implementations may poll it to stop early — TODO confirm.
    /// `workspace` is optional, so implementations must cope with its absence.
    fn complete_argument(
        self: Arc<Self>,
        arguments: &[String],
        cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>>;
    /// Whether the command requires an argument.
    fn requires_argument(&self) -> bool;
    /// Whether the command accepts arguments. Defaults to
    /// [`requires_argument`](Self::requires_argument), so commands with optional arguments
    /// must override this.
    fn accepts_arguments(&self) -> bool {
        self.requires_argument()
    }
    /// Runs the command, producing its output as a [`SlashCommandResult`] event stream.
    ///
    /// `context_slash_command_output_sections` are anchor-based output sections that belong to
    /// `context_buffer` (presumably output of previously run commands — TODO confirm).
    fn run(
        self: Arc<Self>,
        arguments: &[String],
        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        context_buffer: BufferSnapshot,
        workspace: WeakView<Workspace>,
        // TODO: We're just using the `LspAdapterDelegate` here because that is
        // what the extension API is already expecting.
        //
        // It may be that `LspAdapterDelegate` needs a more general name, or
        // perhaps another kind of delegate is needed here.
        delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<SlashCommandResult>;
}
 96
/// Shared, thread-safe callback that builds an [`AnyElement`] from an [`ElementId`], an inner
/// `Fn(&mut WindowContext)` callback, and the window context.
///
/// NOTE(review): judging by the name, this renders the placeholder shown for a folded region
/// and the inner callback unfolds it — confirm with callers.
pub type RenderFoldPlaceholder = Arc<
    dyn Send
        + Sync
        + Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
>;
102
/// A piece of content produced by a slash command.
#[derive(Debug, PartialEq, Eq)]
pub enum SlashCommandContent {
    /// Text that is appended to the command's output text.
    Text {
        text: String,
        /// Copied into [`SlashCommandOutput::run_commands_in_text`] when the event is collected;
        /// the value from the last content event wins.
        run_commands_in_text: bool,
    },
}
110
/// One step of a slash command's streamed output.
///
/// [`SlashCommandOutput::from_event_stream`] consumes these in order: content is appended to
/// the output text, and `StartSection`/`EndSection` pairs delimit sections of that text.
#[derive(Debug, PartialEq, Eq)]
pub enum SlashCommandEvent {
    /// Opens a new section that begins at the current end of the output text.
    StartSection {
        icon: IconName,
        label: SharedString,
        /// Metadata attached to the section when it is opened.
        metadata: Option<serde_json::Value>,
    },
    /// A chunk of output content.
    Content(SlashCommandContent),
    /// Closes the most recently opened section that is still open; ignored if none is open.
    EndSection {
        /// Replaces the metadata that was given when the section was opened.
        metadata: Option<serde_json::Value>,
    },
}
123
/// The complete, collected output of a slash command.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
    /// The full output text.
    pub text: String,
    /// Sections of `text`, expressed as byte-offset ranges into it.
    pub sections: Vec<SlashCommandOutputSection<usize>>,
    /// NOTE(review): presumably controls whether slash commands appearing in `text` are
    /// executed when the output is inserted — confirm with consumers.
    pub run_commands_in_text: bool,
}
130
131impl SlashCommandOutput {
132    pub fn ensure_valid_section_ranges(&mut self) {
133        for section in &mut self.sections {
134            section.range.start = section.range.start.min(self.text.len());
135            section.range.end = section.range.end.min(self.text.len());
136            while !self.text.is_char_boundary(section.range.start) {
137                section.range.start -= 1;
138            }
139            while !self.text.is_char_boundary(section.range.end) {
140                section.range.end += 1;
141            }
142        }
143    }
144
145    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
146    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
147        self.ensure_valid_section_ranges();
148
149        let mut events = Vec::new();
150        let mut last_section_end = 0;
151
152        for section in self.sections {
153            if last_section_end < section.range.start {
154                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
155                    text: self
156                        .text
157                        .get(last_section_end..section.range.start)
158                        .unwrap_or_default()
159                        .to_string(),
160                    run_commands_in_text: self.run_commands_in_text,
161                })));
162            }
163
164            events.push(Ok(SlashCommandEvent::StartSection {
165                icon: section.icon,
166                label: section.label,
167                metadata: section.metadata.clone(),
168            }));
169            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
170                text: self
171                    .text
172                    .get(section.range.start..section.range.end)
173                    .unwrap_or_default()
174                    .to_string(),
175                run_commands_in_text: self.run_commands_in_text,
176            })));
177            events.push(Ok(SlashCommandEvent::EndSection {
178                metadata: section.metadata,
179            }));
180
181            last_section_end = section.range.end;
182        }
183
184        if last_section_end < self.text.len() {
185            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
186                text: self.text[last_section_end..].to_string(),
187                run_commands_in_text: self.run_commands_in_text,
188            })));
189        }
190
191        stream::iter(events).boxed()
192    }
193
194    pub async fn from_event_stream(
195        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
196    ) -> Result<SlashCommandOutput> {
197        let mut output = SlashCommandOutput::default();
198        let mut section_stack = Vec::new();
199
200        while let Some(event) = events.next().await {
201            match event? {
202                SlashCommandEvent::StartSection {
203                    icon,
204                    label,
205                    metadata,
206                } => {
207                    let start = output.text.len();
208                    section_stack.push(SlashCommandOutputSection {
209                        range: start..start,
210                        icon,
211                        label,
212                        metadata,
213                    });
214                }
215                SlashCommandEvent::Content(SlashCommandContent::Text {
216                    text,
217                    run_commands_in_text,
218                }) => {
219                    output.text.push_str(&text);
220                    output.run_commands_in_text = run_commands_in_text;
221
222                    if let Some(section) = section_stack.last_mut() {
223                        section.range.end = output.text.len();
224                    }
225                }
226                SlashCommandEvent::EndSection { metadata } => {
227                    if let Some(mut section) = section_stack.pop() {
228                        section.metadata = metadata;
229                        output.sections.push(section);
230                    }
231                }
232            }
233        }
234
235        while let Some(section) = section_stack.pop() {
236            output.sections.push(section);
237        }
238
239        Ok(output)
240    }
241}
242
/// A labelled region of slash command output.
///
/// `T` is the position type of `range`. [`SlashCommandOutput`] uses `usize` byte offsets into
/// its text, while sections passed to [`SlashCommand::run`] use [`language::Anchor`] positions.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
    /// The region of output covered by this section.
    pub range: Range<T>,
    /// Icon shown for the section.
    pub icon: IconName,
    /// Label shown for the section.
    pub label: SharedString,
    /// Optional JSON metadata attached to the section.
    pub metadata: Option<serde_json::Value>,
}
250
251impl SlashCommandOutputSection<language::Anchor> {
252    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
253        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
254    }
255}
256
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds an output section covering the byte range `range`.
    fn section(
        range: std::ops::Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds the event that opens a section.
    fn start_event(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds a text content event with `run_commands_in_text` disabled.
    fn content_event(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Builds the event that closes a section.
    fn end_event(metadata: Option<serde_json::Value>) -> SlashCommandEvent {
        SlashCommandEvent::EndSection { metadata }
    }

    /// Checks that `output` converts to exactly `expected_events`, and that collecting its
    /// event stream again reproduces `output`.
    async fn assert_round_trip(
        output: SlashCommandOutput,
        expected_events: Vec<SlashCommandEvent>,
    ) {
        let events = output
            .clone()
            .to_event_stream()
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(|event| event.ok())
            .collect::<Vec<_>>();
        assert_eq!(events, expected_events);

        let rebuilt = SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
            .await
            .unwrap();
        assert_eq!(rebuilt, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // A single section covering all of the text.
        let greeting = "Hello, world!";
        assert_round_trip(
            SlashCommandOutput {
                text: greeting.to_string(),
                sections: vec![section(0..greeting.len(), IconName::Code, "Section 1", None)],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::Code, "Section 1", None),
                content_event("Hello, world!"),
                end_event(None),
            ],
        )
        .await;

        // Sections that leave part of the text uncovered.
        assert_round_trip(
            SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::Check, "Fruit", None),
                content_event("Apple\n"),
                end_event(None),
                content_event("Cucumber\n"),
                start_event(IconName::Check, "Fruit", None),
                content_event("Banana\n"),
                end_event(None),
            ],
        )
        .await;

        // Several sections, each carrying metadata, separated by uncovered newlines.
        assert_round_trip(
            SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(0..6, IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                    section(7..13, IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                    section(14..20, IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                    section(21..27, IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                content_event("Line 1"),
                end_event(Some(json!({ "a": true }))),
                content_event("\n"),
                start_event(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                content_event("Line 2"),
                end_event(Some(json!({ "b": true }))),
                content_event("\n"),
                start_event(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                content_event("Line 3"),
                end_event(Some(json!({ "c": true }))),
                content_event("\n"),
                start_event(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                content_event("Line 4"),
                end_event(Some(json!({ "d": true }))),
                content_event("\n"),
            ],
        )
        .await;
    }
}