thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use util::path_list::PathList;
  6
/// GPUI global wrapper holding the application-wide [`ThreadStore`] entity.
///
/// Installed by [`ThreadStore::init_global`] and read back by [`ThreadStore::global`] and
/// [`ThreadStore::try_global`].
struct GlobalThreadStore(Entity<ThreadStore>);

// Implementing `Global` allows the wrapper to be stored via `App::set_global`.
impl Global for GlobalThreadStore {}
 10
/// In-memory cache of thread metadata loaded from [`ThreadsDatabase`].
///
/// The cache is filled asynchronously when the store is created and is refreshed after each
/// save or delete performed through this store.
pub struct ThreadStore {
    /// Metadata for top-level threads only. Entries with a `parent_session_id` (presumably
    /// subagent threads — confirm) are filtered out on reload. The order is whatever
    /// `ThreadsDatabase::list_threads` returns.
    threads: Vec<DbThreadMetadata>,
}
 14
 15impl ThreadStore {
 16    pub fn init_global(cx: &mut App) {
 17        let thread_store = cx.new(|cx| Self::new(cx));
 18        cx.set_global(GlobalThreadStore(thread_store));
 19    }
 20
 21    pub fn global(cx: &App) -> Entity<Self> {
 22        cx.global::<GlobalThreadStore>().0.clone()
 23    }
 24
 25    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 26        cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
 27    }
 28
 29    pub fn new(cx: &mut Context<Self>) -> Self {
 30        let this = Self {
 31            threads: Vec::new(),
 32        };
 33        this.reload(cx);
 34        this
 35    }
 36
 37    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 38        self.threads.iter().find(|thread| &thread.id == session_id)
 39    }
 40
 41    pub fn load_thread(
 42        &mut self,
 43        id: acp::SessionId,
 44        cx: &mut Context<Self>,
 45    ) -> Task<Result<Option<DbThread>>> {
 46        let database_future = ThreadsDatabase::connect(cx);
 47        cx.background_spawn(async move {
 48            let database = database_future.await.map_err(|err| anyhow!(err))?;
 49            database.load_thread(id).await
 50        })
 51    }
 52
 53    pub fn save_thread(
 54        &mut self,
 55        id: acp::SessionId,
 56        thread: crate::DbThread,
 57        folder_paths: PathList,
 58        cx: &mut Context<Self>,
 59    ) -> Task<Result<()>> {
 60        let database_future = ThreadsDatabase::connect(cx);
 61        cx.spawn(async move |this, cx| {
 62            let database = database_future.await.map_err(|err| anyhow!(err))?;
 63            database.save_thread(id, thread, folder_paths).await?;
 64            this.update(cx, |this, cx| this.reload(cx))
 65        })
 66    }
 67
 68    pub fn delete_thread(
 69        &mut self,
 70        id: acp::SessionId,
 71        cx: &mut Context<Self>,
 72    ) -> Task<Result<()>> {
 73        let database_future = ThreadsDatabase::connect(cx);
 74        cx.spawn(async move |this, cx| {
 75            let database = database_future.await.map_err(|err| anyhow!(err))?;
 76            database.delete_thread(id.clone()).await?;
 77            this.update(cx, |this, cx| this.reload(cx))
 78        })
 79    }
 80
 81    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 82        let database_future = ThreadsDatabase::connect(cx);
 83        cx.spawn(async move |this, cx| {
 84            let database = database_future.await.map_err(|err| anyhow!(err))?;
 85            database.delete_threads().await?;
 86            this.update(cx, |this, cx| this.reload(cx))
 87        })
 88    }
 89
 90    pub fn reload(&self, cx: &mut Context<Self>) {
 91        let database_connection = ThreadsDatabase::connect(cx);
 92        cx.spawn(async move |this, cx| {
 93            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 94            let all_threads = database.list_threads().await?;
 95            this.update(cx, |this, cx| {
 96                this.threads.clear();
 97                for thread in all_threads {
 98                    if thread.parent_session_id.is_some() {
 99                        continue;
100                    }
101                    this.threads.push(thread);
102                }
103                cx.notify();
104            })
105        })
106        .detach_and_log_err(cx);
107    }
108
109    pub fn is_empty(&self) -> bool {
110        self.threads.is_empty()
111    }
112
113    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
114        self.threads.iter().cloned()
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a session id from a string slice.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = value.into();
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds a minimal thread with the given title and last-update timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Creates a store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id` with no folder paths and panics if the save fails.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = thread_store.update(cx, |store, cx| {
            store.save_thread(id.clone(), thread, PathList::default(), cx)
        });
        task.await.unwrap();
    }

    /// Collects the session ids of the store's entries, in iteration order.
    fn entry_ids(thread_store: &Entity<ThreadStore>, cx: &TestAppContext) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, &older_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &newer_id, make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        save(&thread_store, &thread_id, make_thread("Thread A", jan_2024(1)), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        thread_store
            .update(cx, |store, cx| store.delete_threads(cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        thread_store
            .update(cx, |store, cx| store.delete_thread(first_id.clone(), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a newer timestamp should move it to the front.
        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}