thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use std::collections::HashMap;
  6use util::path_list::PathList;
  7
/// gpui global wrapper that holds the shared [`ThreadStore`] entity.
///
/// Registered by [`ThreadStore::init_global`] and read back by
/// [`ThreadStore::global`] / [`ThreadStore::try_global`].
struct GlobalThreadStore(Entity<ThreadStore>);

impl Global for GlobalThreadStore {}
 11
/// In-memory cache of top-level thread metadata loaded from [`ThreadsDatabase`].
pub struct ThreadStore {
    /// Metadata for every cached thread without a parent session, in the order
    /// returned by `ThreadsDatabase::list_threads`.
    threads: Vec<DbThreadMetadata>,
    /// Maps an exact folder-path list to indices into `threads`.
    /// Rebuilt together with `threads` on every reload, so the indices always
    /// refer to the current contents of `threads`.
    threads_by_paths: HashMap<PathList, Vec<usize>>,
}
 16
 17impl ThreadStore {
 18    pub fn init_global(cx: &mut App) {
 19        let thread_store = cx.new(|cx| Self::new(cx));
 20        cx.set_global(GlobalThreadStore(thread_store));
 21    }
 22
 23    pub fn global(cx: &App) -> Entity<Self> {
 24        cx.global::<GlobalThreadStore>().0.clone()
 25    }
 26
 27    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 28        cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
 29    }
 30
 31    pub fn new(cx: &mut Context<Self>) -> Self {
 32        let this = Self {
 33            threads: Vec::new(),
 34            threads_by_paths: HashMap::default(),
 35        };
 36        this.reload(cx);
 37        this
 38    }
 39
 40    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 41        self.threads.iter().find(|thread| &thread.id == session_id)
 42    }
 43
 44    pub fn load_thread(
 45        &mut self,
 46        id: acp::SessionId,
 47        cx: &mut Context<Self>,
 48    ) -> Task<Result<Option<DbThread>>> {
 49        let database_future = ThreadsDatabase::connect(cx);
 50        cx.background_spawn(async move {
 51            let database = database_future.await.map_err(|err| anyhow!(err))?;
 52            database.load_thread(id).await
 53        })
 54    }
 55
 56    pub fn save_thread(
 57        &mut self,
 58        id: acp::SessionId,
 59        thread: crate::DbThread,
 60        folder_paths: PathList,
 61        cx: &mut Context<Self>,
 62    ) -> Task<Result<()>> {
 63        let database_future = ThreadsDatabase::connect(cx);
 64        cx.spawn(async move |this, cx| {
 65            let database = database_future.await.map_err(|err| anyhow!(err))?;
 66            database.save_thread(id, thread, folder_paths).await?;
 67            this.update(cx, |this, cx| this.reload(cx))
 68        })
 69    }
 70
 71    pub fn delete_thread(
 72        &mut self,
 73        id: acp::SessionId,
 74        cx: &mut Context<Self>,
 75    ) -> Task<Result<()>> {
 76        let database_future = ThreadsDatabase::connect(cx);
 77        cx.spawn(async move |this, cx| {
 78            let database = database_future.await.map_err(|err| anyhow!(err))?;
 79            database.delete_thread(id.clone()).await?;
 80            this.update(cx, |this, cx| this.reload(cx))
 81        })
 82    }
 83
 84    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 85        let database_future = ThreadsDatabase::connect(cx);
 86        cx.spawn(async move |this, cx| {
 87            let database = database_future.await.map_err(|err| anyhow!(err))?;
 88            database.delete_threads().await?;
 89            this.update(cx, |this, cx| this.reload(cx))
 90        })
 91    }
 92
 93    pub fn reload(&self, cx: &mut Context<Self>) {
 94        let database_connection = ThreadsDatabase::connect(cx);
 95        cx.spawn(async move |this, cx| {
 96            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 97            let all_threads = database.list_threads().await?;
 98            this.update(cx, |this, cx| {
 99                this.threads.clear();
100                this.threads_by_paths.clear();
101                for thread in all_threads {
102                    if thread.parent_session_id.is_some() {
103                        continue;
104                    }
105                    let index = this.threads.len();
106                    this.threads_by_paths
107                        .entry(thread.folder_paths.clone())
108                        .or_default()
109                        .push(index);
110                    this.threads.push(thread);
111                }
112                cx.notify();
113            })
114        })
115        .detach_and_log_err(cx);
116    }
117
118    pub fn is_empty(&self) -> bool {
119        self.threads.is_empty()
120    }
121
122    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
123        self.threads.iter().cloned()
124    }
125
126    /// Returns threads whose folder_paths match the given paths exactly.
127    /// Uses a cached index for O(1) lookup per path list.
128    pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
129        self.threads_by_paths
130            .get(paths)
131            .into_iter()
132            .flat_map(|indices| indices.iter().map(|&index| &self.threads[index]))
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a session id from a string slice.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds a minimal thread with the given title and `updated_at` timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Saves `thread` under `id` and waits for the save task to succeed.
    async fn save(
        store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        paths: PathList,
        cx: &mut TestAppContext,
    ) {
        let task = store.update(cx, |this, cx| this.save_thread(id.clone(), thread, paths, cx));
        task.await.unwrap();
    }

    /// Snapshot of the store's cached entries.
    fn listed_entries(
        store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<DbThreadMetadata> {
        store.read_with(cx, |this, _cx| this.entries().collect())
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let older = session_id("thread-a");
        let newer = session_id("thread-b");

        save(&store, &older, make_thread("Thread A", jan_2024(1)), PathList::default(), cx).await;
        save(&store, &newer, make_thread("Thread B", jan_2024(2)), PathList::default(), cx).await;
        cx.run_until_parked();

        let listed = listed_entries(&store, cx);
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].id, newer);
        assert_eq!(listed[1].id, older);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let id = session_id("thread-a");
        save(&store, &id, make_thread("Thread A", jan_2024(1)), PathList::default(), cx).await;
        cx.run_until_parked();
        assert!(!store.read_with(cx, |this, _cx| this.is_empty()));

        let wipe = store.update(cx, |this, cx| this.delete_threads(cx));
        wipe.await.unwrap();
        cx.run_until_parked();

        assert!(store.read_with(cx, |this, _cx| this.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let doomed = session_id("thread-a");
        let survivor = session_id("thread-b");

        save(&store, &doomed, make_thread("Thread A", jan_2024(1)), PathList::default(), cx).await;
        save(&store, &survivor, make_thread("Thread B", jan_2024(2)), PathList::default(), cx)
            .await;
        cx.run_until_parked();

        let removal = store.update(cx, |this, cx| this.delete_thread(doomed.clone(), cx));
        removal.await.unwrap();
        cx.run_until_parked();

        let listed = listed_entries(&store, cx);
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, survivor);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let first = session_id("thread-a");
        let second = session_id("thread-b");

        save(&store, &first, make_thread("Thread A", jan_2024(1)), PathList::default(), cx).await;
        save(&store, &second, make_thread("Thread B", jan_2024(2)), PathList::default(), cx).await;
        cx.run_until_parked();

        // Re-saving the older thread with a later timestamp should move it to the front.
        save(&store, &first, make_thread("Thread A", jan_2024(3)), PathList::default(), cx).await;
        cx.run_until_parked();

        let listed = listed_entries(&store, cx);
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].id, first);
        assert_eq!(listed[1].id, second);
    }

    #[gpui::test]
    async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let project_a = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
        let project_b = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);

        let id_a = session_id("thread-a");
        let id_b = session_id("thread-b");

        save(&store, &id_a, make_thread("Thread in A", jan_2024(1)), project_a.clone(), cx).await;
        save(&store, &id_b, make_thread("Thread in B", jan_2024(2)), project_b.clone(), cx).await;
        cx.run_until_parked();

        store.read_with(cx, |this, _cx| {
            let ids_for = |paths: &PathList| -> Vec<acp::SessionId> {
                this.threads_for_paths(paths)
                    .map(|thread| thread.id.clone())
                    .collect()
            };

            assert_eq!(ids_for(&project_a), vec![id_a.clone()]);
            assert_eq!(ids_for(&project_b), vec![id_b.clone()]);

            let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
            assert!(ids_for(&nonexistent).is_empty());
        });
    }
}