1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Context, Entity, Global, Task, prelude::*};
5use std::collections::HashMap;
6use util::path_list::PathList;
7
8struct GlobalThreadStore(Entity<ThreadStore>);
9
10impl Global for GlobalThreadStore {}
11
12pub struct ThreadStore {
13 threads: Vec<DbThreadMetadata>,
14 threads_by_paths: HashMap<PathList, Vec<usize>>,
15}
16
17impl ThreadStore {
18 pub fn init_global(cx: &mut App) {
19 let thread_store = cx.new(|cx| Self::new(cx));
20 cx.set_global(GlobalThreadStore(thread_store));
21 }
22
23 pub fn global(cx: &App) -> Entity<Self> {
24 cx.global::<GlobalThreadStore>().0.clone()
25 }
26
27 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
28 cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
29 }
30
31 pub fn new(cx: &mut Context<Self>) -> Self {
32 let this = Self {
33 threads: Vec::new(),
34 threads_by_paths: HashMap::default(),
35 };
36 this.reload(cx);
37 this
38 }
39
40 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
41 self.threads.iter().find(|thread| &thread.id == session_id)
42 }
43
44 pub fn load_thread(
45 &mut self,
46 id: acp::SessionId,
47 cx: &mut Context<Self>,
48 ) -> Task<Result<Option<DbThread>>> {
49 let database_future = ThreadsDatabase::connect(cx);
50 cx.background_spawn(async move {
51 let database = database_future.await.map_err(|err| anyhow!(err))?;
52 database.load_thread(id).await
53 })
54 }
55
56 pub fn save_thread(
57 &mut self,
58 id: acp::SessionId,
59 thread: crate::DbThread,
60 folder_paths: PathList,
61 cx: &mut Context<Self>,
62 ) -> Task<Result<()>> {
63 let database_future = ThreadsDatabase::connect(cx);
64 cx.spawn(async move |this, cx| {
65 let database = database_future.await.map_err(|err| anyhow!(err))?;
66 database.save_thread(id, thread, folder_paths).await?;
67 this.update(cx, |this, cx| this.reload(cx))
68 })
69 }
70
71 pub fn delete_thread(
72 &mut self,
73 id: acp::SessionId,
74 cx: &mut Context<Self>,
75 ) -> Task<Result<()>> {
76 let database_future = ThreadsDatabase::connect(cx);
77 cx.spawn(async move |this, cx| {
78 let database = database_future.await.map_err(|err| anyhow!(err))?;
79 database.delete_thread(id.clone()).await?;
80 this.update(cx, |this, cx| this.reload(cx))
81 })
82 }
83
84 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
85 let database_future = ThreadsDatabase::connect(cx);
86 cx.spawn(async move |this, cx| {
87 let database = database_future.await.map_err(|err| anyhow!(err))?;
88 database.delete_threads().await?;
89 this.update(cx, |this, cx| this.reload(cx))
90 })
91 }
92
93 pub fn reload(&self, cx: &mut Context<Self>) {
94 let database_connection = ThreadsDatabase::connect(cx);
95 cx.spawn(async move |this, cx| {
96 let database = database_connection.await.map_err(|err| anyhow!(err))?;
97 let all_threads = database.list_threads().await?;
98 this.update(cx, |this, cx| {
99 this.threads.clear();
100 this.threads_by_paths.clear();
101 for thread in all_threads {
102 if thread.parent_session_id.is_some() {
103 continue;
104 }
105 let index = this.threads.len();
106 this.threads_by_paths
107 .entry(thread.folder_paths.clone())
108 .or_default()
109 .push(index);
110 this.threads.push(thread);
111 }
112 cx.notify();
113 })
114 })
115 .detach_and_log_err(cx);
116 }
117
118 pub fn is_empty(&self) -> bool {
119 self.threads.is_empty()
120 }
121
122 pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
123 self.threads.iter().cloned()
124 }
125
126 /// Returns threads whose folder_paths match the given paths exactly.
127 /// Uses a cached index for O(1) lookup per path list.
128 pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
129 self.threads_by_paths
130 .get(paths)
131 .into_iter()
132 .flat_map(|indices| indices.iter().map(|&index| &self.threads[index]))
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an ACP session id from a string slice.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds a minimal thread (no messages, no optional state) with a title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            provisional_title: None,
            ui_scroll_position: None,
        }
    }

    /// Creates a store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id` with `paths` and waits for the save task to succeed.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
        id: &acp::SessionId,
        thread: DbThread,
        paths: PathList,
    ) {
        let task = thread_store.update(cx, |store, cx| {
            store.save_thread(id.clone(), thread, paths, cx)
        });
        task.await.unwrap();
    }

    /// Snapshot of all cached entries, in store order.
    fn entries(thread_store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<DbThreadMetadata> {
        thread_store.read_with(cx, |store, _cx| store.entries().collect())
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older = make_thread("Thread A", jan_2024(1));
        save(&thread_store, cx, &older_id, older, PathList::default()).await;
        let newer = make_thread("Thread B", jan_2024(2));
        save(&thread_store, cx, &newer_id, newer, PathList::default()).await;
        cx.run_until_parked();

        let listed = entries(&thread_store, cx);
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].id, newer_id);
        assert_eq!(listed[1].id, older_id);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        let thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, cx, &thread_id, thread, PathList::default()).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let wipe = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        wipe.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        let first = make_thread("Thread A", jan_2024(1));
        save(&thread_store, cx, &first_id, first, PathList::default()).await;
        let second = make_thread("Thread B", jan_2024(2));
        save(&thread_store, cx, &second_id, second, PathList::default()).await;
        cx.run_until_parked();

        let removal =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        removal.await.unwrap();
        cx.run_until_parked();

        let remaining = entries(&thread_store, cx);
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].id, second_id);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        let first = make_thread("Thread A", jan_2024(1));
        save(&thread_store, cx, &first_id, first, PathList::default()).await;
        let second = make_thread("Thread B", jan_2024(2));
        save(&thread_store, cx, &second_id, second, PathList::default()).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp should move it to the front.
        let bumped = make_thread("Thread A", jan_2024(3));
        save(&thread_store, cx, &first_id, bumped, PathList::default()).await;
        cx.run_until_parked();

        let listed = entries(&thread_store, cx);
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].id, first_id);
        assert_eq!(listed[1].id, second_id);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
        let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);
        let thread_a_id = session_id("thread-a");
        let thread_b_id = session_id("thread-b");

        let thread_a = make_thread("Thread in A", jan_2024(1));
        save(&thread_store, cx, &thread_a_id, thread_a, project_a_paths.clone()).await;
        let thread_b = make_thread("Thread in B", jan_2024(2));
        save(&thread_store, cx, &thread_b_id, thread_b, project_b_paths.clone()).await;
        cx.run_until_parked();

        thread_store.read_with(cx, |store, _cx| {
            let in_a: Vec<_> = store.threads_for_paths(&project_a_paths).collect();
            assert_eq!(in_a.len(), 1);
            assert_eq!(in_a[0].id, thread_a_id);

            let in_b: Vec<_> = store.threads_for_paths(&project_b_paths).collect();
            assert_eq!(in_b.len(), 1);
            assert_eq!(in_b[0].id, thread_b_id);

            let nowhere = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
            let none_found: Vec<_> = store.threads_for_paths(&nowhere).collect();
            assert!(none_found.is_empty());
        });
    }
}