assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3
  4pub use crate::extension_slash_command::*;
  5pub use crate::slash_command_registry::*;
  6use anyhow::Result;
  7use futures::stream::{self, BoxStream};
  8use futures::StreamExt;
  9use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
 10use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 11pub use language_model::Role;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    ops::Range,
 15    sync::{atomic::AtomicBool, Arc},
 16};
 17use workspace::{ui::IconName, Workspace};
 18
/// Initializes this crate by calling [`SlashCommandRegistry::default_global`], which
/// (per its name — confirm in `slash_command_registry`) sets up the global registry.
pub fn init(cx: &mut AppContext) {
    SlashCommandRegistry::default_global(cx);
}
 22
 23#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 24pub enum AfterCompletion {
 25    /// Run the command
 26    Run,
 27    /// Continue composing the current argument, doesn't add a space
 28    Compose,
 29    /// Continue the command composition, adds a space
 30    Continue,
 31}
 32
 33impl From<bool> for AfterCompletion {
 34    fn from(value: bool) -> Self {
 35        if value {
 36            AfterCompletion::Run
 37        } else {
 38            AfterCompletion::Continue
 39        }
 40    }
 41}
 42
 43impl AfterCompletion {
 44    pub fn run(&self) -> bool {
 45        match self {
 46            AfterCompletion::Run => true,
 47            AfterCompletion::Compose | AfterCompletion::Continue => false,
 48        }
 49    }
 50}
 51
/// A single completion offered for a slash command argument.
#[derive(Debug)]
pub struct ArgumentCompletion {
    /// The label to display for this completion.
    pub label: CodeLabel,
    /// The new text that should be inserted into the command when this completion is accepted.
    pub new_text: String,
    /// What should happen once this completion is accepted: run the command, keep composing
    /// the current argument, or continue to the next argument.
    pub after_completion: AfterCompletion,
    /// Whether to replace all of the previous arguments, or to treat this as an independent argument.
    pub replace_previous_arguments: bool,
}
 63
/// The result of running a [`SlashCommand`]: on success, a stream of [`SlashCommandEvent`]s in
/// which each individual event may itself be an error.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 65
/// A named command that can offer completions for its arguments and produce output when run.
pub trait SlashCommand: 'static + Send + Sync {
    /// The name of the command.
    fn name(&self) -> String;
    /// The icon for the command. Defaults to [`IconName::Slash`].
    fn icon(&self) -> IconName {
        IconName::Slash
    }
    /// The label used to display the command. Defaults to a plain [`CodeLabel`] containing
    /// [`SlashCommand::name`].
    fn label(&self, _cx: &AppContext) -> CodeLabel {
        CodeLabel::plain(self.name(), None)
    }
    /// A description of the command.
    fn description(&self) -> String;
    /// The text for the command's menu entry.
    fn menu_text(&self) -> String;
    /// Computes completions for the command's arguments, given the arguments entered so far.
    ///
    /// NOTE(review): `cancel` is presumably set when the completion request is no longer
    /// needed, so long-running implementations can stop early — confirm with the caller.
    fn complete_argument(
        self: Arc<Self>,
        arguments: &[String],
        cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>>;
    /// Whether the command requires an argument.
    fn requires_argument(&self) -> bool;
    /// Whether the command accepts arguments at all. Defaults to
    /// [`SlashCommand::requires_argument`].
    fn accepts_arguments(&self) -> bool {
        self.requires_argument()
    }
    /// Runs the command with the given arguments.
    ///
    /// `context_buffer` is a snapshot of the buffer the command runs in, and
    /// `context_slash_command_output_sections` are anchored output sections in that buffer
    /// (NOTE(review): presumably the sections already produced by earlier commands — confirm).
    /// The returned task resolves to a [`SlashCommandResult`].
    fn run(
        self: Arc<Self>,
        arguments: &[String],
        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        context_buffer: BufferSnapshot,
        workspace: WeakView<Workspace>,
        // TODO: We're just using the `LspAdapterDelegate` here because that is
        // what the extension API is already expecting.
        //
        // It may be that `LspAdapterDelegate` needs a more general name, or
        // perhaps another kind of delegate is needed here.
        delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<SlashCommandResult>;
}
102
/// A thread-safe (`Send + Sync`) callback that renders a fold placeholder as an [`AnyElement`].
///
/// It receives an [`ElementId`], a second callback that takes the window context
/// (NOTE(review): presumably invoked to unfold the region — confirm with callers), and the
/// window context itself.
pub type RenderFoldPlaceholder = Arc<
    dyn Send
        + Sync
        + Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
>;
108
/// A piece of content emitted by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// Plain text. `run_commands_in_text` is a flag carried with the text
    /// (NOTE(review): presumably whether slash commands inside `text` should be run — confirm).
    Text {
        text: String,
        run_commands_in_text: bool,
    },
}

