1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Context, Entity, Global, Task, prelude::*};
5use util::path_list::PathList;
6
/// gpui global wrapper around the app-wide [`ThreadStore`] entity.
///
/// Set by [`ThreadStore::init_global`] and read back by
/// [`ThreadStore::global`] / [`ThreadStore::try_global`].
struct GlobalThreadStore(Entity<ThreadStore>);

impl Global for GlobalThreadStore {}
10
/// In-memory cache of thread metadata read from the [`ThreadsDatabase`].
///
/// The cache is refreshed asynchronously by [`ThreadStore::reload`], which
/// runs on construction and after every save/delete.
pub struct ThreadStore {
    // Top-level threads only (entries with a `parent_session_id` are filtered
    // out by `reload`), in the order the database listed them.
    threads: Vec<DbThreadMetadata>,
}
14
15impl ThreadStore {
16 pub fn init_global(cx: &mut App) {
17 let thread_store = cx.new(|cx| Self::new(cx));
18 cx.set_global(GlobalThreadStore(thread_store));
19 }
20
21 pub fn global(cx: &App) -> Entity<Self> {
22 cx.global::<GlobalThreadStore>().0.clone()
23 }
24
25 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
26 cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
27 }
28
29 pub fn new(cx: &mut Context<Self>) -> Self {
30 let this = Self {
31 threads: Vec::new(),
32 };
33 this.reload(cx);
34 this
35 }
36
37 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
38 self.threads.iter().find(|thread| &thread.id == session_id)
39 }
40
41 pub fn load_thread(
42 &mut self,
43 id: acp::SessionId,
44 cx: &mut Context<Self>,
45 ) -> Task<Result<Option<DbThread>>> {
46 let database_future = ThreadsDatabase::connect(cx);
47 cx.background_spawn(async move {
48 let database = database_future.await.map_err(|err| anyhow!(err))?;
49 database.load_thread(id).await
50 })
51 }
52
53 pub fn save_thread(
54 &mut self,
55 id: acp::SessionId,
56 thread: crate::DbThread,
57 folder_paths: PathList,
58 cx: &mut Context<Self>,
59 ) -> Task<Result<()>> {
60 let database_future = ThreadsDatabase::connect(cx);
61 cx.spawn(async move |this, cx| {
62 let database = database_future.await.map_err(|err| anyhow!(err))?;
63 database.save_thread(id, thread, folder_paths).await?;
64 this.update(cx, |this, cx| this.reload(cx))
65 })
66 }
67
68 pub fn delete_thread(
69 &mut self,
70 id: acp::SessionId,
71 cx: &mut Context<Self>,
72 ) -> Task<Result<()>> {
73 let database_future = ThreadsDatabase::connect(cx);
74 cx.spawn(async move |this, cx| {
75 let database = database_future.await.map_err(|err| anyhow!(err))?;
76 database.delete_thread(id.clone()).await?;
77 this.update(cx, |this, cx| this.reload(cx))
78 })
79 }
80
81 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
82 let database_future = ThreadsDatabase::connect(cx);
83 cx.spawn(async move |this, cx| {
84 let database = database_future.await.map_err(|err| anyhow!(err))?;
85 database.delete_threads().await?;
86 this.update(cx, |this, cx| this.reload(cx))
87 })
88 }
89
90 pub fn reload(&self, cx: &mut Context<Self>) {
91 let database_connection = ThreadsDatabase::connect(cx);
92 cx.spawn(async move |this, cx| {
93 let database = database_connection.await.map_err(|err| anyhow!(err))?;
94 let threads = database
95 .list_threads()
96 .await?
97 .into_iter()
98 .filter(|thread| thread.parent_session_id.is_none())
99 .collect::<Vec<_>>();
100 this.update(cx, |this, cx| {
101 this.threads = threads;
102 cx.notify();
103 })
104 })
105 .detach_and_log_err(cx);
106 }
107
108 pub fn is_empty(&self) -> bool {
109 self.threads.is_empty()
110 }
111
112 pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
113 self.threads.iter().cloned()
114 }
115
116 /// Returns threads whose folder_paths match the given paths exactly.
117 pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
118 self.threads
119 .iter()
120 .filter(move |thread| &thread.folder_paths == paths)
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Wraps a string slice in an `acp::SessionId`.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = Arc::from(value);
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds a minimal, message-less thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_owned().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let (older_id, newer_id) = (session_id("thread-a"), session_id("thread-b"));

        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    older_id.clone(),
                    make_thread("Thread A", jan_2024(1)),
                    PathList::default(),
                    cx,
                )
            })
            .await
            .unwrap();
        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    newer_id.clone(),
                    make_thread("Thread B", jan_2024(2)),
                    PathList::default(),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        let ids = thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect::<Vec<_>>()
        });
        assert_eq!(ids, vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    session_id("thread-a"),
                    make_thread("Thread A", jan_2024(1)),
                    PathList::default(),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        let empty_after_save = thread_store.read_with(cx, |store, _cx| store.is_empty());
        assert!(!empty_after_save);

        thread_store
            .update(cx, |store, cx| store.delete_threads(cx))
            .await
            .unwrap();
        cx.run_until_parked();

        let empty_after_delete = thread_store.read_with(cx, |store, _cx| store.is_empty());
        assert!(empty_after_delete);
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let (first_id, second_id) = (session_id("thread-a"), session_id("thread-b"));

        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    first_id.clone(),
                    make_thread("Thread A", jan_2024(1)),
                    PathList::default(),
                    cx,
                )
            })
            .await
            .unwrap();
        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    second_id.clone(),
                    make_thread("Thread B", jan_2024(2)),
                    PathList::default(),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        thread_store
            .update(cx, |store, cx| store.delete_thread(first_id.clone(), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        let remaining = thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect::<Vec<_>>()
        });
        assert_eq!(remaining, vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let (first_id, second_id) = (session_id("thread-a"), session_id("thread-b"));

        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    first_id.clone(),
                    make_thread("Thread A", jan_2024(1)),
                    PathList::default(),
                    cx,
                )
            })
            .await
            .unwrap();
        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    second_id.clone(),
                    make_thread("Thread B", jan_2024(2)),
                    PathList::default(),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp should move it to the front.
        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    first_id.clone(),
                    make_thread("Thread A", jan_2024(3)),
                    PathList::default(),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        let ids = thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect::<Vec<_>>()
        });
        assert_eq!(ids, vec![first_id, second_id]);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
        let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);
        let (thread_a_id, thread_b_id) = (session_id("thread-a"), session_id("thread-b"));

        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    thread_a_id.clone(),
                    make_thread("Thread in A", jan_2024(1)),
                    project_a_paths.clone(),
                    cx,
                )
            })
            .await
            .unwrap();
        thread_store
            .update(cx, |store, cx| {
                store.save_thread(
                    thread_b_id.clone(),
                    make_thread("Thread in B", jan_2024(2)),
                    project_b_paths.clone(),
                    cx,
                )
            })
            .await
            .unwrap();
        cx.run_until_parked();

        thread_store.read_with(cx, |store, _cx| {
            let ids_for = |paths: &PathList| {
                store
                    .threads_for_paths(paths)
                    .map(|thread| thread.id.clone())
                    .collect::<Vec<_>>()
            };

            assert_eq!(ids_for(&project_a_paths), vec![thread_a_id.clone()]);
            assert_eq!(ids_for(&project_b_paths), vec![thread_b_id.clone()]);

            let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
            assert!(ids_for(&nonexistent).is_empty());
        });
    }
}