WIP

Bennet Bo Fenner and Ben Brandt created

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

crates/agent_ui/src/thread_import.rs         |   3 
crates/agent_ui/src/thread_metadata_store.rs | 177 ++++++++++++++-------
crates/sidebar/src/sidebar_tests.rs          |  14 -
3 files changed, 123 insertions(+), 71 deletions(-)

Detailed changes

crates/agent_ui/src/thread_import.rs 🔗

@@ -163,7 +163,7 @@ impl ThreadImportModal {
                 Ok(threads) => {
                     let imported_count = threads.len();
                     ThreadMetadataStore::global(cx)
-                        .update(cx, |store, cx| store.save_all(threads, cx));
+                        .update(cx, |store, cx| store.save_all_to_archive(threads, cx));
                     this.is_importing = false;
                     this.last_error = None;
                     this.show_imported_threads_toast(imported_count, cx);
@@ -468,7 +468,6 @@ fn collect_importable_threads(
                 updated_at: session.updated_at.unwrap_or_else(|| Utc::now()),
                 created_at: session.created_at,
                 folder_paths,
-                archived: true,
             });
         }
     }

crates/agent_ui/src/thread_metadata_store.rs 🔗

@@ -14,7 +14,7 @@ use db::{
 };
 use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt};
 use futures::{FutureExt as _, future::Shared};
-use gpui::{AppContext as _, Entity, Global, Subscription, Task};
+use gpui::{AppContext as _, Entity, Global, Subscription, Task, proptest::num::f64::NEGATIVE};
 use project::AgentId;
 use ui::{App, Context, SharedString};
 use util::ResultExt as _;
@@ -63,7 +63,6 @@ fn migrate_thread_metadata(cx: &mut App) {
                         updated_at: entry.updated_at,
                         created_at: entry.created_at,
                         folder_paths: entry.folder_paths,
-                        archived: true,
                     })
                 })
                 .collect::<Vec<_>>()
