thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5
  6struct GlobalThreadStore(Entity<ThreadStore>);
  7
  8impl Global for GlobalThreadStore {}
  9
 10pub struct ThreadStore {
 11    threads: Vec<DbThreadMetadata>,
 12}
 13
 14impl ThreadStore {
 15    pub fn init_global(cx: &mut App) {
 16        let thread_store = cx.new(|cx| Self::new(cx));
 17        cx.set_global(GlobalThreadStore(thread_store));
 18    }
 19
 20    pub fn global(cx: &App) -> Entity<Self> {
 21        cx.global::<GlobalThreadStore>().0.clone()
 22    }
 23
 24    pub fn new(cx: &mut Context<Self>) -> Self {
 25        let this = Self {
 26            threads: Vec::new(),
 27        };
 28        this.reload(cx);
 29        this
 30    }
 31
 32    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 33        self.threads.iter().find(|thread| &thread.id == session_id)
 34    }
 35
 36    pub fn load_thread(
 37        &mut self,
 38        id: acp::SessionId,
 39        cx: &mut Context<Self>,
 40    ) -> Task<Result<Option<DbThread>>> {
 41        let database_future = ThreadsDatabase::connect(cx);
 42        cx.background_spawn(async move {
 43            let database = database_future.await.map_err(|err| anyhow!(err))?;
 44            database.load_thread(id).await
 45        })
 46    }
 47
 48    pub fn save_thread(
 49        &mut self,
 50        id: acp::SessionId,
 51        thread: crate::DbThread,
 52        cx: &mut Context<Self>,
 53    ) -> Task<Result<()>> {
 54        let database_future = ThreadsDatabase::connect(cx);
 55        cx.spawn(async move |this, cx| {
 56            let database = database_future.await.map_err(|err| anyhow!(err))?;
 57            database.save_thread(id, thread).await?;
 58            this.update(cx, |this, cx| this.reload(cx))
 59        })
 60    }
 61
 62    pub fn delete_thread(
 63        &mut self,
 64        id: acp::SessionId,
 65        cx: &mut Context<Self>,
 66    ) -> Task<Result<()>> {
 67        let database_future = ThreadsDatabase::connect(cx);
 68        cx.spawn(async move |this, cx| {
 69            let database = database_future.await.map_err(|err| anyhow!(err))?;
 70            database.delete_thread(id.clone()).await?;
 71            this.update(cx, |this, cx| this.reload(cx))
 72        })
 73    }
 74
 75    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 76        let database_future = ThreadsDatabase::connect(cx);
 77        cx.spawn(async move |this, cx| {
 78            let database = database_future.await.map_err(|err| anyhow!(err))?;
 79            database.delete_threads().await?;
 80            this.update(cx, |this, cx| this.reload(cx))
 81        })
 82    }
 83
 84    pub fn reload(&self, cx: &mut Context<Self>) {
 85        let database_connection = ThreadsDatabase::connect(cx);
 86        cx.spawn(async move |this, cx| {
 87            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 88            let threads = database
 89                .list_threads()
 90                .await?
 91                .into_iter()
 92                .filter(|thread| thread.parent_session_id.is_none())
 93                .collect::<Vec<_>>();
 94            this.update(cx, |this, cx| {
 95                this.threads = threads;
 96                cx.notify();
 97            })
 98        })
 99        .detach_and_log_err(cx);
100    }
101
102    pub fn is_empty(&self) -> bool {
103        self.threads.is_empty()
104    }
105
106    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
107        self.threads.iter().cloned()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an ACP session id from a string slice.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn january_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Minimal thread with the given title and timestamp and every other field defaulted.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            updated_at,
            messages: Vec::new(),
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
        }
    }

    /// Creates a store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id` through the store and waits for the save task.
    async fn save(
        store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        store
            .update(cx, |store, cx| store.save_thread(id.clone(), thread, cx))
            .await
            .unwrap();
    }

    /// Ids of the store's current entries, in iteration order.
    fn entry_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id.clone()).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let store = new_store(cx);
        let (older, newer) = (session_id("thread-a"), session_id("thread-b"));

        save(&store, &older, make_thread("Thread A", january_2024(1)), cx).await;
        save(&store, &newer, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&store, cx), vec![newer, older]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let store = new_store(cx);
        let id = session_id("thread-a");

        save(&store, &id, make_thread("Thread A", january_2024(1)), cx).await;
        cx.run_until_parked();
        assert!(!store.read_with(cx, |store, _cx| store.is_empty()));

        store
            .update(cx, |store, cx| store.delete_threads(cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert!(store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let store = new_store(cx);
        let (first, second) = (session_id("thread-a"), session_id("thread-b"));

        save(&store, &first, make_thread("Thread A", january_2024(1)), cx).await;
        save(&store, &second, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.delete_thread(first.clone(), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&store, cx), vec![second]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let store = new_store(cx);
        let (first, second) = (session_id("thread-a"), session_id("thread-b"));

        save(&store, &first, make_thread("Thread A", january_2024(1)), cx).await;
        save(&store, &second, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-saving the older thread with a later timestamp should move it to the front.
        save(&store, &first, make_thread("Thread A", january_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&store, cx), vec![first, second]);
    }
}