assistant_slash_command.rs

  1mod slash_command_registry;
  2
  3use anyhow::Result;
  4use futures::stream::{self, BoxStream};
  5use futures::StreamExt;
  6use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
  7use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
  8use serde::{Deserialize, Serialize};
  9pub use slash_command_registry::*;
 10use std::{
 11    ops::Range,
 12    sync::{atomic::AtomicBool, Arc},
 13};
 14use workspace::{ui::IconName, Workspace};
 15
 16pub fn init(cx: &mut AppContext) {
 17    SlashCommandRegistry::default_global(cx);
 18}
 19
 20#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 21pub enum AfterCompletion {
 22    /// Run the command
 23    Run,
 24    /// Continue composing the current argument, doesn't add a space
 25    Compose,
 26    /// Continue the command composition, adds a space
 27    Continue,
 28}
 29
 30impl From<bool> for AfterCompletion {
 31    fn from(value: bool) -> Self {
 32        if value {
 33            AfterCompletion::Run
 34        } else {
 35            AfterCompletion::Continue
 36        }
 37    }
 38}
 39
 40impl AfterCompletion {
 41    pub fn run(&self) -> bool {
 42        match self {
 43            AfterCompletion::Run => true,
 44            AfterCompletion::Compose | AfterCompletion::Continue => false,
 45        }
 46    }
 47}
 48
 49#[derive(Debug)]
 50pub struct ArgumentCompletion {
 51    /// The label to display for this completion.
 52    pub label: CodeLabel,
 53    /// The new text that should be inserted into the command when this completion is accepted.
 54    pub new_text: String,
 55    /// Whether the command should be run when accepting this completion.
 56    pub after_completion: AfterCompletion,
 57    /// Whether to replace the all arguments, or whether to treat this as an independent argument.
 58    pub replace_previous_arguments: bool,
 59}
 60
 61pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 62
 63pub trait SlashCommand: 'static + Send + Sync {
 64    fn name(&self) -> String;
 65    fn icon(&self) -> IconName {
 66        IconName::Slash
 67    }
 68    fn label(&self, _cx: &AppContext) -> CodeLabel {
 69        CodeLabel::plain(self.name(), None)
 70    }
 71    fn description(&self) -> String;
 72    fn menu_text(&self) -> String;
 73    fn complete_argument(
 74        self: Arc<Self>,
 75        arguments: &[String],
 76        cancel: Arc<AtomicBool>,
 77        workspace: Option<WeakView<Workspace>>,
 78        cx: &mut WindowContext,
 79    ) -> Task<Result<Vec<ArgumentCompletion>>>;
 80    fn requires_argument(&self) -> bool;
 81    fn accepts_arguments(&self) -> bool {
 82        self.requires_argument()
 83    }
 84    fn run(
 85        self: Arc<Self>,
 86        arguments: &[String],
 87        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
 88        context_buffer: BufferSnapshot,
 89        workspace: WeakView<Workspace>,
 90        // TODO: We're just using the `LspAdapterDelegate` here because that is
 91        // what the extension API is already expecting.
 92        //
 93        // It may be that `LspAdapterDelegate` needs a more general name, or
 94        // perhaps another kind of delegate is needed here.
 95        delegate: Option<Arc<dyn LspAdapterDelegate>>,
 96        cx: &mut WindowContext,
 97    ) -> Task<SlashCommandResult>;
 98}
 99
