Polish streaming slash commands (#20345)

Antonio Scandurra created

This improves the experience in a few ways:

- It avoids merging slash command output sections that are adjacent.
- When hitting cmd-z, all the output from a command is undone at once.
- When deleting a pending command, it stops the command and prevents new
output from flowing in.

Release Notes:

- N/A

Change summary

crates/assistant/src/assistant_panel.rs       |  46 +---
crates/assistant/src/context.rs               | 198 +++++++++++++-------
crates/assistant/src/context/context_tests.rs |   2 
3 files changed, 140 insertions(+), 106 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs 🔗

@@ -77,8 +77,8 @@ use text::SelectionGoal;
 use ui::{
     prelude::*,
     utils::{format_distance_from_now, DateTimeType},
-    Avatar, ButtonLike, ContextMenu, Disclosure, ElevationIndex, IconButtonShape, KeyBinding,
-    ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, TintColor, Tooltip,
+    Avatar, ButtonLike, ContextMenu, Disclosure, ElevationIndex, KeyBinding, ListItem,
+    ListItemSpacing, PopoverMenu, PopoverMenuHandle, TintColor, Tooltip,
 };
 use util::{maybe, ResultExt};
 use workspace::{
@@ -2111,7 +2111,6 @@ impl ContextEditor {
         command_id: SlashCommandId,
         cx: &mut ViewContext<Self>,
     ) {
-        let context_editor = cx.view().downgrade();
         self.editor.update(cx, |editor, cx| {
             if let Some(invoked_slash_command) =
                 self.context.read(cx).invoked_slash_command(&command_id)
@@ -2152,7 +2151,7 @@ impl ContextEditor {
                         .anchor_in_excerpt(excerpt_id, invoked_slash_command.range.end)
                         .unwrap();
                     let fold_placeholder =
-                        invoked_slash_command_fold_placeholder(command_id, context, context_editor);
+                        invoked_slash_command_fold_placeholder(command_id, context);
                     let crease_ids = editor.insert_creases(
                         [Crease::new(
                             crease_start..crease_end,
@@ -2352,6 +2351,7 @@ impl ContextEditor {
                                 section.icon,
                                 section.label.clone(),
                             ),
+                            merge_adjacent: false,
                             ..Default::default()
                         },
                         render_slash_command_output_toggle,
@@ -4963,6 +4963,7 @@ fn quote_selection_fold_placeholder(title: String, editor: WeakView<Editor>) ->
                     .into_any_element()
             }
         }),
+        merge_adjacent: false,
         ..Default::default()
     }
 }
@@ -5096,7 +5097,6 @@ enum PendingSlashCommand {}
 fn invoked_slash_command_fold_placeholder(
     command_id: SlashCommandId,
     context: WeakModel<Context>,
-    context_editor: WeakView<ContextEditor>,
 ) -> FoldPlaceholder {
     FoldPlaceholder {
         constrain_width: false,
@@ -5126,37 +5126,11 @@ fn invoked_slash_command_fold_placeholder(
                             |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
                         ))
                     }
-                    InvokedSlashCommandStatus::Error(message) => parent
-                        .child(
-                            Label::new(format!("error: {message}"))
-                                .single_line()
-                                .color(Color::Error),
-                        )
-                        .child(
-                            IconButton::new("dismiss-error", IconName::Close)
-                                .shape(IconButtonShape::Square)
-                                .icon_size(IconSize::XSmall)
-                                .icon_color(Color::Muted)
-                                .on_click({
-                                    let context_editor = context_editor.clone();
-                                    move |_event, cx| {
-                                        context_editor
-                                            .update(cx, |context_editor, cx| {
-                                                context_editor.editor.update(cx, |editor, cx| {
-                                                    editor.remove_creases(
-                                                        HashSet::from_iter(
-                                                            context_editor
-                                                                .invoked_slash_command_creases
-                                                                .remove(&command_id),
-                                                        ),
-                                                        cx,
-                                                    );
-                                                })
-                                            })
-                                            .log_err();
-                                    }
-                                }),
-                        ),
+                    InvokedSlashCommandStatus::Error(message) => parent.child(
+                        Label::new(format!("error: {message}"))
+                            .single_line()
+                            .color(Color::Error),
+                    ),
                     InvokedSlashCommandStatus::Finished => parent,
                 })
                 .into_any_element()

crates/assistant/src/context.rs 🔗

@@ -545,7 +545,6 @@ pub struct Context {
     parsed_slash_commands: Vec<ParsedSlashCommand>,
     invoked_slash_commands: HashMap<SlashCommandId, InvokedSlashCommand>,
     edits_since_last_parse: language::Subscription,
-    finished_slash_commands: HashSet<SlashCommandId>,
     slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
     pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
     message_anchors: Vec<MessageAnchor>,
@@ -647,7 +646,6 @@ impl Context {
             messages_metadata: Default::default(),
             parsed_slash_commands: Vec::new(),
             invoked_slash_commands: HashMap::default(),
-            finished_slash_commands: HashSet::default(),
             pending_tool_uses_by_id: HashMap::default(),
             slash_command_output_sections: Vec::new(),
             edits_since_last_parse: edits_since_last_slash_command_parse,
@@ -905,6 +903,8 @@ impl Context {
                             name: name.into(),
                             range: output_range,
                             status: InvokedSlashCommandStatus::Running(Task::ready(())),
+                            transaction: None,
+                            timestamp: id.0,
                         },
                     );
                     cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id });
