Wire up history completely

Conrad Irwin and Antonio Scandurra created

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/acp_thread/src/connection.rs    |  3 
crates/agent2/src/agent.rs             | 69 +++++++++++++++------------
crates/agent2/src/db.rs                |  8 +-
crates/agent2/src/history_store.rs     | 58 +++++++---------------
crates/agent2/src/tests/mod.rs         |  4 
crates/agent2/src/thread.rs            | 32 +++++++-----
crates/agent_ui/src/acp/thread_view.rs | 17 ++++--
7 files changed, 93 insertions(+), 98 deletions(-)

Detailed changes

crates/acp_thread/src/connection.rs 🔗

@@ -2,7 +2,6 @@ use crate::{AcpThread, AcpThreadMetadata};
 use agent_client_protocol::{self as acp};
 use anyhow::Result;
 use collections::IndexMap;
-use futures::channel::mpsc::UnboundedReceiver;
 use gpui::{Entity, SharedString, Task};
 use project::Project;
 use serde::{Deserialize, Serialize};
@@ -27,6 +26,8 @@ pub trait AgentConnection {
         cx: &mut App,
     ) -> Task<Result<Entity<AcpThread>>>;
 
+    // todo!(expose a history trait, and include list_threads and load_thread)
+    // todo!(write a test)
     fn list_threads(
         &self,
         _cx: &mut App,

crates/agent2/src/agent.rs 🔗

@@ -5,16 +5,15 @@ use crate::{
     OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
     UserMessageContent, WebSearchTool, templates::Templates,
 };
-use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id};
+use crate::{ThreadsDatabase, generate_session_id};
 use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
 use agent_client_protocol as acp;
 use agent_settings::AgentSettings;
 use anyhow::{Context as _, Result, anyhow};
 use collections::{HashSet, IndexMap};
 use fs::Fs;
-use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
-use futures::future::Shared;
-use futures::{SinkExt, StreamExt, future};
+use futures::channel::mpsc;
+use futures::{StreamExt, future};
 use gpui::{
     App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
 };
@@ -30,6 +29,7 @@ use std::collections::HashMap;
 use std::path::Path;
 use std::rc::Rc;
 use std::sync::Arc;
