history_store.rs

  1use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
  2use acp_thread::MentionUri;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result, anyhow};
  5use assistant_context::{AssistantContext, SavedContextMetadata};
  6use chrono::{DateTime, Utc};
  7use db::kvp::KEY_VALUE_STORE;
  8use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
  9use itertools::Itertools;
 10use paths::contexts_dir;
 11use project::Project;
 12use serde::{Deserialize, Serialize};
 13use std::{collections::VecDeque, path::Path, rc::Rc, sync::Arc, time::Duration};
 14use ui::ElementId;
 15use util::ResultExt as _;
 16
/// Maximum number of entries kept in the recently-opened list.
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
/// Key under which the recently-opened list is stored in the key-value store.
const RECENTLY_OPENED_THREADS_KEY: &str = "recent-agent-threads";
/// Delay awaited before the recently-opened list is written to the key-value store.
const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);

/// Title reported for an ACP thread whose stored title is empty.
const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
 22
 23//todo: We should remove this function once we support loading all acp thread
 24pub fn load_agent_thread(
 25    session_id: acp::SessionId,
 26    history_store: Entity<HistoryStore>,
 27    project: Entity<Project>,
 28    cx: &mut App,
 29) -> Task<Result<Entity<crate::Thread>>> {
 30    use agent_servers::{AgentServer, AgentServerDelegate};
 31
 32    let server = Rc::new(crate::NativeAgentServer::new(
 33        project.read(cx).fs().clone(),
 34        history_store,
 35    ));
 36    let delegate = AgentServerDelegate::new(
 37        project.read(cx).agent_server_store().clone(),
 38        project.clone(),
 39        None,
 40        None,
 41    );
 42    let connection = server.connect(None, delegate, cx);
 43    cx.spawn(async move |cx| {
 44        let (agent, _) = connection.await?;
 45        let agent = agent.downcast::<crate::NativeAgentConnection>().unwrap();
 46        cx.update(|cx| agent.load_thread(session_id, cx))?.await
 47    })
 48}
 49
 50#[derive(Clone, Debug)]
 51pub enum HistoryEntry {
 52    AcpThread(DbThreadMetadata),
 53    TextThread(SavedContextMetadata),
 54}
 55
 56impl HistoryEntry {
 57    pub fn updated_at(&self) -> DateTime<Utc> {
 58        match self {
 59            HistoryEntry::AcpThread(thread) => thread.updated_at,
 60            HistoryEntry::TextThread(context) => context.mtime.to_utc(),
 61        }
 62    }
 63
 64    pub fn id(&self) -> HistoryEntryId {
 65        match self {
 66            HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()),
 67            HistoryEntry::TextThread(context) => HistoryEntryId::TextThread(context.path.clone()),
 68        }
 69    }
 70
 71    pub fn mention_uri(&self) -> MentionUri {
 72        match self {
 73            HistoryEntry::AcpThread(thread) => MentionUri::Thread {
 74                id: thread.id.clone(),
 75                name: thread.title.to_string(),
 76            },
 77            HistoryEntry::TextThread(context) => MentionUri::TextThread {
 78                path: context.path.as_ref().to_owned(),
 79                name: context.title.to_string(),
 80            },
 81        }
 82    }
 83
 84    pub fn title(&self) -> &SharedString {
 85        match self {
 86            HistoryEntry::AcpThread(thread) => {
 87                if thread.title.is_empty() {
 88                    DEFAULT_TITLE
 89                } else {
 90                    &thread.title
 91                }
 92            }
 93            HistoryEntry::TextThread(context) => &context.title,
 94        }
 95    }
 96}
 97
 98/// Generic identifier for a history entry.
 99#[derive(Clone, PartialEq, Eq, Debug, Hash)]