@@ -921,10 +921,14 @@ impl Context {
                     }
                 }
                 ContextOperation::SlashCommandFinished {
-                    id, error_message, ..
+                    id,
+                    error_message,
+                    timestamp,
+                    ..
                 } => {
-                    if self.finished_slash_commands.insert(id) {
-                        if let Some(slash_command) = self.invoked_slash_commands.get_mut(&id) {
+                    if let Some(slash_command) = self.invoked_slash_commands.get_mut(&id) {
+                        if timestamp > slash_command.timestamp {
+                            slash_command.timestamp = timestamp;
                             match error_message {
                                 Some(message) => {
                                     slash_command.status =
@@ -934,9 +938,8 @@ impl Context {
                                     slash_command.status = InvokedSlashCommandStatus::Finished;
                                 }
                             }
+                            cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id });
                         }
-
-                        cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id: id });
                     }
                 }
                 ContextOperation::BufferOperation(_) => unreachable!(),
@@ -1370,8 +1373,8 @@ impl Context {
             })
             .peekable();
 
-        let mut removed_slash_command_ranges = Vec::new();
-        let mut updated_slash_commands = Vec::new();
+        let mut removed_parsed_slash_command_ranges = Vec::new();
+        let mut updated_parsed_slash_commands = Vec::new();
         let mut removed_patches = Vec::new();
         let mut updated_patches = Vec::new();
         while let Some(mut row_range) = row_ranges.next() {
@@ -1393,10 +1396,11 @@ impl Context {
             self.reparse_slash_commands_in_range(
                 start..end,
                 &buffer,
-                &mut updated_slash_commands,
-                &mut removed_slash_command_ranges,
+                &mut updated_parsed_slash_commands,
+                &mut removed_parsed_slash_command_ranges,
                 cx,
             );
+            self.invalidate_pending_slash_commands(&buffer, cx);
             self.reparse_patches_in_range(
                 start..end,
                 &buffer,
@@ -1406,10 +1410,12 @@ impl Context {
             );
         }
 
-        if !updated_slash_commands.is_empty() || !removed_slash_command_ranges.is_empty() {
+        if !updated_parsed_slash_commands.is_empty()
+            || !removed_parsed_slash_command_ranges.is_empty()
+        {
             cx.emit(ContextEvent::ParsedSlashCommandsUpdated {
-                removed: removed_slash_command_ranges,
-                updated: updated_slash_commands,
+                removed: removed_parsed_slash_command_ranges,
+                updated: updated_parsed_slash_commands,
             });
         }
 
@@ -1478,6 +1484,37 @@ impl Context {
         removed.extend(removed_commands.map(|command| command.source_range));
     }
 
+    fn invalidate_pending_slash_commands(
+        &mut self,
+        buffer: &BufferSnapshot,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let mut invalidated_command_ids = Vec::new();
+        for (&command_id, command) in self.invoked_slash_commands.iter_mut() {
+            if !matches!(command.status, InvokedSlashCommandStatus::Finished)
+                && (!command.range.start.is_valid(buffer) || !command.range.end.is_valid(buffer))
+            {
+                command.status = InvokedSlashCommandStatus::Finished;
+                cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id });
+                invalidated_command_ids.push(command_id);
+            }
+        }
+
+        for command_id in invalidated_command_ids {
+            let version = self.version.clone();
+            let timestamp = self.next_timestamp();
+            self.push_op(
+                ContextOperation::SlashCommandFinished {
+                    id: command_id,
+                    timestamp,
+                    error_message: None,
+                    version: version.clone(),
+                },
+                cx,
+            );
+        }
+    }
+
     fn reparse_patches_in_range(
         &mut self,
         range: Range<text::Anchor>,
@@ -1814,13 +1851,16 @@ impl Context {
 
         const PENDING_OUTPUT_END_MARKER: &str = "…";
 
-        let (command_range, command_source_range, insert_position) =
+        let (command_range, command_source_range, insert_position, first_transaction) =
             self.buffer.update(cx, |buffer, cx| {
                 let command_source_range = command_source_range.to_offset(buffer);
                 let mut insertion = format!("\n{PENDING_OUTPUT_END_MARKER}");
                 if ensure_trailing_newline {
                     insertion.push('\n');
                 }
+
+                buffer.finalize_last_transaction();
+                buffer.start_transaction();
                 buffer.edit(
                     [(
                         command_source_range.end..command_source_range.end,
@@ -1829,14 +1869,22 @@ impl Context {
                     None,
                     cx,
                 );
+                let first_transaction = buffer.end_transaction(cx).unwrap();
+                buffer.finalize_last_transaction();
+
                 let insert_position = buffer.anchor_after(command_source_range.end + 1);
-                let command_range = buffer.anchor_before(command_source_range.start)
-                    ..buffer.anchor_after(
+                let command_range = buffer.anchor_after(command_source_range.start)
+                    ..buffer.anchor_before(
                         command_source_range.end + 1 + PENDING_OUTPUT_END_MARKER.len(),
                     );
                 let command_source_range = buffer.anchor_before(command_source_range.start)
                     ..buffer.anchor_before(command_source_range.end + 1);
-                (command_range, command_source_range, insert_position)
+                (
+                    command_range,
+                    command_source_range,
+                    insert_position,
+                    first_transaction,
+                )
             });
         self.reparse(cx);
 
@@ -1858,13 +1906,18 @@ impl Context {
 
                 while let Some(event) = stream.next().await {
                     let event = event?;
-                    match event {
-                        SlashCommandEvent::StartMessage {
-                            role,
-                            merge_same_roles,
-                        } => {
-                            if !merge_same_roles && Some(role) != last_role {
-                                this.update(&mut cx, |this, cx| {
+                    this.update(&mut cx, |this, cx| {
+                        this.buffer.update(cx, |buffer, _cx| {
+                            buffer.finalize_last_transaction();
+                            buffer.start_transaction()
+                        });
+
+                        match event {
+                            SlashCommandEvent::StartMessage {
+                                role,
+                                merge_same_roles,
+                            } => {
+                                if !merge_same_roles && Some(role) != last_role {
                                     let offset = this.buffer.read_with(cx, |buffer, _cx| {
                                         insert_position.to_offset(buffer)
                                     });
@@ -1874,17 +1927,15 @@ impl Context {
                                         MessageStatus::Pending,
                                         cx,
                                     );
-                                })?;
-                            }
+                                }
 
-                            last_role = Some(role);
-                        }
-                        SlashCommandEvent::StartSection {
-                            icon,
-                            label,
-                            metadata,
-                        } => {
-                            this.update(&mut cx, |this, cx| {
+                                last_role = Some(role);
+                            }
+                            SlashCommandEvent::StartSection {
+                                icon,
+                                label,
+                                metadata,
+                            } => {
                                 this.buffer.update(cx, |buffer, cx| {
                                     let insert_point = insert_position.to_point(buffer);
                                     if insert_point.column > 0 {
@@ -1898,16 +1949,14 @@ impl Context {
                                         metadata,
                                     });
                                 });
-                            })?;
-                        }
-                        SlashCommandEvent::Content(SlashCommandContent::Text {
-                            text,
-                            run_commands_in_text,
-                        }) => {
-                            this.update(&mut cx, |this, cx| {
+                            }
+                            SlashCommandEvent::Content(SlashCommandContent::Text {
+                                text,
+                                run_commands_in_text,
+                            }) => {
                                 let start = this.buffer.read(cx).anchor_before(insert_position);
 
-                                let result = this.buffer.update(cx, |buffer, cx| {
+                                this.buffer.update(cx, |buffer, cx| {
                                     buffer.edit(
                                         [(insert_position..insert_position, text)],
                                         None,
@@ -1919,41 +1968,44 @@ impl Context {
                                 if run_commands_in_text {
                                     run_commands_in_ranges.push(start..end);
                                 }
-
-                                result
-                            })?;
-                        }
-                        SlashCommandEvent::EndSection { metadata } => {
-                            if let Some(pending_section) = pending_section_stack.pop() {
-                                this.update(&mut cx, |this, cx| {
+                            }
+                            SlashCommandEvent::EndSection { metadata } => {
+                                if let Some(pending_section) = pending_section_stack.pop() {
                                     let offset_range = (pending_section.start..insert_position)
                                         .to_offset(this.buffer.read(cx));
-                                    if offset_range.is_empty() {
-                                        return;
+                                    if !offset_range.is_empty() {
+                                        let range = this.buffer.update(cx, |buffer, _cx| {
+                                            buffer.anchor_after(offset_range.start)
+                                                ..buffer.anchor_before(offset_range.end)
+                                        });
+                                        this.insert_slash_command_output_section(
+                                            SlashCommandOutputSection {
+                                                range: range.clone(),
+                                                icon: pending_section.icon,
+                                                label: pending_section.label,
+                                                metadata: metadata.or(pending_section.metadata),
+                                            },
+                                            cx,
+                                        );
+                                        last_section_range = Some(range);
                                     }
-
-                                    let range = this.buffer.update(cx, |buffer, _cx| {
-                                        buffer.anchor_after(offset_range.start)
-                                            ..buffer.anchor_before(offset_range.end)
-                                    });
-                                    this.insert_slash_command_output_section(
-                                        SlashCommandOutputSection {
-                                            range: range.clone(),
-                                            icon: pending_section.icon,
-                                            label: pending_section.label,
-                                            metadata: metadata.or(pending_section.metadata),
-                                        },
-                                        cx,
-                                    );
-                                    last_section_range = Some(range);
-                                })?;
+                                }
                             }
                         }
-                    }
+
+                        this.buffer.update(cx, |buffer, cx| {
+                            if let Some(event_transaction) = buffer.end_transaction(cx) {
+                                buffer.merge_transactions(event_transaction, first_transaction);
+                            }
+                        });
+                    })?;
                 }
 
                 this.update(&mut cx, |this, cx| {
                     this.buffer.update(cx, |buffer, cx| {
+                        buffer.finalize_last_transaction();
+                        buffer.start_transaction();
+
                         let mut deletions = vec![(command_source_range.to_offset(buffer), "")];
                         let insert_position = insert_position.to_offset(buffer);
                         let command_range_end = command_range.end.to_offset(buffer);
@@ -1981,6 +2033,10 @@ impl Context {
                         }
 
                         buffer.edit(deletions, None, cx);
+
+                        if let Some(deletion_transaction) = buffer.end_transaction(cx) {
+                            buffer.merge_transactions(deletion_transaction, first_transaction);
+                        }
                     });
                 })?;
 
@@ -2031,6 +2087,8 @@ impl Context {
                 name: name.to_string().into(),
                 range: command_range.clone(),
                 status: InvokedSlashCommandStatus::Running(insert_output_task),
+                transaction: Some(first_transaction),
+                timestamp: command_id.0,
             },
         );
         cx.emit(ContextEvent::InvokedSlashCommandChanged { command_id });
@@ -3101,6 +3159,8 @@ pub struct InvokedSlashCommand {
     pub name: SharedString,
     pub range: Range<language::Anchor>,
     pub status: InvokedSlashCommandStatus,
+    pub transaction: Option<language::TransactionId>,
+    timestamp: clock::Lamport,
 }
 
 #[derive(Debug)]

crates/assistant/src/context/context_tests.rs 🔗

@@ -1357,7 +1357,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
         let first_context = contexts[0].read(cx);
         for context in &contexts[1..] {
             let context = context.read(cx);
-            assert!(context.pending_ops.is_empty());
+            assert!(context.pending_ops.is_empty(), "pending ops: {:?}", context.pending_ops);
             assert_eq!(
                 context.buffer.read(cx).text(),
                 first_context.buffer.read(cx).text(),