+use std::time::Duration;
 use util::ResultExt;
 
 const RULES_FILE_NAMES: [&'static str; 9] = [
@@ -174,7 +174,7 @@ pub struct NativeAgent {
     prompt_store: Option<Entity<PromptStore>>,
     thread_database: Arc<ThreadsDatabase>,
     history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
-    load_history: Task<Result<()>>,
+    load_history: Task<()>,
     fs: Arc<dyn Fs>,
     _subscriptions: Vec<Subscription>,
 }
@@ -212,7 +212,7 @@ impl NativeAgent {
 
             let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                 watch::channel(());
-            let this = Self {
+            let mut this = Self {
                 sessions: HashMap::new(),
                 project_context: Rc::new(RefCell::new(project_context)),
                 project_context_needs_refresh: project_context_needs_refresh_tx,
@@ -229,7 +229,7 @@ impl NativeAgent {
                 prompt_store,
                 fs,
                 history: watch::channel(None).0,
-                load_history: Task::ready(Ok(())),
+                load_history: Task::ready(()),
                 _subscriptions: subscriptions,
             };
             this.reload_history(cx);
@@ -249,7 +249,7 @@ impl NativeAgent {
             Session {
                 thread: thread.clone(),
                 acp_thread: acp_thread.downgrade(),
-                save_task: Task::ready(()),
+                save_task: Task::ready(Ok(())),
                 _subscriptions: vec![
                     cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                         this.sessions.remove(acp_thread.session_id());
@@ -280,24 +280,30 @@ impl NativeAgent {
     }
 
     fn reload_history(&mut self, cx: &mut Context<Self>) {
+        dbg!("");
         let thread_database = self.thread_database.clone();
         self.load_history = cx.spawn(async move |this, cx| {
             let results = cx
                 .background_spawn(async move {
                     let results = thread_database.list_threads().await?;
-                    Ok(results
-                        .into_iter()
-                        .map(|thread| AcpThreadMetadata {
-                            agent: NATIVE_AGENT_SERVER_NAME.clone(),
-                            id: thread.id.into(),
-                            title: thread.title,
-                            updated_at: thread.updated_at,
-                        })
-                        .collect())
+                    dbg!(&results);
+                    anyhow::Ok(
+                        results
+                            .into_iter()
+                            .map(|thread| AcpThreadMetadata {
+                                agent: NATIVE_AGENT_SERVER_NAME.clone(),
+                                id: thread.id.into(),
+                                title: thread.title,
+                                updated_at: thread.updated_at,
+                            })
+                            .collect(),
+                    )
                 })
-                .await?;
-            this.update(cx, |this, cx| this.history.send(Some(results)))?;
-            anyhow::Ok(())
+                .await;
+            if let Some(results) = results.log_err() {
+                this.update(cx, |this, _| this.history.send(Some(results)))
+                    .ok();
+            }
         });
     }
 
@@ -509,10 +515,10 @@ impl NativeAgent {
     ) {
         self.models.refresh_list(cx);
         for session in self.sessions.values_mut() {
-            session.thread.update(cx, |thread, _| {
+            session.thread.update(cx, |thread, cx| {
                 let model_id = LanguageModels::model_id(&thread.model());
                 if let Some(model) = self.models.model_from_id(&model_id) {
-                    thread.set_model(model.clone());
+                    thread.set_model(model.clone(), cx);
                 }
             });
         }
@@ -715,8 +721,8 @@ impl AgentModelSelector for NativeAgentConnection {
             return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
         };
 
-        thread.update(cx, |thread, _cx| {
-            thread.set_model(model.clone());
+        thread.update(cx, |thread, cx| {
+            thread.set_model(model.clone(), cx);
         });
 
         update_settings_file::<AgentSettings>(
@@ -867,12 +873,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         session_id: acp::SessionId,
         cx: &mut App,
     ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
-        let thread_id = ThreadId::from(session_id.clone());
         let database = self.0.update(cx, |this, _| this.thread_database.clone());
         cx.spawn(async move |cx| {
-            let database = database.await.map_err(|e| anyhow!(e))?;
             let db_thread = database
-                .load_thread(thread_id.clone())
+                .load_thread(session_id.clone())
                 .await?
                 .context("no such thread found")?;
 
@@ -915,7 +919,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
 
                 let thread = cx.new(|cx| {
                     let mut thread = Thread::from_db(
-                        thread_id,
+                        session_id,
                         db_thread,
                         project.clone(),
                         agent.project_context.clone(),
@@ -934,7 +938,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
 
             // Store the session
             agent.update(cx, |agent, cx| {
-                agent.insert_session(session_id, thread, acp_thread, cx)
+                agent.insert_session(thread.clone(), acp_thread.clone(), cx)
             })?;
 
             let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
@@ -995,7 +999,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         log::info!("Cancelling on session: {}", session_id);
         self.0.update(cx, |agent, cx| {
             if let Some(agent) = agent.sessions.get(session_id) {
-                agent.thread.update(cx, |thread, _cx| thread.cancel());
+                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
             }
         });
     }
@@ -1022,7 +1026,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
 
 impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
     fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
-        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
+        Task::ready(
+            self.0
+                .update(cx, |thread, cx| thread.truncate(message_id, cx)),
+        )
     }
 }
 

crates/agent2/src/db.rs 🔗

@@ -1,4 +1,4 @@
-use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent};
+use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
 use agent::thread_store;
 use agent_client_protocol as acp;
 use agent_settings::{AgentProfileId, CompletionMode};
@@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel;
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct DbThreadMetadata {
-    pub id: ThreadId,
+    pub id: acp::SessionId,
     #[serde(alias = "summary")]
     pub title: SharedString,
     pub updated_at: DateTime<Utc>,
@@ -323,7 +323,7 @@ impl ThreadsDatabase {
 
             for (id, summary, updated_at) in rows {
                 threads.push(DbThreadMetadata {
-                    id: ThreadId(id),
+                    id: acp::SessionId(id),
                     title: summary.into(),
                     updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
                 });
@@ -333,7 +333,7 @@ impl ThreadsDatabase {
         })
     }
 
-    pub fn load_thread(&self, id: ThreadId) -> Task<Result<Option<DbThread>>> {
+    pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
         let connection = self.connection.clone();
 
         self.executor.spawn(async move {

crates/agent2/src/history_store.rs 🔗

@@ -1,17 +1,13 @@
 use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
-use agent::{ThreadId, thread_store::ThreadStore};
 use agent_client_protocol as acp;
 use anyhow::{Context as _, Result};
 use assistant_context::SavedContextMetadata;
 use chrono::{DateTime, Utc};
 use collections::HashMap;
-use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
-use itertools::Itertools;
-use paths::contexts_dir;
+use gpui::{SharedString, Task, prelude::*};
 use serde::{Deserialize, Serialize};
 use smol::stream::StreamExt;
-use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
-use util::ResultExt as _;
+use std::{path::Path, sync::Arc, time::Duration};
 
 const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
 const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
@@ -64,16 +60,16 @@ enum SerializedRecentOpen {
 }
 
 pub struct AgentHistory {
-    entries: HashMap<acp::SessionId, AcpThreadMetadata>,
-    _task: Task<Result<()>>,
+    entries: watch::Receiver<Option<Vec<AcpThreadMetadata>>>,
+    _task: Task<()>,
 }
 
 pub struct HistoryStore {
-    agents: HashMap<AgentServerName, AgentHistory>,
+    agents: HashMap<AgentServerName, AgentHistory>, // todo!() text threads
 }
 
 impl HistoryStore {
-    pub fn new(cx: &mut Context<Self>) -> Self {
+    pub fn new(_cx: &mut Context<Self>) -> Self {
         Self {
             agents: HashMap::default(),
         }
@@ -88,33 +84,18 @@ impl HistoryStore {
         let Some(mut history) = connection.list_threads(cx) else {
             return;
         };
-        let task = cx.spawn(async move |this, cx| {
-            while let Some(updated_history) = history.next().await {
-                dbg!(&updated_history);
-                this.update(cx, |this, cx| {
-                    for entry in updated_history {
-                        let agent = this
-                            .agents
-                            .get_mut(&entry.agent)
-                            .context("agent not found")?;
-                        agent.entries.insert(entry.id.clone(), entry);
-                    }
-                    cx.notify();
-                    anyhow::Ok(())
-                })??
-            }
-            Ok(())
-        });
-        self.agents.insert(
-            agent_name,
-            AgentHistory {
-                entries: Default::default(),
-                _task: task,
-            },
-        );
+        let history = AgentHistory {
+            entries: history.clone(),
+            _task: cx.spawn(async move |this, cx| {
+                while history.changed().await.is_ok() {
+                    this.update(cx, |_, cx| cx.notify()).ok();
+                }
+            }),
+        };
+        self.agents.insert(agent_name.clone(), history);
     }
 
-    pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
+    pub fn entries(&mut self, _cx: &mut Context<Self>) -> Vec<HistoryEntry> {
         let mut history_entries = Vec::new();
 
         #[cfg(debug_assertions)]
@@ -124,9 +105,8 @@ impl HistoryStore {
 
         history_entries.extend(
             self.agents
-                .values()
-                .flat_map(|agent| agent.entries.values())
-                .cloned()
+                .values_mut()
+                .flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?")
                 .map(HistoryEntry::Thread),
         );
         // todo!() include the text threads in here.
@@ -135,7 +115,7 @@ impl HistoryStore {
         history_entries
     }
 
-    pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
+    pub fn recent_entries(&mut self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
         self.entries(cx).into_iter().take(limit).collect()
     }
 }

crates/agent2/src/tests/mod.rs 🔗

@@ -938,7 +938,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
 
     // Cancel the current send and ensure that the event stream is closed, even
     // if one of the tools is still running.
-    thread.update(cx, |thread, _cx| thread.cancel());
+    thread.update(cx, |thread, cx| thread.cancel(cx));
     let events = events.collect::<Vec<_>>().await;
     let last_event = events.last();
     assert!(
@@ -1113,7 +1113,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
     });
 
     thread
-        .update(cx, |thread, _cx| thread.truncate(message_id))
+        .update(cx, |thread, cx| thread.truncate(message_id, cx))
         .unwrap();
     cx.run_until_parked();
     thread.read_with(cx, |thread, _| {

crates/agent2/src/thread.rs 🔗

@@ -802,16 +802,18 @@ impl Thread {
         &self.model
     }
 
-    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
+    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
         self.model = model;
+        cx.notify()
     }
 
     pub fn completion_mode(&self) -> CompletionMode {
         self.completion_mode
     }
 
-    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
+    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
         self.completion_mode = mode;
+        cx.notify()
     }
 
     #[cfg(any(test, feature = "test-support"))]
@@ -839,21 +841,22 @@ impl Thread {
         self.profile_id = profile_id;
     }
 
-    pub fn cancel(&mut self) {
+    pub fn cancel(&mut self, cx: &mut Context<Self>) {
         if let Some(running_turn) = self.running_turn.take() {
             running_turn.cancel();
         }
-        self.flush_pending_message();
+        self.flush_pending_message(cx);
     }
 
-    pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
-        self.cancel();
+    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
+        self.cancel(cx);
         let Some(position) = self.messages.iter().position(
             |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
         ) else {
             return Err(anyhow!("Message not found"));
         };
         self.messages.truncate(position);
+        cx.notify();
         Ok(())
     }
 
@@ -900,7 +903,7 @@ impl Thread {
     }
 
     fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
-        self.cancel();
+        self.cancel(cx);
 
         let model = self.model.clone();
         let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
@@ -938,8 +941,8 @@ impl Thread {
                                 LanguageModelCompletionEvent::Stop(reason) => {
                                     event_stream.send_stop(reason);
                                     if reason == StopReason::Refusal {
-                                        this.update(cx, |this, _cx| {
-                                            this.flush_pending_message();
+                                        this.update(cx, |this, cx| {
+                                            this.flush_pending_message(cx);
                                             this.messages.truncate(message_ix);
                                         })?;
                                         return Ok(());
@@ -991,7 +994,7 @@ impl Thread {
                             log::info!("No tool uses found, completing turn");
                             return Ok(());
                         } else {
-                            this.update(cx, |this, _| this.flush_pending_message())?;
+                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
                             completion_intent = CompletionIntent::ToolResults;
                         }
                     }
@@ -1005,8 +1008,8 @@ impl Thread {
                     log::info!("Turn execution completed successfully");
                 }
 
-                this.update(cx, |this, _| {
-                    this.flush_pending_message();
+                this.update(cx, |this, cx| {
+                    this.flush_pending_message(cx);
                     this.running_turn.take();
                 })
                 .ok();
@@ -1046,7 +1049,7 @@ impl Thread {
 
         match event {
             StartMessage { .. } => {
-                self.flush_pending_message();
+                self.flush_pending_message(cx);
                 self.pending_message = Some(AgentMessage::default());
             }
             Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
@@ -1255,7 +1258,7 @@ impl Thread {
         self.pending_message.get_or_insert_default()
     }
 
-    fn flush_pending_message(&mut self) {
+    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
         let Some(mut message) = self.pending_message.take() else {
             return;
         };
@@ -1280,6 +1283,7 @@ impl Thread {
         }
 
         self.messages.push(Message::Agent(message));
+        cx.notify()
     }
 
     pub(crate) fn build_completion_request(

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -2487,12 +2487,15 @@ impl AcpThreadView {
             return;
         };
 
-        thread.update(cx, |thread, _cx| {
+        thread.update(cx, |thread, cx| {
             let current_mode = thread.completion_mode();
-            thread.set_completion_mode(match current_mode {
-                CompletionMode::Burn => CompletionMode::Normal,
-                CompletionMode::Normal => CompletionMode::Burn,
-            });
+            thread.set_completion_mode(
+                match current_mode {
+                    CompletionMode::Burn => CompletionMode::Normal,
+                    CompletionMode::Normal => CompletionMode::Burn,
+                },
+                cx,
+            );
         });
     }
 
@@ -3274,8 +3277,8 @@ impl AcpThreadView {
                             .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
                             .on_click({
                                 cx.listener(move |this, _, _window, cx| {
-                                    thread.update(cx, |thread, _cx| {
-                                        thread.set_completion_mode(CompletionMode::Burn);
+                                    thread.update(cx, |thread, cx| {
+                                        thread.set_completion_mode(CompletionMode::Burn, cx);
                                     });
                                     this.resume_chat(cx);
                                 })