100pub enum HistoryEntryId {
101    AcpThread(acp::SessionId),
102    TextThread(Arc<Path>),
103}
104
105impl Into<ElementId> for HistoryEntryId {
106    fn into(self) -> ElementId {
107        match self {
108            HistoryEntryId::AcpThread(session_id) => ElementId::Name(session_id.0.into()),
109            HistoryEntryId::TextThread(path) => ElementId::Path(path),
110        }
111    }
112}
113
/// Serializable form of a recently opened entry, as written to and read from
/// the key-value store.
#[derive(Serialize, Deserialize, Debug)]
enum SerializedRecentOpen {
    /// Session id of an ACP thread, rendered as a string.
    AcpThread(String),
    /// File name of a text thread; the directory part of its path is not stored.
    TextThread(String),
}
119
/// Holds the ACP-thread and text-thread metadata that make up the agent
/// history, together with the list of recently opened entries.
pub struct HistoryStore {
    /// ACP thread metadata as last listed from the threads database.
    threads: Vec<DbThreadMetadata>,
    /// ACP threads and text threads combined, sorted by `updated_at`, newest first.
    entries: Vec<HistoryEntry>,
    /// Store providing the text-thread metadata and text-thread operations.
    text_thread_store: Entity<assistant_context::ContextStore>,
    /// Identifiers of recently opened entries, most recent at the front.
    recently_opened_entries: VecDeque<HistoryEntryId>,
    /// Subscriptions created in `new` (the observation of `text_thread_store`).
    _subscriptions: Vec<gpui::Subscription>,
    /// Most recently spawned task that persists `recently_opened_entries`.
    _save_recently_opened_entries_task: Task<()>,
}
128
129impl HistoryStore {
130    pub fn new(
131        text_thread_store: Entity<assistant_context::ContextStore>,
132        cx: &mut Context<Self>,
133    ) -> Self {
134        let subscriptions =
135            vec![cx.observe(&text_thread_store, |this, _, cx| this.update_entries(cx))];
136
137        cx.spawn(async move |this, cx| {
138            let entries = Self::load_recently_opened_entries(cx).await;
139            this.update(cx, |this, cx| {
140                if let Some(entries) = entries.log_err() {
141                    this.recently_opened_entries = entries;
142                }
143
144                this.reload(cx);
145            })
146            .ok();
147        })
148        .detach();
149
150        Self {
151            text_thread_store,
152            recently_opened_entries: VecDeque::default(),
153            threads: Vec::default(),
154            entries: Vec::default(),
155            _subscriptions: subscriptions,
156            _save_recently_opened_entries_task: Task::ready(()),
157        }
158    }
159
160    pub fn thread_from_session_id(&self, session_id: &acp::SessionId) -> Option<&DbThreadMetadata> {
161        self.threads.iter().find(|thread| &thread.id == session_id)
162    }
163
164    pub fn load_thread(
165        &mut self,
166        id: acp::SessionId,
167        cx: &mut Context<Self>,
168    ) -> Task<Result<Option<DbThread>>> {
169        let database_future = ThreadsDatabase::connect(cx);
170        cx.background_spawn(async move {
171            let database = database_future.await.map_err(|err| anyhow!(err))?;
172            database.load_thread(id).await
173        })
174    }
175
176    pub fn delete_thread(
177        &mut self,
178        id: acp::SessionId,
179        cx: &mut Context<Self>,
180    ) -> Task<Result<()>> {
181        let database_future = ThreadsDatabase::connect(cx);
182        cx.spawn(async move |this, cx| {
183            let database = database_future.await.map_err(|err| anyhow!(err))?;
184            database.delete_thread(id.clone()).await?;
185            this.update(cx, |this, cx| this.reload(cx))
186        })
187    }
188
189    pub fn delete_text_thread(
190        &mut self,
191        path: Arc<Path>,
192        cx: &mut Context<Self>,
193    ) -> Task<Result<()>> {
194        self.text_thread_store
195            .update(cx, |store, cx| store.delete_local_context(path, cx))
196    }
197
198    pub fn load_text_thread(
199        &self,
200        path: Arc<Path>,
201        cx: &mut Context<Self>,
202    ) -> Task<Result<Entity<AssistantContext>>> {
203        self.text_thread_store
204            .update(cx, |store, cx| store.open_local_context(path, cx))
205    }
206
207    pub fn reload(&self, cx: &mut Context<Self>) {
208        let database_future = ThreadsDatabase::connect(cx);
209        cx.spawn(async move |this, cx| {
210            let threads = database_future
211                .await
212                .map_err(|err| anyhow!(err))?
213                .list_threads()
214                .await?;
215
216            this.update(cx, |this, cx| {
217                if this.recently_opened_entries.len() < MAX_RECENTLY_OPENED_ENTRIES {
218                    for thread in threads
219                        .iter()
220                        .take(MAX_RECENTLY_OPENED_ENTRIES - this.recently_opened_entries.len())
221                        .rev()
222                    {
223                        this.push_recently_opened_entry(
224                            HistoryEntryId::AcpThread(thread.id.clone()),
225                            cx,
226                        )
227                    }
228                }
229                this.threads = threads;
230                this.update_entries(cx);
231            })
232        })
233        .detach_and_log_err(cx);
234    }
235
236    fn update_entries(&mut self, cx: &mut Context<Self>) {
237        #[cfg(debug_assertions)]
238        if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
239            return;
240        }
241        let mut history_entries = Vec::new();
242        history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread));
243        history_entries.extend(
244            self.text_thread_store
245                .read(cx)
246                .unordered_contexts()
247                .cloned()
248                .map(HistoryEntry::TextThread),
249        );
250
251        history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
252        self.entries = history_entries;
253        cx.notify()
254    }
255
    /// Returns `true` when the store currently holds no history entries.
    ///
    /// The `_cx` argument is not used.
    pub fn is_empty(&self, _cx: &App) -> bool {
        self.entries.is_empty()
    }
