Persist token count and scroll position across agent restarts (#50620)

Eric Holk created

Release Notes:

- Token counts and scroll position are restored when loading a previous
agent thread

Change summary

crates/acp_thread/src/acp_thread.rs                | 11 ++
crates/agent/src/agent.rs                          | 39 ++++++++++
crates/agent/src/db.rs                             | 60 ++++++++++++++++
crates/agent/src/thread.rs                         | 20 +++++
crates/agent/src/thread_store.rs                   |  1 
crates/agent_ui/src/connection_view.rs             |  4 +
crates/agent_ui/src/connection_view/thread_view.rs | 56 +++++++++++---
7 files changed, 178 insertions(+), 13 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -972,6 +972,8 @@ pub struct AcpThread {
     had_error: bool,
     /// The user's unsent prompt text, persisted so it can be restored when reloading the thread.
     draft_prompt: Option<Vec<acp::ContentBlock>>,
+    /// The initial scroll position for the thread view, set during session registration.
+    ui_scroll_position: Option<gpui::ListOffset>,
 }
 
 impl From<&AcpThread> for ActionLogTelemetry {
@@ -1210,6 +1212,7 @@ impl AcpThread {
             pending_terminal_exit: HashMap::default(),
             had_error: false,
             draft_prompt: None,
+            ui_scroll_position: None,
         }
     }
 
@@ -1229,6 +1232,14 @@ impl AcpThread {
         self.draft_prompt = prompt;
     }
 
+    pub fn ui_scroll_position(&self) -> Option<gpui::ListOffset> {
+        self.ui_scroll_position
+    }
+
+    pub fn set_ui_scroll_position(&mut self, position: Option<gpui::ListOffset>) {
+        self.ui_scroll_position = position;
+    }
+
     pub fn connection(&self) -> &Rc<dyn AgentConnection> {
         &self.connection
     }

crates/agent/src/agent.rs 🔗

@@ -352,6 +352,8 @@ impl NativeAgent {
         let parent_session_id = thread.parent_thread_id();
         let title = thread.title();
         let draft_prompt = thread.draft_prompt().map(Vec::from);
+        let scroll_position = thread.ui_scroll_position();
+        let token_usage = thread.latest_token_usage();
         let project = thread.project.clone();
         let action_log = thread.action_log.clone();
         let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
@@ -367,6 +369,8 @@ impl NativeAgent {
                 cx,
             );
             acp_thread.set_draft_prompt(draft_prompt);
+            acp_thread.set_ui_scroll_position(scroll_position);
+            acp_thread.update_token_usage(token_usage, cx);
             acp_thread
         });
 
@@ -1917,7 +1921,9 @@ mod internal_tests {
     use gpui::TestAppContext;
     use indoc::formatdoc;
     use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
-    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
+    use language_model::{
+        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
+    };
     use serde_json::json;
     use settings::SettingsStore;
     use util::{path, rel_path::rel_path};
@@ -2549,6 +2555,13 @@ mod internal_tests {
         cx.run_until_parked();
 
         model.send_last_completion_stream_text_chunk("Lorem.");
+        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+            language_model::TokenUsage {
+                input_tokens: 150,
+                output_tokens: 75,
+                ..Default::default()
+            },
+        ));
         model.end_last_completion_stream();
         cx.run_until_parked();
         summary_model
@@ -2587,6 +2600,12 @@ mod internal_tests {
         acp_thread.update(cx, |thread, _cx| {
             thread.set_draft_prompt(Some(draft_blocks.clone()));
         });
+        thread.update(cx, |thread, _cx| {
+            thread.set_ui_scroll_position(Some(gpui::ListOffset {
+                item_ix: 5,
+                offset_in_item: gpui::px(12.5),
+            }));
+        });
         thread.update(cx, |_thread, cx| cx.notify());
         cx.run_until_parked();
 
@@ -2632,6 +2651,24 @@ mod internal_tests {
         acp_thread.read_with(cx, |thread, _| {
             assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
         });
+
+        // Ensure token usage survived the round-trip.
+        acp_thread.read_with(cx, |thread, _| {
+            let usage = thread
+                .token_usage()
+                .expect("token usage should be restored after reload");
+            assert_eq!(usage.input_tokens, 150);
+            assert_eq!(usage.output_tokens, 75);
+        });
+
+        // Ensure scroll position survived the round-trip.
+        acp_thread.read_with(cx, |thread, _| {
+            let scroll = thread
+                .ui_scroll_position()
+                .expect("scroll position should be restored after reload");
+            assert_eq!(scroll.item_ix, 5);
+            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
+        });
     }
 
     fn thread_entries(

crates/agent/src/db.rs 🔗

@@ -66,6 +66,14 @@ pub struct DbThread {
     pub thinking_effort: Option<String>,
     #[serde(default)]
     pub draft_prompt: Option<Vec<acp::ContentBlock>>,
+    #[serde(default)]
+    pub ui_scroll_position: Option<SerializedScrollPosition>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
+pub struct SerializedScrollPosition {
+    pub item_ix: usize,
+    pub offset_in_item: f32,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -108,6 +116,7 @@ impl SharedThread {
             thinking_enabled: false,
             thinking_effort: None,
             draft_prompt: None,
+            ui_scroll_position: None,
         }
     }
 
@@ -286,6 +295,7 @@ impl DbThread {
             thinking_enabled: false,
             thinking_effort: None,
             draft_prompt: None,
+            ui_scroll_position: None,
         })
     }
 }
