thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Task, prelude::*};
  5use project::Project;
  6use std::rc::Rc;
  7
  8// TODO: Remove once ACP thread loading is fully handled elsewhere.
  9pub fn load_agent_thread(
 10    session_id: acp::SessionId,
 11    thread_store: Entity<ThreadStore>,
 12    project: Entity<Project>,
 13    cx: &mut App,
 14) -> Task<Result<Entity<crate::Thread>>> {
 15    use agent_servers::{AgentServer, AgentServerDelegate};
 16
 17    let server = Rc::new(crate::NativeAgentServer::new(
 18        project.read(cx).fs().clone(),
 19        thread_store,
 20    ));
 21    let delegate = AgentServerDelegate::new(
 22        project.read(cx).agent_server_store().clone(),
 23        project.clone(),
 24        None,
 25        None,
 26    );
 27    let connection = server.connect(None, delegate, cx);
 28    cx.spawn(async move |cx| {
 29        let (agent, _) = connection.await?;
 30        let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
 31        cx.update(|cx| agent.load_thread(session_id, cx)).await
 32    })
 33}
 34
 35pub struct ThreadStore {
 36    threads: Vec<DbThreadMetadata>,
 37}
 38
 39impl ThreadStore {
 40    pub fn new(cx: &mut Context<Self>) -> Self {
 41        let this = Self {
 42            threads: Vec::new(),
 43        };
 44        this.reload(cx);
 45        this
 46    }
 47
 48    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 49        self.threads.iter().find(|thread| &thread.id == session_id)
 50    }
 51
 52    pub fn load_thread(
 53        &mut self,
 54        id: acp::SessionId,
 55        cx: &mut Context<Self>,
 56    ) -> Task<Result<Option<DbThread>>> {
 57        let database_future = ThreadsDatabase::connect(cx);
 58        cx.background_spawn(async move {
 59            let database = database_future.await.map_err(|err| anyhow!(err))?;
 60            database.load_thread(id).await
 61        })
 62    }
 63
 64    pub fn save_thread(
 65        &mut self,
 66        id: acp::SessionId,
 67        thread: crate::DbThread,
 68        cx: &mut Context<Self>,
 69    ) -> Task<Result<()>> {
 70        let database_future = ThreadsDatabase::connect(cx);
 71        cx.spawn(async move |this, cx| {
 72            let database = database_future.await.map_err(|err| anyhow!(err))?;
 73            database.save_thread(id, thread).await?;
 74            this.update(cx, |this, cx| this.reload(cx))
 75        })
 76    }
 77
 78    pub fn delete_thread(
 79        &mut self,
 80        id: acp::SessionId,
 81        cx: &mut Context<Self>,
 82    ) -> Task<Result<()>> {
 83        let database_future = ThreadsDatabase::connect(cx);
 84        cx.spawn(async move |this, cx| {
 85            let database = database_future.await.map_err(|err| anyhow!(err))?;
 86            database.delete_thread(id.clone()).await?;
 87            this.update(cx, |this, cx| this.reload(cx))
 88        })
 89    }
 90
 91    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 92        let database_future = ThreadsDatabase::connect(cx);
 93        cx.spawn(async move |this, cx| {
 94            let database = database_future.await.map_err(|err| anyhow!(err))?;
 95            database.delete_threads().await?;
 96            this.update(cx, |this, cx| this.reload(cx))
 97        })
 98    }
 99
100    pub fn reload(&self, cx: &mut Context<Self>) {
101        let database_connection = ThreadsDatabase::connect(cx);
102        cx.spawn(async move |this, cx| {
103            let database = database_connection.await.map_err(|err| anyhow!(err))?;
104            let threads = database.list_threads().await?;
105            this.update(cx, |this, cx| {
106                this.threads = threads;
107                cx.notify();
108            })
109        })
110        .detach_and_log_err(cx);
111    }
112
113    pub fn is_empty(&self) -> bool {
114        self.threads.is_empty()
115    }
116
117    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
118        self.threads.iter().cloned()
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a session id from a plain string.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = value.into();
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn january_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds an otherwise-empty thread with the given title and update timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_owned().into(),
            updated_at,
            messages: Vec::new(),
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            completion_mode: None,
            profile: None,
            imported: false,
        }
    }

    /// Creates a store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id` and waits for the save task to finish.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = thread_store.update(cx, |store, cx| store.save_thread(id.clone(), thread, cx));
        task.await.unwrap();
    }

    /// Collects the ids of the store's cached entries, preserving their order.
    fn entry_ids(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, &older_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &newer_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        // The most recently updated thread is listed first.
        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let only_id = session_id("thread-a");

        save(&thread_store, &only_id, make_thread("Thread A", january_2024(1)), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        thread_store
            .update(cx, |store, cx| store.delete_threads(cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let doomed_id = session_id("thread-a");
        let survivor_id = session_id("thread-b");

        save(&thread_store, &doomed_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &survivor_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        thread_store
            .update(cx, |store, cx| store.delete_thread(doomed_id.clone(), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![survivor_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a newer timestamp should move it to the front.
        save(&thread_store, &first_id, make_thread("Thread A", january_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}