1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Context, Entity, Global, Task, prelude::*};
5use project::Project;
6use std::rc::Rc;
7
8struct GlobalThreadStore(Entity<ThreadStore>);
9
10impl Global for GlobalThreadStore {}
11
12// TODO: Remove once ACP thread loading is fully handled elsewhere.
13pub fn load_agent_thread(
14 session_id: acp::SessionId,
15 thread_store: Entity<ThreadStore>,
16 project: Entity<Project>,
17 cx: &mut App,
18) -> Task<Result<Entity<crate::Thread>>> {
19 use agent_servers::{AgentServer, AgentServerDelegate};
20
21 let server = Rc::new(crate::NativeAgentServer::new(
22 project.read(cx).fs().clone(),
23 thread_store,
24 ));
25 let delegate = AgentServerDelegate::new(
26 project.read(cx).agent_server_store().clone(),
27 project.clone(),
28 None,
29 None,
30 );
31 let connection = server.connect(None, delegate, cx);
32 cx.spawn(async move |cx| {
33 let (agent, _) = connection.await?;
34 let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
35 cx.update(|cx| agent.load_thread(session_id, cx)).await
36 })
37}
38
39pub struct ThreadStore {
40 threads: Vec<DbThreadMetadata>,
41}
42
43impl ThreadStore {
44 pub fn init_global(cx: &mut App) {
45 let thread_store = cx.new(|cx| Self::new(cx));
46 cx.set_global(GlobalThreadStore(thread_store));
47 }
48
49 pub fn global(cx: &App) -> Entity<Self> {
50 cx.global::<GlobalThreadStore>().0.clone()
51 }
52
53 pub fn new(cx: &mut Context<Self>) -> Self {
54 let this = Self {
55 threads: Vec::new(),
56 };
57 this.reload(cx);
58 this
59 }
60
61 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
62 self.threads.iter().find(|thread| &thread.id == session_id)
63 }
64
65 pub fn load_thread(
66 &mut self,
67 id: acp::SessionId,
68 cx: &mut Context<Self>,
69 ) -> Task<Result<Option<DbThread>>> {
70 let database_future = ThreadsDatabase::connect(cx);
71 cx.background_spawn(async move {
72 let database = database_future.await.map_err(|err| anyhow!(err))?;
73 database.load_thread(id).await
74 })
75 }
76
77 pub fn save_thread(
78 &mut self,
79 id: acp::SessionId,
80 thread: crate::DbThread,
81 cx: &mut Context<Self>,
82 ) -> Task<Result<()>> {
83 let database_future = ThreadsDatabase::connect(cx);
84 cx.spawn(async move |this, cx| {
85 let database = database_future.await.map_err(|err| anyhow!(err))?;
86 database.save_thread(id, thread).await?;
87 this.update(cx, |this, cx| this.reload(cx))
88 })
89 }
90
91 pub fn delete_thread(
92 &mut self,
93 id: acp::SessionId,
94 cx: &mut Context<Self>,
95 ) -> Task<Result<()>> {
96 let database_future = ThreadsDatabase::connect(cx);
97 cx.spawn(async move |this, cx| {
98 let database = database_future.await.map_err(|err| anyhow!(err))?;
99 database.delete_thread(id.clone()).await?;
100 this.update(cx, |this, cx| this.reload(cx))
101 })
102 }
103
104 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
105 let database_future = ThreadsDatabase::connect(cx);
106 cx.spawn(async move |this, cx| {
107 let database = database_future.await.map_err(|err| anyhow!(err))?;
108 database.delete_threads().await?;
109 this.update(cx, |this, cx| this.reload(cx))
110 })
111 }
112
113 pub fn reload(&self, cx: &mut Context<Self>) {
114 let database_connection = ThreadsDatabase::connect(cx);
115 cx.spawn(async move |this, cx| {
116 let database = database_connection.await.map_err(|err| anyhow!(err))?;
117 let threads = database
118 .list_threads()
119 .await?
120 .into_iter()
121 .filter(|thread| thread.parent_session_id.is_none())
122 .collect::<Vec<_>>();
123 this.update(cx, |this, cx| {
124 this.threads = threads;
125 cx.notify();
126 })
127 })
128 .detach_and_log_err(cx);
129 }
130
131 pub fn is_empty(&self) -> bool {
132 self.threads.is_empty()
133 }
134
135 pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
136 self.threads.iter().cloned()
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds a session id from a string.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn january_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds a minimal, message-less thread with the given title and timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            profile: None,
            imported: false,
            subagent_context: None,
            git_worktree_info: None,
        }
    }

    /// Creates a fresh store and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        thread_store
    }

    /// Saves `thread` under `id` and waits for the save task to complete.
    async fn save(
        thread_store: &Entity<ThreadStore>,
        id: &acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = thread_store.update(cx, |store, cx| store.save_thread(id.clone(), thread, cx));
        task.await.unwrap();
    }

    /// Collects the ids of the store's entries in iteration order.
    fn entry_ids(thread_store: &Entity<ThreadStore>, cx: &TestAppContext) -> Vec<acp::SessionId> {
        thread_store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        save(&thread_store, &older_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &newer_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        save(
            &thread_store,
            &session_id("thread-a"),
            make_thread("Thread A", january_2024(1)),
            cx,
        )
        .await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete_task = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        let delete_task =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id.clone(), cx));
        delete_task.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        save(&thread_store, &first_id, make_thread("Thread A", january_2024(1)), cx).await;
        save(&thread_store, &second_id, make_thread("Thread B", january_2024(2)), cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a later timestamp should move it to the front.
        save(&thread_store, &first_id, make_thread("Thread A", january_2024(3)), cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}