thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use std::collections::HashMap;
  6use util::path_list::PathList;
  7
/// Newtype wrapper that lets the shared [`ThreadStore`] entity be stored as a gpui global.
/// It is registered by `ThreadStore::init_global` and read back by `ThreadStore::global`
/// and `ThreadStore::try_global`.
struct GlobalThreadStore(Entity<ThreadStore>);

impl Global for GlobalThreadStore {}
 11
/// In-memory cache of thread metadata read from [`ThreadsDatabase`], plus an index from
/// folder paths to the cached threads.
pub struct ThreadStore {
    /// Cached thread metadata, kept in the order the database listing returned it.
    threads: Vec<DbThreadMetadata>,
    /// Maps each `folder_paths` value to the positions in `threads` that carry it.
    /// Must be rebuilt whenever `threads` is replaced, otherwise the indices go stale.
    threads_by_paths: HashMap<PathList, Vec<usize>>,
}
 16
 17impl ThreadStore {
 18    pub fn init_global(cx: &mut App) {
 19        let thread_store = cx.new(|cx| Self::new(cx));
 20        cx.set_global(GlobalThreadStore(thread_store));
 21    }
 22
 23    pub fn global(cx: &App) -> Entity<Self> {
 24        cx.global::<GlobalThreadStore>().0.clone()
 25    }
 26
 27    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 28        cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
 29    }
 30
 31    pub fn new(cx: &mut Context<Self>) -> Self {
 32        let this = Self {
 33            threads: Vec::new(),
 34            threads_by_paths: HashMap::default(),
 35        };
 36        this.reload(cx);
 37        this
 38    }
 39
 40    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 41        self.threads.iter().find(|thread| &thread.id == session_id)
 42    }
 43
 44    pub fn load_thread(
 45        &mut self,
 46        id: acp::SessionId,
 47        cx: &mut Context<Self>,
 48    ) -> Task<Result<Option<DbThread>>> {
 49        let database_future = ThreadsDatabase::connect(cx);
 50        cx.background_spawn(async move {
 51            let database = database_future.await.map_err(|err| anyhow!(err))?;
 52            database.load_thread(id).await
 53        })
 54    }
 55
 56    pub fn save_thread(
 57        &mut self,
 58        id: acp::SessionId,
 59        thread: crate::DbThread,
 60        folder_paths: PathList,
 61        cx: &mut Context<Self>,
 62    ) -> Task<Result<()>> {
 63        let database_future = ThreadsDatabase::connect(cx);
 64        cx.spawn(async move |this, cx| {
 65            let database = database_future.await.map_err(|err| anyhow!(err))?;
 66            database.save_thread(id, thread, folder_paths).await?;
 67            this.update(cx, |this, cx| this.reload(cx))
 68        })
 69    }
 70
 71    pub fn delete_thread(
 72        &mut self,
 73        id: acp::SessionId,
 74        cx: &mut Context<Self>,
 75    ) -> Task<Result<()>> {
 76        let database_future = ThreadsDatabase::connect(cx);
 77        cx.spawn(async move |this, cx| {
 78            let database = database_future.await.map_err(|err| anyhow!(err))?;
 79            database.delete_thread(id.clone()).await?;
 80            this.update(cx, |this, cx| this.reload(cx))
 81        })
 82    }
 83
 84    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 85        let database_future = ThreadsDatabase::connect(cx);
 86        cx.spawn(async move |this, cx| {
 87            let database = database_future.await.map_err(|err| anyhow!(err))?;
 88            database.delete_threads().await?;
 89            this.update(cx, |this, cx| this.reload(cx))
 90        })
 91    }
 92
 93    pub fn reload(&self, cx: &mut Context<Self>) {
 94        let database_connection = ThreadsDatabase::connect(cx);
 95        cx.spawn(async move |this, cx| {
 96            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 97            let all_threads = database.list_threads().await?;
 98            this.update(cx, |this, cx| {
 99                this.threads.clear();
100                this.threads_by_paths.clear();
101                for thread in all_threads {
102                    if thread.parent_session_id.is_some() {
103                        continue;
104                    }
105                    let index = this.threads.len();
106                    this.threads_by_paths
107                        .entry(thread.folder_paths.clone())
108                        .or_default()
109                        .push(index);
110                    this.threads.push(thread);
111                }
112                cx.notify();
113            })
114        })
115        .detach_and_log_err(cx);
116    }
117
118    pub fn is_empty(&self) -> bool {
119        self.threads.is_empty()
120    }
121
122    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
123        self.threads.iter().cloned()
124    }
125
126    /// Returns threads whose folder_paths match the given paths exactly.
127    /// Uses a cached index for O(1) lookup per path list.
128    pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
129        self.threads_by_paths
130            .get(paths)
131            .into_iter()
132            .flat_map(|indices| indices.iter().map(|&index| &self.threads[index]))
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            provisional_title: None,
            ui_scroll_position: None,
        }
    }

    /// Creates a store and lets its initial background reload finish.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id` and waits for the database write. The follow-up reload
    /// is detached, so callers must still `run_until_parked` before reading the cache.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        folder_paths: PathList,
        cx: &mut TestAppContext,
    ) {
        let task = thread_store.update(cx, |store, cx| {
            store.save_thread(id.clone(), thread, folder_paths, cx)
        });
        task.await.unwrap();
    }

    /// Ids of the cached entries, in `entries()` order.
    fn entry_ids(thread_store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        // Save the newer thread first. A "most recently saved first" ordering would then
        // put the older thread on top, so only ordering by `updated_at` passes.
        let newer_thread = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &newer_id, newer_thread, PathList::default(), cx).await;
        let older_thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &older_id, older_thread, PathList::default(), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let thread_id = session_id("thread-a");
        let thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &thread_id, thread, PathList::default(), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete_task = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        let first_thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &first_id, first_thread, PathList::default(), cx).await;
        let second_thread = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &second_id, second_thread, PathList::default(), cx).await;
        cx.run_until_parked();

        let delete_task =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        let first_thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, &first_id, first_thread, PathList::default(), cx).await;
        let second_thread = make_thread("Thread B", jan_2024(2));
        save(&thread_store, &second_id, second_thread, PathList::default(), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp must move it to the top.
        let updated_first = make_thread("Thread A", jan_2024(3));
        save(&thread_store, &first_id, updated_first, PathList::default(), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
        let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);

        let thread_a_id = session_id("thread-a");
        let thread_b_id = session_id("thread-b");

        let thread_a = make_thread("Thread in A", jan_2024(1));
        save(&thread_store, &thread_a_id, thread_a, project_a_paths.clone(), cx).await;
        let thread_b = make_thread("Thread in B", jan_2024(2));
        save(&thread_store, &thread_b_id, thread_b, project_b_paths.clone(), cx).await;
        cx.run_until_parked();

        thread_store.read_with(cx, |store, _cx| {
            let a_threads: Vec<_> = store.threads_for_paths(&project_a_paths).collect();
            assert_eq!(a_threads.len(), 1);
            assert_eq!(a_threads[0].id, thread_a_id);

            let b_threads: Vec<_> = store.threads_for_paths(&project_b_paths).collect();
            assert_eq!(b_threads.len(), 1);
            assert_eq!(b_threads[0].id, thread_b_id);

            let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
            let no_threads: Vec<_> = store.threads_for_paths(&nonexistent).collect();
            assert!(no_threads.is_empty());
        });
    }
}