assistant_slash_command.rs

  1mod extension_slash_command;
  2mod slash_command_registry;
  3mod slash_command_working_set;
  4
  5pub use crate::extension_slash_command::*;
  6pub use crate::slash_command_registry::*;
  7pub use crate::slash_command_working_set::*;
  8use anyhow::Result;
  9use futures::stream::{self, BoxStream};
 10use futures::StreamExt;
 11use gpui::{AppContext, SharedString, Task, WeakView, WindowContext};
 12use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate, OffsetRangeExt};
 13pub use language_model::Role;
 14use serde::{Deserialize, Serialize};
 15use std::{
 16    ops::Range,
 17    sync::{atomic::AtomicBool, Arc},
 18};
 19use workspace::{ui::IconName, Workspace};
 20
/// Initializes this crate's global state.
///
/// First installs the default global [`SlashCommandRegistry`], then runs the
/// `extension_slash_command` module's own initialization.
pub fn init(cx: &mut AppContext) {
    SlashCommandRegistry::default_global(cx);
    extension_slash_command::init(cx);
}
 25
 26#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 27pub enum AfterCompletion {
 28    /// Run the command
 29    Run,
 30    /// Continue composing the current argument, doesn't add a space
 31    Compose,
 32    /// Continue the command composition, adds a space
 33    Continue,
 34}
 35
 36impl From<bool> for AfterCompletion {
 37    fn from(value: bool) -> Self {
 38        if value {
 39            AfterCompletion::Run
 40        } else {
 41            AfterCompletion::Continue
 42        }
 43    }
 44}
 45
 46impl AfterCompletion {
 47    pub fn run(&self) -> bool {
 48        match self {
 49            AfterCompletion::Run => true,
 50            AfterCompletion::Compose | AfterCompletion::Continue => false,
 51        }
 52    }
 53}
 54
/// A single completion candidate for a slash command argument, as returned by
/// [`SlashCommand::complete_argument`].
#[derive(Debug)]
pub struct ArgumentCompletion {
    /// The label to display for this completion.
    pub label: CodeLabel,
    /// The new text that should be inserted into the command when this completion is accepted.
    pub new_text: String,
    /// What to do after this completion is accepted (run the command, or keep composing).
    pub after_completion: AfterCompletion,
    /// Whether to replace all of the arguments, or whether to treat this as an independent argument.
    pub replace_previous_arguments: bool,
}
 66
/// The outcome of running a slash command: either an error, or a boxed
/// `'static` stream of [`SlashCommandEvent`]s in which every item may itself
/// be an error.
pub type SlashCommandResult = Result<BoxStream<'static, Result<SlashCommandEvent>>>;
 68
