diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index e110f9c0514e2a030b632872d1df4e3a66973c97..85b943da4bb65b038100b2b842d81bc34662325d 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -51,6 +51,7 @@ use std::path::{Path, PathBuf}; use std::rc::Rc; use std::sync::Arc; use util::ResultExt; +use util::path_list::PathList; use util::rel_path::RelPath; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -848,13 +849,26 @@ impl NativeAgent { let Some(session) = self.sessions.get_mut(&id) else { return; }; + + let folder_paths = PathList::new( + &self + .project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| worktree.read(cx).abs_path().to_path_buf()) + .collect::>(), + ); + let thread_store = self.thread_store.clone(); session.pending_save = cx.spawn(async move |_, cx| { let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { return; }; let db_thread = db_thread.await; - database.save_thread(id, db_thread).await.log_err(); + database + .save_thread(id, db_thread, folder_paths) + .await + .log_err(); thread_store.update(cx, |store, cx| store.reload(cx)); }); } diff --git a/crates/agent/src/db.rs b/crates/agent/src/db.rs index 7dba2f078adac47b951dcec9dd30883fdea618ad..5a14e920e52c18fb6341e09fa9f747b3c5019f1d 100644 --- a/crates/agent/src/db.rs +++ b/crates/agent/src/db.rs @@ -18,6 +18,7 @@ use sqlez::{ }; use std::sync::Arc; use ui::{App, SharedString}; +use util::path_list::PathList; use zed_env_vars::ZED_STATELESS; pub type DbMessage = crate::Message; @@ -31,6 +32,9 @@ pub struct DbThreadMetadata { #[serde(alias = "summary")] pub title: SharedString, pub updated_at: DateTime, + /// The workspace folder paths this thread was created against, sorted + /// lexicographically. Used for grouping threads by project in the sidebar. + pub folder_paths: PathList, } #[derive(Debug, Serialize, Deserialize)] @@ -382,6 +386,14 @@ impl ThreadsDatabase { s().ok(); } + if let Ok(mut s) = connection.exec(indoc! {" + ALTER TABLE threads ADD COLUMN folder_paths TEXT; + ALTER TABLE threads ADD COLUMN folder_paths_order TEXT; + "}) + { + s().ok(); + } + let db = Self { executor, connection: Arc::new(Mutex::new(connection)), @@ -394,6 +406,7 @@ impl ThreadsDatabase { connection: &Arc>, id: acp::SessionId, thread: DbThread, + folder_paths: &PathList, ) -> Result<()> { const COMPRESSION_LEVEL: i32 = 3; @@ -410,6 +423,16 @@ impl ThreadsDatabase { .subagent_context .as_ref() .map(|ctx| ctx.parent_thread_id.0.clone()); + let serialized_folder_paths = folder_paths.serialize(); + let (folder_paths_str, folder_paths_order_str): (Option, Option) = + if folder_paths.is_empty() { + (None, None) + } else { + ( + Some(serialized_folder_paths.paths), + Some(serialized_folder_paths.order), + ) + }; let json_data = serde_json::to_string(&SerializedThread { thread, version: DbThread::VERSION, @@ -421,11 +444,20 @@ impl ThreadsDatabase { let data_type = DataType::Zstd; let data = compressed; - let mut insert = connection.exec_bound::<(Arc, Option>, String, String, DataType, Vec)>(indoc! {" - INSERT OR REPLACE INTO threads (id, parent_id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?) + let mut insert = connection.exec_bound::<(Arc, Option>, Option, Option, String, String, DataType, Vec)>(indoc! {" + INSERT OR REPLACE INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?, ?) "})?; - insert((id.0, parent_id, title, updated_at, data_type, data))?; + insert(( + id.0, + parent_id, + folder_paths_str, + folder_paths_order_str, + title, + updated_at, + data_type, + data, + ))?; Ok(()) } @@ -437,19 +469,28 @@ impl ThreadsDatabase { let connection = connection.lock(); let mut select = connection - .select_bound::<(), (Arc, Option>, String, String)>(indoc! {" - SELECT id, parent_id, summary, updated_at FROM threads ORDER BY updated_at DESC + .select_bound::<(), (Arc, Option>, Option, Option, String, String)>(indoc! {" + SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at FROM threads ORDER BY updated_at DESC "})?; let rows = select(())?; let mut threads = Vec::new(); - for (id, parent_id, summary, updated_at) in rows { + for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at) in rows { + let folder_paths = folder_paths + .map(|paths| { + PathList::deserialize(&util::path_list::SerializedPathList { + paths, + order: folder_paths_order.unwrap_or_default(), + }) + }) + .unwrap_or_default(); threads.push(DbThreadMetadata { id: acp::SessionId::new(id), parent_session_id: parent_id.map(acp::SessionId::new), title: summary.into(), updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), + folder_paths, }); } @@ -483,11 +524,16 @@ impl ThreadsDatabase { }) } - pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task> { + pub fn save_thread( + &self, + id: acp::SessionId, + thread: DbThread, + folder_paths: PathList, + ) -> Task> { let connection = self.connection.clone(); self.executor - .spawn(async move { Self::save_thread_sync(&connection, id, thread) }) + .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) }) } pub fn delete_thread(&self, id: acp::SessionId) -> Task> { @@ -606,11 +652,11 @@ mod tests { ); database - .save_thread(older_id.clone(), older_thread) + .save_thread(older_id.clone(), older_thread, PathList::default()) .await .unwrap(); database - .save_thread(newer_id.clone(), newer_thread) + .save_thread(newer_id.clone(), newer_thread, PathList::default()) .await .unwrap(); @@ -635,11 +681,11 @@ mod tests { ); database - .save_thread(thread_id.clone(), original_thread) + .save_thread(thread_id.clone(), original_thread, PathList::default()) .await .unwrap(); database - .save_thread(thread_id.clone(), updated_thread) + .save_thread(thread_id.clone(), updated_thread, PathList::default()) .await .unwrap(); @@ -686,7 +732,7 @@ mod tests { }); database - .save_thread(child_id.clone(), child_thread) + .save_thread(child_id.clone(), child_thread, PathList::default()) .await .unwrap(); @@ -714,7 +760,7 @@ mod tests { ); database - .save_thread(thread_id.clone(), thread) + .save_thread(thread_id.clone(), thread, PathList::default()) .await .unwrap(); @@ -729,4 +775,49 @@ mod tests { "Regular threads should have no subagent_context" ); } + + #[gpui::test] + async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) { + let database = ThreadsDatabase::new(cx.executor()).unwrap(); + + let thread_id = session_id("folder-thread"); + let thread = make_thread( + "Folder Thread", + Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(), + ); + + let folder_paths = PathList::new(&[ + std::path::PathBuf::from("/home/user/project-a"), + std::path::PathBuf::from("/home/user/project-b"), + ]); + + database + .save_thread(thread_id.clone(), thread, folder_paths.clone()) + .await + .unwrap(); + + let threads = database.list_threads().await.unwrap(); + assert_eq!(threads.len(), 1); + assert_eq!(threads[0].folder_paths, folder_paths); + } + + #[gpui::test] + async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) { + let database = ThreadsDatabase::new(cx.executor()).unwrap(); + + let thread_id = session_id("no-folder-thread"); + let thread = make_thread( + "No Folder Thread", + Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(), + ); + + database + .save_thread(thread_id.clone(), thread, PathList::default()) + .await + .unwrap(); + + let threads = database.list_threads().await.unwrap(); + assert_eq!(threads.len(), 1); + assert!(threads[0].folder_paths.is_empty()); + } } diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 3769355bc8d3495f614ccd6787bb3a33d58e8f2f..5cdce12125da8f7d26677388169e899f94b7e7f1 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -2,6 +2,7 @@ use crate::{DbThread, DbThreadMetadata, ThreadsDatabase}; use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use gpui::{App, Context, Entity, Global, Task, prelude::*}; +use util::path_list::PathList; struct GlobalThreadStore(Entity); @@ -49,12 +50,13 @@ impl ThreadStore { &mut self, id: acp::SessionId, thread: crate::DbThread, + folder_paths: PathList, cx: &mut Context, ) -> Task> { let database_future = ThreadsDatabase::connect(cx); cx.spawn(async move |this, cx| { let database = database_future.await.map_err(|err| anyhow!(err))?; - database.save_thread(id, thread).await?; + database.save_thread(id, thread, folder_paths).await?; this.update(cx, |this, cx| this.reload(cx)) }) } @@ -106,6 +108,13 @@ impl ThreadStore { pub fn entries(&self) -> impl Iterator + '_ { self.threads.iter().cloned() } + + /// Returns threads whose folder_paths match the given paths exactly. + pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator { + self.threads + .iter() + .filter(move |thread| &thread.folder_paths == paths) + } } #[cfg(test)] @@ -157,12 +166,12 @@ mod tests { ); let save_older = thread_store.update(cx, |store, cx| { - store.save_thread(older_id.clone(), older_thread, cx) + store.save_thread(older_id.clone(), older_thread, PathList::default(), cx) }); save_older.await.unwrap(); let save_newer = thread_store.update(cx, |store, cx| { - store.save_thread(newer_id.clone(), newer_thread, cx) + store.save_thread(newer_id.clone(), newer_thread, PathList::default(), cx) }); save_newer.await.unwrap(); @@ -185,8 +194,9 @@ mod tests { Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), ); - let save_task = - thread_store.update(cx, |store, cx| store.save_thread(thread_id, thread, cx)); + let save_task = thread_store.update(cx, |store, cx| { + store.save_thread(thread_id, thread, PathList::default(), cx) + }); save_task.await.unwrap(); cx.run_until_parked(); @@ -217,11 +227,11 @@ mod tests { ); let save_first = thread_store.update(cx, |store, cx| { - store.save_thread(first_id.clone(), first_thread, cx) + store.save_thread(first_id.clone(), first_thread, PathList::default(), cx) }); save_first.await.unwrap(); let save_second = thread_store.update(cx, |store, cx| { - store.save_thread(second_id.clone(), second_thread, cx) + store.save_thread(second_id.clone(), second_thread, PathList::default(), cx) }); save_second.await.unwrap(); cx.run_until_parked(); @@ -254,11 +264,11 @@ mod tests { ); let save_first = thread_store.update(cx, |store, cx| { - store.save_thread(first_id.clone(), first_thread, cx) + store.save_thread(first_id.clone(), first_thread, PathList::default(), cx) }); save_first.await.unwrap(); let save_second = thread_store.update(cx, |store, cx| { - store.save_thread(second_id.clone(), second_thread, cx) + store.save_thread(second_id.clone(), second_thread, PathList::default(), cx) }); save_second.await.unwrap(); cx.run_until_parked(); @@ -268,7 +278,7 @@ mod tests { Utc.with_ymd_and_hms(2024, 1, 3, 0, 0, 0).unwrap(), ); let update_task = thread_store.update(cx, |store, cx| { - store.save_thread(first_id.clone(), updated_first, cx) + store.save_thread(first_id.clone(), updated_first, PathList::default(), cx) }); update_task.await.unwrap(); cx.run_until_parked(); @@ -278,4 +288,50 @@ mod tests { assert_eq!(entries[0].id, first_id); assert_eq!(entries[1].id, second_id); } + + #[gpui::test] + async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) { + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + cx.run_until_parked(); + + let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]); + let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]); + + let thread_a = make_thread( + "Thread in A", + Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), + ); + let thread_b = make_thread( + "Thread in B", + Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(), + ); + let thread_a_id = session_id("thread-a"); + let thread_b_id = session_id("thread-b"); + + let save_a = thread_store.update(cx, |store, cx| { + store.save_thread(thread_a_id.clone(), thread_a, project_a_paths.clone(), cx) + }); + save_a.await.unwrap(); + + let save_b = thread_store.update(cx, |store, cx| { + store.save_thread(thread_b_id.clone(), thread_b, project_b_paths.clone(), cx) + }); + save_b.await.unwrap(); + + cx.run_until_parked(); + + thread_store.read_with(cx, |store, _cx| { + let a_threads: Vec<_> = store.threads_for_paths(&project_a_paths).collect(); + assert_eq!(a_threads.len(), 1); + assert_eq!(a_threads[0].id, thread_a_id); + + let b_threads: Vec<_> = store.threads_for_paths(&project_b_paths).collect(); + assert_eq!(b_threads.len(), 1); + assert_eq!(b_threads[0].id, thread_b_id); + + let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]); + let no_threads: Vec<_> = store.threads_for_paths(&nonexistent).collect(); + assert!(no_threads.is_empty()); + }); + } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 5269e3f1b8d03d16577e4aaeea0c258140853cb5..7097e5be156eb33382a1a0f47c1b4256c84ce9b1 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1461,7 +1461,7 @@ impl AgentPanel { cx.spawn_in(window, async move |this, cx| { thread_store .update(&mut cx.clone(), |store, cx| { - store.save_thread(session_id.clone(), db_thread, cx) + store.save_thread(session_id.clone(), db_thread, Default::default(), cx) }) .await?; diff --git a/crates/agent_ui/src/connection_view/thread_view.rs b/crates/agent_ui/src/connection_view/thread_view.rs index 499b11e5c08bd9b2c811e4cf5119bf7f71663c4b..9578a0752b45ea48477f4fab7935f670f84c25d5 100644 --- a/crates/agent_ui/src/connection_view/thread_view.rs +++ b/crates/agent_ui/src/connection_view/thread_view.rs @@ -1536,7 +1536,7 @@ impl ThreadView { thread_store .update(&mut cx.clone(), |store, cx| { - store.save_thread(session_id.clone(), db_thread, cx) + store.save_thread(session_id.clone(), db_thread, Default::default(), cx) }) .await?; diff --git a/crates/workspace/src/path_list.rs b/crates/util/src/path_list.rs similarity index 92% rename from crates/workspace/src/path_list.rs rename to crates/util/src/path_list.rs index 035f9e44fcce46527faa0c1053b7a6bb09aae0c8..1f923769780de2ae7f1dc18d3334020960ff3bb6 100644 --- a/crates/workspace/src/path_list.rs +++ b/crates/util/src/path_list.rs @@ -3,8 +3,9 @@ use std::{ sync::Arc, }; +use crate::paths::SanitizedPath; use itertools::Itertools; -use util::paths::SanitizedPath; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// A list of absolute paths, in a specific order. /// @@ -118,6 +119,19 @@ impl PathList { } } +impl Serialize for PathList { + fn serialize(&self, serializer: S) -> Result { + self.paths.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for PathList { + fn deserialize>(deserializer: D) -> Result { + let paths: Vec = Vec::deserialize(deserializer)?; + Ok(PathList::new(&paths)) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index 86d26aee884da5f708fec14b5a3c09dccfa7f5f3..4f129ef6d529aff0991b86882e5e60b6ad837d5c 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -2,6 +2,7 @@ pub mod archive; pub mod command; pub mod fs; pub mod markdown; +pub mod path_list; pub mod paths; pub mod process; pub mod redact; diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index bcc6f2ccc26c967537e5c9069ae3c8da7e0a1402..cde04d987a015982006d283c17ee82ed9b7a7cb2 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -7,7 +7,9 @@ mod multi_workspace; pub mod notifications; pub mod pane; pub mod pane_group; -mod path_list; +pub mod path_list { + pub use util::path_list::{PathList, SerializedPathList}; +} mod persistence; pub mod searchable; mod security_modal; @@ -28,7 +30,7 @@ pub use multi_workspace::{ NextWorkspaceInWindow, PreviousWorkspaceInWindow, Sidebar, SidebarEvent, SidebarHandle, ToggleWorkspaceSidebar, }; -pub use path_list::PathList; +pub use path_list::{PathList, SerializedPathList}; pub use toast_layer::{ToastAction, ToastLayer, ToastView}; use anyhow::{Context as _, Result, anyhow}; diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 95ff6f03b1b7902e254c5e405c5d8b50e1f48773..f429c32df79b6a1a62a82832e69d412800544e8a 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -951,7 +951,12 @@ fn handle_open_request(request: OpenRequest, app_state: Arc, cx: &mut thread_store .update(&mut cx.clone(), |store, cx| { - store.save_thread(save_session_id.clone(), db_thread, cx) + store.save_thread( + save_session_id.clone(), + db_thread, + Default::default(), + cx, + ) }) .await?;