thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Task, prelude::*};
  5use project::Project;
  6use std::rc::Rc;
  7
  8// TODO: Remove once ACP thread loading is fully handled elsewhere.
  9pub fn load_agent_thread(
 10    session_id: acp::SessionId,
 11    thread_store: Entity<ThreadStore>,
 12    project: Entity<Project>,
 13    cx: &mut App,
 14) -> Task<Result<Entity<crate::Thread>>> {
 15    use agent_servers::{AgentServer, AgentServerDelegate};
 16
 17    let server = Rc::new(crate::NativeAgentServer::new(
 18        project.read(cx).fs().clone(),
 19        thread_store,
 20    ));
 21    let delegate = AgentServerDelegate::new(
 22        project.read(cx).agent_server_store().clone(),
 23        project.clone(),
 24        None,
 25        None,
 26    );
 27    let connection = server.connect(None, delegate, cx);
 28    cx.spawn(async move |cx| {
 29        let (agent, _) = connection.await?;
 30        let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
 31        cx.update(|cx| agent.load_thread(session_id, cx)).await
 32    })
 33}
 34
 35pub struct ThreadStore {
 36    threads: Vec<DbThreadMetadata>,
 37}
 38
 39impl ThreadStore {
 40    pub fn new(cx: &mut Context<Self>) -> Self {
 41        let this = Self {
 42            threads: Vec::new(),
 43        };
 44        this.reload(cx);
 45        this
 46    }
 47
 48    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 49        self.threads.iter().find(|thread| &thread.id == session_id)
 50    }
 51
 52    pub fn load_thread(
 53        &mut self,
 54        id: acp::SessionId,
 55        cx: &mut Context<Self>,
 56    ) -> Task<Result<Option<DbThread>>> {
 57        let database_future = ThreadsDatabase::connect(cx);
 58        cx.background_spawn(async move {
 59            let database = database_future.await.map_err(|err| anyhow!(err))?;
 60            database.load_thread(id).await
 61        })
 62    }
 63
 64    pub fn save_thread(
 65        &mut self,
 66        id: acp::SessionId,
 67        thread: crate::DbThread,
 68        cx: &mut Context<Self>,
 69    ) -> Task<Result<()>> {
 70        let database_future = ThreadsDatabase::connect(cx);
 71        cx.spawn(async move |this, cx| {
 72            let database = database_future.await.map_err(|err| anyhow!(err))?;
 73            database.save_thread(id, thread).await?;
 74            this.update(cx, |this, cx| this.reload(cx))
 75        })
 76    }
 77
 78    pub fn delete_thread(
 79        &mut self,
 80        id: acp::SessionId,
 81        cx: &mut Context<Self>,
 82    ) -> Task<Result<()>> {
 83        let database_future = ThreadsDatabase::connect(cx);
 84        cx.spawn(async move |this, cx| {
 85            let database = database_future.await.map_err(|err| anyhow!(err))?;
 86            database.delete_thread(id.clone()).await?;
 87            this.update(cx, |this, cx| this.reload(cx))
 88        })
 89    }
 90
 91    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 92        let database_future = ThreadsDatabase::connect(cx);
 93        cx.spawn(async move |this, cx| {
 94            let database = database_future.await.map_err(|err| anyhow!(err))?;
 95            database.delete_threads().await?;
 96            this.update(cx, |this, cx| this.reload(cx))
 97        })
 98    }
 99
100    pub fn reload(&self, cx: &mut Context<Self>) {
101        let database_connection = ThreadsDatabase::connect(cx);
102        cx.spawn(async move |this, cx| {
103            let database = database_connection.await.map_err(|err| anyhow!(err))?;
104            let threads = database.list_threads().await?;
105            this.update(cx, |this, cx| {
106                this.threads = threads;
107                cx.notify();
108            })
109        })
110        .detach_and_log_err(cx);
111    }
112
113    pub fn is_empty(&self) -> bool {
114        self.threads.is_empty()
115    }
116
117    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
118        self.threads.iter().cloned()
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Wraps a string slice in an `acp::SessionId`.
    fn session_id(value: &str) -> acp::SessionId {
        let value: Arc<str> = value.into();
        acp::SessionId::new(value)
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// A minimal `DbThread`: no messages, no token usage, no model or profile.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            updated_at,
            messages: Vec::new(),
            request_token_usage: HashMap::default(),
            cumulative_token_usage: Default::default(),
            detailed_summary: None,
            initial_project_snapshot: None,
            model: None,
            profile: None,
            imported: false,
        }
    }

    /// Creates a `ThreadStore` and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id` and waits for the write task to finish.
    async fn save(
        store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = store.update(cx, |store, cx| store.save_thread(id.clone(), thread, cx));
        task.await.unwrap();
    }

    /// The ids of the store's current entries, in iteration order.
    fn entry_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |store, _| store.entries().map(|entry| entry.id).collect())
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, &older_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &newer_id, make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        save(&thread_store, &thread_id, make_thread("Thread A", jan_2024(1)), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _| store.is_empty()));

        thread_store
            .update(cx, |store, cx| store.delete_threads(cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        thread_store
            .update(cx, |store, cx| store.delete_thread(first_id.clone(), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp should move it to the front.
        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}