thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use util::path_list::PathList;
  6
  7struct GlobalThreadStore(Entity<ThreadStore>);
  8
  9impl Global for GlobalThreadStore {}
 10
 11pub struct ThreadStore {
 12    threads: Vec<DbThreadMetadata>,
 13}
 14
 15impl ThreadStore {
 16    pub fn init_global(cx: &mut App) {
 17        let thread_store = cx.new(|cx| Self::new(cx));
 18        cx.set_global(GlobalThreadStore(thread_store));
 19    }
 20
 21    pub fn global(cx: &App) -> Entity<Self> {
 22        cx.global::<GlobalThreadStore>().0.clone()
 23    }
 24
 25    pub fn new(cx: &mut Context<Self>) -> Self {
 26        let this = Self {
 27            threads: Vec::new(),
 28        };
 29        this.reload(cx);
 30        this
 31    }
 32
 33    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 34        self.threads.iter().find(|thread| &thread.id == session_id)
 35    }
 36
 37    pub fn load_thread(
 38        &mut self,
 39        id: acp::SessionId,
 40        cx: &mut Context<Self>,
 41    ) -> Task<Result<Option<DbThread>>> {
 42        let database_future = ThreadsDatabase::connect(cx);
 43        cx.background_spawn(async move {
 44            let database = database_future.await.map_err(|err| anyhow!(err))?;
 45            database.load_thread(id).await
 46        })
 47    }
 48
 49    pub fn save_thread(
 50        &mut self,
 51        id: acp::SessionId,
 52        thread: crate::DbThread,
 53        folder_paths: PathList,
 54        cx: &mut Context<Self>,
 55    ) -> Task<Result<()>> {
 56        let database_future = ThreadsDatabase::connect(cx);
 57        cx.spawn(async move |this, cx| {
 58            let database = database_future.await.map_err(|err| anyhow!(err))?;
 59            database.save_thread(id, thread, folder_paths).await?;
 60            this.update(cx, |this, cx| this.reload(cx))
 61        })
 62    }
 63
 64    pub fn delete_thread(
 65        &mut self,
 66        id: acp::SessionId,
 67        cx: &mut Context<Self>,
 68    ) -> Task<Result<()>> {
 69        let database_future = ThreadsDatabase::connect(cx);
 70        cx.spawn(async move |this, cx| {
 71            let database = database_future.await.map_err(|err| anyhow!(err))?;
 72            database.delete_thread(id.clone()).await?;
 73            this.update(cx, |this, cx| this.reload(cx))
 74        })
 75    }
 76
 77    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 78        let database_future = ThreadsDatabase::connect(cx);
 79        cx.spawn(async move |this, cx| {
 80            let database = database_future.await.map_err(|err| anyhow!(err))?;
 81            database.delete_threads().await?;
 82            this.update(cx, |this, cx| this.reload(cx))
 83        })
 84    }
 85
 86    pub fn reload(&self, cx: &mut Context<Self>) {
 87        let database_connection = ThreadsDatabase::connect(cx);
 88        cx.spawn(async move |this, cx| {
 89            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 90            let threads = database
 91                .list_threads()
 92                .await?
 93                .into_iter()
 94                .filter(|thread| thread.parent_session_id.is_none())
 95                .collect::<Vec<_>>();
 96            this.update(cx, |this, cx| {
 97                this.threads = threads;
 98                cx.notify();
 99            })
100        })
101        .detach_and_log_err(cx);
102    }
103
104    pub fn is_empty(&self) -> bool {
105        self.threads.is_empty()
106    }
107
108    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
109        self.threads.iter().cloned()
110    }
111
112    /// Returns threads whose folder_paths match the given paths exactly.
113    pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
114        self.threads
115            .iter()
116            .filter(move |thread| &thread.folder_paths == paths)
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a session id from a string.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024; used to control `updated_at` ordering.
    fn january(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// A minimal thread with no messages, carrying only a title and a timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Creates a store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves a freshly built thread through the store and waits for the write to finish.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        title: &str,
        updated_at: DateTime<Utc>,
        folder_paths: PathList,
        cx: &mut TestAppContext,
    ) {
        let thread = make_thread(title, updated_at);
        let task = thread_store.update(cx, |store, cx| {
            store.save_thread(id.clone(), thread, folder_paths, cx)
        });
        task.await.unwrap();
    }

    /// Ids of the cached entries, in iteration order.
    fn entry_ids(thread_store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| store.entries().map(|entry| entry.id).collect())
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, &older_id, "Thread A", january(1), PathList::default(), cx).await;
        save(&thread_store, &newer_id, "Thread B", january(2), PathList::default(), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        save(&thread_store, &thread_id, "Thread A", january(1), PathList::default(), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, "Thread A", january(1), PathList::default(), cx).await;
        save(&thread_store, &second_id, "Thread B", january(2), PathList::default(), cx).await;
        cx.run_until_parked();

        let delete = thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, "Thread A", january(1), PathList::default(), cx).await;
        save(&thread_store, &second_id, "Thread B", january(2), PathList::default(), cx).await;
        cx.run_until_parked();

        // Re-saving the older thread with a later timestamp should move it to the front.
        save(&thread_store, &first_id, "Thread A", january(3), PathList::default(), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let paths_a = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
        let paths_b = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);
        let id_a = session_id("thread-a");
        let id_b = session_id("thread-b");

        save(&thread_store, &id_a, "Thread in A", january(1), paths_a.clone(), cx).await;
        save(&thread_store, &id_b, "Thread in B", january(2), paths_b.clone(), cx).await;
        cx.run_until_parked();

        thread_store.read_with(cx, |store, _cx| {
            let ids_for = |paths: &PathList| -> Vec<acp::SessionId> {
                store
                    .threads_for_paths(paths)
                    .map(|thread| thread.id.clone())
                    .collect()
            };

            assert_eq!(ids_for(&paths_a), vec![id_a.clone()]);
            assert_eq!(ids_for(&paths_b), vec![id_b.clone()]);

            let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
            assert!(ids_for(&nonexistent).is_empty());
        });
    }
}