assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3mod slash_command_working_set;
  4
  5pub use crate::extension_slash_command::*;
  6pub use crate::slash_command_registry::*;
  7pub use crate::slash_command_working_set::*;
  8use anyhow::Result;
  9use futures::StreamExt;
 10use futures::stream::{self, BoxStream};
 11use gpui::{App, SharedString, Task, WeakEntity, Window};
 12use language::CodeLabelBuilder;
 13use language::HighlightId;
 14use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 15pub use language_model::Role;
 16use serde::{Deserialize, Deserializer, Serialize};
 17use std::{
 18    ops::Range,
 19    sync::{Arc, atomic::AtomicBool},
 20};
 21use ui::ActiveTheme;
 22use workspace::{Workspace, ui::IconName};
 23
 24/// Deserializes IconName, falling back to Code for unknown variants.
 25/// This handles old saved data that may contain removed or renamed icon variants.
 26fn deserialize_icon_with_fallback<'de, D>(deserializer: D) -> Result<IconName, D::Error>
 27where
 28    D: Deserializer<'de>,
 29{
 30    Ok(String::deserialize(deserializer)
 31        .ok()
 32        .and_then(|string| serde_json::from_value(serde_json::Value::String(string)).ok())
 33        .unwrap_or(IconName::Code))
 34}
 35
 36pub fn init(cx: &mut App) {
 37    SlashCommandRegistry::default_global(cx);
 38    extension_slash_command::init(cx);
 39}
 40
 41#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 42pub enum AfterCompletion {
 43    /// Run the command
 44    Run,
 45    /// Continue composing the current argument, doesn't add a space
 46    Compose,
 47    /// Continue the command composition, adds a space
 48    Continue,
 49}
 50
 51impl From<bool> for AfterCompletion {
 52    fn from(value: bool) -> Self {
 53        if value {
 54            AfterCompletion::Run
 55        } else {
 56            AfterCompletion::Continue
 57        }
 58    }
 59}
 60
 61impl AfterCompletion {
 62    pub fn run(&self) -> bool {
 63        match self {
 64            AfterCompletion::Run => true,
 65            AfterCompletion::Compose | AfterCompletion::Continue => false,
 66        }
 67    }
 68}
 69
/// A single completion candidate offered for a slash command argument.
#[derive(Debug)]
pub struct ArgumentCompletion {
    /// The label to display for this completion.
    pub label: CodeLabel,
    /// The new text that should be inserted into the command when this completion is accepted.
    pub new_text: String,
    /// What happens once this completion is accepted: run the command, keep composing
    /// the current argument, or continue on to the next one (see [`AfterCompletion`]).
    pub after_completion: AfterCompletion,
    /// Whether to replace all previous arguments, or whether to treat this as an independent argument.
    pub replace_previous_arguments: bool,
}
 81
/// The outcome of [`SlashCommand::run`]: either an error, or a stream of
/// [`SlashCommandEvent`]s in which each item may itself be an error.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 83
/// A slash command: it describes itself (name, icon, label, description, menu
/// text), offers completions for its arguments, and produces output when run.
pub trait SlashCommand: 'static + Send + Sync {
    /// The command's name. The default [`Self::label`] is built from it.
    fn name(&self) -> String;
    /// The command's icon. Defaults to [`IconName::Slash`].
    fn icon(&self) -> IconName {
        IconName::Slash
    }
    /// The command's label. Defaults to `CodeLabel::plain` of [`Self::name`].
    fn label(&self, _cx: &App) -> CodeLabel {
        CodeLabel::plain(self.name(), None)
    }
    /// A description of the command.
    fn description(&self) -> String;
    /// Text for the command's menu entry (inferred from the name; confirm where
    /// callers display it).
    fn menu_text(&self) -> String;
    /// Computes completions for the command's arguments.
    ///
    /// `arguments` holds the arguments entered so far. `cancel` is a shared
    /// flag — presumably set by the caller when the request is superseded, so
    /// long-running implementations can stop early (TODO confirm).
    fn complete_argument(
        self: Arc<Self>,
        arguments: &[String],
        cancel: Arc<AtomicBool>,
        workspace: Option<WeakEntity<Workspace>>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<Result<Vec<ArgumentCompletion>>>;
    /// Whether the command requires an argument.
    fn requires_argument(&self) -> bool;
    /// Whether the command accepts arguments. Defaults to [`Self::requires_argument`].
    fn accepts_arguments(&self) -> bool {
        self.requires_argument()
    }
    /// Runs the command with the given `arguments`.
    ///
    /// `context_slash_command_output_sections` and `context_buffer` describe the
    /// context the command runs in — presumably the sections already present in
    /// that buffer (TODO confirm). The returned task resolves to a
    /// [`SlashCommandResult`].
    fn run(
        self: Arc<Self>,
        arguments: &[String],
        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        context_buffer: BufferSnapshot,
        workspace: WeakEntity<Workspace>,
        // TODO: We're just using the `LspAdapterDelegate` here because that is
        // what the extension API is already expecting.
        //
        // It may be that `LspAdapterDelegate` needs a more general name, or
        // perhaps another kind of delegate is needed here.
        delegate: Option<Arc<dyn LspAdapterDelegate>>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<SlashCommandResult>;
}
122
/// A chunk of content produced by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// Plain text. By its name, `run_commands_in_text` flags whether slash
    /// commands appearing in `text` should be run too (confirm with consumers).
    Text {
        text: String,
        run_commands_in_text: bool,
    },
}

