history_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use acp_thread::MentionUri;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result, anyhow};
  5use assistant_text_thread::{SavedTextThreadMetadata, TextThread};
  6use chrono::{DateTime, Utc};
  7use db::kvp::KEY_VALUE_STORE;
  8use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
  9use itertools::Itertools;
 10use paths::text_threads_dir;
 11use project::Project;
 12use serde::{Deserialize, Serialize};
 13use std::{collections::VecDeque, path::Path, rc::Rc, sync::Arc, time::Duration};
 14use ui::ElementId;
 15use util::ResultExt as _;
 16
 17const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
 18const RECENTLY_OPENED_THREADS_KEY: &str = "recent-agent-threads";
 19const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
 20
 21const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
 22
 23//todo: We should remove this function once we support loading all acp thread
 24pub fn load_agent_thread(
 25    session_id: acp::SessionId,
 26    history_store: Entity<HistoryStore>,
 27    project: Entity<Project>,
 28    cx: &mut App,
 29) -> Task<Result<Entity<crate::Thread>>> {
 30    use agent_servers::{AgentServer, AgentServerDelegate};
 31
 32    let server = Rc::new(crate::NativeAgentServer::new(
 33        project.read(cx).fs().clone(),
 34        history_store,
 35    ));
 36    let delegate = AgentServerDelegate::new(
 37        project.read(cx).agent_server_store().clone(),
 38        project.clone(),
 39        None,
 40        None,
 41    );
 42    let connection = server.connect(None, delegate, cx);
 43    cx.spawn(async move |cx| {
 44        let (agent, _) = connection.await?;
 45        let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
 46        cx.update(|cx| agent.load_thread(session_id, cx))?.await
 47    })
 48}
 49
 50#[derive(Clone, Debug)]
 51pub enum HistoryEntry {
 52    AcpThread(DbThreadMetadata),
 53    TextThread(SavedTextThreadMetadata),
 54}
 55
 56impl HistoryEntry {
 57    pub fn updated_at(&self) -> DateTime<Utc> {
 58        match self {
 59            HistoryEntry::AcpThread(thread) => thread.updated_at,
 60            HistoryEntry::TextThread(text_thread) => text_thread.mtime.to_utc(),
 61        }
 62    }
 63
 64    pub fn id(&self) -> HistoryEntryId {
 65        match self {
 66            HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()),
 67            HistoryEntry::TextThread(text_thread) => {
 68                HistoryEntryId::TextThread(text_thread.path.clone())
 69            }
 70        }
 71    }
 72
 73    pub fn mention_uri(&self) -> MentionUri {
 74        match self {
 75            HistoryEntry::AcpThread(thread) => MentionUri::Thread {
 76                id: thread.id.clone(),
 77                name: thread.title.to_string(),
 78            },
 79            HistoryEntry::TextThread(text_thread) => MentionUri::TextThread {
 80                path: text_thread.path.as_ref().to_owned(),
 81                name: text_thread.title.to_string(),
 82            },
 83        }
 84    }
 85
 86    pub fn title(&self) -> &SharedString {
 87        match self {
 88            HistoryEntry::AcpThread(thread) => {
 89                if thread.title.is_empty() {
 90                    DEFAULT_TITLE
 91                } else {
 92                    &thread.title
 93                }
 94            }
 95            HistoryEntry::TextThread(text_thread) => &text_thread.title,
 96        }
 97    }
 98}
 99
