thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5
  6struct GlobalThreadStore(Entity<ThreadStore>);
  7
  8impl Global for GlobalThreadStore {}
  9
 10pub struct ThreadStore {
 11    threads: Vec<DbThreadMetadata>,
 12}
 13
 14impl ThreadStore {
 15    pub fn init_global(cx: &mut App) {
 16        let thread_store = cx.new(|cx| Self::new(cx));
 17        cx.set_global(GlobalThreadStore(thread_store));
 18    }
 19
 20    pub fn global(cx: &App) -> Entity<Self> {
 21        cx.global::<GlobalThreadStore>().0.clone()
 22    }
 23
 24    pub fn new(cx: &mut Context<Self>) -> Self {
 25        let this = Self {
 26            threads: Vec::new(),
 27        };
 28        this.reload(cx);
 29        this
 30    }
 31
 32    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 33        self.threads.iter().find(|thread| &thread.id == session_id)
 34    }
 35
 36    pub fn load_thread(
 37        &mut self,
 38        id: acp::SessionId,
 39        cx: &mut Context<Self>,
 40    ) -> Task<Result<Option<DbThread>>> {
 41        let database_future = ThreadsDatabase::connect(cx);
 42        cx.background_spawn(async move {
 43            let database = database_future.await.map_err(|err| anyhow!(err))?;
 44            database.load_thread(id).await
 45        })
 46    }
 47
 48    pub fn save_thread(
 49        &mut self,
 50        id: acp::SessionId,
 51        thread: crate::DbThread,
 52        cx: &mut Context<Self>,
 53    ) -> Task<Result<()>> {
 54        let database_future = ThreadsDatabase::connect(cx);
 55        cx.spawn(async move |this, cx| {
 56            let database = database_future.await.map_err(|err| anyhow!(err))?;
 57            database.save_thread(id, thread).await?;
 58            this.update(cx, |this, cx| this.reload(cx))
 59        })
 60    }
 61
 62    pub fn delete_thread(
 63        &mut self,
 64        id: acp::SessionId,
 65        cx: &mut Context<Self>,
 66    ) -> Task<Result<()>> {
 67        let database_future = ThreadsDatabase::connect(cx);
 68        cx.spawn(async move |this, cx| {
 69            let database = database_future.await.map_err(|err| anyhow!(err))?;
 70            database.delete_thread(id.clone()).await?;
 71            this.update(cx, |this, cx| this.reload(cx))
 72        })
 73    }
 74
 75    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 76        let database_future = ThreadsDatabase::connect(cx);
 77        cx.spawn(async move |this, cx| {
 78            let database = database_future.await.map_err(|err| anyhow!(err))?;
 79            database.delete_threads().await?;
 80            this.update(cx, |this, cx| this.reload(cx))
 81        })
 82    }
 83
 84    pub fn reload(&self, cx: &mut Context<Self>) {
 85        let database_connection = ThreadsDatabase::connect(cx);
 86        cx.spawn(async move |this, cx| {
 87            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 88            let threads = database
 89                .list_threads()
 90                .await?
 91                .into_iter()
 92                .filter(|thread| thread.parent_session_id.is_none())
 93                .collect::<Vec<_>>();
 94            this.update(cx, |this, cx| {
 95                this.threads = threads;
 96                cx.notify();
 97            })
 98        })
 99        .detach_and_log_err(cx);
100    }
101
102    pub fn is_empty(&self) -> bool {
103        self.threads.is_empty()
104    }
105
106    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
107        self.threads.iter().cloned()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an `acp::SessionId` from a string.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = Arc::from(value);
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn january_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds a thread with the given title and timestamp and every other field empty.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
        }
    }

    /// Creates a fresh store and lets its initial load settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id` and waits for the save task to succeed.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        thread_store
            .update(cx, |store, cx| store.save_thread(id.clone(), thread, cx))
            .await
            .unwrap();
    }

    /// Session ids currently listed by the store, in listing order.
    fn listed_ids(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| store.entries().map(|entry| entry.id).collect())
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let (older_id, newer_id) = (session_id("thread-a"), session_id("thread-b"));

        save(&thread_store, &older_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &newer_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        assert_eq!(listed_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        save(&thread_store, &thread_id, make_thread("Thread A", january_2024(1)), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        thread_store
            .update(cx, |store, cx| store.delete_threads(cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let (first_id, second_id) = (session_id("thread-a"), session_id("thread-b"));

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        thread_store
            .update(cx, |store, cx| store.delete_thread(first_id.clone(), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert_eq!(listed_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let (first_id, second_id) = (session_id("thread-a"), session_id("thread-b"));

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp should move it to the front.
        save(&thread_store, &first_id, make_thread("Thread A", january_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(listed_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}