259
260    pub fn recently_opened_entries(&self, cx: &App) -> Vec<HistoryEntry> {
261        #[cfg(debug_assertions)]
262        if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
263            return Vec::new();
264        }
265
266        let thread_entries = self.threads.iter().flat_map(|thread| {
267            self.recently_opened_entries
268                .iter()
269                .enumerate()
270                .flat_map(|(index, entry)| match entry {
271                    HistoryEntryId::AcpThread(id) if &thread.id == id => {
272                        Some((index, HistoryEntry::AcpThread(thread.clone())))
273                    }
274                    _ => None,
275                })
276        });
277
278        let context_entries = self
279            .text_thread_store
280            .read(cx)
281            .unordered_contexts()
282            .flat_map(|context| {
283                self.recently_opened_entries
284                    .iter()
285                    .enumerate()
286                    .flat_map(|(index, entry)| match entry {
287                        HistoryEntryId::TextThread(path) if &context.path == path => {
288                            Some((index, HistoryEntry::TextThread(context.clone())))
289                        }
290                        _ => None,
291                    })
292            });
293
294        thread_entries
295            .chain(context_entries)
296            // optimization to halt iteration early
297            .take(self.recently_opened_entries.len())
298            .sorted_unstable_by_key(|(index, _)| *index)
299            .map(|(_, entry)| entry)
300            .collect()
301    }
302
303    fn save_recently_opened_entries(&mut self, cx: &mut Context<Self>) {
304        let serialized_entries = self
305            .recently_opened_entries
306            .iter()
307            .filter_map(|entry| match entry {
308                HistoryEntryId::TextThread(path) => path.file_name().map(|file| {
309                    SerializedRecentOpen::TextThread(file.to_string_lossy().into_owned())
310                }),
311                HistoryEntryId::AcpThread(id) => {
312                    Some(SerializedRecentOpen::AcpThread(id.to_string()))
313                }
314            })
315            .collect::<Vec<_>>();
316
317        self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
318            let content = serde_json::to_string(&serialized_entries).unwrap();
319            cx.background_executor()
320                .timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
321                .await;
322
323            if cfg!(any(feature = "test-support", test)) {
324                return;
325            }
326            KEY_VALUE_STORE
327                .write_kvp(RECENTLY_OPENED_THREADS_KEY.to_owned(), content)
328                .await
329                .log_err();
330        });
331    }
332
333    fn load_recently_opened_entries(cx: &AsyncApp) -> Task<Result<VecDeque<HistoryEntryId>>> {
334        cx.background_spawn(async move {
335            if cfg!(any(feature = "test-support", test)) {
336                anyhow::bail!("history store does not persist in tests");
337            }
338            let json = KEY_VALUE_STORE
339                .read_kvp(RECENTLY_OPENED_THREADS_KEY)?
340                .unwrap_or("[]".to_string());
341            let entries = serde_json::from_str::<Vec<SerializedRecentOpen>>(&json)
342                .context("deserializing persisted agent panel navigation history")?
343                .into_iter()
344                .take(MAX_RECENTLY_OPENED_ENTRIES)
345                .flat_map(|entry| match entry {
346                    SerializedRecentOpen::AcpThread(id) => Some(HistoryEntryId::AcpThread(
347                        acp::SessionId(id.as_str().into()),
348                    )),
349                    SerializedRecentOpen::TextThread(file_name) => Some(
350                        HistoryEntryId::TextThread(contexts_dir().join(file_name).into()),
351                    ),
352                })
353                .collect();
354            Ok(entries)
355        })
356    }
357
358    pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context<Self>) {
359        self.recently_opened_entries
360            .retain(|old_entry| old_entry != &entry);
361        self.recently_opened_entries.push_front(entry);
362        self.recently_opened_entries
363            .truncate(MAX_RECENTLY_OPENED_ENTRIES);
364        self.save_recently_opened_entries(cx);
365    }
366
367    pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context<Self>) {
368        self.recently_opened_entries.retain(
369            |entry| !matches!(entry, HistoryEntryId::AcpThread(thread_id) if thread_id == &id),
370        );
371        self.save_recently_opened_entries(cx);
372    }
373
374    pub fn replace_recently_opened_text_thread(
375        &mut self,
376        old_path: &Path,
377        new_path: &Arc<Path>,
378        cx: &mut Context<Self>,
379    ) {
380        for entry in &mut self.recently_opened_entries {
381            match entry {
382                HistoryEntryId::TextThread(path) if path.as_ref() == old_path => {
383                    *entry = HistoryEntryId::TextThread(new_path.clone());
384                    break;
385                }
386                _ => {}
387            }
388        }
389        self.save_recently_opened_entries(cx);
390    }
391
392    pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context<Self>) {
393        self.recently_opened_entries
394            .retain(|old_entry| old_entry != entry);
395        self.save_recently_opened_entries(cx);
396    }
397
    /// Returns an iterator over clones of the history entries, in the order
    /// they are stored in `entries`.
    pub fn entries(&self) -> impl Iterator<Item = HistoryEntry> {
        self.entries.iter().cloned()
    }
401}