1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Context, Entity, Global, Task, prelude::*};
5use util::path_list::PathList;
6
7struct GlobalThreadStore(Entity<ThreadStore>);
8
9impl Global for GlobalThreadStore {}
10
11pub struct ThreadStore {
12 threads: Vec<DbThreadMetadata>,
13}
14
15impl ThreadStore {
16 pub fn init_global(cx: &mut App) {
17 let thread_store = cx.new(|cx| Self::new(cx));
18 cx.set_global(GlobalThreadStore(thread_store));
19 }
20
21 pub fn global(cx: &App) -> Entity<Self> {
22 cx.global::<GlobalThreadStore>().0.clone()
23 }
24
25 pub fn new(cx: &mut Context<Self>) -> Self {
26 let this = Self {
27 threads: Vec::new(),
28 };
29 this.reload(cx);
30 this
31 }
32
33 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
34 self.threads.iter().find(|thread| &thread.id == session_id)
35 }
36
37 pub fn load_thread(
38 &mut self,
39 id: acp::SessionId,
40 cx: &mut Context<Self>,
41 ) -> Task<Result<Option<DbThread>>> {
42 let database_future = ThreadsDatabase::connect(cx);
43 cx.background_spawn(async move {
44 let database = database_future.await.map_err(|err| anyhow!(err))?;
45 database.load_thread(id).await
46 })
47 }
48
49 pub fn save_thread(
50 &mut self,
51 id: acp::SessionId,
52 thread: crate::DbThread,
53 folder_paths: PathList,
54 cx: &mut Context<Self>,
55 ) -> Task<Result<()>> {
56 let database_future = ThreadsDatabase::connect(cx);
57 cx.spawn(async move |this, cx| {
58 let database = database_future.await.map_err(|err| anyhow!(err))?;
59 database.save_thread(id, thread, folder_paths).await?;
60 this.update(cx, |this, cx| this.reload(cx))
61 })
62 }
63
64 pub fn delete_thread(
65 &mut self,
66 id: acp::SessionId,
67 cx: &mut Context<Self>,
68 ) -> Task<Result<()>> {
69 let database_future = ThreadsDatabase::connect(cx);
70 cx.spawn(async move |this, cx| {
71 let database = database_future.await.map_err(|err| anyhow!(err))?;
72 database.delete_thread(id.clone()).await?;
73 this.update(cx, |this, cx| this.reload(cx))
74 })
75 }
76
77 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
78 let database_future = ThreadsDatabase::connect(cx);
79 cx.spawn(async move |this, cx| {
80 let database = database_future.await.map_err(|err| anyhow!(err))?;
81 database.delete_threads().await?;
82 this.update(cx, |this, cx| this.reload(cx))
83 })
84 }
85
86 pub fn reload(&self, cx: &mut Context<Self>) {
87 let database_connection = ThreadsDatabase::connect(cx);
88 cx.spawn(async move |this, cx| {
89 let database = database_connection.await.map_err(|err| anyhow!(err))?;
90 let threads = database
91 .list_threads()
92 .await?
93 .into_iter()
94 .filter(|thread| thread.parent_session_id.is_none())
95 .collect::<Vec<_>>();
96 this.update(cx, |this, cx| {
97 this.threads = threads;
98 cx.notify();
99 })
100 })
101 .detach_and_log_err(cx);
102 }
103
104 pub fn is_empty(&self) -> bool {
105 self.threads.is_empty()
106 }
107
108 pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
109 self.threads.iter().cloned()
110 }
111
112 /// Returns threads whose folder_paths match the given paths exactly.
113 pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
114 self.threads
115 .iter()
116 .filter(move |thread| &thread.folder_paths == paths)
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an `acp::SessionId` from a string slice.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// A minimal `DbThread` carrying only a title and an update timestamp;
    /// every other field is empty, zeroed, or disabled.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Creates a store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id` with `paths` and waits for the write task.
    async fn save(
        store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        paths: PathList,
        cx: &mut TestAppContext,
    ) {
        let task = store.update(cx, |store, cx| store.save_thread(id.clone(), thread, paths, cx));
        task.await.unwrap();
    }

    /// Ids of the store's current entries, in `entries()` order.
    fn entry_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(
            &thread_store,
            &older_id,
            make_thread("Thread A", jan_2024(1)),
            PathList::default(),
            cx,
        )
        .await;
        save(
            &thread_store,
            &newer_id,
            make_thread("Thread B", jan_2024(2)),
            PathList::default(),
            cx,
        )
        .await;
        cx.run_until_parked();

        // Most recently updated first.
        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        save(
            &thread_store,
            &thread_id,
            make_thread("Thread A", jan_2024(1)),
            PathList::default(),
            cx,
        )
        .await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete_all = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_all.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(
            &thread_store,
            &first_id,
            make_thread("Thread A", jan_2024(1)),
            PathList::default(),
            cx,
        )
        .await;
        save(
            &thread_store,
            &second_id,
            make_thread("Thread B", jan_2024(2)),
            PathList::default(),
            cx,
        )
        .await;
        cx.run_until_parked();

        let delete_first =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete_first.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(
            &thread_store,
            &first_id,
            make_thread("Thread A", jan_2024(1)),
            PathList::default(),
            cx,
        )
        .await;
        save(
            &thread_store,
            &second_id,
            make_thread("Thread B", jan_2024(2)),
            PathList::default(),
            cx,
        )
        .await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp moves it to the front.
        save(
            &thread_store,
            &first_id,
            make_thread("Thread A", jan_2024(3)),
            PathList::default(),
            cx,
        )
        .await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
        let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);
        let thread_a_id = session_id("thread-a");
        let thread_b_id = session_id("thread-b");

        save(
            &thread_store,
            &thread_a_id,
            make_thread("Thread in A", jan_2024(1)),
            project_a_paths.clone(),
            cx,
        )
        .await;
        save(
            &thread_store,
            &thread_b_id,
            make_thread("Thread in B", jan_2024(2)),
            project_b_paths.clone(),
            cx,
        )
        .await;
        cx.run_until_parked();

        thread_store.read_with(cx, |store, _cx| {
            let ids_for = |paths: &PathList| -> Vec<acp::SessionId> {
                store
                    .threads_for_paths(paths)
                    .map(|thread| thread.id.clone())
                    .collect()
            };

            assert_eq!(ids_for(&project_a_paths), vec![thread_a_id.clone()]);
            assert_eq!(ids_for(&project_b_paths), vec![thread_b_id.clone()]);

            let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
            assert!(ids_for(&nonexistent).is_empty());
        });
    }
}