100/// Generic identifier for a history entry.
101#[derive(Clone, PartialEq, Eq, Debug, Hash)]
102pub enum HistoryEntryId {
103    AcpThread(acp::SessionId),
104    TextThread(Arc<Path>),
105}
106
107impl Into<ElementId> for HistoryEntryId {
108    fn into(self) -> ElementId {
109        match self {
110            HistoryEntryId::AcpThread(session_id) => ElementId::Name(session_id.0.into()),
111            HistoryEntryId::TextThread(path) => ElementId::Path(path),
112        }
113    }
114}
115
116#[derive(Serialize, Deserialize, Debug)]
117enum SerializedRecentOpen {
118    AcpThread(String),
119    TextThread(String),
120}
121
122pub struct HistoryStore {
123    threads: Vec<DbThreadMetadata>,
124    entries: Vec<HistoryEntry>,
125    text_thread_store: Entity<assistant_text_thread::TextThreadStore>,
126    recently_opened_entries: VecDeque<HistoryEntryId>,
127    _subscriptions: Vec<gpui::Subscription>,
128    _save_recently_opened_entries_task: Task<()>,
129}
130
131impl HistoryStore {
132    pub fn new(
133        text_thread_store: Entity<assistant_text_thread::TextThreadStore>,
134        cx: &mut Context<Self>,
135    ) -> Self {
136        let subscriptions =
137            vec![cx.observe(&text_thread_store, |this, _, cx| this.update_entries(cx))];
138
139        cx.spawn(async move |this, cx| {
140            let entries = Self::load_recently_opened_entries(cx).await;
141            this.update(cx, |this, cx| {
142                if let Some(entries) = entries.log_err() {
143                    this.recently_opened_entries = entries;
144                }
145
146                this.reload(cx);
147            })
148            .ok();
149        })
150        .detach();
151
152        Self {
153            text_thread_store,
154            recently_opened_entries: VecDeque::default(),
155            threads: Vec::default(),
156            entries: Vec::default(),
157            _subscriptions: subscriptions,
158            _save_recently_opened_entries_task: Task::ready(()),
159        }
160    }
161
162    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
163        self.threads.iter().find(|thread| &thread.id == session_id)
164    }
165
166    pub fn load_thread(
167        &mut self,
168        id: acp::SessionId,
169        cx: &mut Context<Self>,
170    ) -> Task<Result<Option<DbThread>>> {
171        let database_future = ThreadsDatabase::connect(cx);
172        cx.background_spawn(async move {
173            let database = database_future.await.map_err(|err| anyhow!(err))?;
174            database.load_thread(id).await
175        })
176    }
177
178    pub fn save_thread(
179        &mut self,
180        id: acp::SessionId,
181        thread: crate::DbThread,
182        cx: &mut Context<Self>,
183    ) -> Task<Result<()>> {
184        let database_future = ThreadsDatabase::connect(cx);
185        cx.spawn(async move |this, cx| {
186            let database = database_future.await.map_err(|err| anyhow!(err))?;
187            database.save_thread(id, thread).await?;
188            this.update(cx, |this, cx| this.reload(cx))
189        })
190    }
191
192    pub fn delete_thread(
193        &mut self,
194        id: acp::SessionId,
195        cx: &mut Context<Self>,
196    ) -> Task<Result<()>> {
197        let database_future = ThreadsDatabase::connect(cx);
198        cx.spawn(async move |this, cx| {
199            let database = database_future.await.map_err(|err| anyhow!(err))?;
200            database.delete_thread(id.clone()).await?;
201            this.update(cx, |this, cx| this.reload(cx))
202        })
203    }
204
205    pub fn delete_threads(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
206        let database_future = ThreadsDatabase::connect(cx);
207        cx.spawn(async move |this, cx| {
208            let database = database_future.await.map_err(|err| anyhow!(err))?;
209            database.delete_threads().await?;
210            this.update(cx, |this, cx| this.reload(cx))
211        })
212    }
213
214    pub fn delete_text_thread(
215        &mut self,
216        path: Arc<Path>,
217        cx: &mut Context<Self>,
218    ) -> Task<Result<()>> {
219        self.text_thread_store
220            .update(cx, |store, cx| store.delete_local(path, cx))
221    }
222
223    pub fn load_text_thread(
224        &self,
225        path: Arc<Path>,
226        cx: &mut Context<Self>,
227    ) -> Task<Result<Entity<TextThread>>> {
228        self.text_thread_store
229            .update(cx, |store, cx| store.open_local(path, cx))
230    }
231
232    pub fn reload(&self, cx: &mut Context<Self>) {
233        let database_connection = ThreadsDatabase::connect(cx);
234        cx.spawn(async move |this, cx| {
235            let database = database_connection.await;
236            let threads = database.map_err(|err| anyhow!(err))?.list_threads().await?;
237            this.update(cx, |this, cx| {
238                if this.recently_opened_entries.len() < MAX_RECENTLY_OPENED_ENTRIES {
239                    for thread in threads
240                        .iter()
241                        .take(MAX_RECENTLY_OPENED_ENTRIES - this.recently_opened_entries.len())
242                        .rev()
243                    {
244                        this.push_recently_opened_entry(
245                            HistoryEntryId::AcpThread(thread.id.clone()),
246                            cx,
247                        )
248                    }
249                }
250                this.threads = threads;
251                this.update_entries(cx);
252            })
253        })
254        .detach_and_log_err(cx);
255    }
256
257    fn update_entries(&mut self, cx: &mut Context<Self>) {
258        #[cfg(debug_assertions)]
259        if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
260            return;
261        }
262        let mut history_entries = Vec::new();
263        history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread));
264        history_entries.extend(
265            self.text_thread_store
266                .read(cx)
267                .unordered_text_threads()
268                .cloned()
269                .map(HistoryEntry::TextThread),
270        );
271
272        history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
273        self.entries = history_entries;
274        cx.notify()
275    }
276
277    pub fn is_empty(&self, _cx: &App) -> bool {
278        self.entries.is_empty()
279    }
280
281    pub fn recently_opened_entries(&self, cx: &App) -> Vec<HistoryEntry> {
282        #[cfg(debug_assertions)]
283        if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
284            return Vec::new();
285        }
286
287        let thread_entries = self.threads.iter().flat_map(|thread| {
288            self.recently_opened_entries
289                .iter()
290                .enumerate()
291                .flat_map(|(index, entry)| match entry {
292                    HistoryEntryId::AcpThread(id) if &thread.id == id => {
293                        Some((index, HistoryEntry::AcpThread(thread.clone())))
294                    }
295                    _ => None,
296                })
297        });
298
299        let context_entries = self
300            .text_thread_store
301            .read(cx)
302            .unordered_text_threads()
303            .flat_map(|text_thread| {
304                self.recently_opened_entries
305                    .iter()
306                    .enumerate()
307                    .flat_map(|(index, entry)| match entry {
308                        HistoryEntryId::TextThread(path) if &text_thread.path == path => {
309                            Some((index, HistoryEntry::TextThread(text_thread.clone())))
310                        }
311                        _ => None,
312                    })
313            });
314
315        thread_entries
316            .chain(context_entries)
317            // optimization to halt iteration early
318            .take(self.recently_opened_entries.len())
319            .sorted_unstable_by_key(|(index, _)| *index)
320            .map(|(_, entry)| entry)
321            .collect()
322    }
323
324    fn save_recently_opened_entries(&mut self, cx: &mut Context<Self>) {
325        let serialized_entries = self
326            .recently_opened_entries
327            .iter()
328            .filter_map(|entry| match entry {
329                HistoryEntryId::TextThread(path) => path.file_name().map(|file| {
330                    SerializedRecentOpen::TextThread(file.to_string_lossy().into_owned())
331                }),
332                HistoryEntryId::AcpThread(id) => {
333                    Some(SerializedRecentOpen::AcpThread(id.to_string()))
334                }
335            })
336            .collect::<Vec<_>>();
337
338        self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
339            let content = serde_json::to_string(&serialized_entries).unwrap();
340            cx.background_executor()
341                .timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
342                .await;
343
344            if cfg!(any(feature = "test-support", test)) {
345                return;
346            }
347            KEY_VALUE_STORE
348                .write_kvp(RECENTLY_OPENED_THREADS_KEY.to_owned(), content)
349                .await
350                .log_err();
351        });
352    }
353
354    fn load_recently_opened_entries(cx: &AsyncApp) -> Task<Result<VecDeque<HistoryEntryId>>> {
355        cx.background_spawn(async move {
356            if cfg!(any(feature = "test-support", test)) {
357                log::warn!("history store does not persist in tests");
358                return Ok(VecDeque::new());
359            }
360            let json = KEY_VALUE_STORE
361                .read_kvp(RECENTLY_OPENED_THREADS_KEY)?
362                .unwrap_or("[]".to_string());
363            let entries = serde_json::from_str::<Vec<SerializedRecentOpen>>(&json)
364                .context("deserializing persisted agent panel navigation history")?
365                .into_iter()
366                .take(MAX_RECENTLY_OPENED_ENTRIES)
367                .flat_map(|entry| match entry {
368                    SerializedRecentOpen::AcpThread(id) => {
369                        Some(HistoryEntryId::AcpThread(acp::SessionId::new(id.as_str())))
370                    }
371                    SerializedRecentOpen::TextThread(file_name) => Some(
372                        HistoryEntryId::TextThread(text_threads_dir().join(file_name).into()),
373                    ),
374                })
375                .collect();
376            Ok(entries)
377        })
378    }
379
380    pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context<Self>) {
381        self.recently_opened_entries
382            .retain(|old_entry| old_entry != &entry);
383        self.recently_opened_entries.push_front(entry);
384        self.recently_opened_entries
385            .truncate(MAX_RECENTLY_OPENED_ENTRIES);
386        self.save_recently_opened_entries(cx);
387    }
388
389    pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context<Self>) {
390        self.recently_opened_entries.retain(
391            |entry| !matches!(entry, HistoryEntryId::AcpThread(thread_id) if thread_id == &id),
392        );
393        self.save_recently_opened_entries(cx);
394    }
395
396    pub fn replace_recently_opened_text_thread(
397        &mut self,
398        old_path: &Path,
399        new_path: &Arc<Path>,
400        cx: &mut Context<Self>,
401    ) {
402        for entry in &mut self.recently_opened_entries {
403            match entry {
404                HistoryEntryId::TextThread(path) if path.as_ref() == old_path => {
405                    *entry = HistoryEntryId::TextThread(new_path.clone());
406                    break;
407                }
408                _ => {}
409            }
410        }
411        self.save_recently_opened_entries(cx);
412    }
413
414    pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context<Self>) {
415        self.recently_opened_entries
416            .retain(|old_entry| old_entry != entry);
417        self.save_recently_opened_entries(cx);
418    }
419
420    pub fn entries(&self) -> impl Iterator<Item = HistoryEntry> {
421        self.entries.iter().cloned()
422    }
423}