thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5
/// GPUI global slot that holds the shared [`ThreadStore`] entity.
///
/// Installed by [`ThreadStore::init_global`] and read back by [`ThreadStore::global`].
struct GlobalThreadStore(Entity<ThreadStore>);

impl Global for GlobalThreadStore {}
  9
/// In-memory cache of thread metadata backed by [`ThreadsDatabase`].
///
/// The cache is refreshed by [`ThreadStore::reload`]. The mutating methods
/// (`save_thread`, `delete_thread`, `delete_threads`) write to the database first
/// and then schedule a reload.
pub struct ThreadStore {
    // Metadata for threads without a `parent_session_id`, kept in the order the
    // database listing returned them.
    threads: Vec<DbThreadMetadata>,
}
 13
 14impl ThreadStore {
 15    pub fn init_global(cx: &mut App) {
 16        let thread_store = cx.new(|cx| Self::new(cx));
 17        cx.set_global(GlobalThreadStore(thread_store));
 18    }
 19
 20    pub fn global(cx: &App) -> Entity<Self> {
 21        cx.global::<GlobalThreadStore>().0.clone()
 22    }
 23
 24    pub fn new(cx: &mut Context<Self>) -> Self {
 25        let this = Self {
 26            threads: Vec::new(),
 27        };
 28        this.reload(cx);
 29        this
 30    }
 31
 32    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 33        self.threads.iter().find(|thread| &thread.id == session_id)
 34    }
 35
 36    pub fn load_thread(
 37        &mut self,
 38        id: acp::SessionId,
 39        cx: &mut Context<Self>,
 40    ) -> Task<Result<Option<DbThread>>> {
 41        let database_future = ThreadsDatabase::connect(cx);
 42        cx.background_spawn(async move {
 43            let database = database_future.await.map_err(|err| anyhow!(err))?;
 44            database.load_thread(id).await
 45        })
 46    }
 47
 48    pub fn save_thread(
 49        &mut self,
 50        id: acp::SessionId,
 51        thread: crate::DbThread,
 52        cx: &mut Context<Self>,
 53    ) -> Task<Result<()>> {
 54        let database_future = ThreadsDatabase::connect(cx);
 55        cx.spawn(async move |this, cx| {
 56            let database = database_future.await.map_err(|err| anyhow!(err))?;
 57            database.save_thread(id, thread).await?;
 58            this.update(cx, |this, cx| this.reload(cx))
 59        })
 60    }
 61
 62    pub fn delete_thread(
 63        &mut self,
 64        id: acp::SessionId,
 65        cx: &mut Context<Self>,
 66    ) -> Task<Result<()>> {
 67        let database_future = ThreadsDatabase::connect(cx);
 68        cx.spawn(async move |this, cx| {
 69            let database = database_future.await.map_err(|err| anyhow!(err))?;
 70            database.delete_thread(id.clone()).await?;
 71            this.update(cx, |this, cx| this.reload(cx))
 72        })
 73    }
 74
 75    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 76        let database_future = ThreadsDatabase::connect(cx);
 77        cx.spawn(async move |this, cx| {
 78            let database = database_future.await.map_err(|err| anyhow!(err))?;
 79            database.delete_threads().await?;
 80            this.update(cx, |this, cx| this.reload(cx))
 81        })
 82    }
 83
 84    pub fn reload(&self, cx: &mut Context<Self>) {
 85        let database_connection = ThreadsDatabase::connect(cx);
 86        cx.spawn(async move |this, cx| {
 87            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 88            let threads = database
 89                .list_threads()
 90                .await?
 91                .into_iter()
 92                .filter(|thread| thread.parent_session_id.is_none())
 93                .collect::<Vec<_>>();
 94            this.update(cx, |this, cx| {
 95                this.threads = threads;
 96                cx.notify();
 97            })
 98        })
 99        .detach_and_log_err(cx);
100    }
101
102    pub fn is_empty(&self) -> bool {
103        self.threads.is_empty()
104    }
105
106    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
107        self.threads.iter().cloned()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an ACP session id from a plain string.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = Arc::from(value);
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn january_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds a thread with no messages and the given title and `updated_at`
    /// stamp. Every other field is empty or defaulted.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: String::from(title).into(),
            updated_at,
            messages: Vec::new(),
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            git_worktree_info: None,
        }
    }

    /// Creates a store and lets its initial reload settle.
    fn new_thread_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id` and waits for the save task to resolve.
    ///
    /// This does not wait for the follow-up reload. Callers run the executor
    /// until parked before inspecting entries.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = thread_store.update(cx, |store, cx| store.save_thread(id.clone(), thread, cx));
        task.await.unwrap();
    }

    /// Ids of the store's current entries, in iteration order.
    fn entry_ids(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect::<Vec<_>>()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_thread_store(cx);

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, &older_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &newer_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_thread_store(cx);

        let thread_id = session_id("thread-a");
        save(&thread_store, &thread_id, make_thread("Thread A", january_2024(1)), cx).await;
        cx.run_until_parked();
        assert!(!entry_ids(&thread_store, cx).is_empty());

        let delete_task = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert!(entry_ids(&thread_store, cx).is_empty());
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_thread_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        let delete_task = thread_store.update(cx, |store, cx| store.delete_thread(first_id, cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_thread_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp should move it to the front.
        save(&thread_store, &first_id, make_thread("Thread A", january_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}