assistant_slash_command.rs

  1mod slash_command_registry;
  2
  3use anyhow::Result;
  4use futures::stream::{self, BoxStream};
  5use futures::StreamExt;
  6use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
  7use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
  8pub use language_model::Role;
  9use serde::{Deserialize, Serialize};
 10pub use slash_command_registry::*;
 11use std::{
 12    ops::Range,
 13    sync::{atomic::AtomicBool, Arc},
 14};
 15use workspace::{ui::IconName, Workspace};
 16
 17pub fn init(cx: &mut AppContext) {
 18    SlashCommandRegistry::default_global(cx);
 19}
 20
 21#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 22pub enum AfterCompletion {
 23    /// Run the command
 24    Run,
 25    /// Continue composing the current argument, doesn't add a space
 26    Compose,
 27    /// Continue the command composition, adds a space
 28    Continue,
 29}
 30
 31impl From<bool> for AfterCompletion {
 32    fn from(value: bool) -> Self {
 33        if value {
 34            AfterCompletion::Run
 35        } else {
 36            AfterCompletion::Continue
 37        }
 38    }
 39}
 40
 41impl AfterCompletion {
 42    pub fn run(&self) -> bool {
 43        match self {
 44            AfterCompletion::Run => true,
 45            AfterCompletion::Compose | AfterCompletion::Continue => false,
 46        }
 47    }
 48}
 49
 50#[derive(Debug)]
 51pub struct ArgumentCompletion {
 52    /// The label to display for this completion.
 53    pub label: CodeLabel,
 54    /// The new text that should be inserted into the command when this completion is accepted.
 55    pub new_text: String,
 56    /// Whether the command should be run when accepting this completion.
 57    pub after_completion: AfterCompletion,
 58    /// Whether to replace the all arguments, or whether to treat this as an independent argument.
 59    pub replace_previous_arguments: bool,
 60}
 61
 62pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 63
 64pub trait SlashCommand: 'static + Send + Sync {
 65    fn name(&self) -> String;
 66    fn icon(&self) -> IconName {
 67        IconName::Slash
 68    }
 69    fn label(&self, _cx: &AppContext) -> CodeLabel {
 70        CodeLabel::plain(self.name(), None)
 71    }
 72    fn description(&self) -> String;
 73    fn menu_text(&self) -> String;
 74    fn complete_argument(
 75        self: Arc<Self>,
 76        arguments: &[String],
 77        cancel: Arc<AtomicBool>,
 78        workspace: Option<WeakView<Workspace>>,
 79        cx: &mut WindowContext,
 80    ) -> Task<Result<Vec<ArgumentCompletion>>>;
 81    fn requires_argument(&self) -> bool;
 82    fn accepts_arguments(&self) -> bool {
 83        self.requires_argument()
 84    }
 85    fn run(
 86        self: Arc<Self>,
 87        arguments: &[String],
 88        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
 89        context_buffer: BufferSnapshot,
 90        workspace: WeakView<Workspace>,
 91        // TODO: We're just using the `LspAdapterDelegate` here because that is
 92        // what the extension API is already expecting.
 93        //
 94        // It may be that `LspAdapterDelegate` needs a more general name, or
 95        // perhaps another kind of delegate is needed here.
 96        delegate: Option<Arc<dyn LspAdapterDelegate>>,
 97        cx: &mut WindowContext,
 98    ) -> Task<SlashCommandResult>;
 99}