/// The interface every slash command implements.
///
/// Implementors must be `'static + Send + Sync`. The two methods that do the
/// actual work, [`complete_argument`](SlashCommand::complete_argument) and
/// [`run`](SlashCommand::run), take `self: Arc<Self>` and return a [`Task`].
pub trait SlashCommand: 'static + Send + Sync {
    /// Returns the name of the command. The default [`label`](SlashCommand::label)
    /// is built from it.
    fn name(&self) -> String;
    /// Returns the icon for this command. Defaults to [`IconName::Slash`].
    fn icon(&self) -> IconName {
        IconName::Slash
    }
    /// Returns the label for this command. Defaults to a plain [`CodeLabel`]
    /// created from [`name`](SlashCommand::name).
    fn label(&self, _cx: &AppContext) -> CodeLabel {
        CodeLabel::plain(self.name(), None)
    }
    /// Returns a description of the command.
    fn description(&self) -> String;
    /// Returns the text for this command's menu entry.
    // NOTE(review): which menu consumes this is not visible in this file — confirm at call sites.
    fn menu_text(&self) -> String;
    /// Computes completion candidates for the given `arguments`.
    ///
    /// Returns a [`Task`] resolving to the list of [`ArgumentCompletion`]s, or
    /// to an error.
    // NOTE(review): `cancel` is presumably set by the caller to abandon an
    // in-flight completion request — confirm against implementors.
    fn complete_argument(
        self: Arc<Self>,
        arguments: &[String],
        cancel: Arc<AtomicBool>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<ArgumentCompletion>>>;
    /// Whether this command requires an argument.
    fn requires_argument(&self) -> bool;
    /// Whether this command accepts arguments. Defaults to the value of
    /// [`requires_argument`](SlashCommand::requires_argument).
    fn accepts_arguments(&self) -> bool {
        self.requires_argument()
    }
    /// Runs the command with the given `arguments`.
    ///
    /// Returns a [`Task`] resolving to a [`SlashCommandResult`], i.e. a stream
    /// of [`SlashCommandEvent`]s on success.
    fn run(
        self: Arc<Self>,
        arguments: &[String],
        context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
        context_buffer: BufferSnapshot,
        workspace: WeakView<Workspace>,
        // TODO: We're just using the `LspAdapterDelegate` here because that is
        // what the extension API is already expecting.
        //
        // It may be that `LspAdapterDelegate` needs a more general name, or
        // perhaps another kind of delegate is needed here.
        delegate: Option<Arc<dyn LspAdapterDelegate>>,
        cx: &mut WindowContext,
    ) -> Task<SlashCommandResult>;
}
105
/// A piece of content produced by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandContent {
    /// A chunk of text.
    Text {
        /// The text itself.
        text: String,
        /// Flag carried alongside the text; `From<&str>` sets it to `false`.
        run_commands_in_text: bool,
    },
}

