1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Context, Entity, Task, prelude::*};
5use project::Project;
6use std::rc::Rc;
7
8// TODO: Remove once ACP thread loading is fully handled elsewhere.
9pub fn load_agent_thread(
10 session_id: acp::SessionId,
11 thread_store: Entity<ThreadStore>,
12 project: Entity<Project>,
13 cx: &mut App,
14) -> Task<Result<Entity<crate::Thread>>> {
15 use agent_servers::{AgentServer, AgentServerDelegate};
16
17 let server = Rc::new(crate::NativeAgentServer::new(
18 project.read(cx).fs().clone(),
19 thread_store,
20 ));
21 let delegate = AgentServerDelegate::new(
22 project.read(cx).agent_server_store().clone(),
23 project.clone(),
24 None,
25 None,
26 );
27 let connection = server.connect(None, delegate, cx);
28 cx.spawn(async move |cx| {
29 let (agent, _) = connection.await?;
30 let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
31 cx.update(|cx| agent.load_thread(session_id, cx)).await
32 })
33}
34
35pub struct ThreadStore {
36 threads: Vec<DbThreadMetadata>,
37}
38
39impl ThreadStore {
40 pub fn new(cx: &mut Context<Self>) -> Self {
41 let this = Self {
42 threads: Vec::new(),
43 };
44 this.reload(cx);
45 this
46 }
47
48 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
49 self.threads.iter().find(|thread| &thread.id == session_id)
50 }
51
52 pub fn load_thread(
53 &mut self,
54 id: acp::SessionId,
55 cx: &mut Context<Self>,
56 ) -> Task<Result<Option<DbThread>>> {
57 let database_future = ThreadsDatabase::connect(cx);
58 cx.background_spawn(async move {
59 let database = database_future.await.map_err(|err| anyhow!(err))?;
60 database.load_thread(id).await
61 })
62 }
63
64 pub fn save_thread(
65 &mut self,
66 id: acp::SessionId,
67 thread: crate::DbThread,
68 cx: &mut Context<Self>,
69 ) -> Task<Result<()>> {
70 let database_future = ThreadsDatabase::connect(cx);
71 cx.spawn(async move |this, cx| {
72 let database = database_future.await.map_err(|err| anyhow!(err))?;
73 database.save_thread(id, thread).await?;
74 this.update(cx, |this, cx| this.reload(cx))
75 })
76 }
77
78 pub fn delete_thread(
79 &mut self,
80 id: acp::SessionId,
81 cx: &mut Context<Self>,
82 ) -> Task<Result<()>> {
83 let database_future = ThreadsDatabase::connect(cx);
84 cx.spawn(async move |this, cx| {
85 let database = database_future.await.map_err(|err| anyhow!(err))?;
86 database.delete_thread(id.clone()).await?;
87 this.update(cx, |this, cx| this.reload(cx))
88 })
89 }
90
91 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
92 let database_future = ThreadsDatabase::connect(cx);
93 cx.spawn(async move |this, cx| {
94 let database = database_future.await.map_err(|err| anyhow!(err))?;
95 database.delete_threads().await?;
96 this.update(cx, |this, cx| this.reload(cx))
97 })
98 }
99
100 pub fn reload(&self, cx: &mut Context<Self>) {
101 let database_connection = ThreadsDatabase::connect(cx);
102 cx.spawn(async move |this, cx| {
103 let database = database_connection.await.map_err(|err| anyhow!(err))?;
104 let threads = database.list_threads().await?;
105 this.update(cx, |this, cx| {
106 this.threads = threads;
107 cx.notify();
108 })
109 })
110 .detach_and_log_err(cx);
111 }
112
113 pub fn is_empty(&self) -> bool {
114 self.threads.is_empty()
115 }
116
117 pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
118 self.threads.iter().cloned()
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};
    use collections::HashMap;
    use gpui::TestAppContext;
    use std::sync::Arc;

    /// Builds an ACP session id from a plain string.
    fn session_id(value: &str) -> acp::SessionId {
        acp::SessionId::new(Arc::<str>::from(value))
    }

    /// Midnight UTC on the given day of January 2024.
    fn jan_2024(day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, day, 0, 0, 0).unwrap()
    }

    /// Builds an otherwise-empty thread carrying only a title and a timestamp.
    fn make_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
        DbThread {
            title: title.to_string().into(),
            messages: Vec::new(),
            updated_at,
            detailed_summary: None,
            initial_project_snapshot: None,
            cumulative_token_usage: Default::default(),
            request_token_usage: HashMap::default(),
            model: None,
            completion_mode: None,
            profile: None,
            imported: false,
        }
    }

    /// Creates a `ThreadStore` and lets its initial reload settle.
    fn new_store(cx: &mut TestAppContext) -> Entity<ThreadStore> {
        let store = cx.new(|cx| ThreadStore::new(cx));
        cx.run_until_parked();
        store
    }

    /// Saves `thread` under `id`, failing the test if the save errors.
    async fn save(
        store: &Entity<ThreadStore>,
        id: acp::SessionId,
        thread: DbThread,
        cx: &mut TestAppContext,
    ) {
        let task = store.update(cx, |store, cx| store.save_thread(id, thread, cx));
        task.await.unwrap();
    }

    /// Ids of the store's cached entries, in iteration order.
    fn entry_ids(store: &Entity<ThreadStore>, cx: &mut TestAppContext) -> Vec<acp::SessionId> {
        store.read_with(cx, |store, _cx| {
            store.entries().map(|entry| entry.id).collect()
        })
    }

    #[gpui::test]
    async fn test_entries_are_sorted_by_updated_at(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let older_id = session_id("thread-a");
        let newer_id = session_id("thread-b");

        let older = make_thread("Thread A", jan_2024(1));
        save(&thread_store, older_id.clone(), older, cx).await;
        let newer = make_thread("Thread B", jan_2024(2));
        save(&thread_store, newer_id.clone(), newer, cx).await;
        cx.run_until_parked();

        // The most recently updated thread is listed first.
        assert_eq!(entry_ids(&thread_store, cx), vec![newer_id, older_id]);
    }

    #[gpui::test]
    async fn test_delete_threads_clears_entries(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);

        let thread = make_thread("Thread A", jan_2024(1));
        save(&thread_store, session_id("thread-a"), thread, cx).await;
        cx.run_until_parked();
        assert!(!thread_store.read_with(cx, |store, _cx| store.is_empty()));

        let delete_all = thread_store.update(cx, |store, cx| store.delete_threads(cx));
        delete_all.await.unwrap();
        cx.run_until_parked();

        assert!(thread_store.read_with(cx, |store, _cx| store.is_empty()));
    }

    #[gpui::test]
    async fn test_delete_thread_removes_only_target(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        let first = make_thread("Thread A", jan_2024(1));
        save(&thread_store, first_id.clone(), first, cx).await;
        let second = make_thread("Thread B", jan_2024(2));
        save(&thread_store, second_id.clone(), second, cx).await;
        cx.run_until_parked();

        let delete_first =
            thread_store.update(cx, |store, cx| store.delete_thread(first_id, cx));
        delete_first.await.unwrap();
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![second_id]);
    }

    #[gpui::test]
    async fn test_save_thread_refreshes_ordering(cx: &mut TestAppContext) {
        let thread_store = new_store(cx);
        let first_id = session_id("thread-a");
        let second_id = session_id("thread-b");

        let first = make_thread("Thread A", jan_2024(1));
        save(&thread_store, first_id.clone(), first, cx).await;
        let second = make_thread("Thread B", jan_2024(2));
        save(&thread_store, second_id.clone(), second, cx).await;
        cx.run_until_parked();

        // Re-saving the first thread with a newer timestamp should move it to the front.
        let refreshed_first = make_thread("Thread A", jan_2024(3));
        save(&thread_store, first_id.clone(), refreshed_first, cx).await;
        cx.run_until_parked();

        assert_eq!(entry_ids(&thread_store, cx), vec![first_id, second_id]);
    }
}