thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use project::Project;
  6use std::rc::Rc;
  7
/// Newtype wrapper that lets the app-wide [`ThreadStore`] entity be stored as a
/// gpui global. It is installed by [`ThreadStore::init_global`] and read back by
/// [`ThreadStore::global`].
struct GlobalThreadStore(Entity<ThreadStore>);

impl Global for GlobalThreadStore {}
 11
 12// TODO: Remove once ACP thread loading is fully handled elsewhere.
 13pub fn load_agent_thread(
 14    session_id: acp::SessionId,
 15    thread_store: Entity<ThreadStore>,
 16    project: Entity<Project>,
 17    cx: &mut App,
 18) -> Task<Result<Entity<crate::Thread>>> {
 19    use agent_servers::{AgentServer, AgentServerDelegate};
 20
 21    let server = Rc::new(crate::NativeAgentServer::new(
 22        project.read(cx).fs().clone(),
 23        thread_store,
 24    ));
 25    let delegate = AgentServerDelegate::new(
 26        project.read(cx).agent_server_store().clone(),
 27        project.clone(),
 28        None,
 29        None,
 30    );
 31    let connection = server.connect(None, delegate, cx);
 32    cx.spawn(async move |cx| {
 33        let (agent, _) = connection.await?;
 34        let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
 35        cx.update(|cx| agent.load_thread(session_id, cx)).await
 36    })
 37}
 38
 39pub struct ThreadStore {
 40    threads: Vec<DbThreadMetadata>,
 41}
 42
 43impl ThreadStore {
 44    pub fn init_global(cx: &mut App) {
 45        let thread_store = cx.new(|cx| Self::new(cx));
 46        cx.set_global(GlobalThreadStore(thread_store));
 47    }
 48
 49    pub fn global(cx: &App) -> Entity<Self> {
 50        cx.global::<GlobalThreadStore>().0.clone()
 51    }
 52
 53    pub fn new(cx: &mut Context<Self>) -> Self {
 54        let this = Self {
 55            threads: Vec::new(),
 56        };
 57        this.reload(cx);
 58        this
 59    }
 60
 61    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 62        self.threads.iter().find(|thread| &thread.id == session_id)
 63    }
 64
 65    pub fn load_thread(
 66        &mut self,
 67        id: acp::SessionId,
 68        cx: &mut Context<Self>,
 69    ) -> Task<Result<Option<DbThread>>> {
 70        let database_future = ThreadsDatabase::connect(cx);
 71        cx.background_spawn(async move {
 72            let database = database_future.await.map_err(|err| anyhow!(err))?;
 73            database.load_thread(id).await
 74        })
 75    }
 76
 77    pub fn save_thread(
 78        &mut self,
 79        id: acp::SessionId,
 80        thread: crate::DbThread,
 81        cx: &mut Context<Self>,
 82    ) -> Task<Result<()>> {
 83        let database_future = ThreadsDatabase::connect(cx);
 84        cx.spawn(async move |this, cx| {
 85            let database = database_future.await.map_err(|err| anyhow!(err))?;
 86            database.save_thread(id, thread).await?;
 87            this.update(cx, |this, cx| this.reload(cx))
 88        })
 89    }
 90
 91    pub fn delete_thread(
 92        &mut self,
 93        id: acp::SessionId,
 94        cx: &mut Context<Self>,
 95    ) -> Task<Result<()>> {
 96        let database_future = ThreadsDatabase::connect(cx);
 97        cx.spawn(async move |this, cx| {
 98            let database = database_future.await.map_err(|err| anyhow!(err))?;
 99            database.delete_thread(id.clone()).await?;
100            this.update(cx, |this, cx| this.reload(cx))
101        })
102    }
103
104    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
105        let database_future = ThreadsDatabase::connect(cx);
106        cx.spawn(async move |this, cx| {
107            let database = database_future.await.map_err(|err| anyhow!(err))?;
108            database.delete_threads().await?;
109            this.update(cx, |this, cx| this.reload(cx))
110        })
111    }
112
113    pub fn reload(&self, cx: &mut Context<Self>) {
114        let database_connection = ThreadsDatabase::connect(cx);
115        cx.spawn(async move |this, cx| {
116            let database = database_connection.await.map_err(|err| anyhow!(err))?;
117            let threads = database
118                .list_threads()
119                .await?
120                .into_iter()
121                .filter(|thread| thread.parent_session_id.is_none())
122                .collect::<Vec<_>>();
123            this.update(cx, |this, cx| {
124                this.threads = threads;
125                cx.notify();
126            })
127        })
128        .detach_and_log_err(cx);
129    }
130
131    pub fn is_empty(&self) -> bool {
132        self.threads.is_empty()
133    }
134
135    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
136        self.threads.iter().cloned()
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an ACP session id from a string.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Builds an empty thread with the given title and `updated_at` timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            git_worktree_info: None,
        }
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Creates a store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id`, waits for the save, and lets the follow-up
    /// reload settle so the cache reflects the write.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = thread_store.update(cx, |store, cx| store.save_thread(id, thread, cx));
        task.await.unwrap();
        cx.run_until_parked();
    }

    /// Snapshot of the store's cached entries.
    fn cached_entries(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<DbThreadMetadata> {
        thread_store.read_with(cx, |store, _cx| store.entries().collect())
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, older_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, newer_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;

        let entries = cached_entries(&thread_store, cx);
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, newer_id);
        assert_eq!(entries[1].id, older_id);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        save(&thread_store, session_id("thread-a"), make_thread("Thread A", jan_2024(1)), cx).await;
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete_task = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, first_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, second_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;

        let delete_task =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        let entries = cached_entries(&thread_store, cx);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, second_id);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, first_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, second_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;

        // Re-saving the older thread with a newer timestamp should move it to the front.
        save(&thread_store, first_id.clone(), make_thread("Thread A", jan_2024(3)), cx).await;

        let entries = cached_entries(&thread_store, cx);
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].id, first_id);
        assert_eq!(entries[1].id, second_id);
    }

    #[gpui::test]
    async fn test_thread_from_session_id_finds_only_saved_threads(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        save(&thread_store, thread_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;

        let found = thread_store.read_with(cx, |store, _cx| {
            store
                .thread_from_session_id(&thread_id)
                .map(|thread| thread.id.clone())
        });
        assert_eq!(found, Some(thread_id));

        let missing_found = thread_store.read_with(cx, |store, _cx| {
            store
                .thread_from_session_id(&session_id("missing"))
                .is_some()
        });
        assert!(!missing_found);
    }
}