/// Wraps a string slice as [`SlashCommandContent::Text`] with
/// `run_commands_in_text` set to `false`.
impl From<&str> for SlashCommandContent {
    fn from(text: &str) -> Self {
        let text = String::from(text);
        SlashCommandContent::Text {
            text,
            run_commands_in_text: false,
        }
    }
}
139
/// An incremental piece of slash-command output.
///
/// [`SlashCommandOutput::into_event_stream`] produces these, and
/// [`SlashCommandOutput::from_event_stream`] reassembles them. Sections are
/// delimited by matching `StartSection`/`EndSection` pairs.
#[derive(Debug, PartialEq)]
pub enum SlashCommandEvent {
    /// Begins a message with the given role. `merge_same_roles` presumably
    /// allows merging with an adjacent message of the same role (confirm with
    /// consumers). `SlashCommandOutput::from_event_stream` ignores this event.
    StartMessage {
        role: Role,
        merge_same_roles: bool,
    },
    /// Opens a section. Subsequent content belongs to it until the matching
    /// `EndSection`.
    StartSection {
        icon: IconName,
        label: SharedString,
        metadata: Option<serde_json::Value>,
    },
    /// A chunk of output content.
    Content(SlashCommandContent),
    /// Closes the most recently opened, still-open section.
    EndSection,
}
154
/// The fully buffered output of a slash command: its text plus the sections
/// that annotate spans of it.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
    /// The text produced by the command.
    pub text: String,
    /// Annotated sections; each `range` is a byte range into `text`.
    pub sections: Vec<SlashCommandOutputSection<usize>>,
    /// Copied onto every `Content` event by `into_event_stream`. By its name,
    /// it marks whether slash commands inside `text` should be run as well.
    pub run_commands_in_text: bool,
}
161
162impl SlashCommandOutput {
163    pub fn ensure_valid_section_ranges(&mut self) {
164        for section in &mut self.sections {
165            section.range.start = section.range.start.min(self.text.len());
166            section.range.end = section.range.end.min(self.text.len());
167            while !self.text.is_char_boundary(section.range.start) {
168                section.range.start -= 1;
169            }
170            while !self.text.is_char_boundary(section.range.end) {
171                section.range.end += 1;
172            }
173        }
174    }
175
176    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
177    pub fn into_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
178        self.ensure_valid_section_ranges();
179
180        let mut events = Vec::new();
181
182        let mut section_endpoints = Vec::new();
183        for section in self.sections {
184            section_endpoints.push((
185                section.range.start,
186                SlashCommandEvent::StartSection {
187                    icon: section.icon,
188                    label: section.label,
189                    metadata: section.metadata,
190                },
191            ));
192            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
193        }
194        section_endpoints.sort_by_key(|(offset, _)| *offset);
195
196        let mut content_offset = 0;
197        for (endpoint_offset, endpoint) in section_endpoints {
198            if content_offset < endpoint_offset {
199                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
200                    text: self.text[content_offset..endpoint_offset].to_string(),
201                    run_commands_in_text: self.run_commands_in_text,
202                })));
203                content_offset = endpoint_offset;
204            }
205
206            events.push(Ok(endpoint));
207        }
208
209        if content_offset < self.text.len() {
210            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
211                text: self.text[content_offset..].to_string(),
212                run_commands_in_text: self.run_commands_in_text,
213            })));
214        }
215
216        stream::iter(events).boxed()
217    }
218
219    pub async fn from_event_stream(
220        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
221    ) -> Result<SlashCommandOutput> {
222        let mut output = SlashCommandOutput::default();
223        let mut section_stack = Vec::new();
224
225        while let Some(event) = events.next().await {
226            match event? {
227                SlashCommandEvent::StartSection {
228                    icon,
229                    label,
230                    metadata,
231                } => {
232                    let start = output.text.len();
233                    section_stack.push(SlashCommandOutputSection {
234                        range: start..start,
235                        icon,
236                        label,
237                        metadata,
238                    });
239                }
240                SlashCommandEvent::Content(SlashCommandContent::Text {
241                    text,
242                    run_commands_in_text,
243                }) => {
244                    output.text.push_str(&text);
245                    output.run_commands_in_text = run_commands_in_text;
246
247                    if let Some(section) = section_stack.last_mut() {
248                        section.range.end = output.text.len();
249                    }
250                }
251                SlashCommandEvent::EndSection => {
252                    if let Some(section) = section_stack.pop() {
253                        output.sections.push(section);
254                    }
255                }
256                SlashCommandEvent::StartMessage { .. } => {}
257            }
258        }
259
260        while let Some(section) = section_stack.pop() {
261            output.sections.push(section);
262        }
263
264        Ok(output)
265    }
266}
267
/// A labeled span of slash-command output, with an icon and optional JSON metadata.
///
/// `T` is the position type of `range`: `usize` byte offsets within
/// [`SlashCommandOutput::text`], or `language::Anchor`s as passed to
/// [`SlashCommand::run`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
    /// The span of text this section covers.
    pub range: Range<T>,
    /// The section's icon. Unknown or unreadable icon names in saved data
    /// deserialize as `IconName::Code` instead of failing.
    #[serde(deserialize_with = "deserialize_icon_with_fallback")]
    pub icon: IconName,
    /// The section's label.
    pub label: SharedString,
    /// Optional JSON metadata carried alongside the section.
    pub metadata: Option<serde_json::Value>,
}
276
277impl SlashCommandOutputSection<language::Anchor> {
278    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
279        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
280    }
281}
282
/// Byte ranges of a parsed slash-command line (see [`SlashCommandLine::parse`]).
pub struct SlashCommandLine {
    /// The range within the line containing the command name.
    pub name: Range<usize>,
    /// Ranges within the line containing the command arguments.
    pub arguments: Vec<Range<usize>>,
}

