1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
2use acp_thread::MentionUri;
3use agent_client_protocol as acp;
4use anyhow::{Context as _, Result, anyhow};
5use assistant_text_thread::{SavedTextThreadMetadata, TextThread};
6use chrono::{DateTime, Utc};
7use db::kvp::KEY_VALUE_STORE;
8use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
9use itertools::Itertools;
10use paths::text_threads_dir;
11use project::Project;
12use serde::{Deserialize, Serialize};
13use std::{collections::VecDeque, path::Path, rc::Rc, sync::Arc, time::Duration};
14use ui::ElementId;
15use util::ResultExt as _;
16
17const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
18const RECENTLY_OPENED_THREADS_KEY: &str = "recent-agent-threads";
19const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
20
21const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
22
23//todo: We should remove this function once we support loading all acp thread
24pub fn load_agent_thread(
25 session_id: acp::SessionId,
26 history_store: Entity<HistoryStore>,
27 project: Entity<Project>,
28 cx: &mut App,
29) -> Task<Result<Entity<crate::Thread>>> {
30 use agent_servers::{AgentServer, AgentServerDelegate};
31
32 let server = Rc::new(crate::NativeAgentServer::new(
33 project.read(cx).fs().clone(),
34 history_store,
35 ));
36 let delegate = AgentServerDelegate::new(
37 project.read(cx).agent_server_store().clone(),
38 project.clone(),
39 None,
40 None,
41 );
42 let connection = server.connect(None, delegate, cx);
43 cx.spawn(async move |cx| {
44 let (agent, _) = connection.await?;
45 let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
46 cx.update(|cx| agent.load_thread(session_id, cx))?.await
47 })
48}
49
50#[derive(Clone, Debug)]
51pub enum HistoryEntry {
52 AcpThread(DbThreadMetadata),
53 TextThread(SavedTextThreadMetadata),
54}
55
56impl HistoryEntry {
57 pub fn updated_at(&self) -> DateTime<Utc> {
58 match self {
59 HistoryEntry::AcpThread(thread) => thread.updated_at,
60 HistoryEntry::TextThread(text_thread) => text_thread.mtime.to_utc(),
61 }
62 }
63
64 pub fn id(&self) -> HistoryEntryId {
65 match self {
66 HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()),
67 HistoryEntry::TextThread(text_thread) => {
68 HistoryEntryId::TextThread(text_thread.path.clone())
69 }
70 }
71 }
72
73 pub fn mention_uri(&self) -> MentionUri {
74 match self {
75 HistoryEntry::AcpThread(thread) => MentionUri::Thread {
76 id: thread.id.clone(),
77 name: thread.title.to_string(),
78 },
79 HistoryEntry::TextThread(text_thread) => MentionUri::TextThread {
80 path: text_thread.path.as_ref().to_owned(),
81 name: text_thread.title.to_string(),
82 },
83 }
84 }
85
86 pub fn title(&self) -> &SharedString {
87 match self {
88 HistoryEntry::AcpThread(thread) => {
89 if thread.title.is_empty() {
90 DEFAULT_TITLE
91 } else {
92 &thread.title
93 }
94 }
95 HistoryEntry::TextThread(text_thread) => &text_thread.title,
96 }
97 }
98}
99
100/// Generic identifier for a history entry.
101#[derive(Clone, PartialEq, Eq, Debug, Hash)]
102pub enum HistoryEntryId {
103 AcpThread(acp::SessionId),
104 TextThread(Arc<Path>),
105}
106
107impl Into<ElementId> for HistoryEntryId {
108 fn into(self) -> ElementId {
109 match self {
110 HistoryEntryId::AcpThread(session_id) => ElementId::Name(session_id.0.into()),
111 HistoryEntryId::TextThread(path) => ElementId::Path(path),
112 }
113 }
114}
115
116#[derive(Serialize, Deserialize, Debug)]
117enum SerializedRecentOpen {
118 AcpThread(String),
119 TextThread(String),
120}
121
/// History of ACP threads (from `ThreadsDatabase`) and text threads (from
/// `TextThreadStore`), plus a persisted list of recently opened entries.
pub struct HistoryStore {
    /// ACP thread metadata as last returned by `ThreadsDatabase::list_threads`.
    threads: Vec<DbThreadMetadata>,
    /// `threads` and the text threads combined, sorted newest first by `updated_at`.
    entries: Vec<HistoryEntry>,
    /// Source of text thread metadata; observed so `entries` is rebuilt on change.
    text_thread_store: Entity<assistant_text_thread::TextThreadStore>,
    /// Most recently opened entries, front = most recent, capped at
    /// `MAX_RECENTLY_OPENED_ENTRIES`.
    recently_opened_entries: VecDeque<HistoryEntryId>,
    /// Keeps the `text_thread_store` observation alive.
    _subscriptions: Vec<gpui::Subscription>,
    /// Pending debounced write of `recently_opened_entries`; replaced on each save.
    _save_recently_opened_entries_task: Task<()>,
}
130
131impl HistoryStore {
132 pub fn new(
133 text_thread_store: Entity<assistant_text_thread::TextThreadStore>,
134 cx: &mut Context<Self>,
135 ) -> Self {
136 let subscriptions =
137 vec![cx.observe(&text_thread_store, |this, _, cx| this.update_entries(cx))];
138
139 cx.spawn(async move |this, cx| {
140 let entries = Self::load_recently_opened_entries(cx).await;
141 this.update(cx, |this, cx| {
142 if let Some(entries) = entries.log_err() {
143 this.recently_opened_entries = entries;
144 }
145
146 this.reload(cx);
147 })
148 .ok();
149 })
150 .detach();
151
152 Self {
153 text_thread_store,
154 recently_opened_entries: VecDeque::default(),
155 threads: Vec::default(),
156 entries: Vec::default(),
157 _subscriptions: subscriptions,
158 _save_recently_opened_entries_task: Task::ready(()),
159 }
160 }
161
162 pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
163 self.threads.iter().find(|thread| &thread.id == session_id)
164 }
165
166 pub fn load_thread(
167 &mut self,
168 id: acp::SessionId,
169 cx: &mut Context<Self>,
170 ) -> Task<Result<Option<DbThread>>> {
171 let database_future = ThreadsDatabase::connect(cx);
172 cx.background_spawn(async move {
173 let database = database_future.await.map_err(|err| anyhow!(err))?;
174 database.load_thread(id).await
175 })
176 }
177
178 pub fn save_thread(
179 &mut self,
180 id: acp::SessionId,
181 thread: crate::DbThread,
182 cx: &mut Context<Self>,
183 ) -> Task<Result<()>> {
184 let database_future = ThreadsDatabase::connect(cx);
185 cx.spawn(async move |this, cx| {
186 let database = database_future.await.map_err(|err| anyhow!(err))?;
187 database.save_thread(id, thread).await?;
188 this.update(cx, |this, cx| this.reload(cx))
189 })
190 }
191
192 pub fn delete_thread(
193 &mut self,
194 id: acp::SessionId,
195 cx: &mut Context<Self>,
196 ) -> Task<Result<()>> {
197 let database_future = ThreadsDatabase::connect(cx);
198 cx.spawn(async move |this, cx| {
199 let database = database_future.await.map_err(|err| anyhow!(err))?;
200 database.delete_thread(id.clone()).await?;
201 this.update(cx, |this, cx| this.reload(cx))
202 })
203 }
204
205 pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
206 let database_future = ThreadsDatabase::connect(cx);
207 cx.spawn(async move |this, cx| {
208 let database = database_future.await.map_err(|err| anyhow!(err))?;
209 database.delete_threads().await?;
210 this.update(cx, |this, cx| this.reload(cx))
211 })
212 }
213
214 pub fn delete_text_thread(
215 &mut self,
216 path: Arc<Path>,
217 cx: &mut Context<Self>,
218 ) -> Task<Result<()>> {
219 self.text_thread_store
220 .update(cx, |store, cx| store.delete_local(path, cx))
221 }
222
223 pub fn load_text_thread(
224 &self,
225 path: Arc<Path>,
226 cx: &mut Context<Self>,
227 ) -> Task<Result<Entity<TextThread>>> {
228 self.text_thread_store
229 .update(cx, |store, cx| store.open_local(path, cx))
230 }
231
232 pub fn reload(&self, cx: &mut Context<Self>) {
233 let database_connection = ThreadsDatabase::connect(cx);
234 cx.spawn(async move |this, cx| {
235 let database = database_connection.await;
236 let threads = database.map_err(|err| anyhow!(err))?.list_threads().await?;
237 this.update(cx, |this, cx| {
238 if this.recently_opened_entries.len() < MAX_RECENTLY_OPENED_ENTRIES {
239 for thread in threads
240 .iter()
241 .take(MAX_RECENTLY_OPENED_ENTRIES - this.recently_opened_entries.len())
242 .rev()
243 {
244 this.push_recently_opened_entry(
245 HistoryEntryId::AcpThread(thread.id.clone()),
246 cx,
247 )
248 }
249 }
250 this.threads = threads;
251 this.update_entries(cx);
252 })
253 })
254 .detach_and_log_err(cx);
255 }
256
257 fn update_entries(&mut self, cx: &mut Context<Self>) {
258 #[cfg(debug_assertions)]
259 if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
260 return;
261 }
262 let mut history_entries = Vec::new();
263 history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread));
264 history_entries.extend(
265 self.text_thread_store
266 .read(cx)
267 .unordered_text_threads()
268 .cloned()
269 .map(HistoryEntry::TextThread),
270 );
271
272 history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
273 self.entries = history_entries;
274 cx.notify()
275 }
276
277 pub fn is_empty(&self, _cx: &App) -> bool {
278 self.entries.is_empty()
279 }
280
281 pub fn recently_opened_entries(&self, cx: &App) -> Vec<HistoryEntry> {
282 #[cfg(debug_assertions)]
283 if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
284 return Vec::new();
285 }
286
287 let thread_entries = self.threads.iter().flat_map(|thread| {
288 self.recently_opened_entries
289 .iter()
290 .enumerate()
291 .flat_map(|(index, entry)| match entry {
292 HistoryEntryId::AcpThread(id) if &thread.id == id => {
293 Some((index, HistoryEntry::AcpThread(thread.clone())))
294 }
295 _ => None,
296 })
297 });
298
299 let context_entries = self
300 .text_thread_store
301 .read(cx)
302 .unordered_text_threads()
303 .flat_map(|text_thread| {
304 self.recently_opened_entries
305 .iter()
306 .enumerate()
307 .flat_map(|(index, entry)| match entry {
308 HistoryEntryId::TextThread(path) if &text_thread.path == path => {
309 Some((index, HistoryEntry::TextThread(text_thread.clone())))
310 }
311 _ => None,
312 })
313 });
314
315 thread_entries
316 .chain(context_entries)
317 // optimization to halt iteration early
318 .take(self.recently_opened_entries.len())
319 .sorted_unstable_by_key(|(index, _)| *index)
320 .map(|(_, entry)| entry)
321 .collect()
322 }
323
324 fn save_recently_opened_entries(&mut self, cx: &mut Context<Self>) {
325 let serialized_entries = self
326 .recently_opened_entries
327 .iter()
328 .filter_map(|entry| match entry {
329 HistoryEntryId::TextThread(path) => path.file_name().map(|file| {
330 SerializedRecentOpen::TextThread(file.to_string_lossy().into_owned())
331 }),
332 HistoryEntryId::AcpThread(id) => {
333 Some(SerializedRecentOpen::AcpThread(id.to_string()))
334 }
335 })
336 .collect::<Vec<_>>();
337
338 self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
339 let content = serde_json::to_string(&serialized_entries).unwrap();
340 cx.background_executor()
341 .timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
342 .await;
343
344 if cfg!(any(feature = "test-support", test)) {
345 return;
346 }
347 KEY_VALUE_STORE
348 .write_kvp(RECENTLY_OPENED_THREADS_KEY.to_owned(), content)
349 .await
350 .log_err();
351 });
352 }
353
354 fn load_recently_opened_entries(cx: &AsyncApp) -> Task<Result<VecDeque<HistoryEntryId>>> {
355 cx.background_spawn(async move {
356 if cfg!(any(feature = "test-support", test)) {
357 log::warn!("history store does not persist in tests");
358 return Ok(VecDeque::new());
359 }
360 let json = KEY_VALUE_STORE
361 .read_kvp(RECENTLY_OPENED_THREADS_KEY)?
362 .unwrap_or("[]".to_string());
363 let entries = serde_json::from_str::<Vec<SerializedRecentOpen>>(&json)
364 .context("deserializing persisted agent panel navigation history")?
365 .into_iter()
366 .take(MAX_RECENTLY_OPENED_ENTRIES)
367 .flat_map(|entry| match entry {
368 SerializedRecentOpen::AcpThread(id) => {
369 Some(HistoryEntryId::AcpThread(acp::SessionId::new(id.as_str())))
370 }
371 SerializedRecentOpen::TextThread(file_name) => Some(
372 HistoryEntryId::TextThread(text_threads_dir().join(file_name).into()),
373 ),
374 })
375 .collect();
376 Ok(entries)
377 })
378 }
379
380 pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context<Self>) {
381 self.recently_opened_entries
382 .retain(|old_entry| old_entry != &entry);
383 self.recently_opened_entries.push_front(entry);
384 self.recently_opened_entries
385 .truncate(MAX_RECENTLY_OPENED_ENTRIES);
386 self.save_recently_opened_entries(cx);
387 }
388
389 pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context<Self>) {
390 self.recently_opened_entries.retain(
391 |entry| !matches!(entry, HistoryEntryId::AcpThread(thread_id) if thread_id == &id),
392 );
393 self.save_recently_opened_entries(cx);
394 }
395
396 pub fn replace_recently_opened_text_thread(
397 &mut self,
398 old_path: &Path,
399 new_path: &Arc<Path>,
400 cx: &mut Context<Self>,
401 ) {
402 for entry in &mut self.recently_opened_entries {
403 match entry {
404 HistoryEntryId::TextThread(path) if path.as_ref() == old_path => {
405 *entry = HistoryEntryId::TextThread(new_path.clone());
406 break;
407 }
408 _ => {}
409 }
410 }
411 self.save_recently_opened_entries(cx);
412 }
413
414 pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context<Self>) {
415 self.recently_opened_entries
416 .retain(|old_entry| old_entry != entry);
417 self.save_recently_opened_entries(cx);
418 }
419
420 pub fn entries(&self) -> impl Iterator<Item = HistoryEntry> {
421 self.entries.iter().cloned()
422 }
423}