thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use util::path_list::PathList;
  6
/// Newtype wrapper that lets the shared [`ThreadStore`] entity be stored as a gpui
/// global (installed by `ThreadStore::init_global`, read by `ThreadStore::global`).
struct GlobalThreadStore(Entity<ThreadStore>);

// Marker impl required for `App::set_global` / `App::global` to accept this type.
impl Global for GlobalThreadStore {}
 10
/// In-memory cache of thread metadata read from `ThreadsDatabase`.
///
/// The cache is refreshed asynchronously by `ThreadStore::reload`; mutating
/// operations (save/delete) write to the database and then schedule a reload.
pub struct ThreadStore {
    // Top-level threads only (entries whose `parent_session_id` is `None`), kept in
    // the order `ThreadsDatabase::list_threads` returned them.
    threads: Vec<DbThreadMetadata>,
}
 14
 15impl ThreadStore {
 16    pub fn init_global(cx: &mut App) {
 17        let thread_store = cx.new(|cx| Self::new(cx));
 18        cx.set_global(GlobalThreadStore(thread_store));
 19    }
 20
 21    pub fn global(cx: &App) -> Entity<Self> {
 22        cx.global::<GlobalThreadStore>().0.clone()
 23    }
 24
 25    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 26        cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
 27    }
 28
 29    pub fn new(cx: &mut Context<Self>) -> Self {
 30        let this = Self {
 31            threads: Vec::new(),
 32        };
 33        this.reload(cx);
 34        this
 35    }
 36
 37    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 38        self.threads.iter().find(|thread| &thread.id == session_id)
 39    }
 40
 41    pub fn load_thread(
 42        &mut self,
 43        id: acp::SessionId,
 44        cx: &mut Context<Self>,
 45    ) -> Task<Result<Option<DbThread>>> {
 46        let database_future = ThreadsDatabase::connect(cx);
 47        cx.background_spawn(async move {
 48            let database = database_future.await.map_err(|err| anyhow!(err))?;
 49            database.load_thread(id).await
 50        })
 51    }
 52
 53    pub fn save_thread(
 54        &mut self,
 55        id: acp::SessionId,
 56        thread: crate::DbThread,
 57        folder_paths: PathList,
 58        cx: &mut Context<Self>,
 59    ) -> Task<Result<()>> {
 60        let database_future = ThreadsDatabase::connect(cx);
 61        cx.spawn(async move |this, cx| {
 62            let database = database_future.await.map_err(|err| anyhow!(err))?;
 63            database.save_thread(id, thread, folder_paths).await?;
 64            this.update(cx, |this, cx| this.reload(cx))
 65        })
 66    }
 67
 68    pub fn delete_thread(
 69        &mut self,
 70        id: acp::SessionId,
 71        cx: &mut Context<Self>,
 72    ) -> Task<Result<()>> {
 73        let database_future = ThreadsDatabase::connect(cx);
 74        cx.spawn(async move |this, cx| {
 75            let database = database_future.await.map_err(|err| anyhow!(err))?;
 76            database.delete_thread(id.clone()).await?;
 77            this.update(cx, |this, cx| this.reload(cx))
 78        })
 79    }
 80
 81    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 82        let database_future = ThreadsDatabase::connect(cx);
 83        cx.spawn(async move |this, cx| {
 84            let database = database_future.await.map_err(|err| anyhow!(err))?;
 85            database.delete_threads().await?;
 86            this.update(cx, |this, cx| this.reload(cx))
 87        })
 88    }
 89
 90    pub fn reload(&self, cx: &mut Context<Self>) {
 91        let database_connection = ThreadsDatabase::connect(cx);
 92        cx.spawn(async move |this, cx| {
 93            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 94            let threads = database
 95                .list_threads()
 96                .await?
 97                .into_iter()
 98                .filter(|thread| thread.parent_session_id.is_none())
 99                .collect::<Vec<_>>();
100            this.update(cx, |this, cx| {
101                this.threads = threads;
102                cx.notify();
103            })
104        })
105        .detach_and_log_err(cx);
106    }
107
108    pub fn is_empty(&self) -> bool {
109        self.threads.is_empty()
110    }
111
112    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
113        self.threads.iter().cloned()
114    }
115
116    /// Returns threads whose folder_paths match the given paths exactly.
117    pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
118        self.threads
119            .iter()
120            .filter(move |thread| &thread.folder_paths == paths)
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::{Entity, TestAppContext};
    use std::sync::Arc;

    /// Builds a session id from a plain string.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024 (a higher day is more recent).
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// A `PathList` holding exactly one path.
    fn single_path(path: &str) -> PathList {
        PathList::new(&[std::path::PathBuf::from(path)])
    }

    /// A minimal `DbThread`. Only `title` and `updated_at` vary; every other field is
    /// empty, `None`, `false`, or its default.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Creates a store and lets its initial background reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id` and waits for the save task to finish.
    ///
    /// This does not wait for the follow-up reload; callers run the executor until
    /// parked for that.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        folder_paths: PathList,
        cx: &mut TestAppContext,
    ) {
        let id = id.clone();
        thread_store
            .update(cx, |store, cx| store.save_thread(id, thread, folder_paths, cx))
            .await
            .unwrap();
    }

    /// Ids of the store's cached entries, in iteration order.
    fn entry_ids(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let (older_id, newer_id) = (session_id("thread-a"), session_id("thread-b"));

        let older = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &older_id, older, PathList::default(), cx).await;
        let newer = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &newer_id, newer, PathList::default(), cx).await;
        cx.run_until_parked();

        // The most recently updated thread is listed first.
        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let only = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &session_id("thread-a"), only, PathList::default(), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        thread_store
            .update(cx, |store, cx| store.delete_threads(cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let (first_id, second_id) = (session_id("thread-a"), session_id("thread-b"));

        let first = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &first_id, first, PathList::default(), cx).await;
        let second = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &second_id, second, PathList::default(), cx).await;
        cx.run_until_parked();

        thread_store
            .update(cx, |store, cx| store.delete_thread(first_id, cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let (first_id, second_id) = (session_id("thread-a"), session_id("thread-b"));

        let first = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &first_id, first, PathList::default(), cx).await;
        let second = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &second_id, second, PathList::default(), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a newer timestamp moves it to the front.
        let first_again = make_thread("Thread A", jan_2024(3));
        save(&thread_store, &first_id, first_again, PathList::default(), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let project_a = single_path("/home/user/project-a");
        let project_b = single_path("/home/user/project-b");
        let (thread_a_id, thread_b_id) = (session_id("thread-a"), session_id("thread-b"));

        let thread_a = make_thread("Thread in A", jan_2024(1));
        save(&thread_store, &thread_a_id, thread_a, project_a.clone(), cx).await;
        let thread_b = make_thread("Thread in B", jan_2024(2));
        save(&thread_store, &thread_b_id, thread_b, project_b.clone(), cx).await;
        cx.run_until_parked();

        thread_store.read_with(cx, |store, _cx| {
            let ids_for = |paths: &PathList| -> Vec<acp::SessionId> {
                store
                    .threads_for_paths(paths)
                    .map(|thread| thread.id.clone())
                    .collect()
            };

            assert_eq!(ids_for(&project_a), vec![thread_a_id.clone()]);
            assert_eq!(ids_for(&project_b), vec![thread_b_id.clone()]);
            assert!(ids_for(&single_path("/nonexistent")).is_empty());
        });
    }
}