1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Context, Entity, Global, Task, prelude::*};
5use util::path_list::PathList;
6
7struct GlobalThreadStore(Entity<ThreadStore>);
8
9impl Global for GlobalThreadStore {}
10
11pub struct ThreadStore {
12 threads: Vec<DbThreadMetadata>,
13}
14
15impl ThreadStore {
16 pub fn init_global(cx: &mut App) {
17 let thread_store = cx.new(|cx| Self::new(cx));
18 cx.set_global(GlobalThreadStore(thread_store));
19 }
20
21 pub fn global(cx: &App) -> Entity<Self> {
22 cx.global::<GlobalThreadStore>().0.clone()
23 }
24
25 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
26 cx.try_global::<GlobalThreadStore>().map(|g| g.0.clone())
27 }
28
29 pub fn new(cx: &mut Context<Self>) -> Self {
30 let this = Self {
31 threads: Vec::new(),
32 };
33 this.reload(cx);
34 this
35 }
36
37 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
38 self.threads.iter().find(|thread| &thread.id == session_id)
39 }
40
41 pub fn load_thread(
42 &mut self,
43 id: acp::SessionId,
44 cx: &mut Context<Self>,
45 ) -> Task<Result<Option<DbThread>>> {
46 let database_future = ThreadsDatabase::connect(cx);
47 cx.background_spawn(async move {
48 let database = database_future.await.map_err(|err| anyhow!(err))?;
49 database.load_thread(id).await
50 })
51 }
52
53 pub fn save_thread(
54 &mut self,
55 id: acp::SessionId,
56 thread: crate::DbThread,
57 folder_paths: PathList,
58 cx: &mut Context<Self>,
59 ) -> Task<Result<()>> {
60 let database_future = ThreadsDatabase::connect(cx);
61 cx.spawn(async move |this, cx| {
62 let database = database_future.await.map_err(|err| anyhow!(err))?;
63 database.save_thread(id, thread, folder_paths).await?;
64 this.update(cx, |this, cx| this.reload(cx))
65 })
66 }
67
68 pub fn delete_thread(
69 &mut self,
70 id: acp::SessionId,
71 cx: &mut Context<Self>,
72 ) -> Task<Result<()>> {
73 let database_future = ThreadsDatabase::connect(cx);
74 cx.spawn(async move |this, cx| {
75 let database = database_future.await.map_err(|err| anyhow!(err))?;
76 database.delete_thread(id.clone()).await?;
77 this.update(cx, |this, cx| this.reload(cx))
78 })
79 }
80
81 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
82 let database_future = ThreadsDatabase::connect(cx);
83 cx.spawn(async move |this, cx| {
84 let database = database_future.await.map_err(|err| anyhow!(err))?;
85 database.delete_threads().await?;
86 this.update(cx, |this, cx| this.reload(cx))
87 })
88 }
89
90 pub fn reload(&self, cx: &mut Context<Self>) {
91 let database_connection = ThreadsDatabase::connect(cx);
92 cx.spawn(async move |this, cx| {
93 let database = database_connection.await.map_err(|err| anyhow!(err))?;
94 let all_threads = database.list_threads().await?;
95 this.update(cx, |this, cx| {
96 this.threads.clear();
97 for thread in all_threads {
98 if thread.parent_session_id.is_some() {
99 continue;
100 }
101 this.threads.push(thread);
102 }
103 cx.notify();
104 })
105 })
106 .detach_and_log_err(cx);
107 }
108
109 pub fn is_empty(&self) -> bool {
110 self.threads.is_empty()
111 }
112
113 pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
114 self.threads.iter().cloned()
115 }
116
117 pub fn entry_ids(&self) -> impl Iterator<Item = acp::SessionId> + '_ {
118 self.threads.iter().map(|t| t.id.clone())
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a session id from a plain string.
    fn session_id(value: &str) -> acp::SessionId {
        let raw: Arc<str> = Arc::from(value);
        acp::SessionId::new(raw)
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds a minimal thread with only the title and timestamp set.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            updated_at,
            messages: Vec::new(),
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            speed: None,
            thinking_enabled: false,
            thinking_effort: None,
            draft_prompt: None,
            ui_scroll_position: None,
        }
    }

    /// Saves `thread` under `id` with no folder paths and waits for the save task.
    async fn save(
        store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = store.update(cx, |store, cx| {
            store.save_thread(id.clone(), thread, PathList::default(), cx)
        });
        task.await.unwrap();
    }

    /// Session ids of the store's current entries, in listed order.
    fn listed_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    // NOTE(review): `ThreadStore` does not sort; the newest-first order asserted
    // below presumably comes from `ThreadsDatabase::list_threads` — confirm there.
    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, &older_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &newer_id, make_thread("Thread B", jan_2024(2)), cx).await;

        // The reload started by a save is detached, so let it finish first.
        cx.run_until_parked();

        assert_eq!(listed_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let thread_id = session_id("thread-a");
        save(&thread_store, &thread_id, make_thread("Thread A", jan_2024(1)), cx).await;

        cx.run_until_parked();
        let empty_after_save = thread_store.read_with(cx, |store, _cx| store.is_empty());
        assert!(!empty_after_save);

        let delete_task = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        let empty_after_delete = thread_store.read_with(cx, |store, _cx| store.is_empty());
        assert!(empty_after_delete);
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        let delete_task =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert_eq!(listed_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();

        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", jan_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-save the first thread with a newer timestamp; it should move to the front.
        save(&thread_store, &first_id, make_thread("Thread A", jan_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(listed_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}