thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use util::path_list::PathList;
  6
  7struct GlobalThreadStore(Entity<ThreadStore>);
  8
  9impl Global for GlobalThreadStore {}
 10
 11pub struct ThreadStore {
 12    threads: Vec<DbThreadMetadata>,
 13}
 14
 15impl ThreadStore {
 16    pub fn init_global(cx: &mut App) {
 17        let thread_store = cx.new(|cx| Self::new(cx));
 18        cx.set_global(GlobalThreadStore(thread_store));
 19    }
 20
 21    pub fn global(cx: &App) -> Entity<Self> {
 22        cx.global::<GlobalThreadStore>().0.clone()
 23    }
 24
 25    pub fn new(cx: &mut Context<Self>) -> Self {
 26        let this = Self {
 27            threads: Vec::new(),
 28        };
 29        this.reload(cx);
 30        this
 31    }
 32
 33    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 34        self.threads.iter().find(|thread| &thread.id == session_id)
 35    }
 36
 37    pub fn load_thread(
 38        &mut self,
 39        id: acp::SessionId,
 40        cx: &mut Context<Self>,
 41    ) -> Task<Result<Option<DbThread>>> {
 42        let database_future = ThreadsDatabase::connect(cx);
 43        cx.background_spawn(async move {
 44            let database = database_future.await.map_err(|err| anyhow!(err))?;
 45            database.load_thread(id).await
 46        })
 47    }
 48
 49    pub fn save_thread(
 50        &mut self,
 51        id: acp::SessionId,
 52        thread: crate::DbThread,
 53        folder_paths: PathList,
 54        cx: &mut Context<Self>,
 55    ) -> Task<Result<()>> {
 56        let database_future = ThreadsDatabase::connect(cx);
 57        cx.spawn(async move |this, cx| {
 58            let database = database_future.await.map_err(|err| anyhow!(err))?;
 59            database.save_thread(id, thread, folder_paths).await?;
 60            this.update(cx, |this, cx| this.reload(cx))
 61        })
 62    }
 63
 64    pub fn delete_thread(
 65        &mut self,
 66        id: acp::SessionId,
 67        cx: &mut Context<Self>,
 68    ) -> Task<Result<()>> {
 69        let database_future = ThreadsDatabase::connect(cx);
 70        cx.spawn(async move |this, cx| {
 71            let database = database_future.await.map_err(|err| anyhow!(err))?;
 72            database.delete_thread(id.clone()).await?;
 73            this.update(cx, |this, cx| this.reload(cx))
 74        })
 75    }
 76
 77    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 78        let database_future = ThreadsDatabase::connect(cx);
 79        cx.spawn(async move |this, cx| {
 80            let database = database_future.await.map_err(|err| anyhow!(err))?;
 81            database.delete_threads().await?;
 82            this.update(cx, |this, cx| this.reload(cx))
 83        })
 84    }
 85
 86    pub fn reload(&self, cx: &mut Context<Self>) {
 87        let database_connection = ThreadsDatabase::connect(cx);
 88        cx.spawn(async move |this, cx| {
 89            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 90            let threads = database
 91                .list_threads()
 92                .await?
 93                .into_iter()
 94                .filter(|thread| thread.parent_session_id.is_none())
 95                .collect::<Vec<_>>();
 96            this.update(cx, |this, cx| {
 97                this.threads = threads;
 98                cx.notify();
 99            })
100        })
101        .detach_and_log_err(cx);
102    }
103
104    pub fn is_empty(&self) -> bool {
105        self.threads.is_empty()
106    }
107
108    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
109        self.threads.iter().cloned()
110    }
111
112    /// Returns threads whose folder_paths match the given paths exactly.
113    pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
114        self.threads
115            .iter()
116            .filter(move |thread| &thread.folder_paths == paths)
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an `acp::SessionId` from a string slice.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = Arc::from(value);
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn january(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// A message-less thread titled `title` and stamped with `updated_at`.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_owned().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
        }
    }

    /// Creates a fresh store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` through `store` and waits for the save task to succeed.
    async fn save(
        store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        paths: PathList,
        cx: &mut TestAppContext,
    ) {
        let task = store.update(cx, |state, cx| state.save_thread(id.clone(), thread, paths, cx));
        task.await.unwrap();
    }

    /// Ids of the cached entries, in the order `entries()` yields them.
    fn entry_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |state, _cx| state.entries().map(|entry| entry.id).collect())
    }

    /// Ids of the threads whose folder paths equal `paths`.
    fn ids_for(store: &ThreadStore, paths: &PathList) -> Vec<acp::SessionId> {
        store
            .threads_for_paths(paths)
            .map(|thread| thread.id.clone())
            .collect()
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let store = new_store(cx);

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");
        let older = make_thread("Thread A", january(1));
        let newer = make_thread("Thread B", january(2));

        save(&store, &older_id, older, PathList::default(), cx).await;
        save(&store, &newer_id, newer, PathList::default(), cx).await;
        cx.run_until_parked();

        // Most recently updated first.
        assert_eq!(entry_ids(&store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let store = new_store(cx);

        let thread_id = session_id("thread-a");
        let thread = make_thread("Thread A", january(1));
        save(&store, &thread_id, thread, PathList::default(), cx).await;
        cx.run_until_parked();
        assert!(!store.read_with(cx, |state, _cx| state.is_empty()));

        let delete = store.update(cx, |state, cx| state.delete_threads(cx));
        delete.await.unwrap();
        cx.run_until_parked();

        assert!(store.read_with(cx, |state, _cx| state.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let store = new_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");
        let first = make_thread("Thread A", january(1));
        let second = make_thread("Thread B", january(2));

        save(&store, &first_id, first, PathList::default(), cx).await;
        save(&store, &second_id, second, PathList::default(), cx).await;
        cx.run_until_parked();

        let delete = store.update(cx, |state, cx| state.delete_thread(first_id.clone(), cx));
        delete.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let store = new_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");
        let first = make_thread("Thread A", january(1));
        let second = make_thread("Thread B", january(2));

        save(&store, &first_id, first, PathList::default(), cx).await;
        save(&store, &second_id, second, PathList::default(), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp moves it to the front.
        let refreshed = make_thread("Thread A", january(3));
        save(&store, &first_id, refreshed, PathList::default(), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&store, cx), vec![first_id, second_id]);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let store = new_store(cx);

        let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
        let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);

        let thread_a = make_thread("Thread in A", january(1));
        let thread_b = make_thread("Thread in B", january(2));
        let thread_a_id = session_id("thread-a");
        let thread_b_id = session_id("thread-b");

        save(&store, &thread_a_id, thread_a, project_a_paths.clone(), cx).await;
        save(&store, &thread_b_id, thread_b, project_b_paths.clone(), cx).await;
        cx.run_until_parked();

        store.read_with(cx, |state, _cx| {
            assert_eq!(ids_for(state, &project_a_paths), vec![thread_a_id.clone()]);
            assert_eq!(ids_for(state, &project_b_paths), vec![thread_b_id.clone()]);

            let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
            assert!(ids_for(state, &nonexistent).is_empty());
        });
    }
}