1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Context, Entity, Global, Task, prelude::*};
5use std::collections::HashMap;
6use util::path_list::PathList;
7
8struct GlobalThreadStore(Entity<ThreadStore>);
9
10impl Global for GlobalThreadStore {}
11
12pub struct ThreadStore {
13 threads: Vec<DbThreadMetadata>,
14 threads_by_paths: HashMap<PathList, Vec<usize>>,
15}
16
17impl ThreadStore {
18 pub fn init_global(cx: &mut App) {
19 let thread_store = cx.new(|cx| Self::new(cx));
20 cx.set_global(GlobalThreadStore(thread_store));
21 }
22
23 pub fn global(cx: &App) -> Entity<Self> {
24 cx.global::<GlobalThreadStore>().0.clone()
25 }
26
27 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
28 cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
29 }
30
31 pub fn new(cx: &mut Context<Self>) -> Self {
32 let this = Self {
33 threads: Vec::new(),
34 threads_by_paths: HashMap::default(),
35 };
36 this.reload(cx);
37 this
38 }
39
40 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
41 self.threads.iter().find(|thread| &thread.id == session_id)
42 }
43
44 pub fn load_thread(
45 &mut self,
46 id: acp::SessionId,
47 cx: &mut Context<Self>,
48 ) -> Task<Result<Option<DbThread>>> {
49 let database_future = ThreadsDatabase::connect(cx);
50 cx.background_spawn(async move {
51 let database = database_future.await.map_err(|err| anyhow!(err))?;
52 database.load_thread(id).await
53 })
54 }
55
56 pub fn save_thread(
57 &mut self,
58 id: acp::SessionId,
59 thread: crate::DbThread,
60 folder_paths: PathList,
61 cx: &mut Context<Self>,
62 ) -> Task<Result<()>> {
63 let database_future = ThreadsDatabase::connect(cx);
64 cx.spawn(async move |this, cx| {
65 let database = database_future.await.map_err(|err| anyhow!(err))?;
66 database.save_thread(id, thread, folder_paths).await?;
67 this.update(cx, |this, cx| this.reload(cx))
68 })
69 }
70
71 pub fn delete_thread(
72 &mut self,
73 id: acp::SessionId,
74 cx: &mut Context<Self>,
75 ) -> Task<Result<()>> {
76 let database_future = ThreadsDatabase::connect(cx);
77 cx.spawn(async move |this, cx| {
78 let database = database_future.await.map_err(|err| anyhow!(err))?;
79 database.delete_thread(id.clone()).await?;
80 this.update(cx, |this, cx| this.reload(cx))
81 })
82 }
83
84 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
85 let database_future = ThreadsDatabase::connect(cx);
86 cx.spawn(async move |this, cx| {
87 let database = database_future.await.map_err(|err| anyhow!(err))?;
88 database.delete_threads().await?;
89 this.update(cx, |this, cx| this.reload(cx))
90 })
91 }
92
93 pub fn reload(&self, cx: &mut Context<Self>) {
94 let database_connection = ThreadsDatabase::connect(cx);
95 cx.spawn(async move |this, cx| {
96 let database = database_connection.await.map_err(|err| anyhow!(err))?;
97 let all_threads = database.list_threads().await?;
98 this.update(cx, |this, cx| {
99 this.threads.clear();
100 this.threads_by_paths.clear();
101 for thread in all_threads {
102 if thread.parent_session_id.is_some() {
103 continue;
104 }
105 let index = this.threads.len();
106 this.threads_by_paths
107 .entry(thread.folder_paths.clone())
108 .or_default()
109 .push(index);
110 this.threads.push(thread);
111 }
112 cx.notify();
113 })
114 })
115 .detach_and_log_err(cx);
116 }
117
118 pub fn is_empty(&self) -> bool {
119 self.threads.is_empty()
120 }
121
122 pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
123 self.threads.iter().cloned()
124 }
125
126 /// Returns threads whose folder_paths match the given paths exactly.
127 /// Uses a cached index for O(1) lookup per path list.
128 pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
129 self.threads_by_paths
130 .get(paths)
131 .into_iter()
132 .flat_map(|indices| indices.iter().map(|&index| &self.threads[index]))
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an ACP session id from a string.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// A thread with the given title and timestamp; every other field is empty/default.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Creates a fresh store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id` with `folder_paths`, panicking if the save fails.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        folder_paths: PathList,
        cx: &mut TestAppContext,
    ) {
        let task = thread_store.update(cx, |store, cx| {
            store.save_thread(id.clone(), thread, folder_paths, cx)
        });
        task.await.unwrap();
    }

    /// Session ids of every cached entry, in store order.
    fn entry_ids(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older_thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &older_id, older_thread, PathList::default(), cx).await;
        let newer_thread = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &newer_id, newer_thread, PathList::default(), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        let thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &thread_id, thread, PathList::default(), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete_all = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_all.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        let first_thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &first_id, first_thread, PathList::default(), cx).await;
        let second_thread = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &second_id, second_thread, PathList::default(), cx).await;
        cx.run_until_parked();

        let delete_first =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete_first.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        let first_thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &first_id, first_thread, PathList::default(), cx).await;
        let second_thread = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &second_id, second_thread, PathList::default(), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp should move it to the front.
        let bumped_first = make_thread("Thread A", jan_2024(3));
        save(&thread_store, &first_id, bumped_first, PathList::default(), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
        let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);
        let thread_a_id = session_id("thread-a");
        let thread_b_id = session_id("thread-b");

        let thread_a = make_thread("Thread in A", jan_2024(1));
        save(&thread_store, &thread_a_id, thread_a, project_a_paths.clone(), cx).await;
        let thread_b = make_thread("Thread in B", jan_2024(2));
        save(&thread_store, &thread_b_id, thread_b, project_b_paths.clone(), cx).await;
        cx.run_until_parked();

        thread_store.read_with(cx, |store, _cx| {
            let ids_for = |paths: &PathList| -> Vec<acp::SessionId> {
                store
                    .threads_for_paths(paths)
                    .map(|thread| thread.id.clone())
                    .collect()
            };

            assert_eq!(ids_for(&project_a_paths), vec![thread_a_id.clone()]);
            assert_eq!(ids_for(&project_b_paths), vec![thread_b_id.clone()]);

            let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
            assert!(ids_for(&nonexistent).is_empty());
        });
    }
}