assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3
  4pub use crate::extension_slash_command::*;
  5pub use crate::slash_command_registry::*;
  6use anyhow::Result;
  7use futures::stream::{self, BoxStream};
  8use futures::StreamExt;
  9use gpui::{AnyElement, AppContext, ElementId, SharedString, Task, WeakView, WindowContext};
 10use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 11pub use language_model::Role;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    ops::Range,
 15    sync::{atomic::AtomicBool, Arc},
 16};
 17use workspace::{ui::IconName, Workspace};
 18
/// Initializes slash command support for the application.
///
/// Calls `SlashCommandRegistry::default_global` and then
/// `extension_slash_command::init`.
// NOTE(review): the registry is set up before extension commands are
// initialized; presumably the extension init relies on the registry global
// existing — confirm before reordering these calls.
pub fn init(cx: &mut AppContext) {
    SlashCommandRegistry::default_global(cx);
    extension_slash_command::init(cx);
}
 23
 24#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 25pub enum AfterCompletion {
 26    /// Run the command
 27    Run,
 28    /// Continue composing the current argument, doesn't add a space
 29    Compose,
 30    /// Continue the command composition, adds a space
 31    Continue,
 32}
 33
 34impl From<bool> for AfterCompletion {
 35    fn from(value: bool) -> Self {
 36        if value {
 37            AfterCompletion::Run
 38        } else {
 39            AfterCompletion::Continue
 40        }
 41    }
 42}
 43
 44impl AfterCompletion {
 45    pub fn run(&self) -> bool {
 46        match self {
 47            AfterCompletion::Run => true,
 48            AfterCompletion::Compose | AfterCompletion::Continue => false,
 49        }
 50    }
 51}
 52
/// A single completion candidate for a slash command argument.
#[derive(Debug)]
pub struct ArgumentCompletion {
    /// The label to display for this completion.
    pub label: CodeLabel,
    /// The new text that should be inserted into the command when this completion is accepted.
    pub new_text: String,
    /// What should happen after this completion is accepted (run the command,
    /// keep composing the current argument, or continue with a new argument).
    pub after_completion: AfterCompletion,
    /// Whether to replace all of the previous arguments, or to treat this as an independent argument.
    pub replace_previous_arguments: bool,
}
 64
/// The outcome of running a slash command.
///
/// Either an error, or a stream of [`SlashCommandEvent`]s in which each
/// individual event may itself be an error.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 66
/// A slash command that can be completed and run.
///
/// Implementors must be `'static + Send + Sync`. Both `complete_argument` and
/// `run` take `self: Arc<Self>`.
pub trait SlashCommand: 'static + Send + Sync {
    /// The command's name.
    fn name(&self) -> String;
    /// Icon shown for the command. Defaults to [`IconName::Slash`].
    fn icon(&self) -> IconName {
        IconName::Slash
    }
    /// Label shown for the command. Defaults to
    /// `CodeLabel::plain(self.name(), None)`.
    fn label(&self, _cx: &AppContext) -> CodeLabel {
        CodeLabel::plain(self.name(), None)
    }
    /// Human-readable description of the command.
    fn description(&self) -> String;
    /// Text shown for the command in a menu.
    fn menu_text(&self) -> String;
    /// Produces candidate [`ArgumentCompletion`]s given the `arguments`
    /// entered so far.
    ///
    /// NOTE(review): `cancel` looks like a cancellation flag that
    /// implementations are expected to poll — confirm with callers.
    fn complete_argument(
        self: Arc<Self>,
        arguments: &[String],
        cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>>;
    /// Whether the command requires an argument.
    fn requires_argument(&self) -> bool;
    /// Whether the command accepts arguments. Defaults to the value of
    /// [`SlashCommand::requires_argument`].
    fn accepts_arguments(&self) -> bool {
        self.requires_argument()
    }
    /// Runs the command with `arguments`, returning a task that resolves to a
    /// [`SlashCommandResult`].
    ///
    /// NOTE(review): `context_slash_command_output_sections` and
    /// `context_buffer` presumably describe the buffer the command is invoked
    /// from and its existing slash command output — confirm with callers.
    fn run(
        self: Arc<Self>,
        arguments: &[String],
        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        context_buffer: BufferSnapshot,
        workspace: WeakView<Workspace>,
        // TODO: We're just using the `LspAdapterDelegate` here because that is
        // what the extension API is already expecting.
        //
        // It may be that `LspAdapterDelegate` needs a more general name, or
        // perhaps another kind of delegate is needed here.
        delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<SlashCommandResult>;
}
103
/// A shareable (`Send + Sync`) callback that renders a fold placeholder element.
///
/// It receives an [`ElementId`], a callback taking `&mut WindowContext`, and
/// the window context, and returns the element to display.
// NOTE(review): the inner callback is presumably invoked to unfold the
// region when the placeholder is activated — confirm at call sites.
pub type RenderFoldPlaceholder = Arc<
    dyn Send
        + Sync
        + Fn(ElementId, Arc<dyn Fn(&mut WindowContext)>, &mut WindowContext) -> AnyElement,
>;
109
/// A piece of content emitted by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// Plain text content.
    Text {
        /// The text itself.
        text: String,
        /// NOTE(review): presumably whether slash commands appearing in
        /// `text` should themselves be run — confirm.
        run_commands_in_text: bool,
    },
}

