Maintain scroll bottom when streaming assistant responses

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs          | 205 +++++++++++++++++++++++-------
crates/editor/src/editor_tests.rs   |   6 
crates/editor/src/items.rs          |  10 
crates/editor/src/scroll.rs         |  22 +-
crates/editor/src/scroll/actions.rs |   6 
crates/vim/src/normal.rs            |   2 
6 files changed, 176 insertions(+), 75 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -5,13 +5,21 @@ use crate::{
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
 use collections::{HashMap, HashSet};
-use editor::{Anchor, Editor, ExcerptId, ExcerptRange, MultiBuffer};
+use editor::{
+    display_map::ToDisplayPoint,
+    scroll::{
+        autoscroll::{Autoscroll, AutoscrollStrategy},
+        ScrollAnchor,
+    },
+    Anchor, DisplayPoint, Editor, ExcerptId, ExcerptRange, MultiBuffer,
+};
 use fs::Fs;
 use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 use gpui::{
     actions,
     elements::*,
     executor::Background,
+    geometry::vector::vec2f,
     platform::{CursorStyle, MouseButton},
     Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
     Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
@@ -414,6 +422,7 @@ impl Panel for AssistantPanel {
 enum AssistantEvent {
     MessagesEdited { ids: Vec<ExcerptId> },
     SummaryChanged,
+    StreamedCompletion,
 }
 
 struct Assistant {
@@ -531,7 +540,7 @@ impl Assistant {
         cx.notify();
     }
 
-    fn assist(&mut self, cx: &mut ModelContext<Self>) {
+    fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
         let messages = self
             .messages
             .iter()
@@ -548,24 +557,30 @@ impl Assistant {
             stream: true,
         };
 
-        let api_key = self.api_key.borrow().clone();
-        if let Some(api_key) = api_key {
-            let stream = stream_completion(api_key, cx.background().clone(), request);
-            let (excerpt_id, content) =
-                self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
-            self.insert_message_after(ExcerptId::max(), Role::User, cx);
-            let task = cx.spawn_weak(|this, mut cx| async move {
+        let api_key = self.api_key.borrow().clone()?;
+        let stream = stream_completion(api_key, cx.background().clone(), request);
+        let assistant_message = self.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
+        let user_message = self.insert_message_after(ExcerptId::max(), Role::User, cx);
+        let task = cx.spawn_weak({
+            let assistant_message = assistant_message.clone();
+            |this, mut cx| async move {
+                let assistant_message = assistant_message;
                 let stream_completion = async {
                     let mut messages = stream.await?;
 
                     while let Some(message) = messages.next().await {
                         let mut message = message?;
                         if let Some(choice) = message.choices.pop() {
-                            content.update(&mut cx, |content, cx| {
+                            assistant_message.content.update(&mut cx, |content, cx| {
                                 let text: Arc<str> = choice.delta.content?.into();
                                 content.edit([(content.len()..content.len(), text)], None, cx);
                                 Some(())
                             });
+                            this.upgrade(&cx)
+                                .ok_or_else(|| anyhow!("assistant was dropped"))?
+                                .update(&mut cx, |_, cx| {
+                                    cx.emit(AssistantEvent::StreamedCompletion);
+                                });
                         }
                     }
 
@@ -580,23 +595,28 @@ impl Assistant {
                     anyhow::Ok(())
                 };
 
-                if let Err(error) = stream_completion.await {
-                    if let Some(this) = this.upgrade(&cx) {
-                        this.update(&mut cx, |this, cx| {
-                            if let Some(metadata) = this.messages_metadata.get_mut(&excerpt_id) {
+                let result = stream_completion.await;
+                if let Some(this) = this.upgrade(&cx) {
+                    this.update(&mut cx, |this, cx| {
+                        if let Err(error) = result {
+                            if let Some(metadata) = this
+                                .messages_metadata
+                                .get_mut(&assistant_message.excerpt_id)
+                            {
                                 metadata.error = Some(error.to_string().trim().into());
                                 cx.notify();
                             }
-                        });
-                    }
+                        }
+                    });
                 }
-            });
+            }
+        });
 
-            self.pending_completions.push(PendingCompletion {
-                id: post_inc(&mut self.completion_count),
-                _task: task,
-            });
-        }
+        self.pending_completions.push(PendingCompletion {
+            id: post_inc(&mut self.completion_count),
+            _task: task,
+        });
+        Some((assistant_message, user_message))
     }
 
     fn cancel_last_assist(&mut self) -> bool {
@@ -646,7 +666,7 @@ impl Assistant {
         excerpt_id: ExcerptId,
         role: Role,
         cx: &mut ModelContext<Self>,
-    ) -> (ExcerptId, ModelHandle<Buffer>) {
+    ) -> Message {
         let content = cx.add_model(|cx| {
             let mut buffer = Buffer::new(0, "", cx);
             let markdown = self.languages.language_for_name("Markdown");
@@ -684,13 +704,11 @@ impl Assistant {
             .iter()
             .position(|message| message.excerpt_id == excerpt_id)
             .map_or(self.messages.len(), |ix| ix + 1);
-        self.messages.insert(
-            ix,
-            Message {
-                excerpt_id: new_excerpt_id,
-                content: content.clone(),
-            },
-        );
+        let message = Message {
+            excerpt_id: new_excerpt_id,
+            content: content.clone(),
+        };
+        self.messages.insert(ix, message.clone());
         self.messages_metadata.insert(
             new_excerpt_id,
             MessageMetadata {
@@ -699,7 +717,7 @@ impl Assistant {
                 error: None,
             },
         );
-        (new_excerpt_id, content)
+        message
     }
 
     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
@@ -766,6 +784,7 @@ enum AssistantEditorEvent {
 struct AssistantEditor {
     assistant: ModelHandle<Assistant>,
     editor: ViewHandle<Editor>,
+    scroll_bottom: ScrollAnchor,
     _subscriptions: Vec<Subscription>,
 }
 
@@ -875,37 +894,64 @@ impl AssistantEditor {
         let _subscriptions = vec![
             cx.observe(&assistant, |_, _, cx| cx.notify()),
             cx.subscribe(&assistant, Self::handle_assistant_event),
+            cx.subscribe(&editor, Self::handle_editor_event),
         ];
 
         Self {
             assistant,
             editor,
+            scroll_bottom: ScrollAnchor {
+                offset: Default::default(),
+                anchor: Anchor::max(),
+            },
             _subscriptions,
         }
     }
 
     fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
-        self.assistant.update(cx, |assistant, cx| {
+        let user_message = self.assistant.update(cx, |assistant, cx| {
             let editor = self.editor.read(cx);
             let newest_selection = editor.selections.newest_anchor();
             let excerpt_id = if newest_selection.head() == Anchor::min() {
-                assistant.messages.first().map(|message| message.excerpt_id)
+                assistant
+                    .messages
+                    .first()
+                    .map(|message| message.excerpt_id)?
             } else if newest_selection.head() == Anchor::max() {
-                assistant.messages.last().map(|message| message.excerpt_id)
+                assistant
+                    .messages
+                    .last()
+                    .map(|message| message.excerpt_id)?
             } else {
-                Some(newest_selection.head().excerpt_id())
+                newest_selection.head().excerpt_id()
             };
 
-            if let Some(excerpt_id) = excerpt_id {
-                if let Some(metadata) = assistant.messages_metadata.get(&excerpt_id) {
-                    if metadata.role == Role::User {
-                        assistant.assist(cx);
-                    } else {
-                        assistant.insert_message_after(excerpt_id, Role::User, cx);
-                    }
-                }
-            }
+            let metadata = assistant.messages_metadata.get(&excerpt_id)?;
+            let user_message = if metadata.role == Role::User {
+                let (_, user_message) = assistant.assist(cx)?;
+                user_message
+            } else {
+                let user_message = assistant.insert_message_after(excerpt_id, Role::User, cx);
+                user_message
+            };
+            Some(user_message)
         });
+
+        if let Some(user_message) = user_message {
+            self.editor.update(cx, |editor, cx| {
+                let cursor = editor
+                    .buffer()
+                    .read(cx)
+                    .snapshot(cx)
+                    .anchor_in_excerpt(user_message.excerpt_id, language::Anchor::MIN);
+                editor.change_selections(
+                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
+                    cx,
+                    |selections| selections.select_anchor_ranges([cursor..cursor]),
+                );
+            });
+            self.update_scroll_bottom(cx);
+        }
     }
 
     fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
@@ -919,7 +965,7 @@ impl AssistantEditor {
 
     fn handle_assistant_event(
         &mut self,
-        assistant: ModelHandle<Assistant>,
+        _: ModelHandle<Assistant>,
         event: &AssistantEvent,
         cx: &mut ViewContext<Self>,
     ) {
@@ -931,16 +977,70 @@ impl AssistantEditor {
                     .map(|selection| selection.head())
                     .collect::<HashSet<usize>>();
                 let ids = ids.iter().copied().collect::<HashSet<_>>();
-                assistant.update(cx, |assistant, cx| {
+                self.assistant.update(cx, |assistant, cx| {
                     assistant.remove_empty_messages(ids, selection_heads, cx)
                 });
             }
             AssistantEvent::SummaryChanged => {
                 cx.emit(AssistantEditorEvent::TabContentChanged);
             }
+            AssistantEvent::StreamedCompletion => {
+                self.editor.update(cx, |editor, cx| {
+                    let snapshot = editor.snapshot(cx);
+                    let scroll_bottom_row = self
+                        .scroll_bottom
+                        .anchor
+                        .to_display_point(&snapshot.display_snapshot)
+                        .row();
+
+                    let scroll_bottom = scroll_bottom_row as f32 + self.scroll_bottom.offset.y();
+                    let visible_line_count = editor.visible_line_count().unwrap_or(0.);
+                    let scroll_top = scroll_bottom - visible_line_count;
+                    editor
+                        .set_scroll_position(vec2f(self.scroll_bottom.offset.x(), scroll_top), cx);
+                });
+            }
+        }
+    }
+
+    fn handle_editor_event(
+        &mut self,
+        _: ViewHandle<Editor>,
+        event: &editor::Event,
+        cx: &mut ViewContext<Self>,
+    ) {
+        match event {
+            editor::Event::ScrollPositionChanged { .. } => self.update_scroll_bottom(cx),
+            _ => {}
         }
     }
 
+    fn update_scroll_bottom(&mut self, cx: &mut ViewContext<Self>) {
+        self.editor.update(cx, |editor, cx| {
+            let snapshot = editor.snapshot(cx);
+            let scroll_position = editor
+                .scroll_manager
+                .anchor()
+                .scroll_position(&snapshot.display_snapshot);
+            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
+            let scroll_bottom_point = cmp::min(
+                DisplayPoint::new(scroll_bottom.floor() as u32, 0),
+                snapshot.display_snapshot.max_point(),
+            );
+            let scroll_bottom_anchor = snapshot
+                .buffer_snapshot
+                .anchor_after(scroll_bottom_point.to_point(&snapshot.display_snapshot));
+            let scroll_bottom_offset = vec2f(
+                scroll_position.x(),
+                scroll_bottom - scroll_bottom_point.row() as f32,
+            );
+            self.scroll_bottom = ScrollAnchor {
+                anchor: scroll_bottom_anchor,
+                offset: scroll_bottom_offset,
+            };
+        });
+    }
+
     fn quote_selection(
         workspace: &mut Workspace,
         _: &QuoteSelection,
@@ -1155,7 +1255,7 @@ impl Item for AssistantEditor {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct Message {
     excerpt_id: ExcerptId,
     content: ModelHandle<Buffer>,
@@ -1265,15 +1365,16 @@ mod tests {
 
         cx.add_model(|cx| {
             let mut assistant = Assistant::new(Default::default(), registry, cx);
-            let (excerpt_1, _) =
-                assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
-            let (excerpt_2, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
-            let (excerpt_3, _) = assistant.insert_message_after(excerpt_1, Role::User, cx);
+            let message_1 = assistant.insert_message_after(ExcerptId::max(), Role::Assistant, cx);
+            let message_2 = assistant.insert_message_after(message_1.excerpt_id, Role::User, cx);
+            let message_3 = assistant.insert_message_after(message_1.excerpt_id, Role::User, cx);
             assistant.remove_empty_messages(
-                HashSet::from_iter([excerpt_2, excerpt_3]),
+                HashSet::from_iter([message_2.excerpt_id, message_3.excerpt_id]),
                 Default::default(),
                 cx,
             );
+            assert_eq!(assistant.messages.len(), 1);
+            assert_eq!(assistant.messages[0].excerpt_id, message_1.excerpt_id);
             assistant
         });
     }

crates/editor/src/editor_tests.rs 🔗

@@ -579,7 +579,7 @@ async fn test_navigation_history(cx: &mut TestAppContext) {
         assert_eq!(editor.scroll_manager.anchor(), original_scroll_position);
 
         // Ensure we don't panic when navigation data contains invalid anchors *and* points.
-        let mut invalid_anchor = editor.scroll_manager.anchor().top_anchor;
+        let mut invalid_anchor = editor.scroll_manager.anchor().anchor;
         invalid_anchor.text_anchor.buffer_id = Some(999);
         let invalid_point = Point::new(9999, 0);
         editor.navigate(
@@ -587,7 +587,7 @@ async fn test_navigation_history(cx: &mut TestAppContext) {
                 cursor_anchor: invalid_anchor,
                 cursor_position: invalid_point,
                 scroll_anchor: ScrollAnchor {
-                    top_anchor: invalid_anchor,
+                    anchor: invalid_anchor,
                     offset: Default::default(),
                 },
                 scroll_top_row: invalid_point.row,
@@ -5815,7 +5815,7 @@ async fn test_following(cx: &mut gpui::TestAppContext) {
         let top_anchor = follower.buffer().read(cx).read(cx).anchor_after(0);
         follower.set_scroll_anchor(
             ScrollAnchor {
-                top_anchor,
+                anchor: top_anchor,
                 offset: vec2f(0.0, 0.5),
             },
             cx,

crates/editor/src/items.rs 🔗

@@ -196,7 +196,7 @@ impl FollowableItem for Editor {
             singleton: buffer.is_singleton(),
             title: (!buffer.is_singleton()).then(|| buffer.title(cx).into()),
             excerpts,
-            scroll_top_anchor: Some(serialize_anchor(&scroll_anchor.top_anchor)),
+            scroll_top_anchor: Some(serialize_anchor(&scroll_anchor.anchor)),
             scroll_x: scroll_anchor.offset.x(),
             scroll_y: scroll_anchor.offset.y(),
             selections: self
@@ -253,7 +253,7 @@ impl FollowableItem for Editor {
                 }
                 Event::ScrollPositionChanged { .. } => {
                     let scroll_anchor = self.scroll_manager.anchor();
-                    update.scroll_top_anchor = Some(serialize_anchor(&scroll_anchor.top_anchor));
+                    update.scroll_top_anchor = Some(serialize_anchor(&scroll_anchor.anchor));
                     update.scroll_x = scroll_anchor.offset.x();
                     update.scroll_y = scroll_anchor.offset.y();
                     true
@@ -412,7 +412,7 @@ async fn update_editor_from_message(
         } else if let Some(scroll_top_anchor) = scroll_top_anchor {
             editor.set_scroll_anchor_remote(
                 ScrollAnchor {
-                    top_anchor: scroll_top_anchor,
+                    anchor: scroll_top_anchor,
                     offset: vec2f(message.scroll_x, message.scroll_y),
                 },
                 cx,
@@ -510,8 +510,8 @@ impl Item for Editor {
             };
 
             let mut scroll_anchor = data.scroll_anchor;
-            if !buffer.can_resolve(&scroll_anchor.top_anchor) {
-                scroll_anchor.top_anchor = buffer.anchor_before(
+            if !buffer.can_resolve(&scroll_anchor.anchor) {
+                scroll_anchor.anchor = buffer.anchor_before(
                     buffer.clip_point(Point::new(data.scroll_top_row, 0), Bias::Left),
                 );
             }

crates/editor/src/scroll.rs 🔗

@@ -36,21 +36,21 @@ pub struct ScrollbarAutoHide(pub bool);
 #[derive(Clone, Copy, Debug, PartialEq)]
 pub struct ScrollAnchor {
     pub offset: Vector2F,
-    pub top_anchor: Anchor,
+    pub anchor: Anchor,
 }
 
 impl ScrollAnchor {
     fn new() -> Self {
         Self {
             offset: Vector2F::zero(),
-            top_anchor: Anchor::min(),
+            anchor: Anchor::min(),
         }
     }
 
     pub fn scroll_position(&self, snapshot: &DisplaySnapshot) -> Vector2F {
         let mut scroll_position = self.offset;
-        if self.top_anchor != Anchor::min() {
-            let scroll_top = self.top_anchor.to_display_point(snapshot).row() as f32;
+        if self.anchor != Anchor::min() {
+            let scroll_top = self.anchor.to_display_point(snapshot).row() as f32;
             scroll_position.set_y(scroll_top + scroll_position.y());
         } else {
             scroll_position.set_y(0.);
@@ -59,7 +59,7 @@ impl ScrollAnchor {
     }
 
     pub fn top_row(&self, buffer: &MultiBufferSnapshot) -> u32 {
-        self.top_anchor.to_point(buffer).row
+        self.anchor.to_point(buffer).row
     }
 }
 
@@ -179,7 +179,7 @@ impl ScrollManager {
         let (new_anchor, top_row) = if scroll_position.y() <= 0. {
             (
                 ScrollAnchor {
-                    top_anchor: Anchor::min(),
+                    anchor: Anchor::min(),
                     offset: scroll_position.max(vec2f(0., 0.)),
                 },
                 0,
@@ -193,7 +193,7 @@ impl ScrollManager {
 
             (
                 ScrollAnchor {
-                    top_anchor,
+                    anchor: top_anchor,
                     offset: vec2f(
                         scroll_position.x(),
                         scroll_position.y() - top_anchor.to_display_point(&map).row() as f32,
@@ -322,7 +322,7 @@ impl Editor {
         hide_hover(self, cx);
         let workspace_id = self.workspace.as_ref().map(|workspace| workspace.1);
         let top_row = scroll_anchor
-            .top_anchor
+            .anchor
             .to_point(&self.buffer().read(cx).snapshot(cx))
             .row;
         self.scroll_manager
@@ -337,7 +337,7 @@ impl Editor {
         hide_hover(self, cx);
         let workspace_id = self.workspace.as_ref().map(|workspace| workspace.1);
         let top_row = scroll_anchor
-            .top_anchor
+            .anchor
             .to_point(&self.buffer().read(cx).snapshot(cx))
             .row;
         self.scroll_manager
@@ -377,7 +377,7 @@ impl Editor {
         let screen_top = self
             .scroll_manager
             .anchor
-            .top_anchor
+            .anchor
             .to_display_point(&snapshot);
 
         if screen_top > newest_head {
@@ -408,7 +408,7 @@ impl Editor {
                 .anchor_at(Point::new(top_row as u32, 0), Bias::Left);
             let scroll_anchor = ScrollAnchor {
                 offset: Vector2F::new(x, y),
-                top_anchor,
+                anchor: top_anchor,
             };
             self.set_scroll_anchor(scroll_anchor, cx);
         }

crates/editor/src/scroll/actions.rs 🔗

@@ -86,7 +86,7 @@ impl Editor {
 
         editor.set_scroll_anchor(
             ScrollAnchor {
-                top_anchor: new_anchor,
+                anchor: new_anchor,
                 offset: Default::default(),
             },
             cx,
@@ -113,7 +113,7 @@ impl Editor {
 
         editor.set_scroll_anchor(
             ScrollAnchor {
-                top_anchor: new_anchor,
+                anchor: new_anchor,
                 offset: Default::default(),
             },
             cx,
@@ -143,7 +143,7 @@ impl Editor {
 
         editor.set_scroll_anchor(
             ScrollAnchor {
-                top_anchor: new_anchor,
+                anchor: new_anchor,
                 offset: Default::default(),
             },
             cx,

crates/vim/src/normal.rs 🔗

@@ -400,7 +400,7 @@ fn scroll(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContext<Edito
         };
 
         let scroll_margin_rows = editor.vertical_scroll_margin() as u32;
-        let top_anchor = editor.scroll_manager.anchor().top_anchor;
+        let top_anchor = editor.scroll_manager.anchor().anchor;
 
         editor.change_selections(None, cx, |s| {
             s.replace_cursors_with(|snapshot| {