history_store.rs

  1use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
  2use agent_client_protocol as acp;
  3use anyhow::{Context as _, Result};
  4use assistant_context::SavedContextMetadata;
  5use chrono::{DateTime, Utc};
  6use collections::HashMap;
  7use gpui::{SharedString, Task, prelude::*};
  8use serde::{Deserialize, Serialize};
  9use smol::stream::StreamExt;
 10use std::{path::Path, sync::Arc, time::Duration};
 11
 12const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
 13const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
 14const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
 15
 16// todo!(put this in the UI)
 17#[derive(Clone, Debug)]
 18pub enum HistoryEntry {
 19    AcpThread(AcpThreadMetadata),
 20    TextThread(SavedContextMetadata),
 21}
 22
 23impl HistoryEntry {
 24    pub fn updated_at(&self) -> DateTime<Utc> {
 25        match self {
 26            HistoryEntry::AcpThread(thread) => thread.updated_at,
 27            HistoryEntry::TextThread(context) => context.mtime.to_utc(),
 28        }
 29    }
 30
 31    pub fn id(&self) -> HistoryEntryId {
 32        match self {
 33            HistoryEntry::AcpThread(thread) => {
 34                HistoryEntryId::Thread(thread.agent.clone(), thread.id.clone())
 35            }
 36            HistoryEntry::TextThread(context) => HistoryEntryId::Context(context.path.clone()),
 37        }
 38    }
 39
 40    pub fn title(&self) -> &SharedString {
 41        match self {
 42            HistoryEntry::AcpThread(thread) => &thread.title,
 43            HistoryEntry::TextThread(context) => &context.title,
 44        }
 45    }
 46}
 47
 48/// Generic identifier for a history entry.
 49#[derive(Clone, PartialEq, Eq, Debug)]
 50pub enum HistoryEntryId {
 51    Thread(AgentServerName, acp::SessionId),
 52    Context(Arc<Path>),
 53}
 54
/// Serializable form of a recently opened history entry.
///
/// NOTE(review): not referenced by the code visible in this file; presumably used to
/// persist the recently-opened list — confirm before removing. Variant names are part
/// of the serde-derived format, so renaming or reordering them can break existing data.
#[derive(Serialize, Deserialize)]
enum SerializedRecentOpen {
    /// Presumably a thread identifier (e.g. a session id) — TODO confirm.
    Thread(String),
    /// Presumably a text thread stored by name rather than full path — TODO confirm.
    ContextName(String),
    /// Old format which stores the full path
    Context(String),
}
 62
/// Thread history tracked for a single agent inside [`HistoryStore`].
pub struct AgentHistory {
    /// Latest thread listing for the agent. `None` is read as "no entries" by
    /// `HistoryStore::entries`; presumably it means the listing is still loading.
    entries: watch::Receiver<Option<Vec<AcpThreadMetadata>>>,
    /// Background task that calls `cx.notify()` on the store whenever `entries` changes;
    /// stored here so it lives exactly as long as this struct.
    _task: Task<()>,
}
 67
 68pub struct HistoryStore {
 69    agents: HashMap<AgentServerName, AgentHistory>, // todo!() text threads
 70}
 71
 72impl HistoryStore {
 73    pub fn new(_cx: &mut Context<Self>) -> Self {
 74        Self {
 75            agents: HashMap::default(),
 76        }
 77    }
 78
 79    pub fn register_agent(
 80        &mut self,
 81        agent_name: AgentServerName,
 82        connection: &dyn AgentConnection,
 83        cx: &mut Context<Self>,
 84    ) {
 85        let Some(mut history) = connection.list_threads(cx) else {
 86            return;
 87        };
 88        let history = AgentHistory {
 89            entries: history.clone(),
 90            _task: cx.spawn(async move |this, cx| {
 91                while history.changed().await.is_ok() {
 92                    this.update(cx, |_, cx| cx.notify()).ok();
 93                }
 94            }),
 95        };
 96        self.agents.insert(agent_name.clone(), history);
 97    }
 98
 99    pub fn entries(&mut self, _cx: &mut Context<Self>) -> Vec<HistoryEntry> {
100        let mut history_entries = Vec::new();
101
102        #[cfg(debug_assertions)]
103        if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
104            return history_entries;
105        }
106
107        history_entries.extend(
108            self.agents
109                .values_mut()
110                .flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?")
111                .map(HistoryEntry::AcpThread),
112        );
113        // todo!() include the text threads in here.
114
115        history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
116        history_entries
117    }
118
119    pub fn recent_entries(&mut self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
120        self.entries(cx).into_iter().take(limit).collect()
121    }
122}