impl From<&str> for SlashCommandContent {
    /// Wraps `text` in a [`SlashCommandContent::Text`] with
    /// `run_commands_in_text` turned off.
    fn from(text: &str) -> Self {
        let text = text.to_owned();
        Self::Text {
            text,
            run_commands_in_text: false,
        }
    }
}
122
/// One item in the stream of output produced by a slash command.
#[derive(Debug, PartialEq)]
pub enum SlashCommandEvent {
    /// Marks the start of a message. [`SlashCommandOutput::from_event_stream`]
    /// ignores this event.
    // NOTE(review): the exact meaning of `merge_same_roles` is defined by the
    // consumer of the stream, which is not visible in this file.
    StartMessage {
        role: Role,
        merge_same_roles: bool,
    },
    /// Opens a section; content that follows belongs to it until the matching
    /// [`SlashCommandEvent::EndSection`].
    StartSection {
        icon: IconName,
        label: SharedString,
        metadata: Option<serde_json::Value>,
    },
    /// A piece of content to append to the output.
    Content(SlashCommandContent),
    /// Closes the most recently opened section.
    EndSection,
}
137
/// The complete, materialized output of a slash command: the text plus the
/// sections that annotate parts of it.
#[derive(Debug, Default, PartialEq, Clone)]
pub struct SlashCommandOutput {
    /// The full output text.
    pub text: String,
    /// Sections of `text`; each `range` holds byte offsets into `text`.
    pub sections: Vec<SlashCommandOutputSection<usize>>,
    /// Flag attached to every text content event emitted by `to_event_stream`.
    pub run_commands_in_text: bool,
}
144
145impl SlashCommandOutput {
146    pub fn ensure_valid_section_ranges(&mut self) {
147        for section in &mut self.sections {
148            section.range.start = section.range.start.min(self.text.len());
149            section.range.end = section.range.end.min(self.text.len());
150            while !self.text.is_char_boundary(section.range.start) {
151                section.range.start -= 1;
152            }
153            while !self.text.is_char_boundary(section.range.end) {
154                section.range.end += 1;
155            }
156        }
157    }
158
159    /// Returns this [`SlashCommandOutput`] as a stream of [`SlashCommandEvent`]s.
160    pub fn to_event_stream(mut self) -> BoxStream<'static, Result<SlashCommandEvent>> {
161        self.ensure_valid_section_ranges();
162
163        let mut events = Vec::new();
164
165        let mut section_endpoints = Vec::new();
166        for section in self.sections {
167            section_endpoints.push((
168                section.range.start,
169                SlashCommandEvent::StartSection {
170                    icon: section.icon,
171                    label: section.label,
172                    metadata: section.metadata,
173                },
174            ));
175            section_endpoints.push((section.range.end, SlashCommandEvent::EndSection));
176        }
177        section_endpoints.sort_by_key(|(offset, _)| *offset);
178
179        let mut content_offset = 0;
180        for (endpoint_offset, endpoint) in section_endpoints {
181            if content_offset < endpoint_offset {
182                events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
183                    text: self.text[content_offset..endpoint_offset].to_string(),
184                    run_commands_in_text: self.run_commands_in_text,
185                })));
186                content_offset = endpoint_offset;
187            }
188
189            events.push(Ok(endpoint));
190        }
191
192        if content_offset < self.text.len() {
193            events.push(Ok(SlashCommandEvent::Content(SlashCommandContent::Text {
194                text: self.text[content_offset..].to_string(),
195                run_commands_in_text: self.run_commands_in_text,
196            })));
197        }
198
199        stream::iter(events).boxed()
200    }
201
202    pub async fn from_event_stream(
203        mut events: BoxStream<'static, Result<SlashCommandEvent>>,
204    ) -> Result<SlashCommandOutput> {
205        let mut output = SlashCommandOutput::default();
206        let mut section_stack = Vec::new();
207
208        while let Some(event) = events.next().await {
209            match event? {
210                SlashCommandEvent::StartSection {
211                    icon,
212                    label,
213                    metadata,
214                } => {
215                    let start = output.text.len();
216                    section_stack.push(SlashCommandOutputSection {
217                        range: start..start,
218                        icon,
219                        label,
220                        metadata,
221                    });
222                }
223                SlashCommandEvent::Content(SlashCommandContent::Text {
224                    text,
225                    run_commands_in_text,
226                }) => {
227                    output.text.push_str(&text);
228                    output.run_commands_in_text = run_commands_in_text;
229
230                    if let Some(section) = section_stack.last_mut() {
231                        section.range.end = output.text.len();
232                    }
233                }
234                SlashCommandEvent::EndSection => {
235                    if let Some(section) = section_stack.pop() {
236                        output.sections.push(section);
237                    }
238                }
239                SlashCommandEvent::StartMessage { .. } => {}
240            }
241        }
242
243        while let Some(section) = section_stack.pop() {
244            output.sections.push(section);
245        }
246
247        Ok(output)
248    }
249}
250
/// A labelled section of slash command output.
///
/// `T` is the position type of the range: this file uses `usize` offsets in
/// [`SlashCommandOutput`] and `language::Anchor` for sections that live in a
/// buffer.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct SlashCommandOutputSection<T> {
    /// The span covered by this section.
    pub range: Range<T>,
    /// The icon associated with this section.
    pub icon: IconName,
    /// The label associated with this section.
    pub label: SharedString,
    /// Optional JSON metadata attached to this section.
    pub metadata: Option<serde_json::Value>,
}
258
259impl SlashCommandOutputSection<language::Anchor> {
260    pub fn is_valid(&self, buffer: &language::TextBuffer) -> bool {
261        self.range.start.is_valid(buffer) && !self.range.to_offset(buffer).is_empty()
262    }
263}
264
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Shorthand for an output section covering `range`.
    fn section(
        range: Range<usize>,
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandOutputSection<usize> {
        SlashCommandOutputSection {
            range,
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Shorthand for the event that opens a section.
    fn section_start(
        icon: IconName,
        label: &'static str,
        metadata: Option<serde_json::Value>,
    ) -> SlashCommandEvent {
        SlashCommandEvent::StartSection {
            icon,
            label: label.into(),
            metadata,
        }
    }

    /// Shorthand for a text content event with `run_commands_in_text` off.
    fn content(text: &str) -> SlashCommandEvent {
        SlashCommandEvent::Content(SlashCommandContent::Text {
            text: text.into(),
            run_commands_in_text: false,
        })
    }

    /// Checks that `output` is converted to exactly `expected_events`, and that
    /// converting those events back yields `output` again.
    async fn assert_round_trip(
        output: SlashCommandOutput,
        expected_events: Vec<SlashCommandEvent>,
    ) {
        let events = output
            .clone()
            .to_event_stream()
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(|event| event.ok())
            .collect::<Vec<_>>();

        assert_eq!(events, expected_events);

        let new_output = SlashCommandOutput::from_event_stream(output.clone().to_event_stream())
            .await
            .unwrap();

        assert_eq!(new_output, output);
    }

    #[gpui::test]
    async fn test_slash_command_output_to_events_round_trip() {
        // Test basic output consisting of a single section.
        {
            let text = "Hello, world!".to_string();
            let range = 0..text.len();

            assert_round_trip(
                SlashCommandOutput {
                    text,
                    sections: vec![section(range, IconName::Code, "Section 1", None)],
                    run_commands_in_text: false,
                },
                vec![
                    section_start(IconName::Code, "Section 1", None),
                    content("Hello, world!"),
                    SlashCommandEvent::EndSection,
                ],
            )
            .await;
        }

        // Test output where the sections do not comprise all of the text.
        {
            assert_round_trip(
                SlashCommandOutput {
                    text: "Apple\nCucumber\nBanana\n".to_string(),
                    sections: vec![
                        section(0..6, IconName::Check, "Fruit", None),
                        section(15..22, IconName::Check, "Fruit", None),
                    ],
                    run_commands_in_text: false,
                },
                vec![
                    section_start(IconName::Check, "Fruit", None),
                    content("Apple\n"),
                    SlashCommandEvent::EndSection,
                    content("Cucumber\n"),
                    section_start(IconName::Check, "Fruit", None),
                    content("Banana\n"),
                    SlashCommandEvent::EndSection,
                ],
            )
            .await;
        }

        // Test output consisting of multiple sections.
        {
            assert_round_trip(
                SlashCommandOutput {
                    text: "Line 1\nLine 2\nLine 3\nLine 4\n".to_string(),
                    sections: vec![
                        section(
                            0..6,
                            IconName::FileCode,
                            "Section 1",
                            Some(json!({ "a": true })),
                        ),
                        section(
                            7..13,
                            IconName::FileDoc,
                            "Section 2",
                            Some(json!({ "b": true })),
                        ),
                        section(
                            14..20,
                            IconName::FileGit,
                            "Section 3",
                            Some(json!({ "c": true })),
                        ),
                        section(
                            21..27,
                            IconName::FileToml,
                            "Section 4",
                            Some(json!({ "d": true })),
                        ),
                    ],
                    run_commands_in_text: false,
                },
                vec![
                    section_start(
                        IconName::FileCode,
                        "Section 1",
                        Some(json!({ "a": true })),
                    ),
                    content("Line 1"),
                    SlashCommandEvent::EndSection,
                    content("\n"),
                    section_start(IconName::FileDoc, "Section 2", Some(json!({ "b": true }))),
                    content("Line 2"),
                    SlashCommandEvent::EndSection,
                    content("\n"),
                    section_start(IconName::FileGit, "Section 3", Some(json!({ "c": true }))),
                    content("Line 3"),
                    SlashCommandEvent::EndSection,
                    content("\n"),
                    section_start(
                        IconName::FileToml,
                        "Section 4",
                        Some(json!({ "d": true })),
                    ),
                    content("Line 4"),
                    SlashCommandEvent::EndSection,
                    content("\n"),
                ],
            )
            .await;
        }
    }
}