thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use project::Project;
  6use std::rc::Rc;
  7
  8struct GlobalThreadStore(Entity<ThreadStore>);
  9
 10impl Global for GlobalThreadStore {}
 11
 12// TODO: Remove once ACP thread loading is fully handled elsewhere.
 13pub fn load_agent_thread(
 14    session_id: acp::SessionId,
 15    thread_store: Entity<ThreadStore>,
 16    project: Entity<Project>,
 17    cx: &mut App,
 18) -> Task<Result<Entity<crate::Thread>>> {
 19    use agent_servers::{AgentServer, AgentServerDelegate};
 20
 21    let server = Rc::new(crate::NativeAgentServer::new(
 22        project.read(cx).fs().clone(),
 23        thread_store,
 24    ));
 25    let delegate = AgentServerDelegate::new(
 26        project.read(cx).agent_server_store().clone(),
 27        project.clone(),
 28        None,
 29        None,
 30    );
 31    let connection = server.connect(None, delegate, cx);
 32    cx.spawn(async move |cx| {
 33        let (agent, _) = connection.await?;
 34        let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
 35        cx.update(|cx| agent.load_thread(session_id, cx)).await
 36    })
 37}
 38
 39pub struct ThreadStore {
 40    threads: Vec<DbThreadMetadata>,
 41}
 42
 43impl ThreadStore {
 44    pub fn init_global(cx: &mut App) {
 45        let thread_store = cx.new(|cx| Self::new(cx));
 46        cx.set_global(GlobalThreadStore(thread_store));
 47    }
 48
 49    pub fn global(cx: &App) -> Entity<Self> {
 50        cx.global::<GlobalThreadStore>().0.clone()
 51    }
 52
 53    pub fn new(cx: &mut Context<Self>) -> Self {
 54        let this = Self {
 55            threads: Vec::new(),
 56        };
 57        this.reload(cx);
 58        this
 59    }
 60
 61    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 62        self.threads.iter().find(|thread| &thread.id == session_id)
 63    }
 64
 65    pub fn load_thread(
 66        &mut self,
 67        id: acp::SessionId,
 68        cx: &mut Context<Self>,
 69    ) -> Task<Result<Option<DbThread>>> {
 70        let database_future = ThreadsDatabase::connect(cx);
 71        cx.background_spawn(async move {
 72            let database = database_future.await.map_err(|err| anyhow!(err))?;
 73            database.load_thread(id).await
 74        })
 75    }
 76
 77    pub fn save_thread(
 78        &mut self,
 79        id: acp::SessionId,
 80        thread: crate::DbThread,
 81        cx: &mut Context<Self>,
 82    ) -> Task<Result<()>> {
 83        let database_future = ThreadsDatabase::connect(cx);
 84        cx.spawn(async move |this, cx| {
 85            let database = database_future.await.map_err(|err| anyhow!(err))?;
 86            database.save_thread(id, thread).await?;
 87            this.update(cx, |this, cx| this.reload(cx))
 88        })
 89    }
 90
 91    pub fn delete_thread(
 92        &mut self,
 93        id: acp::SessionId,
 94        cx: &mut Context<Self>,
 95    ) -> Task<Result<()>> {
 96        let database_future = ThreadsDatabase::connect(cx);
 97        cx.spawn(async move |this, cx| {
 98            let database = database_future.await.map_err(|err| anyhow!(err))?;
 99            database.delete_thread(id.clone()).await?;
100            this.update(cx, |this, cx| this.reload(cx))
101        })
102    }
103
104    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
105        let database_future = ThreadsDatabase::connect(cx);
106        cx.spawn(async move |this, cx| {
107            let database = database_future.await.map_err(|err| anyhow!(err))?;
108            database.delete_threads().await?;
109            this.update(cx, |this, cx| this.reload(cx))
110        })
111    }
112
113    pub fn reload(&self, cx: &mut Context<Self>) {
114        let database_connection = ThreadsDatabase::connect(cx);
115        cx.spawn(async move |this, cx| {
116            let database = database_connection.await.map_err(|err| anyhow!(err))?;
117            let threads = database.list_threads().await?;
118            this.update(cx, |this, cx| {
119                this.threads = threads;
120                cx.notify();
121            })
122        })
123        .detach_and_log_err(cx);
124    }
125
126    pub fn is_empty(&self) -> bool {
127        self.threads.is_empty()
128    }
129
130    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
131        self.threads.iter().cloned()
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
        }
    }

    /// Creates a store and waits for its initial reload to settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id`, then lets the reload triggered by the save run to completion.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = thread_store.update(cx, |store, cx| store.save_thread(id, thread, cx));
        task.await.unwrap();
        cx.run_until_parked();
    }

    /// Returns the ids of the cached entries, in the order the store exposes them.
    fn entry_ids(
        thread_store: &Entity<ThreadStore>,
        cx: &TestAppContext,
    ) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        // Save the newer thread first so that insertion order and `updated_at` order disagree.
        // Otherwise this test could not distinguish the two orderings.
        save(&thread_store, newer_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;
        save(&thread_store, older_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        save(
            &thread_store,
            session_id("thread-a"),
            make_thread("Thread A", jan_2024(1)),
            cx,
        )
        .await;
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete_task = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, first_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, second_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;

        let delete_task =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
        assert!(thread_store.read_with(cx, |store, _cx| {
            store.thread_from_session_id(&first_id).is_none()
        }));
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, first_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, second_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;
        assert_eq!(
            entry_ids(&thread_store, cx),
            vec![second_id.clone(), first_id.clone()]
        );

        // Re-saving the first thread with a later timestamp must move it to the front.
        save(&thread_store, first_id.clone(), make_thread("Thread A", jan_2024(3)), cx).await;

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}