@@ -637,6 +647,7 @@ mod tests {
             thinking_enabled: false,
             thinking_effort: None,
             draft_prompt: None,
+            ui_scroll_position: None,
         }
     }
 
@@ -841,4 +852,53 @@ mod tests {
         assert_eq!(threads.len(), 1);
         assert!(threads[0].folder_paths.is_empty());
     }
+
+    #[test]
+    fn test_scroll_position_defaults_to_none() {
+        let json = r#"{
+            "title": "Old Thread",
+            "messages": [],
+            "updated_at": "2024-01-01T00:00:00Z"
+        }"#;
+
+        let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
+
+        assert!(
+            db_thread.ui_scroll_position.is_none(),
+            "Legacy threads without scroll_position field should default to None"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
+        let database = ThreadsDatabase::new(cx.executor()).unwrap();
+
+        let thread_id = session_id("thread-with-scroll");
+
+        let mut thread = make_thread(
+            "Thread With Scroll",
+            Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
+        );
+        thread.ui_scroll_position = Some(SerializedScrollPosition {
+            item_ix: 42,
+            offset_in_item: 13.5,
+        });
+
+        database
+            .save_thread(thread_id.clone(), thread, PathList::default())
+            .await
+            .unwrap();
+
+        let loaded = database
+            .load_thread(thread_id)
+            .await
+            .unwrap()
+            .expect("thread should exist");
+
+        let scroll = loaded
+            .ui_scroll_position
+            .expect("scroll_position should be restored");
+        assert_eq!(scroll.item_ix, 42);
+        assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
+    }
 }

crates/agent/src/thread.rs 🔗

@@ -901,6 +901,7 @@ pub struct Thread {
     subagent_context: Option<SubagentContext>,
     /// The user's unsent prompt text, persisted so it can be restored when reloading the thread.
     draft_prompt: Option<Vec<acp::ContentBlock>>,
+    ui_scroll_position: Option<gpui::ListOffset>,
     /// Weak references to running subagent threads for cancellation propagation
     running_subagents: Vec<WeakEntity<Thread>>,
 }
@@ -1017,6 +1018,7 @@ impl Thread {
             imported: false,
             subagent_context: None,
             draft_prompt: None,
+            ui_scroll_position: None,
             running_subagents: Vec::new(),
         }
     }
@@ -1233,6 +1235,10 @@ impl Thread {
             imported: db_thread.imported,
             subagent_context: db_thread.subagent_context,
             draft_prompt: db_thread.draft_prompt,
+            ui_scroll_position: db_thread.ui_scroll_position.map(|sp| gpui::ListOffset {
+                item_ix: sp.item_ix,
+                offset_in_item: gpui::px(sp.offset_in_item),
+            }),
             running_subagents: Vec::new(),
         }
     }
@@ -1258,6 +1264,12 @@ impl Thread {
             thinking_enabled: self.thinking_enabled,
             thinking_effort: self.thinking_effort.clone(),
             draft_prompt: self.draft_prompt.clone(),
+            ui_scroll_position: self.ui_scroll_position.map(|lo| {
+                crate::db::SerializedScrollPosition {
+                    item_ix: lo.item_ix,
+                    offset_in_item: lo.offset_in_item.as_f32(),
+                }
+            }),
         };
 
         cx.background_spawn(async move {
@@ -1307,6 +1319,14 @@ impl Thread {
         self.draft_prompt = prompt;
     }
 
+    pub fn ui_scroll_position(&self) -> Option<gpui::ListOffset> {
+        self.ui_scroll_position
+    }
+
+    pub fn set_ui_scroll_position(&mut self, position: Option<gpui::ListOffset>) {
+        self.ui_scroll_position = position;
+    }
+
     pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
         self.model.as_ref()
     }

crates/agent_ui/src/connection_view.rs 🔗

@@ -845,6 +845,10 @@ impl ConnectionView {
             );
         });
 
