1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Context, Entity, Global, Task, prelude::*};
5
/// gpui global wrapper holding the shared [`ThreadStore`] entity.
///
/// Installed by `ThreadStore::init_global` and read back by `ThreadStore::global`.
struct GlobalThreadStore(Entity<ThreadStore>);

// Marker impl so the wrapper can be stored with `cx.set_global` / read with `cx.global`.
impl Global for GlobalThreadStore {}
9
/// In-memory cache of metadata for top-level agent threads.
///
/// The cache is refreshed from `ThreadsDatabase` on construction and after every
/// save or delete performed through this store.
pub struct ThreadStore {
    /// Metadata for threads that have no parent session, as of the last completed `reload`.
    threads: Vec<DbThreadMetadata>,
}
13
14impl ThreadStore {
15 pub fn init_global(cx: &mut App) {
16 let thread_store = cx.new(|cx| Self::new(cx));
17 cx.set_global(GlobalThreadStore(thread_store));
18 }
19
20 pub fn global(cx: &App) -> Entity<Self> {
21 cx.global::<GlobalThreadStore>().0.clone()
22 }
23
24 pub fn new(cx: &mut Context<Self>) -> Self {
25 let this = Self {
26 threads: Vec::new(),
27 };
28 this.reload(cx);
29 this
30 }
31
32 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
33 self.threads.iter().find(|thread| &thread.id == session_id)
34 }
35
36 pub fn load_thread(
37 &mut self,
38 id: acp::SessionId,
39 cx: &mut Context<Self>,
40 ) -> Task<Result<Option<DbThread>>> {
41 let database_future = ThreadsDatabase::connect(cx);
42 cx.background_spawn(async move {
43 let database = database_future.await.map_err(|err| anyhow!(err))?;
44 database.load_thread(id).await
45 })
46 }
47
48 pub fn save_thread(
49 &mut self,
50 id: acp::SessionId,
51 thread: crate::DbThread,
52 cx: &mut Context<Self>,
53 ) -> Task<Result<()>> {
54 let database_future = ThreadsDatabase::connect(cx);
55 cx.spawn(async move |this, cx| {
56 let database = database_future.await.map_err(|err| anyhow!(err))?;
57 database.save_thread(id, thread).await?;
58 this.update(cx, |this, cx| this.reload(cx))
59 })
60 }
61
62 pub fn delete_thread(
63 &mut self,
64 id: acp::SessionId,
65 cx: &mut Context<Self>,
66 ) -> Task<Result<()>> {
67 let database_future = ThreadsDatabase::connect(cx);
68 cx.spawn(async move |this, cx| {
69 let database = database_future.await.map_err(|err| anyhow!(err))?;
70 database.delete_thread(id.clone()).await?;
71 this.update(cx, |this, cx| this.reload(cx))
72 })
73 }
74
75 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
76 let database_future = ThreadsDatabase::connect(cx);
77 cx.spawn(async move |this, cx| {
78 let database = database_future.await.map_err(|err| anyhow!(err))?;
79 database.delete_threads().await?;
80 this.update(cx, |this, cx| this.reload(cx))
81 })
82 }
83
84 pub fn reload(&self, cx: &mut Context<Self>) {
85 let database_connection = ThreadsDatabase::connect(cx);
86 cx.spawn(async move |this, cx| {
87 let database = database_connection.await.map_err(|err| anyhow!(err))?;
88 let threads = database
89 .list_threads()
90 .await?
91 .into_iter()
92 .filter(|thread| thread.parent_session_id.is_none())
93 .collect::<Vec<_>>();
94 this.update(cx, |this, cx| {
95 this.threads = threads;
96 cx.notify();
97 })
98 })
99 .detach_and_log_err(cx);
100 }
101
102 pub fn is_empty(&self) -> bool {
103 self.threads.is_empty()
104 }
105
106 pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
107 self.threads.iter().cloned()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a session id from a plain string.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn january_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Creates an otherwise-empty top-level thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
        }
    }

    /// Creates a fresh store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id` and waits for the save task to succeed.
    async fn save(
        store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        store
            .update(cx, |store, cx| store.save_thread(id.clone(), thread, cx))
            .await
            .unwrap();
    }

    /// Session ids of the store's current entries, in iteration order.
    fn entry_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, &older_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &newer_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let thread_id = session_id("thread-a");

        save(&thread_store, &thread_id, make_thread("Thread A", january_2024(1)), cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        thread_store
            .update(cx, |store, cx| store.delete_threads(cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        thread_store
            .update(cx, |store, cx| store.delete_thread(first_id.clone(), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp should move it to the front.
        save(&thread_store, &first_id, make_thread("Thread A", january_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}