@@ -78,7 +77,7 @@ fn migrate_thread_metadata(cx: &mut App) {
         // Manually save each entry to the database and call reload, otherwise
         // we'll end up triggering lots of reloads after each save
         for entry in to_migrate {
-            db.save(entry).await?;
+            db.save(entry, true).await?;
         }
 
         log::info!("Finished migrating thread store entries");
@@ -102,12 +101,35 @@ pub struct ThreadMetadata {
     pub updated_at: DateTime<Utc>,
     pub created_at: Option<DateTime<Utc>>,
     pub folder_paths: PathList,
+}
+
+impl From<DbThreadMetadata> for ThreadMetadata {
+    fn from(row: DbThreadMetadata) -> Self {
+        ThreadMetadata {
+            session_id: row.session_id,
+            agent_id: row.agent_id,
+            title: row.title,
+            updated_at: row.updated_at,
+            created_at: row.created_at,
+            folder_paths: row.folder_paths,
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq)]
+struct DbThreadMetadata {
+    pub session_id: acp::SessionId,
+    pub agent_id: AgentId,
+    pub title: SharedString,
+    pub updated_at: DateTime<Utc>,
+    pub created_at: Option<DateTime<Utc>>,
+    pub folder_paths: PathList,
     pub archived: bool,
 }
 
+
 impl ThreadMetadata {
     pub fn from_thread(
-        is_archived: bool,
         thread: &Entity<acp_thread::AcpThread>,
         cx: &App,
     ) -> Self {
@@ -136,7 +158,6 @@ impl ThreadMetadata {
             created_at: Some(updated_at), // handled by db `ON CONFLICT`
             updated_at,
             folder_paths,
-            archived: is_archived,
         }
     }
 }
@@ -147,6 +168,7 @@ impl ThreadMetadata {
 pub struct ThreadMetadataStore {
     db: ThreadMetadataDb,
     threads: HashMap<acp::SessionId, ThreadMetadata>,
+    archived_threads: HashMap<acp::SessionId, ThreadMetadata>,
     threads_by_paths: HashMap<PathList, Vec<ThreadMetadata>>,
     reload_task: Option<Shared<Task<()>>>,
     session_subscriptions: HashMap<acp::SessionId, Subscription>,
@@ -156,14 +178,16 @@ pub struct ThreadMetadataStore {
 
 #[derive(Debug, PartialEq)]
 enum DbOperation {
-    Insert(ThreadMetadata),
+    UpdateArchived(acp::SessionId, bool),
+    Upsert(ThreadMetadata, Option<bool>),
     Delete(acp::SessionId),
 }
 
 impl DbOperation {
     fn id(&self) -> &acp::SessionId {
         match self {
-            DbOperation::Insert(thread) => &thread.session_id,
+            DbOperation::UpdateArchived(session_id, _) => session_id,
+            DbOperation::Upsert(thread, _) => &thread.session_id,
             DbOperation::Delete(session_id) => session_id,
         }
     }
@@ -206,17 +230,17 @@ impl ThreadMetadataStore {
 
     /// Returns all thread IDs.
     pub fn entry_ids(&self) -> impl Iterator<Item = acp::SessionId> + '_ {
-        self.threads.keys().cloned()
+        self.threads.keys().chain(self.archived_threads.keys()).cloned()
     }
 
     /// Returns all threads.
     pub fn entries(&self) -> impl Iterator<Item = ThreadMetadata> + '_ {
-        self.threads.values().cloned()
+        self.threads.values().cloned().chain(self.archived_threads.values().cloned())
     }
 
     /// Returns all archived threads.
     pub fn archived_entries(&self) -> impl Iterator<Item = ThreadMetadata> + '_ {
-        self.entries().filter(|t| t.archived)
+        self.archived_threads.values().cloned()
     }
 
     /// Returns all threads for the given path list, excluding archived threads.
@@ -228,7 +252,6 @@ impl ThreadMetadataStore {
             .get(path_list)
             .into_iter()
             .flatten()
-            .filter(|s| !s.archived)
             .cloned()
     }
 
@@ -250,11 +273,18 @@ impl ThreadMetadataStore {
                     this.threads_by_paths.clear();
 
                     for row in rows {
-                        this.threads_by_paths
-                            .entry(row.folder_paths.clone())
-                            .or_default()
-                            .push(row.clone());
-                        this.threads.insert(row.session_id.clone(), row);
+                        let is_archived = row.archived;
+                        let metadata = ThreadMetadata::from(row);
+
+                        if is_archived {
+                            this.archived_threads.entry(metadata.session_id.clone()).or_insert(metadata);
+                        } else {
+                            this.threads_by_paths
+                                .entry(metadata.folder_paths.clone())
+                                .or_default()
+                                .push(metadata.clone());
+                            this.threads.insert(metadata.session_id.clone(), metadata);
+                        }
                     }
 
                     cx.notify();
@@ -266,14 +296,14 @@ impl ThreadMetadataStore {
         reload_task
     }
 
-    pub fn save_all(&mut self, metadata: Vec<ThreadMetadata>, cx: &mut Context<Self>) {
+    pub fn save_all_to_archive(&mut self, metadata: Vec<ThreadMetadata>, cx: &mut Context<Self>) {
         if !cx.has_flag::<AgentV2FeatureFlag>() {
             return;
         }
 
         for metadata in metadata {
             self.pending_thread_ops_tx
-                .try_send(DbOperation::Insert(metadata))
+                .try_send(DbOperation::Upsert(metadata, Some(true)))
                 .log_err();
         }
     }
@@ -284,7 +314,7 @@ impl ThreadMetadataStore {
         }
 
         self.pending_thread_ops_tx
-            .try_send(DbOperation::Insert(metadata))
+            .try_send(DbOperation::Upsert(metadata, None))
             .log_err();
     }
 
@@ -306,16 +336,9 @@ impl ThreadMetadataStore {
             return;
         }
 
-        if let Some(thread) = self.threads.get(session_id) {
-            self.save(
-                ThreadMetadata {
-                    archived,
-                    ..thread.clone()
-                },
-                cx,
-            );
-            cx.notify();
-        }
+        self.pending_thread_ops_tx
+            .try_send(DbOperation::UpdateArchived(session_id.clone(), archived))
+            .log_err();
     }
 
     pub fn delete(&mut self, session_id: acp::SessionId, cx: &mut Context<Self>) {
@@ -379,12 +402,15 @@ impl ThreadMetadataStore {
                     let updates = Self::dedup_db_operations(updates);
                     for operation in updates {
                         match operation {
-                            DbOperation::Insert(metadata) => {
-                                db.save(metadata).await.log_err();
+                            DbOperation::Upsert(metadata, archived) => {
+                                db.save(metadata, archived).await.log_err();
                             }
                             DbOperation::Delete(session_id) => {
                                 db.delete(session_id).await.log_err();
                             }
+                            DbOperation::UpdateArchived(session_id, archived) => {
+                                db.update_archived(&session_id, archived).await.log_err();
+                            },
                         }
                     }
 
@@ -397,6 +423,7 @@ impl ThreadMetadataStore {
             db,
             threads: HashMap::default(),
             threads_by_paths: HashMap::default(),
+            archived_threads: HashMap::default(),
             reload_task: None,
             session_subscriptions: HashMap::default(),
             pending_thread_ops_tx: tx,
@@ -409,7 +436,28 @@ impl ThreadMetadataStore {
     fn dedup_db_operations(operations: Vec<DbOperation>) -> Vec<DbOperation> {
         let mut ops = HashMap::default();
         for operation in operations.into_iter().rev() {
-            if ops.contains_key(operation.id()) {
+            if let Some(existing_operation) = ops.get_mut(operation.id()) {
+                match (existing_operation, operation) {
+                    (DbOperation::Delete(_), _) => {
+                        continue;
+                    }
+                    (DbOperation::UpdateArchived(_, _), DbOperation::UpdateArchived(_, _)) => {
+                        continue;
+                    },
+                    (DbOperation::Upsert(_, left_archive), DbOperation::UpdateArchived(_, right_archive)) if left_archive.is_none() => {
+                        *left_archive = Some(right_archive);
+                    }
+                    (DbOperation::UpdateArchived(_, left_archive), DbOperation::Upsert(thread_metadata, right_archive)) if right_archive.is_none() => {
+                        let archive = *left_archive;
+                        *existing_operation = DbOperation::Upsert(thread_metadata, Some(archive));
+                    }
+                    _ => todo!()
+                    // (DbOperation::UpdateArchived(session_id, _), DbOperation::Upsert(thread_metadata, _)) => todo!(),
+                    // (DbOperation::UpdateArchived(session_id, _), DbOperation::Delete(session_id)) => todo!(),
+                    // (DbOperation::Upsert(thread_metadata, _), DbOperation::UpdateArchived(session_id, _)) => todo!(),
+                    // (DbOperation::Upsert(thread_metadata, _), DbOperation::Upsert(thread_metadata, _)) => todo!(),
+                    // (DbOperation::Upsert(thread_metadata, _), DbOperation::Delete(session_id)) => todo!(),
+                };
                 continue;
             }
             ops.insert(operation.id().clone(), operation);
@@ -440,12 +488,7 @@ impl ThreadMetadataStore {
             | acp_thread::AcpThreadEvent::Error
             | acp_thread::AcpThreadEvent::LoadError(_)
             | acp_thread::AcpThreadEvent::Refusal => {
-                let is_archived = self
-                    .threads
-                    .get(thread.read(cx).session_id())
-                    .map(|t| t.archived)
-                    .unwrap_or(false);
-                let metadata = ThreadMetadata::from_thread(is_archived, &thread, cx);
+                let metadata = ThreadMetadata::from_thread(&thread, cx);
                 self.save(metadata, cx);
             }
             _ => {}
@@ -487,8 +530,8 @@ impl ThreadMetadataDb {
     }
 
     /// List all sidebar thread metadata, ordered by updated_at descending.
-    pub fn list(&self) -> anyhow::Result<Vec<ThreadMetadata>> {
-        self.select::<ThreadMetadata>(
+    pub fn list(&self) -> anyhow::Result<Vec<DbThreadMetadata>> {
+        self.select::<DbThreadMetadata>(
             "SELECT session_id, agent_id, title, updated_at, created_at, folder_paths, folder_paths_order, archived \
              FROM sidebar_threads \
              ORDER BY updated_at DESC"
@@ -496,7 +539,7 @@ impl ThreadMetadataDb {
     }
 
     /// Upsert metadata for a thread.
-    pub async fn save(&self, row: ThreadMetadata) -> anyhow::Result<()> {
+    pub async fn save(&self, row: ThreadMetadata, archived: Option<bool>) -> anyhow::Result<()> {
         let id = row.session_id.0.clone();
         let agent_id = if row.agent_id.as_ref() == ZED_AGENT_ID.as_ref() {
             None
@@ -512,18 +555,19 @@ impl ThreadMetadataDb {
         } else {
             (Some(serialized.paths), Some(serialized.order))
         };
-        let archived = row.archived;
 
         self.write(move |conn| {
-            let sql = "INSERT INTO sidebar_threads(session_id, agent_id, title, updated_at, created_at, folder_paths, folder_paths_order, archived) \
+            let mut sql = "INSERT INTO sidebar_threads(session_id, agent_id, title, updated_at, created_at, folder_paths, folder_paths_order, archived) \
                        VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) \
                        ON CONFLICT(session_id) DO UPDATE SET \
                            agent_id = excluded.agent_id, \
                            title = excluded.title, \
                            updated_at = excluded.updated_at, \
                            folder_paths = excluded.folder_paths, \
-                           folder_paths_order = excluded.folder_paths_order, \
-                           archived = excluded.archived";
+                           folder_paths_order = excluded.folder_paths_order".to_string();
+            if archived.is_some() {
+                sql.push_str(", archived = excluded.archived");
+            }
             let mut stmt = Statement::prepare(conn, sql)?;
             let mut i = stmt.bind(&id, 1)?;
             i = stmt.bind(&agent_id, i)?;
@@ -532,7 +576,19 @@ impl ThreadMetadataDb {
             i = stmt.bind(&created_at, i)?;
             i = stmt.bind(&folder_paths, i)?;
             i = stmt.bind(&folder_paths_order, i)?;
-            stmt.bind(&archived, i)?;
+            stmt.bind(&archived.unwrap_or(false), i)?;
+            stmt.exec()
+        })
+        .await
+    }
+
+    pub async fn update_archived(&self, session_id: &acp::SessionId, archived: bool) -> anyhow::Result<()> {
+        let id = session_id.0.clone();
+        self.write(move |conn| {
+            let mut stmt =
+                Statement::prepare(conn, "UPDATE sidebar_threads SET archived = ? WHERE session_id = ?")?;
+            stmt.bind(&archived, 1)?;
+            stmt.bind(&id, 2)?;
             stmt.exec()
         })
         .await
@@ -551,7 +607,7 @@ impl ThreadMetadataDb {
     }
 }
 
-impl Column for ThreadMetadata {
+impl Column for DbThreadMetadata {
     fn column(statement: &mut Statement, start_index: i32) -> anyhow::Result<(Self, i32)> {
         let (id, next): (Arc<str>, i32) = Column::column(statement, start_index)?;
         let (agent_id, next): (Option<String>, i32) = Column::column(statement, next)?;
@@ -584,7 +640,7 @@ impl Column for ThreadMetadata {
             .unwrap_or_default();
 
         Ok((
-            ThreadMetadata {
+            DbThreadMetadata {
                 session_id: acp::SessionId::new(id),
                 agent_id,
                 title: title.into(),
@@ -669,6 +725,7 @@ mod tests {
             "First Thread",
             now,
             first_paths.clone(),
+            false,
         ))
         .await
         .unwrap();
@@ -677,6 +734,7 @@ mod tests {
             "Second Thread",
             older,
             second_paths.clone(),
+            false,
         ))
         .await
         .unwrap();
@@ -1210,7 +1268,7 @@ mod tests {
         let now = Utc::now();
 
         let operations = vec![
-            DbOperation::Insert(make_metadata(
+            DbOperation::Upsert(make_metadata(
                 "session-1",
                 "First Thread",
                 now,
@@ -1237,12 +1295,12 @@ mod tests {
         let new_metadata = make_metadata("session-1", "New Title", later, PathList::default());
 
         let deduped = ThreadMetadataStore::dedup_db_operations(vec![
-            DbOperation::Insert(old_metadata),
-            DbOperation::Insert(new_metadata.clone()),
+            DbOperation::Upsert(old_metadata),
+            DbOperation::Upsert(new_metadata.clone()),
         ]);
 
         assert_eq!(deduped.len(), 1);
-        assert_eq!(deduped[0], DbOperation::Insert(new_metadata));
+        assert_eq!(deduped[0], DbOperation::Upsert(new_metadata));
     }
 
     #[test]
@@ -1252,13 +1310,13 @@ mod tests {
         let metadata1 = make_metadata("session-1", "First Thread", now, PathList::default());
         let metadata2 = make_metadata("session-2", "Second Thread", now, PathList::default());
         let deduped = ThreadMetadataStore::dedup_db_operations(vec![
-            DbOperation::Insert(metadata1.clone()),
-            DbOperation::Insert(metadata2.clone()),
+            DbOperation::Upsert(metadata1.clone()),
+            DbOperation::Upsert(metadata2.clone()),
         ]);
 
         assert_eq!(deduped.len(), 2);
-        assert!(deduped.contains(&DbOperation::Insert(metadata1)));
-        assert!(deduped.contains(&DbOperation::Insert(metadata2)));
+        assert!(deduped.contains(&DbOperation::Upsert(metadata1)));
+        assert!(deduped.contains(&DbOperation::Upsert(metadata2)));
     }
 
     #[gpui::test]
@@ -1514,11 +1572,10 @@ mod tests {
             let store = ThreadMetadataStore::global(cx);
             let store = store.read(cx);
 
-            let thread = store
+            assert!(store
                 .entries()
-                .find(|e| e.session_id.0.as_ref() == "session-1")
-                .expect("thread should exist after reload");
-            assert!(thread.archived);
+                .find(|e| e.session_id.0.as_ref() == "session-1").is_some(),
+                "thread should exist after reload");
 
             let path_entries = store
                 .entries_for_path(&paths)

crates/sidebar/src/sidebar_tests.rs 🔗

@@ -121,7 +121,6 @@ async fn save_thread_metadata(
         updated_at,
         created_at: None,
         folder_paths: path_list,
-        archived: false,
     };
     cx.update(|cx| {
         ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save(metadata, cx))
@@ -4194,7 +4193,6 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) {
                         chrono::TimeZone::with_ymd_and_hms(&Utc, 2024, 1, 1, 0, 0, 0).unwrap(),
                     ),
                     folder_paths: path_list.clone(),
-                    archived: false,
                 },
                 cx,
             )
@@ -4222,7 +4220,6 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) {
                         chrono::TimeZone::with_ymd_and_hms(&Utc, 2024, 1, 2, 0, 0, 0).unwrap(),
                     ),
                     folder_paths: path_list.clone(),
-                    archived: false,
                 },
                 cx,
             )
@@ -4250,7 +4247,6 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) {
                         chrono::TimeZone::with_ymd_and_hms(&Utc, 2024, 1, 3, 0, 0, 0).unwrap(),
                     ),
                     folder_paths: path_list.clone(),
-                    archived: false,
                 },
                 cx,
             )
@@ -4362,7 +4358,6 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) {
                         chrono::TimeZone::with_ymd_and_hms(&Utc, 2024, 6, 1, 0, 0, 0).unwrap(),
                     ),
                     folder_paths: path_list.clone(),
-                    archived: false,
                 },
                 cx,
             )
@@ -4414,7 +4409,6 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) {
                         chrono::TimeZone::with_ymd_and_hms(&Utc, 2023, 6, 1, 0, 0, 0).unwrap(),
                     ),
                     folder_paths: path_list.clone(),
-                    archived: false,
                 },
                 cx,
             )
@@ -4526,9 +4520,12 @@ async fn test_archived_threads_excluded_from_sidebar_entries(cx: &mut TestAppCon
             updated_at: chrono::TimeZone::with_ymd_and_hms(&Utc, 2024, 1, 1, 0, 0, 0).unwrap(),
             created_at: None,
             folder_paths: path_list.clone(),
-            archived: true,
         };
-        ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save(metadata, cx));
+        ThreadMetadataStore::global(cx).update(cx, |store, cx| {
+            let session_id = metadata.session_id.clone();
+            store.save(metadata, cx);
+            store.archive(&session_id, cx);
+        });
     });
     cx.run_until_parked();
 
@@ -4708,7 +4705,6 @@ mod property_test {
             updated_at,
             created_at: None,
             folder_paths: path_list,
-            archived: false,
         };
         cx.update(|_, cx| {
             ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save(metadata, cx));