impl SlashCommandLine {
    /// Parses `line` as a slash-command invocation such as `/file src/main.rs`.
    ///
    /// Syntax:
    /// - Optional leading whitespace, then `/`.
    /// - A command name whose first character is alphabetic; it runs until the
    ///   first whitespace character.
    /// - Everything after the name is split on whitespace into argument ranges.
    ///
    /// Returns `None` if:
    /// - anything other than whitespace precedes the slash,
    /// - the name starts with a non-alphabetic character, or
    /// - the line has no slash at all.
    ///
    /// All ranges are byte offsets into `line`. A lone `/` yields an empty
    /// name. Trailing whitespace leaves an empty final argument range at the
    /// end of the line.
    pub fn parse(line: &str) -> Option<Self> {
        let mut parsed: Option<Self> = None;
        for (offset, ch) in line.char_indices() {
            let after = offset + ch.len_utf8();
            let is_space = ch.is_whitespace();

            let Some(command) = parsed.as_mut() else {
                // Only whitespace may precede the slash that starts the command.
                if ch == '/' {
                    parsed = Some(SlashCommandLine {
                        name: after..after,
                        arguments: Vec::new(),
                    });
                } else if !is_space {
                    return None;
                }
                continue;
            };

            match command.arguments.last_mut() {
                // Runs of whitespace: an empty pending argument just slides forward.
                Some(argument) if is_space && argument.is_empty() => {
                    *argument = after..after;
                }
                // Whitespace ends a non-empty argument and opens the next one.
                Some(argument) if is_space => {
                    argument.end = offset;
                    command.arguments.push(after..after);
                }
                Some(argument) => argument.end = after,
                // The name ends at the first whitespace, which opens the first argument.
                None if !command.name.is_empty() => {
                    if is_space {
                        command.arguments.push(after..after);
                    } else {
                        command.name.end = after;
                    }
                }
                // The command name must begin with a letter.
                None if ch.is_alphabetic() => command.name.end = after,
                None => return None,
            }
        }
        parsed
    }
}
343
344pub fn create_label_for_command(command_name: &str, arguments: &[&str], cx: &App) -> CodeLabel {
345    let mut label = CodeLabelBuilder::default();
346    label.push_str(command_name, None);
347    label.respan_filter_range(None);
348    label.push_str(" ", None);
349    label.push_str(
350        &arguments.join(" "),
351        cx.theme().syntax().highlight_id("comment").map(HighlightId),
352    );
353    label.build()
354}
355
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Builds a section over `range` with the given presentation.
    fn section(
        range: std::ops::Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds the `StartSection` event expected for a section.
    fn start_section(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Builds a `Content` event for `text` that does not run embedded commands.
    fn text_content(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Checks both directions of the event conversion for `output`.
    ///
    /// First asserts that `output` converts to exactly `expected` (errored
    /// stream items are skipped). Then asserts that converting those events
    /// back yields an output equal to `output`.
    async fn assert_round_trip(output: SlashCommandOutput, expected: Vec<SlashCommandEvent>) {
        let events: Vec<SlashCommandEvent> = output
            .clone()
            .into_event_stream()
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(|event| event.ok())
            .collect();
        assert_eq!(events, expected);

        let rebuilt = SlashCommandOutput::from_event_stream(output.clone().into_event_stream())
            .await
            .unwrap();
        assert_eq!(rebuilt, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // A single section spanning all of the text.
        let text = "Hello, world!".to_string();
        let full_range = 0..text.len();
        assert_round_trip(
            SlashCommandOutput {
                text,
                sections: vec![section(full_range, IconName::Code, "Section 1", None)],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::Code, "Section 1", None),
                text_content("Hello, world!"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Sections that leave part of the text uncovered.
        assert_round_trip(
            SlashCommandOutput {
                text: "Apple\nCucumber\nBanana\n".to_string(),
                sections: vec![
                    section(0..6, IconName::Check, "Fruit", None),
                    section(15..22, IconName::Check, "Fruit", None),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::Check, "Fruit", None),
                text_content("Apple\n"),
                SlashCommandEvent::EndSection,
                text_content("Cucumber\n"),
                start_section(IconName::Check, "Fruit", None),
                text_content("Banana\n"),
                SlashCommandEvent::EndSection,
            ],
        )
        .await;

        // Several consecutive sections, each carrying metadata.
        assert_round_trip(
            SlashCommandOutput {
                text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                sections: vec![
                    section(0..6, IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                    section(7..13, IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                    section(14..20, IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                    section(21..27, IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                ],
                run_commands_in_text: false,
            },
            vec![
                start_section(IconName::FileCode, "Section 1", Some(json!({ "a": true }))),
                text_content("Line 1"),
                SlashCommandEvent::EndSection,
                text_content("\n"),
                start_section(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                text_content("Line 2"),
                SlashCommandEvent::EndSection,
                text_content("\n"),
                start_section(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                text_content("Line 3"),
                SlashCommandEvent::EndSection,
                text_content("\n"),
                start_section(IconName::FileToml, "Section 4", Some(json!({ "d": true }))),
                text_content("Line 4"),
                SlashCommandEvent::EndSection,
                text_content("\n"),
            ],
        )
        .await;
    }

    /// JSON for a serialized section spanning bytes 0..5 with the given icon and label.
    fn serialized_section(icon: &str, label: &str) -> serde_json::Value {
        json!({
            "range": {
                "start": 0,
                "end": 5
            },
            "icon": icon,
            "label": label,
            "metadata": null
        })
    }

    #[test]
    fn test_deserialize_with_valid_icon_pascal_case() {
        // Icons stored under their PascalCase variant name (serde's default) deserialize exactly.
        let section: SlashCommandOutputSection<usize> =
            serde_json::from_value(serialized_section("AcpRegistry", "Test")).unwrap();
        assert_eq!(section.icon, IconName::AcpRegistry);
    }

    #[test]
    fn test_deserialize_with_unknown_icon() {
        // Unknown icon variants fall back to `IconName::Code` instead of failing.
        let section: SlashCommandOutputSection<usize> =
            serde_json::from_value(serialized_section("removed_icon", "Old Icon")).unwrap();
        assert_eq!(section.icon, IconName::Code);
    }
}