+        if let Some(scroll_position) = thread.read(cx).ui_scroll_position() {
+            list_state.scroll_to(scroll_position);
+        }
+
         AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx);
 
         let connection = thread.read(cx).connection().clone();

crates/agent_ui/src/connection_view/thread_view.rs 🔗

@@ -248,7 +248,8 @@ pub struct ThreadView {
     pub resumed_without_history: bool,
     pub resume_thread_metadata: Option<AgentSessionInfo>,
     pub _cancel_task: Option<Task<()>>,
-    _draft_save_task: Option<Task<()>>,
+    _save_task: Option<Task<()>>,
+    _draft_resolve_task: Option<Task<()>>,
     pub skip_queue_processing_count: usize,
     pub user_interrupted_generation: bool,
     pub can_fast_track_queue: bool,
@@ -396,7 +397,7 @@ impl ThreadView {
             } else {
                 Some(editor.update(cx, |editor, cx| editor.draft_contents(cx)))
             };
-            this._draft_save_task = Some(cx.spawn(async move |this, cx| {
+            this._draft_resolve_task = Some(cx.spawn(async move |this, cx| {
                 let draft = if let Some(task) = draft_contents_task {
                     let blocks = task.await.ok().filter(|b| !b.is_empty());
                     blocks
@@ -407,15 +408,7 @@ impl ThreadView {
                     this.thread.update(cx, |thread, _cx| {
                         thread.set_draft_prompt(draft);
                     });
-                })
-                .ok();
-                cx.background_executor()
-                    .timer(SERIALIZATION_THROTTLE_TIME)
-                    .await;
-                this.update(cx, |this, cx| {
-                    if let Some(thread) = this.as_native_thread(cx) {
-                        thread.update(cx, |_thread, cx| cx.notify());
-                    }
+                    this.schedule_save(cx);
                 })
                 .ok();
             }));
@@ -471,7 +464,8 @@ impl ThreadView {
             is_loading_contents: false,
             new_server_version_available: None,
             _cancel_task: None,
-            _draft_save_task: None,
+            _save_task: None,
+            _draft_resolve_task: None,
             skip_queue_processing_count: 0,
             user_interrupted_generation: false,
             can_fast_track_queue: false,
@@ -487,12 +481,50 @@ impl ThreadView {
             _history_subscription: history_subscription,
             show_codex_windows_warning,
         };
+        let list_state_for_scroll = this.list_state.clone();
+        let thread_view = cx.entity().downgrade();
+        this.list_state
+            .set_scroll_handler(move |_event, _window, cx| {
+                let list_state = list_state_for_scroll.clone();
+                let thread_view = thread_view.clone();
+                // N.B. We must defer because the scroll handler is called while the
+                // ListState's RefCell is mutably borrowed. Reading logical_scroll_top()
+                // directly would panic from a double borrow.
+                cx.defer(move |cx| {
+                    let scroll_top = list_state.logical_scroll_top();
+                    let _ = thread_view.update(cx, |this, cx| {
+                        if let Some(thread) = this.as_native_thread(cx) {
+                            thread.update(cx, |thread, _cx| {
+                                thread.set_ui_scroll_position(Some(scroll_top));
+                            });
+                        }
+                        this.schedule_save(cx);
+                    });
+                });
+            });
+
         if should_auto_submit {
             this.send(window, cx);
         }
         this
     }
 
+    /// Schedule a throttled save of the thread state (draft prompt, scroll position, etc.).
+    /// Multiple calls within `SERIALIZATION_THROTTLE_TIME` are coalesced into a single save.
+    fn schedule_save(&mut self, cx: &mut Context<Self>) {
+        self._save_task = Some(cx.spawn(async move |this, cx| {
+            cx.background_executor()
+                .timer(SERIALIZATION_THROTTLE_TIME)
+                .await;
+            this.update(cx, |this, cx| {
+                if let Some(thread) = this.as_native_thread(cx) {
+                    thread.update(cx, |_thread, cx| cx.notify());
+                }
+            })
+            .ok();
+        }));
+    }
+
     pub fn handle_message_editor_event(
         &mut self,
         _editor: &Entity<MessageEditor>,