100pub type RenderFoldPlaceholder = Arc<
101    dyn Send
102        + Sync
103        + Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
104>;
105
/// A piece of content produced by a slash command.
#[derive(Eq, PartialEq, Debug)]
pub enum SlashCommandContent {
    /// Textual output.
    Text {
        /// The text itself.
        text: String,
        /// Presumably whether slash commands appearing in `text` should be run
        /// as well (confirm with consumers).
        run_commands_in_text: bool,
    },
}
113
114#[derive(Debug, PartialEq, Eq)]
115pub enum SlashCommandEvent {
116    StartSection {
117        icon: IconName,
118        label: SharedString,
119        metadata: Option<serde_json::Value>,
120    },
121    Content(SlashCommandContent),
122    EndSection {
123        metadata: Option<serde_json::Value>,
124    },
125}
126
127#[derive(Debug, Default, PartialEq, Clone)]
128pub struct SlashCommandOutput {
129    pub text: String,
130    pub sections: Vec<SlashCommandOutputSection<usize>>,
131    pub run_commands_in_text: bool,
132}
133
134impl SlashCommandOutput {
135    pub fn ensure_valid_section_ranges(&mut self) {
136        for section in &mut self.sections {
137            section.range.start = section.range.start.min(self.text.len());
138            section.range.end = section.range.end.min(self.text.len());
139            while !self.text.is_char_boundary(section.range.start) {
140                section.range.start -= 1;
141            }
142            while !self.text.is_char_boundary(section.range.end) {
143                section.range.end += 1;
144            }
145        }
146    }
147
148    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
149    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
150        self.ensure_valid_section_ranges();
151
152        let mut events = Vec::new();
153        let mut last_section_end = 0;
154
155        for section in self.sections {
156            if last_section_end < section.range.start {
157                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
158                    text: self
159                        .text
160                        .get(last_section_end..section.range.start)
161                        .unwrap_or_default()
162                        .to_string(),
163                    run_commands_in_text: self.run_commands_in_text,
164                })));
165            }
166
167            events.push(Ok(SlashCommandEvent::StartSection {
168                icon: section.icon,
169                label: section.label,
170                metadata: section.metadata.clone(),
171            }));
172            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
173                text: self
174                    .text
175                    .get(section.range.start..section.range.end)
176                    .unwrap_or_default()
177                    .to_string(),
178                run_commands_in_text: self.run_commands_in_text,
179            })));
180            events.push(Ok(SlashCommandEvent::EndSection {
181                metadata: section.metadata,
182            }));
183
184            last_section_end = section.range.end;
185        }
186
187        if last_section_end < self.text.len() {
188            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
189                text: self.text[last_section_end..].to_string(),
190                run_commands_in_text: self.run_commands_in_text,
191            })));
192        }
193
194        stream::iter(events).boxed()
195    }
196
197    pub async fn from_event_stream(
198        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
199    ) -> Result<SlashCommandOutput> {
200        let mut output = SlashCommandOutput::default();
201        let mut section_stack = Vec::new();
202
203        while let Some(event) = events.next().await {
204            match event? {
205                SlashCommandEvent::StartSection {
206                    icon,
207                    label,
208                    metadata,
209                } => {
210                    let start = output.text.len();
211                    section_stack.push(SlashCommandOutputSection {
212                        range: start..start,
213                        icon,
214                        label,
215                        metadata,
216                    });
217                }
218                SlashCommandEvent::Content(SlashCommandContent::Text {
219                    text,
220                    run_commands_in_text,
221                }) => {
222                    output.text.push_str(&text);
223                    output.run_commands_in_text = run_commands_in_text;
224
225                    if let Some(section) = section_stack.last_mut() {
226                        section.range.end = output.text.len();
227                    }
228                }
229                SlashCommandEvent::EndSection { metadata } => {
230                    if let Some(mut section) = section_stack.pop() {
231                        section.metadata = metadata;
232                        output.sections.push(section);
233                    }
234                }
235            }
236        }
237
238        while let Some(section) = section_stack.pop() {
239            output.sections.push(section);
240        }
241
242        Ok(output)
243    }
244}
245
246#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
247pub struct SlashCommandOutputSection<T> {
248    pub range: Range<T>,
249    pub icon: IconName,
250    pub label: SharedString,
251    pub metadata: Option<serde_json::Value>,
252}
253
254impl SlashCommandOutputSection<language::Anchor> {
255    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
256        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
257    }
258}
259
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::{json, Value};

    use super::*;

    /// Builds an output section covering `range`.
    fn section(
        range: std::ops::Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// The `StartSection` event expected for a section.
    fn start(icon: IconName, label: &'static str, metadata: Option<Value>) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// The `EndSection` event expected for a section.
    fn end(metadata: Option<Value>) -> SlashCommandEvent {
        SlashCommandEvent::EndSection { metadata }
    }

    /// A text `Content` event that does not run commands.
    fn content(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.to_string(),
            run_commands_in_text: false,
        })
    }

    /// Streams `output` into events, keeping only the successful ones.
    async fn events_for(output: &SlashCommandOutput) -> Vec<SlashCommandEvent> {
        let results = output.clone().to_event_stream().collect::<Vec<_>>().await;
        results.into_iter().filter_map(|event| event.ok()).collect()
    }

    /// Asserts that converting `output` to events and back reproduces it exactly.
    async fn assert_round_trip(output: &SlashCommandOutput) {
        let rebuilt = SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
            .await
            .unwrap();
        assert_eq!(&rebuilt, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // A single section covering the entire text.
        let hello = "Hello, world!";
        let single = SlashCommandOutput {
            text: hello.to_string(),
            sections: vec![section(0..hello.len(), IconName::Code, "Section 1", None)],
            run_commands_in_text: false,
        };
        assert_eq!(
            events_for(&single).await,
            vec![
                start(IconName::Code, "Section 1", None),
                content("Hello, world!"),
                end(None),
            ]
        );
        assert_round_trip(&single).await;

        // Sections that leave part of the text uncovered.
        let fruit = SlashCommandOutput {
            text: "Apple\nCucumber\nBanana\n".to_string(),
            sections: vec![
                section(0..6, IconName::Check, "Fruit", None),
                section(15..22, IconName::Check, "Fruit", None),
            ],
            run_commands_in_text: false,
        };
        assert_eq!(
            events_for(&fruit).await,
            vec![
                start(IconName::Check, "Fruit", None),
                content("Apple\n"),
                end(None),
                content("Cucumber\n"),
                start(IconName::Check, "Fruit", None),
                content("Banana\n"),
                end(None),
            ]
        );
        assert_round_trip(&fruit).await;

        // Several sections separated by uncovered newlines.
        let lines = SlashCommandOutput {
            text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
            sections: vec![
                section(0..6, IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                section(7..13, IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                section(14..20, IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                section(21..27, IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
            ],
            run_commands_in_text: false,
        };
        assert_eq!(
            events_for(&lines).await,
            vec![
                start(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                content("Line 1"),
                end(Some(json!({ "a": true }))),
                content("\n"),
                start(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                content("Line 2"),
                end(Some(json!({ "b": true }))),
                content("\n"),
                start(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                content("Line 3"),
                end(Some(json!({ "c": true }))),
                content("\n"),
                start(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                content("Line 4"),
                end(Some(json!({ "d": true }))),
                content("\n"),
            ]
        );
        assert_round_trip(&lines).await;
    }
}