impl From<&str> for SlashCommandContent {
    /// Wraps `text` in a [`SlashCommandContent::Text`] with `run_commands_in_text` set to `false`.
    fn from(text: &str) -> Self {
        let owned = String::from(text);
        Self::Text {
            text: owned,
            run_commands_in_text: false,
        }
    }
}
125
/// A single event in the stream produced by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandEvent {
    /// Starts a new message with the given role.
    ///
    /// NOTE(review): `merge_same_roles` presumably controls whether this message merges into an
    /// adjacent message with the same role — confirm with the consumer. This event is ignored by
    /// [`SlashCommandOutput::from_event_stream`].
    StartMessage {
        role: Role,
        merge_same_roles: bool,
    },
    /// Opens a section at the current position in the output text. Sections nest: each
    /// [`SlashCommandEvent::EndSection`] closes the most recently opened section.
    StartSection {
        icon: IconName,
        label: SharedString,
        metadata: Option<serde_json::Value>,
    },
    /// Appends content to the output.
    Content(SlashCommandContent),
    /// Closes the most recently opened section.
    EndSection,
}
140
/// The complete, non-streamed output of a slash command.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
    /// The output text.
    pub text: String,
    /// Sections of the output, as byte ranges into `text`.
    pub sections: Vec<SlashCommandOutputSection<usize>>,
    /// Copied into every `Content` event produced by [`SlashCommandOutput::to_event_stream`].
    pub run_commands_in_text: bool,
}
147
148impl SlashCommandOutput {
149    pub fn ensure_valid_section_ranges(&mut self) {
150        for section in &mut self.sections {
151            section.range.start = section.range.start.min(self.text.len());
152            section.range.end = section.range.end.min(self.text.len());
153            while !self.text.is_char_boundary(section.range.start) {
154                section.range.start -= 1;
155            }
156            while !self.text.is_char_boundary(section.range.end) {
157                section.range.end += 1;
158            }
159        }
160    }
161
162    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
163    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
164        self.ensure_valid_section_ranges();
165
166        let mut events = Vec::new();
167
168        let mut section_endpoints = Vec::new();
169        for section in self.sections {
170            section_endpoints.push((
171                section.range.start,
172                SlashCommandEvent::StartSection {
173                    icon: section.icon,
174                    label: section.label,
175                    metadata: section.metadata,
176                },
177            ));
178            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
179        }
180        section_endpoints.sort_by_key(|(offset, _)| *offset);
181
182        let mut content_offset = 0;
183        for (endpoint_offset, endpoint) in section_endpoints {
184            if content_offset < endpoint_offset {
185                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
186                    text: self.text[content_offset..endpoint_offset].to_string(),
187                    run_commands_in_text: self.run_commands_in_text,
188                })));
189                content_offset = endpoint_offset;
190            }
191
192            events.push(Ok(endpoint));
193        }
194
195        if content_offset < self.text.len() {
196            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
197                text: self.text[content_offset..].to_string(),
198                run_commands_in_text: self.run_commands_in_text,
199            })));
200        }
201
202        stream::iter(events).boxed()
203    }
204
205    pub async fn from_event_stream(
206        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
207    ) -> Result<SlashCommandOutput> {
208        let mut output = SlashCommandOutput::default();
209        let mut section_stack = Vec::new();
210
211        while let Some(event) = events.next().await {
212            match event? {
213                SlashCommandEvent::StartSection {
214                    icon,
215                    label,
216                    metadata,
217                } => {
218                    let start = output.text.len();
219                    section_stack.push(SlashCommandOutputSection {
220                        range: start..start,
221                        icon,
222                        label,
223                        metadata,
224                    });
225                }
226                SlashCommandEvent::Content(SlashCommandContent::Text {
227                    text,
228                    run_commands_in_text,
229                }) => {
230                    output.text.push_str(&text);
231                    output.run_commands_in_text = run_commands_in_text;
232
233                    if let Some(section) = section_stack.last_mut() {
234                        section.range.end = output.text.len();
235                    }
236                }
237                SlashCommandEvent::EndSection => {
238                    if let Some(section) = section_stack.pop() {
239                        output.sections.push(section);
240                    }
241                }
242                SlashCommandEvent::StartMessage { .. } => {}
243            }
244        }
245
246        while let Some(section) = section_stack.pop() {
247            output.sections.push(section);
248        }
249
250        Ok(output)
251    }
252}
253
/// A labeled region of slash command output.
///
/// `T` is the position type of `range`: `usize` byte offsets within
/// [`SlashCommandOutput::text`], or [`language::Anchor`]s when referring to a buffer.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
    /// The region covered by this section.
    pub range: Range<T>,
    /// The icon associated with this section.
    pub icon: IconName,
    /// The label associated with this section.
    pub label: SharedString,
    /// Optional arbitrary JSON attached to the section. It is passed through unchanged when
    /// converting to and from [`SlashCommandEvent`]s.
    pub metadata: Option<serde_json::Value>,
}
261
262impl SlashCommandOutputSection<language::Anchor> {
263    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
264        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
265    }
266}
267
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds a `usize`-ranged output section for a test fixture.
    fn section(
        range: Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds the `StartSection` event expected for a section with these attributes.
    fn start_event(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds a text `Content` event with `run_commands_in_text` set to `false`.
    fn text_event(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Checks a conversion round trip for `output`:
    /// - the `Ok` events produced by `to_event_stream` (errors filtered out) equal
    ///   `expected_events`;
    /// - feeding a fresh event stream back through `from_event_stream` reproduces `output`.
    async fn assert_round_trip(
        output: SlashCommandOutput,
        expected_events: Vec<SlashCommandEvent>,
    ) {
        let events: Vec<SlashCommandEvent> = output
            .clone()
            .to_event_stream()
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(Result::ok)
            .collect();
        assert_eq!(events, expected_events);

        let reconstructed =
            SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                .await
                .unwrap();
        assert_eq!(reconstructed, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // A single section covering all of the text.
        let greeting = "Hello, world!";
        assert_round_trip(
            SlashCommandOutput {
                text: greeting.to_string(),
                sections: vec![section(
                    0..greeting.len(),
                    IconName::Code,
                    "Section 1",
                    None,
                )],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::Code, "Section 1", None),
                text_event("Hello, world!"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Sections that leave part of the text uncovered.
        assert_round_trip(
            SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::Check, "Fruit", None),
                text_event("Apple\n"),
                SlashCommandEvent::EndSection,
                text_event("Cucumber\n"),
                start_event(IconName::Check, "Fruit", None),
                text_event("Banana\n"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Several sections, each followed by an uncovered newline.
        assert_round_trip(
            SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(
                        0..6,
                        IconName::FileCode,
                        "Section 1",
                        Some(json!({ "a": true })),
                    ),
                    section(
                        7..13,
                        IconName::FileDoc,
                        "Section 2",
                        Some(json!({ "b": true })),
                    ),
                    section(
                        14..20,
                        IconName::FileGit,
                        "Section 3",
                        Some(json!({ "c": true })),
                    ),
                    section(
                        21..27,
                        IconName::FileToml,
                        "Section 4",
                        Some(json!({ "d": true })),
                    ),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_event(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                text_event("Line 1"),
                SlashCommandEvent::EndSection,
                text_event("\n"),
                start_event(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                text_event("Line 2"),
                SlashCommandEvent::EndSection,
                text_event("\n"),
                start_event(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                text_event("Line 3"),
                SlashCommandEvent::EndSection,
                text_event("\n"),
                start_event(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                text_event("Line 4"),
                SlashCommandEvent::EndSection,
                text_event("\n"),
            ],
        )
        .await;
    }
}