Detailed changes
@@ -51,6 +51,7 @@ use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::Arc;
use util::ResultExt;
+use util::path_list::PathList;
use util::rel_path::RelPath;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@@ -848,13 +849,26 @@ impl NativeAgent {
let Some(session) = self.sessions.get_mut(&id) else {
return;
};
+
+ let folder_paths = PathList::new(
+ &self
+ .project
+ .read(cx)
+ .visible_worktrees(cx)
+ .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
+ .collect::<Vec<_>>(),
+ );
+
let thread_store = self.thread_store.clone();
session.pending_save = cx.spawn(async move |_, cx| {
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
return;
};
let db_thread = db_thread.await;
- database.save_thread(id, db_thread).await.log_err();
+ database
+ .save_thread(id, db_thread, folder_paths)
+ .await
+ .log_err();
thread_store.update(cx, |store, cx| store.reload(cx));
});
}
@@ -18,6 +18,7 @@ use sqlez::{
};
use std::sync::Arc;
use ui::{App, SharedString};
+use util::path_list::PathList;
use zed_env_vars::ZED_STATELESS;
pub type DbMessage = crate::Message;
@@ -31,6 +32,9 @@ pub struct DbThreadMetadata {
#[serde(alias = "summary")]
pub title: SharedString,
pub updated_at: DateTime<Utc>,
+ /// The workspace folder paths this thread was created against, sorted
+ /// lexicographically. Used for grouping threads by project in the sidebar.
+ pub folder_paths: PathList,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -382,6 +386,14 @@ impl ThreadsDatabase {
s().ok();
}
+ if let Ok(mut s) = connection.exec(indoc! {"
+ ALTER TABLE threads ADD COLUMN folder_paths TEXT;
+ ALTER TABLE threads ADD COLUMN folder_paths_order TEXT;
+ "})
+ {
+ s().ok();
+ }
+
let db = Self {
executor,
connection: Arc::new(Mutex::new(connection)),
@@ -394,6 +406,7 @@ impl ThreadsDatabase {
connection: &Arc<Mutex<Connection>>,
id: acp::SessionId,
thread: DbThread,
+ folder_paths: &PathList,
) -> Result<()> {
const COMPRESSION_LEVEL: i32 = 3;
@@ -410,6 +423,16 @@ impl ThreadsDatabase {
.subagent_context
.as_ref()
.map(|ctx| ctx.parent_thread_id.0.clone());
+ let serialized_folder_paths = folder_paths.serialize();
+ let (folder_paths_str, folder_paths_order_str): (Option<String>, Option<String>) =
+ if folder_paths.is_empty() {
+ (None, None)
+ } else {
+ (
+ Some(serialized_folder_paths.paths),
+ Some(serialized_folder_paths.order),
+ )
+ };
let json_data = serde_json::to_string(&SerializedThread {
thread,
version: DbThread::VERSION,
@@ -421,11 +444,20 @@ impl ThreadsDatabase {
let data_type = DataType::Zstd;
let data = compressed;
- let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, String, String, DataType, Vec<u8>)>(indoc! {"
- INSERT OR REPLACE INTO threads (id, parent_id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?)
+ let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>)>(indoc! {"
+ INSERT OR REPLACE INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"})?;
- insert((id.0, parent_id, title, updated_at, data_type, data))?;
+ insert((
+ id.0,
+ parent_id,
+ folder_paths_str,
+ folder_paths_order_str,
+ title,
+ updated_at,
+ data_type,
+ data,
+ ))?;
Ok(())
}
@@ -437,19 +469,28 @@ impl ThreadsDatabase {
let connection = connection.lock();
let mut select = connection
- .select_bound::<(), (Arc<str>, Option<Arc<str>>, String, String)>(indoc! {"
- SELECT id, parent_id, summary, updated_at FROM threads ORDER BY updated_at DESC
+ .select_bound::<(), (Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String)>(indoc! {"
+ SELECT id, parent_id, folder_paths, folder_paths_order, summary, updated_at FROM threads ORDER BY updated_at DESC
"})?;
let rows = select(())?;
let mut threads = Vec::new();
- for (id, parent_id, summary, updated_at) in rows {
+ for (id, parent_id, folder_paths, folder_paths_order, summary, updated_at) in rows {
+ let folder_paths = folder_paths
+ .map(|paths| {
+ PathList::deserialize(&util::path_list::SerializedPathList {
+ paths,
+ order: folder_paths_order.unwrap_or_default(),
+ })
+ })
+ .unwrap_or_default();
threads.push(DbThreadMetadata {
id: acp::SessionId::new(id),
parent_session_id: parent_id.map(acp::SessionId::new),
title: summary.into(),
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
+ folder_paths,
});
}
@@ -483,11 +524,16 @@ impl ThreadsDatabase {
})
}
- pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
+ pub fn save_thread(
+ &self,
+ id: acp::SessionId,
+ thread: DbThread,
+ folder_paths: PathList,
+ ) -> Task<Result<()>> {
let connection = self.connection.clone();
self.executor
- .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
+ .spawn(async move { Self::save_thread_sync(&connection, id, thread, &folder_paths) })
}
pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
@@ -606,11 +652,11 @@ mod tests {
);
database
- .save_thread(older_id.clone(), older_thread)
+ .save_thread(older_id.clone(), older_thread, PathList::default())
.await
.unwrap();
database
- .save_thread(newer_id.clone(), newer_thread)
+ .save_thread(newer_id.clone(), newer_thread, PathList::default())
.await
.unwrap();
@@ -635,11 +681,11 @@ mod tests {
);
database
- .save_thread(thread_id.clone(), original_thread)
+ .save_thread(thread_id.clone(), original_thread, PathList::default())
.await
.unwrap();
database
- .save_thread(thread_id.clone(), updated_thread)
+ .save_thread(thread_id.clone(), updated_thread, PathList::default())
.await
.unwrap();
@@ -686,7 +732,7 @@ mod tests {
});
database
- .save_thread(child_id.clone(), child_thread)
+ .save_thread(child_id.clone(), child_thread, PathList::default())
.await
.unwrap();
@@ -714,7 +760,7 @@ mod tests {
);
database
- .save_thread(thread_id.clone(), thread)
+ .save_thread(thread_id.clone(), thread, PathList::default())
.await
.unwrap();
@@ -729,4 +775,49 @@ mod tests {
"Regular threads should have no subagent_context"
);
}
+
+ #[gpui::test]
+ async fn test_folder_paths_roundtrip(cx: &mut TestAppContext) {
+ let database = ThreadsDatabase::new(cx.executor()).unwrap();
+
+ let thread_id = session_id("folder-thread");
+ let thread = make_thread(
+ "Folder Thread",
+ Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
+ );
+
+ let folder_paths = PathList::new(&[
+ std::path::PathBuf::from("/home/user/project-a"),
+ std::path::PathBuf::from("/home/user/project-b"),
+ ]);
+
+ database
+ .save_thread(thread_id.clone(), thread, folder_paths.clone())
+ .await
+ .unwrap();
+
+ let threads = database.list_threads().await.unwrap();
+ assert_eq!(threads.len(), 1);
+ assert_eq!(threads[0].folder_paths, folder_paths);
+ }
+
+ #[gpui::test]
+ async fn test_folder_paths_empty_when_not_set(cx: &mut TestAppContext) {
+ let database = ThreadsDatabase::new(cx.executor()).unwrap();
+
+ let thread_id = session_id("no-folder-thread");
+ let thread = make_thread(
+ "No Folder Thread",
+ Utc.with_ymd_and_hms(2024, 6, 15, 12, 0, 0).unwrap(),
+ );
+
+ database
+ .save_thread(thread_id.clone(), thread, PathList::default())
+ .await
+ .unwrap();
+
+ let threads = database.list_threads().await.unwrap();
+ assert_eq!(threads.len(), 1);
+ assert!(threads[0].folder_paths.is_empty());
+ }
}
@@ -2,6 +2,7 @@ use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use gpui::{App, Context, Entity, Global, Task, prelude::*};
+use util::path_list::PathList;
struct GlobalThreadStore(Entity<ThreadStore>);
@@ -49,12 +50,13 @@ impl ThreadStore {
&mut self,
id: acp::SessionId,
thread: crate::DbThread,
+ folder_paths: PathList,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let database_future = ThreadsDatabase::connect(cx);
cx.spawn(async move |this, cx| {
let database = database_future.await.map_err(|err| anyhow!(err))?;
- database.save_thread(id, thread).await?;
+ database.save_thread(id, thread, folder_paths).await?;
this.update(cx, |this, cx| this.reload(cx))
})
}
@@ -106,6 +108,13 @@ impl ThreadStore {
pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
self.threads.iter().cloned()
}
+
+ /// Returns threads whose folder_paths match the given paths exactly.
+ pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
+ self.threads
+ .iter()
+ .filter(move |thread| &thread.folder_paths == paths)
+ }
}
#[cfg(test)]
@@ -157,12 +166,12 @@ mod tests {
);
let save_older = thread_store.update(cx, |store, cx| {
- store.save_thread(older_id.clone(), older_thread, cx)
+ store.save_thread(older_id.clone(), older_thread, PathList::default(), cx)
});
save_older.await.unwrap();
let save_newer = thread_store.update(cx, |store, cx| {
- store.save_thread(newer_id.clone(), newer_thread, cx)
+ store.save_thread(newer_id.clone(), newer_thread, PathList::default(), cx)
});
save_newer.await.unwrap();
@@ -185,8 +194,9 @@ mod tests {
Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
);
- let save_task =
- thread_store.update(cx, |store, cx| store.save_thread(thread_id, thread, cx));
+ let save_task = thread_store.update(cx, |store, cx| {
+ store.save_thread(thread_id, thread, PathList::default(), cx)
+ });
save_task.await.unwrap();
cx.run_until_parked();
@@ -217,11 +227,11 @@ mod tests {
);
let save_first = thread_store.update(cx, |store, cx| {
- store.save_thread(first_id.clone(), first_thread, cx)
+ store.save_thread(first_id.clone(), first_thread, PathList::default(), cx)
});
save_first.await.unwrap();
let save_second = thread_store.update(cx, |store, cx| {
- store.save_thread(second_id.clone(), second_thread, cx)
+ store.save_thread(second_id.clone(), second_thread, PathList::default(), cx)
});
save_second.await.unwrap();
cx.run_until_parked();
@@ -254,11 +264,11 @@ mod tests {
);
let save_first = thread_store.update(cx, |store, cx| {
- store.save_thread(first_id.clone(), first_thread, cx)
+ store.save_thread(first_id.clone(), first_thread, PathList::default(), cx)
});
save_first.await.unwrap();
let save_second = thread_store.update(cx, |store, cx| {
- store.save_thread(second_id.clone(), second_thread, cx)
+ store.save_thread(second_id.clone(), second_thread, PathList::default(), cx)
});
save_second.await.unwrap();
cx.run_until_parked();
@@ -268,7 +278,7 @@ mod tests {
Utc.with_ymd_and_hms(2024, 1, 3, 0, 0, 0).unwrap(),
);
let update_task = thread_store.update(cx, |store, cx| {
- store.save_thread(first_id.clone(), updated_first, cx)
+ store.save_thread(first_id.clone(), updated_first, PathList::default(), cx)
});
update_task.await.unwrap();
cx.run_until_parked();
@@ -278,4 +288,50 @@ mod tests {
assert_eq!(entries[0].id, first_id);
assert_eq!(entries[1].id, second_id);
}
+
+ #[gpui::test]
+ async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ cx.run_until_parked();
+
+ let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
+ let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);
+
+ let thread_a = make_thread(
+ "Thread in A",
+ Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
+ );
+ let thread_b = make_thread(
+ "Thread in B",
+ Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
+ );
+ let thread_a_id = session_id("thread-a");
+ let thread_b_id = session_id("thread-b");
+
+ let save_a = thread_store.update(cx, |store, cx| {
+ store.save_thread(thread_a_id.clone(), thread_a, project_a_paths.clone(), cx)
+ });
+ save_a.await.unwrap();
+
+ let save_b = thread_store.update(cx, |store, cx| {
+ store.save_thread(thread_b_id.clone(), thread_b, project_b_paths.clone(), cx)
+ });
+ save_b.await.unwrap();
+
+ cx.run_until_parked();
+
+ thread_store.read_with(cx, |store, _cx| {
+ let a_threads: Vec<_> = store.threads_for_paths(&project_a_paths).collect();
+ assert_eq!(a_threads.len(), 1);
+ assert_eq!(a_threads[0].id, thread_a_id);
+
+ let b_threads: Vec<_> = store.threads_for_paths(&project_b_paths).collect();
+ assert_eq!(b_threads.len(), 1);
+ assert_eq!(b_threads[0].id, thread_b_id);
+
+ let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
+ let no_threads: Vec<_> = store.threads_for_paths(&nonexistent).collect();
+ assert!(no_threads.is_empty());
+ });
+ }
}
@@ -1461,7 +1461,7 @@ impl AgentPanel {
cx.spawn_in(window, async move |this, cx| {
thread_store
.update(&mut cx.clone(), |store, cx| {
- store.save_thread(session_id.clone(), db_thread, cx)
+ store.save_thread(session_id.clone(), db_thread, Default::default(), cx)
})
.await?;
@@ -1536,7 +1536,7 @@ impl ThreadView {
thread_store
.update(&mut cx.clone(), |store, cx| {
- store.save_thread(session_id.clone(), db_thread, cx)
+ store.save_thread(session_id.clone(), db_thread, Default::default(), cx)
})
.await?;
@@ -3,8 +3,9 @@ use std::{
sync::Arc,
};
+use crate::paths::SanitizedPath;
use itertools::Itertools;
-use util::paths::SanitizedPath;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// A list of absolute paths, in a specific order.
///
@@ -118,6 +119,19 @@ impl PathList {
}
}
+impl Serialize for PathList {
+ fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+ self.paths.serialize(serializer)
+ }
+}
+
+impl<'de> Deserialize<'de> for PathList {
+ fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
+ let paths: Vec<PathBuf> = Vec::deserialize(deserializer)?;
+ Ok(PathList::new(&paths))
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -2,6 +2,7 @@ pub mod archive;
pub mod command;
pub mod fs;
pub mod markdown;
+pub mod path_list;
pub mod paths;
pub mod process;
pub mod redact;
@@ -7,7 +7,9 @@ mod multi_workspace;
pub mod notifications;
pub mod pane;
pub mod pane_group;
-mod path_list;
+pub mod path_list {
+ pub use util::path_list::{PathList, SerializedPathList};
+}
mod persistence;
pub mod searchable;
mod security_modal;
@@ -28,7 +30,7 @@ pub use multi_workspace::{
NextWorkspaceInWindow, PreviousWorkspaceInWindow, Sidebar, SidebarEvent, SidebarHandle,
ToggleWorkspaceSidebar,
};
-pub use path_list::PathList;
+pub use path_list::{PathList, SerializedPathList};
pub use toast_layer::{ToastAction, ToastLayer, ToastView};
use anyhow::{Context as _, Result, anyhow};
@@ -951,7 +951,12 @@ fn handle_open_request(request: OpenRequest, app_state: Arc<AppState>, cx: &mut
thread_store
.update(&mut cx.clone(), |store, cx| {
- store.save_thread(save_session_id.clone(), db_thread, cx)
+ store.save_thread(
+ save_session_id.clone(),
+ db_thread,
+ Default::default(),
+ cx,
+ )
})
.await?;