thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use util::path_list::PathList;
  6
  7struct GlobalThreadStore(Entity<ThreadStore>);
  8
  9impl Global for GlobalThreadStore {}
 10
 11pub struct ThreadStore {
 12    threads: Vec<DbThreadMetadata>,
 13}
 14
 15impl ThreadStore {
 16    pub fn init_global(cx: &mut App) {
 17        let thread_store = cx.new(|cx| Self::new(cx));
 18        cx.set_global(GlobalThreadStore(thread_store));
 19    }
 20
 21    pub fn global(cx: &App) -> Entity<Self> {
 22        cx.global::<GlobalThreadStore>().0.clone()
 23    }
 24
 25    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 26        cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
 27    }
 28
 29    pub fn new(cx: &mut Context<Self>) -> Self {
 30        let this = Self {
 31            threads: Vec::new(),
 32        };
 33        this.reload(cx);
 34        this
 35    }
 36
 37    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 38        self.threads.iter().find(|thread| &thread.id == session_id)
 39    }
 40
 41    pub fn load_thread(
 42        &mut self,
 43        id: acp::SessionId,
 44        cx: &mut Context<Self>,
 45    ) -> Task<Result<Option<DbThread>>> {
 46        let database_future = ThreadsDatabase::connect(cx);
 47        cx.background_spawn(async move {
 48            let database = database_future.await.map_err(|err| anyhow!(err))?;
 49            database.load_thread(id).await
 50        })
 51    }
 52
 53    pub fn save_thread(
 54        &mut self,
 55        id: acp::SessionId,
 56        thread: crate::DbThread,
 57        folder_paths: PathList,
 58        cx: &mut Context<Self>,
 59    ) -> Task<Result<()>> {
 60        let database_future = ThreadsDatabase::connect(cx);
 61        cx.spawn(async move |this, cx| {
 62            let database = database_future.await.map_err(|err| anyhow!(err))?;
 63            database.save_thread(id, thread, folder_paths).await?;
 64            this.update(cx, |this, cx| this.reload(cx))
 65        })
 66    }
 67
 68    pub fn delete_thread(
 69        &mut self,
 70        id: acp::SessionId,
 71        cx: &mut Context<Self>,
 72    ) -> Task<Result<()>> {
 73        let database_future = ThreadsDatabase::connect(cx);
 74        cx.spawn(async move |this, cx| {
 75            let database = database_future.await.map_err(|err| anyhow!(err))?;
 76            database.delete_thread(id.clone()).await?;
 77            this.update(cx, |this, cx| this.reload(cx))
 78        })
 79    }
 80
 81    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 82        let database_future = ThreadsDatabase::connect(cx);
 83        cx.spawn(async move |this, cx| {
 84            let database = database_future.await.map_err(|err| anyhow!(err))?;
 85            database.delete_threads().await?;
 86            this.update(cx, |this, cx| this.reload(cx))
 87        })
 88    }
 89
 90    pub fn reload(&self, cx: &mut Context<Self>) {
 91        let database_connection = ThreadsDatabase::connect(cx);
 92        cx.spawn(async move |this, cx| {
 93            let database = database_connection.await.map_err(|err| anyhow!(err))?;
 94            let all_threads = database.list_threads().await?;
 95            this.update(cx, |this, cx| {
 96                this.threads.clear();
 97                for thread in all_threads {
 98                    if thread.parent_session_id.is_some() {
 99                        continue;
100                    }
101                    this.threads.push(thread);
102                }
103                cx.notify();
104            })
105        })
106        .detach_and_log_err(cx);
107    }
108
109    pub fn is_empty(&self) -> bool {
110        self.threads.is_empty()
111    }
112
113    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
114        self.threads.iter().cloned()
115    }
116
117    pub fn entry_ids(&self) -> impl Iterator<Item = acp::SessionId> + '_ {
118        self.threads.iter().map(|t| t.id.clone())
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Wraps `value` in an [`acp::SessionId`].
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = Arc::from(value);
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn january_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds an otherwise empty thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Creates a store and lets its initial reload run to completion.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id` with no folder paths and waits for the save
    /// task itself (not for the reload it triggers).
    async fn save(
        store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = store.update(cx, |store, cx| {
            store.save_thread(id.clone(), thread, PathList::default(), cx)
        });
        task.await.unwrap();
    }

    /// Session ids of the store's current entries, in listed order.
    fn listed_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(
            &thread_store,
            &older_id,
            make_thread("Thread A", january_2024(1)),
            cx,
        )
        .await;
        save(
            &thread_store,
            &newer_id,
            make_thread("Thread B", january_2024(2)),
            cx,
        )
        .await;
        cx.run_until_parked();

        // Most recently updated thread is listed first.
        assert_eq!(listed_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let thread_id = session_id("thread-a");
        save(
            &thread_store,
            &thread_id,
            make_thread("Thread A", january_2024(1)),
            cx,
        )
        .await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete_all = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_all.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(
            &thread_store,
            &first_id,
            make_thread("Thread A", january_2024(1)),
            cx,
        )
        .await;
        save(
            &thread_store,
            &second_id,
            make_thread("Thread B", january_2024(2)),
            cx,
        )
        .await;
        cx.run_until_parked();

        let delete_first =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete_first.await.unwrap();
        cx.run_until_parked();

        assert_eq!(listed_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(
            &thread_store,
            &first_id,
            make_thread("Thread A", january_2024(1)),
            cx,
        )
        .await;
        save(
            &thread_store,
            &second_id,
            make_thread("Thread B", january_2024(2)),
            cx,
        )
        .await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp moves it to the front.
        save(
            &thread_store,
            &first_id,
            make_thread("Thread A", january_2024(3)),
            cx,
        )
        .await;
        cx.run_until_parked();

        assert_eq!(listed_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}