100
101pub type RenderFoldPlaceholder = Arc<
102    dyn Send
103        + Sync
104        + Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
105>;
106
/// A piece of content produced by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// Plain text content.
    Text {
        /// The text itself.
        text: String,
        /// Whether slash commands appearing in `text` should also be run.
        run_commands_in_text: bool,
    },
}
114
115impl<'a> From<&'a str> for SlashCommandContent {
116    fn from(text: &'a str) -> Self {
117        Self::Text {
118            text: text.into(),
119            run_commands_in_text: false,
120        }
121    }
122}
123
124#[derive(Debug, PartialEq)]
125pub enum SlashCommandEvent {
126    StartMessage {
127        role: Role,
128        merge_same_roles: bool,
129    },
130    StartSection {
131        icon: IconName,
132        label: SharedString,
133        metadata: Option<serde_json::Value>,
134    },
135    Content(SlashCommandContent),
136    EndSection,
137}
138
139#[derive(Debug, Default, PartialEq, Clone)]
140pub struct SlashCommandOutput {
141    pub text: String,
142    pub sections: Vec<SlashCommandOutputSection<usize>>,
143    pub run_commands_in_text: bool,
144}
145
146impl SlashCommandOutput {
147    pub fn ensure_valid_section_ranges(&mut self) {
148        for section in &mut self.sections {
149            section.range.start = section.range.start.min(self.text.len());
150            section.range.end = section.range.end.min(self.text.len());
151            while !self.text.is_char_boundary(section.range.start) {
152                section.range.start -= 1;
153            }
154            while !self.text.is_char_boundary(section.range.end) {
155                section.range.end += 1;
156            }
157        }
158    }
159
160    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
161    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
162        self.ensure_valid_section_ranges();
163
164        let mut events = Vec::new();
165
166        let mut section_endpoints = Vec::new();
167        for section in self.sections {
168            section_endpoints.push((
169                section.range.start,
170                SlashCommandEvent::StartSection {
171                    icon: section.icon,
172                    label: section.label,
173                    metadata: section.metadata,
174                },
175            ));
176            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
177        }
178        section_endpoints.sort_by_key(|(offset, _)| *offset);
179
180        let mut content_offset = 0;
181        for (endpoint_offset, endpoint) in section_endpoints {
182            if content_offset < endpoint_offset {
183                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
184                    text: self.text[content_offset..endpoint_offset].to_string(),
185                    run_commands_in_text: self.run_commands_in_text,
186                })));
187                content_offset = endpoint_offset;
188            }
189
190            events.push(Ok(endpoint));
191        }
192
193        if content_offset < self.text.len() {
194            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
195                text: self.text[content_offset..].to_string(),
196                run_commands_in_text: self.run_commands_in_text,
197            })));
198        }
199
200        stream::iter(events).boxed()
201    }
202
203    pub async fn from_event_stream(
204        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
205    ) -> Result<SlashCommandOutput> {
206        let mut output = SlashCommandOutput::default();
207        let mut section_stack = Vec::new();
208
209        while let Some(event) = events.next().await {
210            match event? {
211                SlashCommandEvent::StartSection {
212                    icon,
213                    label,
214                    metadata,
215                } => {
216                    let start = output.text.len();
217                    section_stack.push(SlashCommandOutputSection {
218                        range: start..start,
219                        icon,
220                        label,
221                        metadata,
222                    });
223                }
224                SlashCommandEvent::Content(SlashCommandContent::Text {
225                    text,
226                    run_commands_in_text,
227                }) => {
228                    output.text.push_str(&text);
229                    output.run_commands_in_text = run_commands_in_text;
230
231                    if let Some(section) = section_stack.last_mut() {
232                        section.range.end = output.text.len();
233                    }
234                }
235                SlashCommandEvent::EndSection => {
236                    if let Some(section) = section_stack.pop() {
237                        output.sections.push(section);
238                    }
239                }
240                SlashCommandEvent::StartMessage { .. } => {}
241            }
242        }
243
244        while let Some(section) = section_stack.pop() {
245            output.sections.push(section);
246        }
247
248        Ok(output)
249    }
250}
251
252#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
253pub struct SlashCommandOutputSection<T> {
254    pub range: Range<T>,
255    pub icon: IconName,
256    pub label: SharedString,
257    pub metadata: Option<serde_json::Value>,
258}
259
260impl SlashCommandOutputSection<language::Anchor> {
261    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
262        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
263    }
264}
265
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds an output section covering `range`.
    fn section(
        range: Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds the `StartSection` event expected for a section.
    fn start_section(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds a text `Content` event with command execution disabled.
    fn content(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Asserts that `output` converts into exactly `expected_events` (error
    /// events are dropped), and that converting its event stream back yields
    /// `output` again.
    async fn assert_round_trip(
        output: SlashCommandOutput,
        expected_events: Vec<SlashCommandEvent>,
    ) {
        let events: Vec<_> = output
            .clone()
            .to_event_stream()
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(Result::ok)
            .collect();
        assert_eq!(events, expected_events);

        let round_tripped =
            SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                .await
                .unwrap();
        assert_eq!(round_tripped, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // Test basic output consisting of a single section.
        let hello = "Hello, world!".to_string();
        let hello_range = 0..hello.len();
        assert_round_trip(
            SlashCommandOutput {
                text: hello,
                sections: vec![section(hello_range, IconName::Code, "Section 1", None)],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::Code, "Section 1", None),
                content("Hello, world!"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Test output where the sections do not comprise all of the text.
        assert_round_trip(
            SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::Check, "Fruit", None),
                content("Apple\n"),
                SlashCommandEvent::EndSection,
                content("Cucumber\n"),
                start_section(IconName::Check, "Fruit", None),
                content("Banana\n"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Test output consisting of multiple sections.
        assert_round_trip(
            SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(
                        0..6,
                        IconName::FileCode,
                        "Section 1",
                        Some(json!({ "a": true })),
                    ),
                    section(
                        7..13,
                        IconName::FileDoc,
                        "Section 2",
                        Some(json!({ "b": true })),
                    ),
                    section(
                        14..20,
                        IconName::FileGit,
                        "Section 3",
                        Some(json!({ "c": true })),
                    ),
                    section(
                        21..27,
                        IconName::FileToml,
                        "Section 4",
                        Some(json!({ "d": true })),
                    ),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                content("Line 1"),
                SlashCommandEvent::EndSection,
                content("\n"),
                start_section(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                content("Line 2"),
                SlashCommandEvent::EndSection,
                content("\n"),
                start_section(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                content("Line 3"),
                SlashCommandEvent::EndSection,
                content("\n"),
                start_section(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                content("Line 4"),
                SlashCommandEvent::EndSection,
                content("\n"),
            ],
        )
        .await;
    }
}