assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3mod slash_command_working_set;
  4
  5pub use crate::extension_slash_command::*;
  6pub use crate::slash_command_registry::*;
  7pub use crate::slash_command_working_set::*;
  8use anyhow::Result;
  9use futures::stream::{self, BoxStream};
 10use futures::StreamExt;
 11use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 13pub use language_model::Role;
 14use serde::{Deserialize, Serialize};
 15use std::{
 16    ops::Range,
 17    sync::{atomic::AtomicBool, Arc},
 18};
 19use workspace::{ui::IconName, Workspace};
 20
 21pub fn init(cx: &mut AppContext) {
 22    SlashCommandRegistry::default_global(cx);
 23    extension_slash_command::init(cx);
 24}
 25
 26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 27pub enum AfterCompletion {
 28    /// Run the command
 29    Run,
 30    /// Continue composing the current argument, doesn't add a space
 31    Compose,
 32    /// Continue the command composition, adds a space
 33    Continue,
 34}
 35
 36impl From<bool> for AfterCompletion {
 37    fn from(value: bool) -> Self {
 38        if value {
 39            AfterCompletion::Run
 40        } else {
 41            AfterCompletion::Continue
 42        }
 43    }
 44}
 45
 46impl AfterCompletion {
 47    pub fn run(&self) -> bool {
 48        match self {
 49            AfterCompletion::Run => true,
 50            AfterCompletion::Compose | AfterCompletion::Continue => false,
 51        }
 52    }
 53}
 54
 55#[derive(Debug)]
 56pub struct ArgumentCompletion {
 57    /// The label to display for this completion.
 58    pub label: CodeLabel,
 59    /// The new text that should be inserted into the command when this completion is accepted.
 60    pub new_text: String,
 61    /// Whether the command should be run when accepting this completion.
 62    pub after_completion: AfterCompletion,
 63    /// Whether to replace the all arguments, or whether to treat this as an independent argument.
 64    pub replace_previous_arguments: bool,
 65}
 66
 67pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 68
 69pub trait SlashCommand: 'static + Send + Sync {
 70    fn name(&self) -> String;
 71    fn icon(&self) -> IconName {
 72        IconName::Slash
 73    }
 74    fn label(&self, _cx: &AppContext) -> CodeLabel {
 75        CodeLabel::plain(self.name(), None)
 76    }
 77    fn description(&self) -> String;
 78    fn menu_text(&self) -> String;
 79    fn complete_argument(
 80        self: Arc<Self>,
 81        arguments: &[String],
 82        cancel: Arc<AtomicBool>,
 83        workspace: Option<WeakView<Workspace>>,
 84        cx: &mut WindowContext,
 85    ) -> Task<Result<Vec<ArgumentCompletion>>>;
 86    fn requires_argument(&self) -> bool;
 87    fn accepts_arguments(&self) -> bool {
 88        self.requires_argument()
 89    }
 90    fn run(
 91        self: Arc<Self>,
 92        arguments: &[String],
 93        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
 94        context_buffer: BufferSnapshot,
 95        workspace: WeakView<Workspace>,
 96        // TODO: We're just using the `LspAdapterDelegate` here because that is
 97        // what the extension API is already expecting.
 98        //
 99        // It may be that `LspAdapterDelegate` needs a more general name, or
100        // perhaps another kind of delegate is needed here.
101        delegate: Option<Arc<dyn LspAdapterDelegate>>,
102        cx: &mut WindowContext,
103    ) -> Task<SlashCommandResult>;
104}
105
106pub type RenderFoldPlaceholder = Arc<
107    dyn Send
108        + Sync
109        + Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
110>;
111
/// A piece of content emitted by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// A run of text.
    Text {
        /// The text itself.
        text: String,
        /// Whether slash commands inside `text` should be run; carried into
        /// `SlashCommandOutput::run_commands_in_text`.
        run_commands_in_text: bool,
    },
}

