thread_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Context, Entity, Global, Task, prelude::*};
  5use project::Project;
  6use std::rc::Rc;
  7
  8struct GlobalThreadStore(Entity<ThreadStore>);
  9
 10impl Global for GlobalThreadStore {}
 11
 12// TODO: Remove once ACP thread loading is fully handled elsewhere.
 13pub fn load_agent_thread(
 14    session_id: acp::SessionId,
 15    thread_store: Entity<ThreadStore>,
 16    project: Entity<Project>,
 17    cx: &mut App,
 18) -> Task<Result<Entity<crate::Thread>>> {
 19    use agent_servers::{AgentServer, AgentServerDelegate};
 20
 21    let server = Rc::new(crate::NativeAgentServer::new(
 22        project.read(cx).fs().clone(),
 23        thread_store,
 24    ));
 25    let delegate = AgentServerDelegate::new(
 26        project.read(cx).agent_server_store().clone(),
 27        project.clone(),
 28        None,
 29        None,
 30    );
 31    let connection = server.connect(None, delegate, cx);
 32    cx.spawn(async move |cx| {
 33        let (agent, _) = connection.await?;
 34        let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
 35        cx.update(|cx| agent.load_thread(session_id, cx)).await
 36    })
 37}
 38
 39pub struct ThreadStore {
 40    threads: Vec<DbThreadMetadata>,
 41}
 42
 43impl ThreadStore {
 44    pub fn init_global(cx: &mut App) {
 45        let thread_store = cx.new(|cx| Self::new(cx));
 46        cx.set_global(GlobalThreadStore(thread_store));
 47    }
 48
 49    pub fn global(cx: &App) -> Entity<Self> {
 50        cx.global::<GlobalThreadStore>().0.clone()
 51    }
 52
 53    pub fn new(cx: &mut Context<Self>) -> Self {
 54        let this = Self {
 55            threads: Vec::new(),
 56        };
 57        this.reload(cx);
 58        this
 59    }
 60
 61    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
 62        self.threads.iter().find(|thread| &thread.id == session_id)
 63    }
 64
 65    pub fn load_thread(
 66        &mut self,
 67        id: acp::SessionId,
 68        cx: &mut Context<Self>,
 69    ) -> Task<Result<Option<DbThread>>> {
 70        let database_future = ThreadsDatabase::connect(cx);
 71        cx.background_spawn(async move {
 72            let database = database_future.await.map_err(|err| anyhow!(err))?;
 73            database.load_thread(id).await
 74        })
 75    }
 76
 77    pub fn save_thread(
 78        &mut self,
 79        id: acp::SessionId,
 80        thread: crate::DbThread,
 81        cx: &mut Context<Self>,
 82    ) -> Task<Result<()>> {
 83        let database_future = ThreadsDatabase::connect(cx);
 84        cx.spawn(async move |this, cx| {
 85            let database = database_future.await.map_err(|err| anyhow!(err))?;
 86            database.save_thread(id, thread).await?;
 87            this.update(cx, |this, cx| this.reload(cx))
 88        })
 89    }
 90
 91    pub fn delete_thread(
 92        &mut self,
 93        id: acp::SessionId,
 94        cx: &mut Context<Self>,
 95    ) -> Task<Result<()>> {
 96        let database_future = ThreadsDatabase::connect(cx);
 97        cx.spawn(async move |this, cx| {
 98            let database = database_future.await.map_err(|err| anyhow!(err))?;
 99            database.delete_thread(id.clone()).await?;
100            this.update(cx, |this, cx| this.reload(cx))
101        })
102    }
103
104    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
105        let database_future = ThreadsDatabase::connect(cx);
106        cx.spawn(async move |this, cx| {
107            let database = database_future.await.map_err(|err| anyhow!(err))?;
108            database.delete_threads().await?;
109            this.update(cx, |this, cx| this.reload(cx))
110        })
111    }
112
113    pub fn reload(&self, cx: &mut Context<Self>) {
114        let database_connection = ThreadsDatabase::connect(cx);
115        cx.spawn(async move |this, cx| {
116            let database = database_connection.await.map_err(|err| anyhow!(err))?;
117            let threads = database
118                .list_threads()
119                .await?
120                .into_iter()
121                .filter(|thread| thread.parent_session_id.is_none())
122                .collect::<Vec<_>>();
123            this.update(cx, |this, cx| {
124                this.threads = threads;
125                cx.notify();
126            })
127        })
128        .detach_and_log_err(cx);
129    }
130
131    pub fn is_empty(&self) -> bool {
132        self.threads.is_empty()
133    }
134
135    pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
136        self.threads.iter().cloned()
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a session id from a string.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = Arc::from(value);
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds an otherwise-empty `DbThread` with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            updated_at,
            messages: Vec::new(),
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
        }
    }

    /// Creates a fresh store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id` and waits for the save task only (not the
    /// reload it schedules).
    async fn save(
        store: &Entity<ThreadStore>,
        id: acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = store.update(cx, |this, cx| this.save_thread(id, thread, cx));
        task.await.unwrap();
    }

    /// Snapshot of the cached thread ids, in the order the store yields them.
    fn cached_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |this, _cx| {
            this.entries().map(|entry| entry.id).collect()
        })
    }

    /// The most recently updated thread is listed first.
    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&store, older_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;
        save(&store, newer_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        assert_eq!(cached_ids(&store, cx), vec![newer_id, older_id]);
    }

    /// `delete_threads` leaves the store empty.
    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let store = new_store(cx);

        save(&store, session_id("thread-a"), make_thread("Thread A", jan_2024(1)), cx).await;
        cx.run_until_parked();
        assert!(!store.read_with(cx, |this, _cx| this.is_empty()));

        let delete_all = store.update(cx, |this, cx| this.delete_threads(cx));
        delete_all.await.unwrap();
        cx.run_until_parked();

        assert!(store.read_with(cx, |this, _cx| this.is_empty()));
    }

    /// `delete_thread` removes only the requested thread.
    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&store, first_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;
        save(&store, second_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        let delete_first = store.update(cx, |this, cx| this.delete_thread(first_id, cx));
        delete_first.await.unwrap();
        cx.run_until_parked();

        assert_eq!(cached_ids(&store, cx), vec![second_id]);
    }

    /// Re-saving a thread with a newer timestamp moves it to the front.
    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&store, first_id.clone(), make_thread("Thread A", jan_2024(1)), cx).await;
        save(&store, second_id.clone(), make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        save(&store, first_id.clone(), make_thread("Thread A", jan_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(cached_ids(&store, cx), vec![first_id, second_id]);
    }
}