impl<'a> From<&'a str> for SlashCommandContent {
    /// Wraps a string slice as [`SlashCommandContent::Text`] with
    /// `run_commands_in_text` set to `false`.
    fn from(text: &'a str) -> Self {
        let owned = String::from(text);
        SlashCommandContent::Text {
            run_commands_in_text: false,
            text: owned,
        }
    }
}
126
/// An event in the stream produced by a running slash command.
///
/// Section events nest like brackets: each `EndSection` closes the most
/// recently opened section that is still open.
#[derive(Debug, PartialEq)]
pub enum SlashCommandEvent {
    /// Begins a new message with the given role.
    // NOTE(review): `merge_same_roles` presumably controls whether this merges
    // into a preceding message with the same role — confirm with consumers.
    StartMessage {
        role: Role,
        merge_same_roles: bool,
    },
    /// Opens a section with the given icon, label, and optional metadata.
    StartSection {
        icon: IconName,
        label: SharedString,
        metadata: Option<serde_json::Value>,
    },
    /// Content appended to the output.
    Content(SlashCommandContent),
    /// Closes the most recently opened section.
    EndSection,
}
141
/// The complete (non-streaming) output of a slash command.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
    /// The full text produced by the command.
    pub text: String,
    /// Labeled ranges of `text`, as byte offsets into `text`.
    pub sections: Vec<SlashCommandOutputSection<usize>>,
    /// Copied onto every text content event produced from this output.
    // NOTE(review): presumably whether slash commands inside `text` should be
    // run — confirm with consumers.
    pub run_commands_in_text: bool,
}
148
149impl SlashCommandOutput {
150    pub fn ensure_valid_section_ranges(&mut self) {
151        for section in &mut self.sections {
152            section.range.start = section.range.start.min(self.text.len());
153            section.range.end = section.range.end.min(self.text.len());
154            while !self.text.is_char_boundary(section.range.start) {
155                section.range.start -= 1;
156            }
157            while !self.text.is_char_boundary(section.range.end) {
158                section.range.end += 1;
159            }
160        }
161    }
162
163    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
164    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
165        self.ensure_valid_section_ranges();
166
167        let mut events = Vec::new();
168
169        let mut section_endpoints = Vec::new();
170        for section in self.sections {
171            section_endpoints.push((
172                section.range.start,
173                SlashCommandEvent::StartSection {
174                    icon: section.icon,
175                    label: section.label,
176                    metadata: section.metadata,
177                },
178            ));
179            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
180        }
181        section_endpoints.sort_by_key(|(offset, _)| *offset);
182
183        let mut content_offset = 0;
184        for (endpoint_offset, endpoint) in section_endpoints {
185            if content_offset < endpoint_offset {
186                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
187                    text: self.text[content_offset..endpoint_offset].to_string(),
188                    run_commands_in_text: self.run_commands_in_text,
189                })));
190                content_offset = endpoint_offset;
191            }
192
193            events.push(Ok(endpoint));
194        }
195
196        if content_offset < self.text.len() {
197            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
198                text: self.text[content_offset..].to_string(),
199                run_commands_in_text: self.run_commands_in_text,
200            })));
201        }
202
203        stream::iter(events).boxed()
204    }
205
206    pub async fn from_event_stream(
207        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
208    ) -> Result<SlashCommandOutput> {
209        let mut output = SlashCommandOutput::default();
210        let mut section_stack = Vec::new();
211
212        while let Some(event) = events.next().await {
213            match event? {
214                SlashCommandEvent::StartSection {
215                    icon,
216                    label,
217                    metadata,
218                } => {
219                    let start = output.text.len();
220                    section_stack.push(SlashCommandOutputSection {
221                        range: start..start,
222                        icon,
223                        label,
224                        metadata,
225                    });
226                }
227                SlashCommandEvent::Content(SlashCommandContent::Text {
228                    text,
229                    run_commands_in_text,
230                }) => {
231                    output.text.push_str(&text);
232                    output.run_commands_in_text = run_commands_in_text;
233
234                    if let Some(section) = section_stack.last_mut() {
235                        section.range.end = output.text.len();
236                    }
237                }
238                SlashCommandEvent::EndSection => {
239                    if let Some(section) = section_stack.pop() {
240                        output.sections.push(section);
241                    }
242                }
243                SlashCommandEvent::StartMessage { .. } => {}
244            }
245        }
246
247        while let Some(section) = section_stack.pop() {
248            output.sections.push(section);
249        }
250
251        Ok(output)
252    }
253}
254
/// A labeled range of slash command output.
///
/// Generic over the position type. [`SlashCommandOutput`] uses `usize` byte
/// offsets, and [`SlashCommand::run`] receives sections positioned with
/// `language::Anchor`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
    /// The span of output covered by this section.
    pub range: Range<T>,
    /// Icon shown for the section.
    pub icon: IconName,
    /// Label shown for the section.
    pub label: SharedString,
    /// Optional arbitrary JSON metadata attached to the section.
    pub metadata: Option<serde_json::Value>,
}
262
263impl SlashCommandOutputSection<language::Anchor> {
264    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
265        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
266    }
267}
268
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    // For several outputs, checks the exact event sequence produced by
    // `to_event_stream`, then checks that `from_event_stream` turns those
    // events back into an output equal to the original.
    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // Test basic output consisting of a single section.
        {
            let text = "Hello, world!".to_string();
            let range = 0..text.len();
            let output = SlashCommandOutput {
                text,
                sections: vec![SlashCommandOutputSection {
                    range,
                    icon: IconName::Code,
                    label: "Section 1".into(),
                    metadata: None,
                }],
                run_commands_in_text: false,
            };

            // Drop any `Err` items; only successful events are compared.
            let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
            let events = events
                .into_iter()
                .filter_map(|event| event.ok())
                .collect::<Vec<_>>();

            assert_eq!(
                events,
                vec![
                    SlashCommandEvent::StartSection {
                        icon: IconName::Code,
                        label: "Section 1".into(),
                        metadata: None
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Hello, world!".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection
                ]
            );

            let new_output =
                SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                    .await
                    .unwrap();

            assert_eq!(new_output, output);
        }

        // Test output where the sections do not comprise all of the text.
        // The uncovered "Cucumber\n" must appear as a standalone Content event.
        {
            let text = "Apple\nCucumber\nBanana\n".to_string();
            let output = SlashCommandOutput {
                text,
                sections: vec![
                    SlashCommandOutputSection {
                        range: 0..6,
                        icon: IconName::Check,
                        label: "Fruit".into(),
                        metadata: None,
                    },
                    SlashCommandOutputSection {
                        range: 15..22,
                        icon: IconName::Check,
                        label: "Fruit".into(),
                        metadata: None,
                    },
                ],
                run_commands_in_text: false,
            };

            let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
            let events = events
                .into_iter()
                .filter_map(|event| event.ok())
                .collect::<Vec<_>>();

            assert_eq!(
                events,
                vec![
                    SlashCommandEvent::StartSection {
                        icon: IconName::Check,
                        label: "Fruit".into(),
                        metadata: None
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Apple\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Cucumber\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::StartSection {
                        icon: IconName::Check,
                        label: "Fruit".into(),
                        metadata: None
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Banana\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection
                ]
            );

            let new_output =
                SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                    .await
                    .unwrap();

            assert_eq!(new_output, output);
        }

        // Test output consisting of multiple sections.
        // Each "\n" gap between sections, including the trailing one after the
        // last section, becomes its own Content event; metadata must survive.
        {
            let text = "Line 1\nLine 2\nLine 3\nLine 4\n".to_string();
            let output = SlashCommandOutput {
                text,
                sections: vec![
                    SlashCommandOutputSection {
                        range: 0..6,
                        icon: IconName::FileCode,
                        label: "Section 1".into(),
                        metadata: Some(json!({ "a": true })),
                    },
                    SlashCommandOutputSection {
                        range: 7..13,
                        icon: IconName::FileDoc,
                        label: "Section 2".into(),
                        metadata: Some(json!({ "b": true })),
                    },
                    SlashCommandOutputSection {
                        range: 14..20,
                        icon: IconName::FileGit,
                        label: "Section 3".into(),
                        metadata: Some(json!({ "c": true })),
                    },
                    SlashCommandOutputSection {
                        range: 21..27,
                        icon: IconName::FileToml,
                        label: "Section 4".into(),
                        metadata: Some(json!({ "d": true })),
                    },
                ],
                run_commands_in_text: false,
            };

            let events = output.clone().to_event_stream().collect::<Vec<_>>().await;
            let events = events
                .into_iter()
                .filter_map(|event| event.ok())
                .collect::<Vec<_>>();

            assert_eq!(
                events,
                vec![
                    SlashCommandEvent::StartSection {
                        icon: IconName::FileCode,
                        label: "Section 1".into(),
                        metadata: Some(json!({ "a": true }))
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Line 1".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::StartSection {
                        icon: IconName::FileDoc,
                        label: "Section 2".into(),
                        metadata: Some(json!({ "b": true }))
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Line 2".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::StartSection {
                        icon: IconName::FileGit,
                        label: "Section 3".into(),
                        metadata: Some(json!({ "c": true }))
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Line 3".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "\n".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::StartSection {
                        icon: IconName::FileToml,
                        label: "Section 4".into(),
                        metadata: Some(json!({ "d": true }))
                    },
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "Line 4".into(),
                        run_commands_in_text: false
                    }),
                    SlashCommandEvent::EndSection,
                    SlashCommandEvent::Content(SlashCommandContent::Text {
                        text: "\n".into(),
                        run_commands_in_text: false
                    }),
                ]
            );

            let new_output =
                SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
                    .await
                    .unwrap();

            assert_eq!(new_output, output);
        }
    }
}