impl<'a> From<&'a str> for SlashCommandContent {
    /// Wraps `text` as [`SlashCommandContent::Text`] with `run_commands_in_text` disabled.
    fn from(text: &'a str) -> Self {
        let owned = String::from(text);
        Self::Text {
            text: owned,
            run_commands_in_text: false,
        }
    }
}
128
129#[derive(Debug, PartialEq)]
130pub enum SlashCommandEvent {
131    StartMessage {
132        role: Role,
133        merge_same_roles: bool,
134    },
135    StartSection {
136        icon: IconName,
137        label: SharedString,
138        metadata: Option<serde_json::Value>,
139    },
140    Content(SlashCommandContent),
141    EndSection,
142}
143
144#[derive(Debug, Default, PartialEq, Clone)]
145pub struct SlashCommandOutput {
146    pub text: String,
147    pub sections: Vec<SlashCommandOutputSection<usize>>,
148    pub run_commands_in_text: bool,
149}
150
151impl SlashCommandOutput {
152    pub fn ensure_valid_section_ranges(&mut self) {
153        for section in &mut self.sections {
154            section.range.start = section.range.start.min(self.text.len());
155            section.range.end = section.range.end.min(self.text.len());
156            while !self.text.is_char_boundary(section.range.start) {
157                section.range.start -= 1;
158            }
159            while !self.text.is_char_boundary(section.range.end) {
160                section.range.end += 1;
161            }
162        }
163    }
164
165    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
166    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
167        self.ensure_valid_section_ranges();
168
169        let mut events = Vec::new();
170
171        let mut section_endpoints = Vec::new();
172        for section in self.sections {
173            section_endpoints.push((
174                section.range.start,
175                SlashCommandEvent::StartSection {
176                    icon: section.icon,
177                    label: section.label,
178                    metadata: section.metadata,
179                },
180            ));
181            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
182        }
183        section_endpoints.sort_by_key(|(offset, _)| *offset);
184
185        let mut content_offset = 0;
186        for (endpoint_offset, endpoint) in section_endpoints {
187            if content_offset < endpoint_offset {
188                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
189                    text: self.text[content_offset..endpoint_offset].to_string(),
190                    run_commands_in_text: self.run_commands_in_text,
191                })));
192                content_offset = endpoint_offset;
193            }
194
195            events.push(Ok(endpoint));
196        }
197
198        if content_offset < self.text.len() {
199            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
200                text: self.text[content_offset..].to_string(),
201                run_commands_in_text: self.run_commands_in_text,
202            })));
203        }
204
205        stream::iter(events).boxed()
206    }
207
208    pub async fn from_event_stream(
209        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
210    ) -> Result<SlashCommandOutput> {
211        let mut output = SlashCommandOutput::default();
212        let mut section_stack = Vec::new();
213
214        while let Some(event) = events.next().await {
215            match event? {
216                SlashCommandEvent::StartSection {
217                    icon,
218                    label,
219                    metadata,
220                } => {
221                    let start = output.text.len();
222                    section_stack.push(SlashCommandOutputSection {
223                        range: start..start,
224                        icon,
225                        label,
226                        metadata,
227                    });
228                }
229                SlashCommandEvent::Content(SlashCommandContent::Text {
230                    text,
231                    run_commands_in_text,
232                }) => {
233                    output.text.push_str(&text);
234                    output.run_commands_in_text = run_commands_in_text;
235
236                    if let Some(section) = section_stack.last_mut() {
237                        section.range.end = output.text.len();
238                    }
239                }
240                SlashCommandEvent::EndSection => {
241                    if let Some(section) = section_stack.pop() {
242                        output.sections.push(section);
243                    }
244                }
245                SlashCommandEvent::StartMessage { .. } => {}
246            }
247        }
248
249        while let Some(section) = section_stack.pop() {
250            output.sections.push(section);
251        }
252
253        Ok(output)
254    }
255}
256
257#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
258pub struct SlashCommandOutputSection<T> {
259    pub range: Range<T>,
260    pub icon: IconName,
261    pub label: SharedString,
262    pub metadata: Option<serde_json::Value>,
263}
264
265impl SlashCommandOutputSection<language::Anchor> {
266    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
267        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
268    }
269}
270
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds a `usize`-ranged section.
    fn section(
        range: Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds the `StartSection` event expected for a section.
    fn start_event(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds a text `Content` event with `run_commands_in_text: false`.
    fn content_event(chunk: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: chunk.into(),
            run_commands_in_text: false,
        })
    }

    /// Converts `output` into events, keeping only the successfully produced ones.
    async fn collect_events(output: SlashCommandOutput) -> Vec<SlashCommandEvent> {
        let results = output.to_event_stream().collect::<Vec<_>>().await;
        results.into_iter().filter_map(Result::ok).collect()
    }

    /// Asserts that `output` survives a trip through `to_event_stream`/`from_event_stream`.
    async fn assert_round_trip(output: &SlashCommandOutput) {
        let rebuilt = SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
            .await
            .unwrap();
        assert_eq!(rebuilt, *output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // A single section spanning the whole text.
        {
            let body = "Hello, world!".to_string();
            let whole = 0..body.len();
            let output = SlashCommandOutput {
                text: body,
                sections: vec![section(whole, IconName::Code, "Section 1", None)],
                run_commands_in_text: false,
            };

            assert_eq!(
                collect_events(output.clone()).await,
                vec![
                    start_event(IconName::Code, "Section 1", None),
                    content_event("Hello, world!"),
                    SlashCommandEvent::EndSection,
                ]
            );
            assert_round_trip(&output).await;
        }

        // Sections that leave part of the text uncovered.
        {
            let output = SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            };

            assert_eq!(
                collect_events(output.clone()).await,
                vec![
                    start_event(IconName::Check, "Fruit", None),
                    content_event("Apple\n"),
                    SlashCommandEvent::EndSection,
                    content_event("Cucumber\n"),
                    start_event(IconName::Check, "Fruit", None),
                    content_event("Banana\n"),
                    SlashCommandEvent::EndSection,
                ]
            );
            assert_round_trip(&output).await;
        }

        // Several consecutive sections, each carrying metadata.
        {
            let output = SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(0..6, IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                    section(7..13, IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                    section(14..20, IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                    section(21..27, IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                ],
                run_commands_in_text: false,
            };

            assert_eq!(
                collect_events(output.clone()).await,
                vec![
                    start_event(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                    content_event("Line 1"),
                    SlashCommandEvent::EndSection,
                    content_event("\n"),
                    start_event(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                    content_event("Line 2"),
                    SlashCommandEvent::EndSection,
                    content_event("\n"),
                    start_event(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                    content_event("Line 3"),
                    SlashCommandEvent::EndSection,
                    content_event("\n"),
                    start_event(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                    content_event("Line 4"),
                    SlashCommandEvent::EndSection,
                    content_event("\n"),
                ]
            );
            assert_round_trip(&output).await;
        }
    }
}