Detailed changes
@@ -196,6 +196,7 @@ dependencies = [
"agent_servers",
"agent_settings",
"anyhow",
+ "assistant_context",
"assistant_tool",
"assistant_tools",
"chrono",
@@ -223,6 +224,7 @@ dependencies = [
"log",
"lsp",
"open",
+ "parking_lot",
"paths",
"portable-pty",
"pretty_assertions",
@@ -235,6 +237,7 @@ dependencies = [
"serde_json",
"settings",
"smol",
+ "sqlez",
"task",
"tempfile",
"terminal",
@@ -251,6 +254,7 @@ dependencies = [
"workspace-hack",
"worktree",
"zlog",
+ "zstd",
]
[[package]]
@@ -893,7 +893,7 @@ impl ThreadsDatabase {
let needs_migration_from_heed = mdb_path.exists();
- let connection = if *ZED_STATELESS {
+ let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
Connection::open_memory(Some("THREAD_FALLBACK_DB"))
} else {
Connection::open_file(&sqlite_path.to_string_lossy())
@@ -19,6 +19,7 @@ agent-client-protocol.workspace = true
agent_servers.workspace = true
agent_settings.workspace = true
anyhow.workspace = true
+assistant_context.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
chrono.workspace = true
@@ -39,6 +40,7 @@ language_model.workspace = true
language_models.workspace = true
log.workspace = true
open.workspace = true
+parking_lot.workspace = true
paths.workspace = true
portable-pty.workspace = true
project.workspace = true
@@ -49,6 +51,7 @@ serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
+sqlez.workspace = true
task.workspace = true
terminal.workspace = true
text.workspace = true
@@ -59,6 +62,7 @@ watch.workspace = true
web_search.workspace = true
which.workspace = true
workspace-hack.workspace = true
+zstd.workspace = true
[dev-dependencies]
agent = { workspace = true, "features" = ["test-support"] }
@@ -1,10 +1,9 @@
use crate::{
- ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
- EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
- OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
- UserMessageContent, WebSearchTool, templates::Templates,
+ ContextServerRegistry, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent,
+ templates::Templates,
};
-use acp_thread::AgentModelSelector;
+use crate::{HistoryStore, ThreadsDatabase};
+use acp_thread::{AcpThread, AgentModelSelector};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
@@ -51,7 +50,8 @@ struct Session {
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication
acp_thread: WeakEntity<acp_thread::AcpThread>,
- _subscription: Subscription,
+ pending_save: Task<()>,
+ _subscriptions: Vec<Subscription>,
}
pub struct LanguageModels {
@@ -155,6 +155,7 @@ impl LanguageModels {
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
+ history: Entity<HistoryStore>,
/// Shared project context for all threads
project_context: Entity<ProjectContext>,
project_context_needs_refresh: watch::Sender<()>,
@@ -173,6 +174,7 @@ pub struct NativeAgent {
impl NativeAgent {
pub async fn new(
project: Entity<Project>,
+ history: Entity<HistoryStore>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
@@ -200,6 +202,7 @@ impl NativeAgent {
watch::channel(());
Self {
sessions: HashMap::new(),
+ history,
project_context: cx.new(|_| project_context),
project_context_needs_refresh: project_context_needs_refresh_tx,
_maintain_project_context: cx.spawn(async move |this, cx| {
@@ -218,6 +221,55 @@ impl NativeAgent {
})
}
+ fn register_session(
+ &mut self,
+ thread_handle: Entity<Thread>,
+ cx: &mut Context<Self>,
+ ) -> Entity<AcpThread> {
+ let connection = Rc::new(NativeAgentConnection(cx.entity()));
+ let registry = LanguageModelRegistry::read_global(cx);
+ let summarization_model = registry.thread_summary_model().map(|c| c.model);
+
+ thread_handle.update(cx, |thread, cx| {
+ thread.set_summarization_model(summarization_model, cx);
+ thread.add_default_tools(cx)
+ });
+
+ let thread = thread_handle.read(cx);
+ let session_id = thread.id().clone();
+ let title = thread.title();
+ let project = thread.project.clone();
+ let action_log = thread.action_log.clone();
+ let acp_thread = cx.new(|_cx| {
+ acp_thread::AcpThread::new(
+ title,
+ connection,
+ project.clone(),
+ action_log.clone(),
+ session_id.clone(),
+ )
+ });
+ let subscriptions = vec![
+ cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
+ this.sessions.remove(acp_thread.session_id());
+ }),
+ cx.observe(&thread_handle, move |this, thread, cx| {
+ this.save_thread(thread.clone(), cx)
+ }),
+ ];
+
+ self.sessions.insert(
+ session_id,
+ Session {
+ thread: thread_handle,
+ acp_thread: acp_thread.downgrade(),
+ _subscriptions: subscriptions,
+ pending_save: Task::ready(()),
+ },
+ );
+ acp_thread
+ }
+
pub fn models(&self) -> &LanguageModels {
&self.models
}
@@ -444,6 +496,63 @@ impl NativeAgent {
});
}
}
+
+ pub fn open_thread(
+ &mut self,
+ id: acp::SessionId,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ let database_future = ThreadsDatabase::connect(cx);
+ cx.spawn(async move |this, cx| {
+ let database = database_future.await.map_err(|err| anyhow!(err))?;
+ let db_thread = database
+ .load_thread(id.clone())
+ .await?
+ .with_context(|| format!("no thread found with ID: {id:?}"))?;
+
+ let thread = this.update(cx, |this, cx| {
+ let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
+ cx.new(|cx| {
+ Thread::from_db(
+ id.clone(),
+ db_thread,
+ this.project.clone(),
+ this.project_context.clone(),
+ this.context_server_registry.clone(),
+ action_log.clone(),
+ this.templates.clone(),
+ cx,
+ )
+ })
+ })?;
+ let acp_thread =
+ this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
+ let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
+ cx.update(|cx| {
+ NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
+ })?
+ .await?;
+ Ok(acp_thread)
+ })
+ }
+
+ fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
+ let database_future = ThreadsDatabase::connect(cx);
+ let (id, db_thread) =
+ thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
+ let Some(session) = self.sessions.get_mut(&id) else {
+ return;
+ };
+ let history = self.history.clone();
+ session.pending_save = cx.spawn(async move |_, cx| {
+ let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
+ return;
+ };
+ let db_thread = db_thread.await;
+ database.save_thread(id, db_thread).await.log_err();
+ history.update(cx, |history, cx| history.reload(cx)).ok();
+ });
+ }
}
/// Wrapper struct that implements the AgentConnection trait
@@ -476,13 +585,21 @@ impl NativeAgentConnection {
};
log::debug!("Found session for: {}", session_id);
- let mut response_stream = match f(thread, cx) {
+ let response_stream = match f(thread, cx) {
Ok(stream) => stream,
Err(err) => return Task::ready(Err(err)),
};
+ Self::handle_thread_events(response_stream, acp_thread, cx)
+ }
+
+ fn handle_thread_events(
+ mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
+ acp_thread: WeakEntity<AcpThread>,
+ cx: &App,
+ ) -> Task<Result<acp::PromptResponse>> {
cx.spawn(async move |cx| {
// Handle response stream and forward to session.acp_thread
- while let Some(result) = response_stream.next().await {
+ while let Some(result) = events.next().await {
match result {
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
@@ -686,8 +803,6 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
- let language_registry = project.read(cx).languages().clone();
-
// Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
@@ -697,72 +812,23 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.models
.model_from_id(&LanguageModels::model_id(&default_model.model))
});
- let summarization_model = registry.thread_summary_model().map(|c| c.model);
let thread = cx.new(|cx| {
- let mut thread = Thread::new(
+ Thread::new(
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
action_log.clone(),
agent.templates.clone(),
default_model,
- summarization_model,
cx,
- );
- thread.add_tool(CopyPathTool::new(project.clone()));
- thread.add_tool(CreateDirectoryTool::new(project.clone()));
- thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
- thread.add_tool(DiagnosticsTool::new(project.clone()));
- thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
- thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
- thread.add_tool(FindPathTool::new(project.clone()));
- thread.add_tool(GrepTool::new(project.clone()));
- thread.add_tool(ListDirectoryTool::new(project.clone()));
- thread.add_tool(MovePathTool::new(project.clone()));
- thread.add_tool(NowTool);
- thread.add_tool(OpenTool::new(project.clone()));
- thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
- thread.add_tool(TerminalTool::new(project.clone(), cx));
- thread.add_tool(ThinkingTool);
- thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
- thread
+ )
});
Ok(thread)
},
)??;
-
- let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
- log::info!("Created session with ID: {}", session_id);
- // Create AcpThread
- let acp_thread = cx.update(|cx| {
- cx.new(|_cx| {
- acp_thread::AcpThread::new(
- "agent2",
- self.clone(),
- project.clone(),
- action_log.clone(),
- session_id.clone(),
- )
- })
- })?;
-
- // Store the session
- agent.update(cx, |agent, cx| {
- agent.sessions.insert(
- session_id,
- Session {
- thread,
- acp_thread: acp_thread.downgrade(),
- _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
- this.sessions.remove(acp_thread.session_id());
- }),
- },
- );
- })?;
-
- Ok(acp_thread)
+ agent.update(cx, |agent, cx| agent.register_session(thread, cx))
})
}
@@ -887,8 +953,11 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
let agent = NativeAgent::new(
project.clone(),
+ history_store,
Templates::new(),
None,
fs.clone(),
@@ -942,9 +1011,12 @@ mod tests {
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
let connection = NativeAgentConnection(
NativeAgent::new(
project.clone(),
+ history_store,
Templates::new(),
None,
fs.clone(),
@@ -995,9 +1067,13 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [], cx).await;
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
+
// Create the agent and connection
let agent = NativeAgent::new(
project.clone(),
+ history_store,
Templates::new(),
None,
fs.clone(),
@@ -1,4 +1,6 @@
mod agent;
+mod db;
+mod history_store;
mod native_agent_server;
mod templates;
mod thread;
@@ -9,6 +11,8 @@ mod tools;
mod tests;
pub use agent::*;
+pub use db::*;
+pub use history_store::*;
pub use native_agent_server::NativeAgentServer;
pub use templates::*;
pub use thread::*;
@@ -0,0 +1,470 @@
+use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
+use agent::thread_store;
+use agent_client_protocol as acp;
+use agent_settings::{AgentProfileId, CompletionMode};
+use anyhow::{Result, anyhow};
+use chrono::{DateTime, Utc};
+use collections::{HashMap, IndexMap};
+use futures::{FutureExt, future::Shared};
+use gpui::{BackgroundExecutor, Global, Task};
+use indoc::indoc;
+use parking_lot::Mutex;
+use serde::{Deserialize, Serialize};
+use sqlez::{
+ bindable::{Bind, Column},
+ connection::Connection,
+ statement::Statement,
+};
+use std::sync::Arc;
+use ui::{App, SharedString};
+
+pub type DbMessage = crate::Message;
+pub type DbSummary = agent::thread::DetailedSummaryState;
+pub type DbLanguageModel = thread_store::SerializedLanguageModel;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbThreadMetadata {
+ pub id: acp::SessionId,
+ #[serde(alias = "summary")]
+ pub title: SharedString,
+ pub updated_at: DateTime<Utc>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct DbThread {
+ pub title: SharedString,
+ pub messages: Vec<DbMessage>,
+ pub updated_at: DateTime<Utc>,
+ #[serde(default)]
+ pub summary: DbSummary,
+ #[serde(default)]
+ pub initial_project_snapshot: Option<Arc<agent::thread::ProjectSnapshot>>,
+ #[serde(default)]
+ pub cumulative_token_usage: language_model::TokenUsage,
+ #[serde(default)]
+ pub request_token_usage: Vec<language_model::TokenUsage>,
+ #[serde(default)]
+ pub model: Option<DbLanguageModel>,
+ #[serde(default)]
+ pub completion_mode: Option<CompletionMode>,
+ #[serde(default)]
+ pub profile: Option<AgentProfileId>,
+}
+
+impl DbThread {
+ pub const VERSION: &'static str = "0.3.0";
+
+ pub fn from_json(json: &[u8]) -> Result<Self> {
+ let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
+ match saved_thread_json.get("version") {
+ Some(serde_json::Value::String(version)) => match version.as_str() {
+ Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?),
+ _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
+ },
+ _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?),
+ }
+ }
+
+ fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result<Self> {
+ let mut messages = Vec::new();
+ for msg in thread.messages {
+ let message = match msg.role {
+ language_model::Role::User => {
+ let mut content = Vec::new();
+
+ // Convert segments to content
+ for segment in msg.segments {
+ match segment {
+ thread_store::SerializedMessageSegment::Text { text } => {
+ content.push(UserMessageContent::Text(text));
+ }
+ thread_store::SerializedMessageSegment::Thinking { text, .. } => {
+ // User messages don't have thinking segments, but handle gracefully
+ content.push(UserMessageContent::Text(text));
+ }
+ thread_store::SerializedMessageSegment::RedactedThinking { .. } => {
+ // User messages don't have redacted thinking, skip.
+ }
+ }
+ }
+
+ // If no content was added, add context as text if available
+ if content.is_empty() && !msg.context.is_empty() {
+ content.push(UserMessageContent::Text(msg.context));
+ }
+
+ crate::Message::User(UserMessage {
+ // MessageId from old format can't be meaningfully converted, so generate a new one
+ id: acp_thread::UserMessageId::new(),
+ content,
+ })
+ }
+ language_model::Role::Assistant => {
+ let mut content = Vec::new();
+
+ // Convert segments to content
+ for segment in msg.segments {
+ match segment {
+ thread_store::SerializedMessageSegment::Text { text } => {
+ content.push(AgentMessageContent::Text(text));
+ }
+ thread_store::SerializedMessageSegment::Thinking {
+ text,
+ signature,
+ } => {
+ content.push(AgentMessageContent::Thinking { text, signature });
+ }
+ thread_store::SerializedMessageSegment::RedactedThinking { data } => {
+ content.push(AgentMessageContent::RedactedThinking(data));
+ }
+ }
+ }
+
+ // Convert tool uses
+ let mut tool_names_by_id = HashMap::default();
+ for tool_use in msg.tool_uses {
+ tool_names_by_id.insert(tool_use.id.clone(), tool_use.name.clone());
+ content.push(AgentMessageContent::ToolUse(
+ language_model::LanguageModelToolUse {
+ id: tool_use.id,
+ name: tool_use.name.into(),
+ raw_input: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ input: tool_use.input,
+ is_input_complete: true,
+ },
+ ));
+ }
+
+ // Convert tool results
+ let mut tool_results = IndexMap::default();
+ for tool_result in msg.tool_results {
+ let name = tool_names_by_id
+ .remove(&tool_result.tool_use_id)
+ .unwrap_or_else(|| SharedString::from("unknown"));
+ tool_results.insert(
+ tool_result.tool_use_id.clone(),
+ language_model::LanguageModelToolResult {
+ tool_use_id: tool_result.tool_use_id,
+ tool_name: name.into(),
+ is_error: tool_result.is_error,
+ content: tool_result.content,
+ output: tool_result.output,
+ },
+ );
+ }
+
+ crate::Message::Agent(AgentMessage {
+ content,
+ tool_results,
+ })
+ }
+ language_model::Role::System => {
+ // Skip system messages as they're not supported in the new format
+ continue;
+ }
+ };
+
+ messages.push(message);
+ }
+
+ Ok(Self {
+ title: thread.summary,
+ messages,
+ updated_at: thread.updated_at,
+ summary: thread.detailed_summary_state,
+ initial_project_snapshot: thread.initial_project_snapshot,
+ cumulative_token_usage: thread.cumulative_token_usage,
+ request_token_usage: thread.request_token_usage,
+ model: thread.model,
+ completion_mode: thread.completion_mode,
+ profile: thread.profile,
+ })
+ }
+}
+
+pub static ZED_STATELESS: std::sync::LazyLock<bool> =
+ std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum DataType {
+ #[serde(rename = "json")]
+ Json,
+ #[serde(rename = "zstd")]
+ Zstd,
+}
+
+impl Bind for DataType {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ let value = match self {
+ DataType::Json => "json",
+ DataType::Zstd => "zstd",
+ };
+ value.bind(statement, start_index)
+ }
+}
+
+impl Column for DataType {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let (value, next_index) = String::column(statement, start_index)?;
+ let data_type = match value.as_str() {
+ "json" => DataType::Json,
+ "zstd" => DataType::Zstd,
+ _ => anyhow::bail!("Unknown data type: {}", value),
+ };
+ Ok((data_type, next_index))
+ }
+}
+
+pub(crate) struct ThreadsDatabase {
+ executor: BackgroundExecutor,
+ connection: Arc<Mutex<Connection>>,
+}
+
+struct GlobalThreadsDatabase(Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>);
+
+impl Global for GlobalThreadsDatabase {}
+
+impl ThreadsDatabase {
+ pub fn connect(cx: &mut App) -> Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
+ if cx.has_global::<GlobalThreadsDatabase>() {
+ return cx.global::<GlobalThreadsDatabase>().0.clone();
+ }
+ let executor = cx.background_executor().clone();
+ let task = executor
+ .spawn({
+ let executor = executor.clone();
+ async move {
+ match ThreadsDatabase::new(executor) {
+ Ok(db) => Ok(Arc::new(db)),
+ Err(err) => Err(Arc::new(err)),
+ }
+ }
+ })
+ .shared();
+
+ cx.set_global(GlobalThreadsDatabase(task.clone()));
+ task
+ }
+
+ pub fn new(executor: BackgroundExecutor) -> Result<Self> {
+ let connection = if *ZED_STATELESS || cfg!(any(feature = "test-support", test)) {
+ Connection::open_memory(Some("THREAD_FALLBACK_DB"))
+ } else {
+ let threads_dir = paths::data_dir().join("threads");
+ std::fs::create_dir_all(&threads_dir)?;
+ let sqlite_path = threads_dir.join("threads.db");
+ Connection::open_file(&sqlite_path.to_string_lossy())
+ };
+
+ connection.exec(indoc! {"
+ CREATE TABLE IF NOT EXISTS threads (
+ id TEXT PRIMARY KEY,
+ summary TEXT NOT NULL,
+ updated_at TEXT NOT NULL,
+ data_type TEXT NOT NULL,
+ data BLOB NOT NULL
+ )
+ "})?()
+ .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
+
+ let db = Self {
+ executor: executor.clone(),
+ connection: Arc::new(Mutex::new(connection)),
+ };
+
+ Ok(db)
+ }
+
+ fn save_thread_sync(
+ connection: &Arc<Mutex<Connection>>,
+ id: acp::SessionId,
+ thread: DbThread,
+ ) -> Result<()> {
+ const COMPRESSION_LEVEL: i32 = 3;
+
+ #[derive(Serialize)]
+ struct SerializedThread {
+ #[serde(flatten)]
+ thread: DbThread,
+ version: &'static str,
+ }
+
+ let title = thread.title.to_string();
+ let updated_at = thread.updated_at.to_rfc3339();
+ let json_data = serde_json::to_string(&SerializedThread {
+ thread,
+ version: DbThread::VERSION,
+ })?;
+
+ let connection = connection.lock();
+
+ let compressed = zstd::encode_all(json_data.as_bytes(), COMPRESSION_LEVEL)?;
+ let data_type = DataType::Zstd;
+ let data = compressed;
+
+ let mut insert = connection.exec_bound::<(Arc<str>, String, String, DataType, Vec<u8>)>(indoc! {"
+ INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
+ "})?;
+
+ insert((id.0.clone(), title, updated_at, data_type, data))?;
+
+ Ok(())
+ }
+
+ pub fn list_threads(&self) -> Task<Result<Vec<DbThreadMetadata>>> {
+ let connection = self.connection.clone();
+
+ self.executor.spawn(async move {
+ let connection = connection.lock();
+
+ let mut select =
+ connection.select_bound::<(), (Arc<str>, String, String)>(indoc! {"
+ SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
+ "})?;
+
+ let rows = select(())?;
+ let mut threads = Vec::new();
+
+ for (id, summary, updated_at) in rows {
+ threads.push(DbThreadMetadata {
+ id: acp::SessionId(id),
+ title: summary.into(),
+ updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
+ });
+ }
+
+ Ok(threads)
+ })
+ }
+
+ pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
+ let connection = self.connection.clone();
+
+ self.executor.spawn(async move {
+ let connection = connection.lock();
+ let mut select = connection.select_bound::<Arc<str>, (DataType, Vec<u8>)>(indoc! {"
+ SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
+ "})?;
+
+ let rows = select(id.0)?;
+ if let Some((data_type, data)) = rows.into_iter().next() {
+ let json_data = match data_type {
+ DataType::Zstd => {
+ let decompressed = zstd::decode_all(&data[..])?;
+ String::from_utf8(decompressed)?
+ }
+ DataType::Json => String::from_utf8(data)?,
+ };
+ let thread = DbThread::from_json(json_data.as_bytes())?;
+ Ok(Some(thread))
+ } else {
+ Ok(None)
+ }
+ })
+ }
+
+ pub fn save_thread(&self, id: acp::SessionId, thread: DbThread) -> Task<Result<()>> {
+ let connection = self.connection.clone();
+
+ self.executor
+ .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
+ }
+
+ pub fn delete_thread(&self, id: acp::SessionId) -> Task<Result<()>> {
+ let connection = self.connection.clone();
+
+ self.executor.spawn(async move {
+ let connection = connection.lock();
+
+ let mut delete = connection.exec_bound::<Arc<str>>(indoc! {"
+ DELETE FROM threads WHERE id = ?
+ "})?;
+
+ delete(id.0)?;
+
+ Ok(())
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use super::*;
+ use agent::MessageSegment;
+ use agent::context::LoadedContext;
+ use client::Client;
+ use fs::FakeFs;
+ use gpui::AppContext;
+ use gpui::TestAppContext;
+ use http_client::FakeHttpClient;
+ use language_model::Role;
+ use project::Project;
+ use settings::SettingsStore;
+
+ fn init_test(cx: &mut TestAppContext) {
+ env_logger::try_init().ok();
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ language::init(cx);
+
+ let http_client = FakeHttpClient::with_404_response();
+ let clock = Arc::new(clock::FakeSystemClock::new());
+ let client = Client::new(clock, http_client, cx);
+ agent::init(cx);
+ agent_settings::init(cx);
+ language_model::init(client.clone(), cx);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_retrieving_old_thread(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+
+ // Save a thread using the old agent.
+ let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx));
+ let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
+ thread.update(cx, |thread, cx| {
+ thread.insert_message(
+ Role::User,
+ vec![MessageSegment::Text("Hey!".into())],
+ LoadedContext::default(),
+ vec![],
+ false,
+ cx,
+ );
+ thread.insert_message(
+ Role::Assistant,
+ vec![MessageSegment::Text("How're you doing?".into())],
+ LoadedContext::default(),
+ vec![],
+ false,
+ cx,
+ )
+ });
+ thread_store
+ .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
+ .await
+ .unwrap();
+
+ // Open that same thread using the new agent.
+ let db = cx.update(ThreadsDatabase::connect).await.unwrap();
+ let threads = db.list_threads().await.unwrap();
+ assert_eq!(threads.len(), 1);
+ let thread = db
+ .load_thread(threads[0].id.clone())
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n");
+ assert_eq!(
+ thread.messages[1].to_markdown(),
+ "## Assistant\n\nHow're you doing?\n"
+ );
+ }
+}
@@ -0,0 +1,314 @@
+use crate::{DbThreadMetadata, ThreadsDatabase};
+use agent_client_protocol as acp;
+use anyhow::{Context as _, Result, anyhow};
+use assistant_context::SavedContextMetadata;
+use chrono::{DateTime, Utc};
+use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
+use itertools::Itertools;
+use paths::contexts_dir;
+use serde::{Deserialize, Serialize};
+use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
+use util::ResultExt as _;
+
+const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
+const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
+const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
+
+const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
+
+#[derive(Clone, Debug)]
+pub enum HistoryEntry {
+ AcpThread(DbThreadMetadata),
+ TextThread(SavedContextMetadata),
+}
+
+impl HistoryEntry {
+ pub fn updated_at(&self) -> DateTime<Utc> {
+ match self {
+ HistoryEntry::AcpThread(thread) => thread.updated_at,
+ HistoryEntry::TextThread(context) => context.mtime.to_utc(),
+ }
+ }
+
+ pub fn id(&self) -> HistoryEntryId {
+ match self {
+ HistoryEntry::AcpThread(thread) => HistoryEntryId::AcpThread(thread.id.clone()),
+ HistoryEntry::TextThread(context) => HistoryEntryId::TextThread(context.path.clone()),
+ }
+ }
+
+ pub fn title(&self) -> &SharedString {
+ match self {
+ HistoryEntry::AcpThread(thread) if thread.title.is_empty() => DEFAULT_TITLE,
+ HistoryEntry::AcpThread(thread) => &thread.title,
+ HistoryEntry::TextThread(context) => &context.title,
+ }
+ }
+}
+
+/// Generic identifier for a history entry.
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub enum HistoryEntryId {
+ AcpThread(acp::SessionId),
+ TextThread(Arc<Path>),
+}
+
+#[derive(Serialize, Deserialize)]
+enum SerializedRecentOpen {
+ Thread(String),
+ ContextName(String),
+ /// Old format which stores the full path
+ Context(String),
+}
+
+pub struct HistoryStore {
+ threads: Vec<DbThreadMetadata>,
+ context_store: Entity<assistant_context::ContextStore>,
+ recently_opened_entries: VecDeque<HistoryEntryId>,
+ _subscriptions: Vec<gpui::Subscription>,
+ _save_recently_opened_entries_task: Task<()>,
+}
+
+impl HistoryStore {
+ pub fn new(
+ context_store: Entity<assistant_context::ContextStore>,
+ initial_recent_entries: impl IntoIterator<Item = HistoryEntryId>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let subscriptions = vec![cx.observe(&context_store, |_, _, cx| cx.notify())];
+
+ cx.spawn(async move |this, cx| {
+ let entries = Self::load_recently_opened_entries(cx).await.log_err()?;
+ this.update(cx, |this, _| {
+ this.recently_opened_entries
+ .extend(
+ entries.into_iter().take(
+ MAX_RECENTLY_OPENED_ENTRIES
+ .saturating_sub(this.recently_opened_entries.len()),
+ ),
+ );
+ })
+ .ok()
+ })
+ .detach();
+
+ Self {
+ context_store,
+ recently_opened_entries: initial_recent_entries.into_iter().collect(),
+ threads: Vec::default(),
+ _subscriptions: subscriptions,
+ _save_recently_opened_entries_task: Task::ready(()),
+ }
+ }
+
+ pub fn delete_thread(
+ &mut self,
+ id: acp::SessionId,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ let database_future = ThreadsDatabase::connect(cx);
+ cx.spawn(async move |this, cx| {
+ let database = database_future.await.map_err(|err| anyhow!(err))?;
+ database.delete_thread(id.clone()).await?;
+ this.update(cx, |this, cx| this.reload(cx))
+ })
+ }
+
+ pub fn delete_text_thread(
+ &mut self,
+ path: Arc<Path>,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ self.context_store.update(cx, |context_store, cx| {
+ context_store.delete_local_context(path, cx)
+ })
+ }
+
+ pub fn reload(&self, cx: &mut Context<Self>) {
+ let database_future = ThreadsDatabase::connect(cx);
+ cx.spawn(async move |this, cx| {
+ let threads = database_future
+ .await
+ .map_err(|err| anyhow!(err))?
+ .list_threads()
+ .await?;
+
+ this.update(cx, |this, cx| {
+ this.threads = threads;
+ cx.notify();
+ })
+ })
+ .detach_and_log_err(cx);
+ }
+
+ pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
+ let mut history_entries = Vec::new();
+
+ #[cfg(debug_assertions)]
+ if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
+ return history_entries;
+ }
+
+ history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread));
+ history_entries.extend(
+ self.context_store
+ .read(cx)
+ .unordered_contexts()
+ .cloned()
+ .map(HistoryEntry::TextThread),
+ );
+
+ history_entries.sort_unstable_by_key(|entry| std::cmp::Reverse(entry.updated_at()));
+ history_entries
+ }
+
+ pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
+ self.entries(cx).into_iter().take(limit).collect()
+ }
+
+ pub fn recently_opened_entries(&self, cx: &App) -> Vec<HistoryEntry> {
+ #[cfg(debug_assertions)]
+ if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
+ return Vec::new();
+ }
+
+ let thread_entries = self.threads.iter().flat_map(|thread| {
+ self.recently_opened_entries
+ .iter()
+ .enumerate()
+ .flat_map(|(index, entry)| match entry {
+ HistoryEntryId::AcpThread(id) if &thread.id == id => {
+ Some((index, HistoryEntry::AcpThread(thread.clone())))
+ }
+ _ => None,
+ })
+ });
+
+ let context_entries =
+ self.context_store
+ .read(cx)
+ .unordered_contexts()
+ .flat_map(|context| {
+ self.recently_opened_entries
+ .iter()
+ .enumerate()
+ .flat_map(|(index, entry)| match entry {
+ HistoryEntryId::TextThread(path) if &context.path == path => {
+ Some((index, HistoryEntry::TextThread(context.clone())))
+ }
+ _ => None,
+ })
+ });
+
+ thread_entries
+ .chain(context_entries)
+ // optimization to halt iteration early
+ .take(self.recently_opened_entries.len())
+ .sorted_unstable_by_key(|(index, _)| *index)
+ .map(|(_, entry)| entry)
+ .collect()
+ }
+
+ fn save_recently_opened_entries(&mut self, cx: &mut Context<Self>) {
+ let serialized_entries = self
+ .recently_opened_entries
+ .iter()
+ .filter_map(|entry| match entry {
+ HistoryEntryId::TextThread(path) => path.file_name().map(|file| {
+ SerializedRecentOpen::ContextName(file.to_string_lossy().to_string())
+ }),
+ HistoryEntryId::AcpThread(id) => Some(SerializedRecentOpen::Thread(id.to_string())),
+ })
+ .collect::<Vec<_>>();
+
+ self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
+ cx.background_executor()
+ .timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
+ .await;
+ cx.background_spawn(async move {
+ let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
+ let content = serde_json::to_string(&serialized_entries)?;
+ std::fs::write(path, content)?;
+ anyhow::Ok(())
+ })
+ .await
+ .log_err();
+ });
+ }
+
+ fn load_recently_opened_entries(cx: &AsyncApp) -> Task<Result<Vec<HistoryEntryId>>> {
+ cx.background_spawn(async move {
+ let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
+ let contents = match smol::fs::read_to_string(path).await {
+ Ok(it) => it,
+ Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
+ return Ok(Vec::new());
+ }
+ Err(e) => {
+ return Err(e)
+ .context("deserializing persisted agent panel navigation history");
+ }
+ };
+ let entries = serde_json::from_str::<Vec<SerializedRecentOpen>>(&contents)
+ .context("deserializing persisted agent panel navigation history")?
+ .into_iter()
+ .take(MAX_RECENTLY_OPENED_ENTRIES)
+ .flat_map(|entry| match entry {
+ SerializedRecentOpen::Thread(id) => Some(HistoryEntryId::AcpThread(
+ acp::SessionId(id.as_str().into()),
+ )),
+ SerializedRecentOpen::ContextName(file_name) => Some(
+ HistoryEntryId::TextThread(contexts_dir().join(file_name).into()),
+ ),
+ SerializedRecentOpen::Context(path) => {
+ Path::new(&path).file_name().map(|file_name| {
+ HistoryEntryId::TextThread(contexts_dir().join(file_name).into())
+ })
+ }
+ })
+ .collect::<Vec<_>>();
+ Ok(entries)
+ })
+ }
+
+ pub fn push_recently_opened_entry(&mut self, entry: HistoryEntryId, cx: &mut Context<Self>) {
+ self.recently_opened_entries
+ .retain(|old_entry| old_entry != &entry);
+ self.recently_opened_entries.push_front(entry);
+ self.recently_opened_entries
+ .truncate(MAX_RECENTLY_OPENED_ENTRIES);
+ self.save_recently_opened_entries(cx);
+ }
+
+ pub fn remove_recently_opened_thread(&mut self, id: acp::SessionId, cx: &mut Context<Self>) {
+ self.recently_opened_entries.retain(|entry| match entry {
+ HistoryEntryId::AcpThread(thread_id) if thread_id == &id => false,
+ _ => true,
+ });
+ self.save_recently_opened_entries(cx);
+ }
+
+ pub fn replace_recently_opened_text_thread(
+ &mut self,
+ old_path: &Path,
+ new_path: &Arc<Path>,
+ cx: &mut Context<Self>,
+ ) {
+ for entry in &mut self.recently_opened_entries {
+ match entry {
+ HistoryEntryId::TextThread(path) if path.as_ref() == old_path => {
+ *entry = HistoryEntryId::TextThread(new_path.clone());
+ break;
+ }
+ _ => {}
+ }
+ }
+ self.save_recently_opened_entries(cx);
+ }
+
+ pub fn remove_recently_opened_entry(&mut self, entry: &HistoryEntryId, cx: &mut Context<Self>) {
+ self.recently_opened_entries
+ .retain(|old_entry| old_entry != entry);
+ self.save_recently_opened_entries(cx);
+ }
+}
@@ -7,16 +7,17 @@ use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
-use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
+use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
#[derive(Clone)]
pub struct NativeAgentServer {
fs: Arc<dyn Fs>,
+ history: Entity<HistoryStore>,
}
impl NativeAgentServer {
- pub fn new(fs: Arc<dyn Fs>) -> Self {
- Self { fs }
+ pub fn new(fs: Arc<dyn Fs>, history: Entity<HistoryStore>) -> Self {
+ Self { fs, history }
}
}
@@ -50,6 +51,7 @@ impl AgentServer for NativeAgentServer {
);
let project = project.clone();
let fs = self.fs.clone();
+ let history = self.history.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
@@ -57,7 +59,8 @@ impl AgentServer for NativeAgentServer {
let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity");
- let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
+ let agent =
+ NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);
@@ -1273,10 +1273,13 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, [], cx));
// Create agent and connection
let agent = NativeAgent::new(
project.clone(),
+ history_store,
templates.clone(),
None,
fake_fs.clone(),
@@ -1756,7 +1759,6 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
action_log,
templates,
Some(model.clone()),
- None,
cx,
)
});
@@ -1,4 +1,9 @@
-use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
+use crate::{
+ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
+ DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
+ ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
+ Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
+};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
@@ -17,13 +22,13 @@ use futures::{
stream::FuturesUnordered,
};
use git::repository::DiffType;
-use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
+use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
- LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
- LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
- TokenUsage,
+ LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+ LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
+ LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
};
use project::{
Project,
@@ -516,8 +521,8 @@ pub struct Thread {
templates: Arc<Templates>,
model: Option<Arc<dyn LanguageModel>>,
summarization_model: Option<Arc<dyn LanguageModel>>,
- project: Entity<Project>,
- action_log: Entity<ActionLog>,
+ pub(crate) project: Entity<Project>,
+ pub(crate) action_log: Entity<ActionLog>,
}
impl Thread {
@@ -528,7 +533,6 @@ impl Thread {
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
model: Option<Arc<dyn LanguageModel>>,
- summarization_model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@@ -557,7 +561,7 @@ impl Thread {
project_context,
templates,
model,
- summarization_model,
+ summarization_model: None,
project,
action_log,
}
@@ -652,6 +656,88 @@ impl Thread {
);
}
+ pub fn from_db(
+ id: acp::SessionId,
+ db_thread: DbThread,
+ project: Entity<Project>,
+ project_context: Entity<ProjectContext>,
+ context_server_registry: Entity<ContextServerRegistry>,
+ action_log: Entity<ActionLog>,
+ templates: Arc<Templates>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let profile_id = db_thread
+ .profile
+ .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
+ let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ db_thread
+ .model
+ .and_then(|model| {
+ let model = SelectedModel {
+ provider: model.provider.clone().into(),
+ model: model.model.clone().into(),
+ };
+ registry.select_model(&model, cx)
+ })
+ .or_else(|| registry.default_model())
+ .map(|model| model.model)
+ });
+
+ Self {
+ id,
+ prompt_id: PromptId::new(),
+ title: if db_thread.title.is_empty() {
+ None
+ } else {
+ Some(db_thread.title.clone())
+ },
+ summary: db_thread.summary,
+ messages: db_thread.messages,
+ completion_mode: db_thread.completion_mode.unwrap_or_default(),
+ running_turn: None,
+ pending_message: None,
+ tools: BTreeMap::default(),
+ tool_use_limit_reached: false,
+ request_token_usage: db_thread.request_token_usage.clone(),
+ cumulative_token_usage: db_thread.cumulative_token_usage,
+ initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
+ context_server_registry,
+ profile_id,
+ project_context,
+ templates,
+ model,
+ summarization_model: None,
+ project,
+ action_log,
+ updated_at: db_thread.updated_at,
+ }
+ }
+
+ pub fn to_db(&self, cx: &App) -> Task<DbThread> {
+ let initial_project_snapshot = self.initial_project_snapshot.clone();
+ let mut thread = DbThread {
+ title: self.title.clone().unwrap_or_default(),
+ messages: self.messages.clone(),
+ updated_at: self.updated_at,
+ summary: self.summary.clone(),
+ initial_project_snapshot: None,
+ cumulative_token_usage: self.cumulative_token_usage,
+ request_token_usage: self.request_token_usage.clone(),
+ model: self.model.as_ref().map(|model| DbLanguageModel {
+ provider: model.provider_id().to_string(),
+ model: model.name().0.to_string(),
+ }),
+ completion_mode: Some(self.completion_mode),
+ profile: Some(self.profile_id.clone()),
+ };
+
+ cx.background_spawn(async move {
+ let initial_project_snapshot = initial_project_snapshot.await;
+ thread.initial_project_snapshot = initial_project_snapshot;
+ thread
+ })
+ }
+
/// Create a snapshot of the current project state including git information and unsaved buffers.
fn project_snapshot(
project: Entity<Project>,
@@ -816,6 +902,32 @@ impl Thread {
}
}
+ pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
+ let language_registry = self.project.read(cx).languages().clone();
+ self.add_tool(CopyPathTool::new(self.project.clone()));
+ self.add_tool(CreateDirectoryTool::new(self.project.clone()));
+ self.add_tool(DeletePathTool::new(
+ self.project.clone(),
+ self.action_log.clone(),
+ ));
+ self.add_tool(DiagnosticsTool::new(self.project.clone()));
+ self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
+ self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
+ self.add_tool(FindPathTool::new(self.project.clone()));
+ self.add_tool(GrepTool::new(self.project.clone()));
+ self.add_tool(ListDirectoryTool::new(self.project.clone()));
+ self.add_tool(MovePathTool::new(self.project.clone()));
+ self.add_tool(NowTool);
+ self.add_tool(OpenTool::new(self.project.clone()));
+ self.add_tool(ReadFileTool::new(
+ self.project.clone(),
+ self.action_log.clone(),
+ ));
+ self.add_tool(TerminalTool::new(self.project.clone(), cx));
+ self.add_tool(ThinkingTool);
+ self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
+ }
+
pub fn add_tool(&mut self, tool: impl AgentTool) {
self.tools.insert(tool.name(), tool.erase());
}
@@ -554,7 +554,6 @@ mod tests {
action_log,
Templates::new(),
Some(model),
- None,
cx,
)
});
@@ -756,7 +755,6 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
- None,
cx,
)
});
@@ -899,7 +897,6 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
- None,
cx,
)
});
@@ -1029,7 +1026,6 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
- None,
cx,
)
});
@@ -1168,7 +1164,6 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
- None,
cx,
)
});
@@ -1279,7 +1274,6 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
- None,
cx,
)
});
@@ -1362,7 +1356,6 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
- None,
cx,
)
});
@@ -1448,7 +1441,6 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
- None,
cx,
)
});
@@ -1531,7 +1523,6 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
- None,
cx,
)
});
@@ -3,8 +3,10 @@ mod entry_view_state;
mod message_editor;
mod model_selector;
mod model_selector_popover;
+mod thread_history;
mod thread_view;
pub use model_selector::AcpModelSelector;
pub use model_selector_popover::AcpModelSelectorPopover;
+pub use thread_history::*;
pub use thread_view::AcpThreadView;
@@ -0,0 +1,766 @@
+use crate::RemoveSelectedThread;
+use agent2::{HistoryEntry, HistoryStore};
+use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
+use editor::{Editor, EditorEvent};
+use fuzzy::{StringMatch, StringMatchCandidate};
+use gpui::{
+ App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
+ UniformListScrollHandle, Window, uniform_list,
+};
+use std::{fmt::Display, ops::Range, sync::Arc};
+use time::{OffsetDateTime, UtcOffset};
+use ui::{
+ HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState,
+ Tooltip, prelude::*,
+};
+use util::ResultExt;
+
+pub struct AcpThreadHistory {
+ pub(crate) history_store: Entity<HistoryStore>,
+ scroll_handle: UniformListScrollHandle,
+ selected_index: usize,
+ hovered_index: Option<usize>,
+ search_editor: Entity<Editor>,
+ all_entries: Arc<Vec<HistoryEntry>>,
+ // When the search is empty, we display date separators between history entries
+ // This vector contains an enum of either a separator or an actual entry
+ separated_items: Vec<ListItemType>,
+ // Maps entry indexes to list item indexes
+ separated_item_indexes: Vec<u32>,
+ _separated_items_task: Option<Task<()>>,
+ search_state: SearchState,
+ scrollbar_visibility: bool,
+ scrollbar_state: ScrollbarState,
+ local_timezone: UtcOffset,
+ _subscriptions: Vec<gpui::Subscription>,
+}
+
+enum SearchState {
+ Empty,
+ Searching {
+ query: SharedString,
+ _task: Task<()>,
+ },
+ Searched {
+ query: SharedString,
+ matches: Vec<StringMatch>,
+ },
+}
+
+enum ListItemType {
+ BucketSeparator(TimeBucket),
+ Entry {
+ index: usize,
+ format: EntryTimeFormat,
+ },
+}
+
+pub enum ThreadHistoryEvent {
+ Open(HistoryEntry),
+}
+
+impl EventEmitter<ThreadHistoryEvent> for AcpThreadHistory {}
+
+impl AcpThreadHistory {
+ pub(crate) fn new(
+ history_store: Entity<agent2::HistoryStore>,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let search_editor = cx.new(|cx| {
+ let mut editor = Editor::single_line(window, cx);
+ editor.set_placeholder_text("Search threads...", cx);
+ editor
+ });
+
+ let search_editor_subscription =
+ cx.subscribe(&search_editor, |this, search_editor, event, cx| {
+ if let EditorEvent::BufferEdited = event {
+ let query = search_editor.read(cx).text(cx);
+ this.search(query.into(), cx);
+ }
+ });
+
+ let history_store_subscription = cx.observe(&history_store, |this, _, cx| {
+ this.update_all_entries(cx);
+ });
+
+ let scroll_handle = UniformListScrollHandle::default();
+ let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
+
+ let mut this = Self {
+ history_store,
+ scroll_handle,
+ selected_index: 0,
+ hovered_index: None,
+ search_state: SearchState::Empty,
+ all_entries: Default::default(),
+ separated_items: Default::default(),
+ separated_item_indexes: Default::default(),
+ search_editor,
+ scrollbar_visibility: true,
+ scrollbar_state,
+ local_timezone: UtcOffset::from_whole_seconds(
+ chrono::Local::now().offset().local_minus_utc(),
+ )
+ .unwrap(),
+ _subscriptions: vec![search_editor_subscription, history_store_subscription],
+ _separated_items_task: None,
+ };
+ this.update_all_entries(cx);
+ this
+ }
+
+ fn update_all_entries(&mut self, cx: &mut Context<Self>) {
+ let new_entries: Arc<Vec<HistoryEntry>> = self
+ .history_store
+ .update(cx, |store, cx| store.entries(cx))
+ .into();
+
+ self._separated_items_task.take();
+
+ let mut items = Vec::with_capacity(new_entries.len() + 1);
+ let mut indexes = Vec::with_capacity(new_entries.len() + 1);
+
+ let bg_task = cx.background_spawn(async move {
+ let mut bucket = None;
+ let today = Local::now().naive_local().date();
+
+ for (index, entry) in new_entries.iter().enumerate() {
+ let entry_date = entry
+ .updated_at()
+ .with_timezone(&Local)
+ .naive_local()
+ .date();
+ let entry_bucket = TimeBucket::from_dates(today, entry_date);
+
+ if Some(entry_bucket) != bucket {
+ bucket = Some(entry_bucket);
+ items.push(ListItemType::BucketSeparator(entry_bucket));
+ }
+
+ indexes.push(items.len() as u32);
+ items.push(ListItemType::Entry {
+ index,
+ format: entry_bucket.into(),
+ });
+ }
+ (new_entries, items, indexes)
+ });
+
+ let task = cx.spawn(async move |this, cx| {
+ let (new_entries, items, indexes) = bg_task.await;
+ this.update(cx, |this, cx| {
+ let previously_selected_entry =
+ this.all_entries.get(this.selected_index).map(|e| e.id());
+
+ this.all_entries = new_entries;
+ this.separated_items = items;
+ this.separated_item_indexes = indexes;
+
+ match &this.search_state {
+ SearchState::Empty => {
+ if this.selected_index >= this.all_entries.len() {
+ this.set_selected_entry_index(
+ this.all_entries.len().saturating_sub(1),
+ cx,
+ );
+ } else if let Some(prev_id) = previously_selected_entry
+ && let Some(new_ix) = this
+ .all_entries
+ .iter()
+ .position(|probe| probe.id() == prev_id)
+ {
+ this.set_selected_entry_index(new_ix, cx);
+ }
+ }
+ SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {
+ this.search(query.clone(), cx);
+ }
+ }
+
+ cx.notify();
+ })
+ .log_err();
+ });
+ self._separated_items_task = Some(task);
+ }
+
+ fn search(&mut self, query: SharedString, cx: &mut Context<Self>) {
+ if query.is_empty() {
+ self.search_state = SearchState::Empty;
+ cx.notify();
+ return;
+ }
+
+ let all_entries = self.all_entries.clone();
+
+ let fuzzy_search_task = cx.background_spawn({
+ let query = query.clone();
+ let executor = cx.background_executor().clone();
+ async move {
+ let mut candidates = Vec::with_capacity(all_entries.len());
+
+ for (idx, entry) in all_entries.iter().enumerate() {
+ candidates.push(StringMatchCandidate::new(idx, entry.title()));
+ }
+
+ const MAX_MATCHES: usize = 100;
+
+ fuzzy::match_strings(
+ &candidates,
+ &query,
+ false,
+ true,
+ MAX_MATCHES,
+ &Default::default(),
+ executor,
+ )
+ .await
+ }
+ });
+
+ let task = cx.spawn({
+ let query = query.clone();
+ async move |this, cx| {
+ let matches = fuzzy_search_task.await;
+
+ this.update(cx, |this, cx| {
+ let SearchState::Searching {
+ query: current_query,
+ _task,
+ } = &this.search_state
+ else {
+ return;
+ };
+
+ if &query == current_query {
+ this.search_state = SearchState::Searched {
+ query: query.clone(),
+ matches,
+ };
+
+ this.set_selected_entry_index(0, cx);
+ cx.notify();
+ };
+ })
+ .log_err();
+ }
+ });
+
+ self.search_state = SearchState::Searching { query, _task: task };
+ cx.notify();
+ }
+
+ fn matched_count(&self) -> usize {
+ match &self.search_state {
+ SearchState::Empty => self.all_entries.len(),
+ SearchState::Searching { .. } => 0,
+ SearchState::Searched { matches, .. } => matches.len(),
+ }
+ }
+
+ fn list_item_count(&self) -> usize {
+ match &self.search_state {
+ SearchState::Empty => self.separated_items.len(),
+ SearchState::Searching { .. } => 0,
+ SearchState::Searched { matches, .. } => matches.len(),
+ }
+ }
+
+ fn search_produced_no_matches(&self) -> bool {
+ match &self.search_state {
+ SearchState::Empty => false,
+ SearchState::Searching { .. } => false,
+ SearchState::Searched { matches, .. } => matches.is_empty(),
+ }
+ }
+
+ fn get_match(&self, ix: usize) -> Option<&HistoryEntry> {
+ match &self.search_state {
+ SearchState::Empty => self.all_entries.get(ix),
+ SearchState::Searching { .. } => None,
+ SearchState::Searched { matches, .. } => matches
+ .get(ix)
+ .and_then(|m| self.all_entries.get(m.candidate_id)),
+ }
+ }
+
+ pub fn select_previous(
+ &mut self,
+ _: &menu::SelectPrevious,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let count = self.matched_count();
+ if count > 0 {
+ if self.selected_index == 0 {
+ self.set_selected_entry_index(count - 1, cx);
+ } else {
+ self.set_selected_entry_index(self.selected_index - 1, cx);
+ }
+ }
+ }
+
+ pub fn select_next(
+ &mut self,
+ _: &menu::SelectNext,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let count = self.matched_count();
+ if count > 0 {
+ if self.selected_index == count - 1 {
+ self.set_selected_entry_index(0, cx);
+ } else {
+ self.set_selected_entry_index(self.selected_index + 1, cx);
+ }
+ }
+ }
+
+ fn select_first(
+ &mut self,
+ _: &menu::SelectFirst,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let count = self.matched_count();
+ if count > 0 {
+ self.set_selected_entry_index(0, cx);
+ }
+ }
+
+ fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
+ let count = self.matched_count();
+ if count > 0 {
+ self.set_selected_entry_index(count - 1, cx);
+ }
+ }
+
+ fn set_selected_entry_index(&mut self, entry_index: usize, cx: &mut Context<Self>) {
+ self.selected_index = entry_index;
+
+ let scroll_ix = match self.search_state {
+ SearchState::Empty | SearchState::Searching { .. } => self
+ .separated_item_indexes
+ .get(entry_index)
+ .map(|ix| *ix as usize)
+ .unwrap_or(entry_index + 1),
+ SearchState::Searched { .. } => entry_index,
+ };
+
+ self.scroll_handle
+ .scroll_to_item(scroll_ix, ScrollStrategy::Top);
+
+ cx.notify();
+ }
+
+ fn render_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
+ if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) {
+ return None;
+ }
+
+ Some(
+ div()
+ .occlude()
+ .id("thread-history-scroll")
+ .h_full()
+ .bg(cx.theme().colors().panel_background.opacity(0.8))
+ .border_l_1()
+ .border_color(cx.theme().colors().border_variant)
+ .absolute()
+ .right_1()
+ .top_0()
+ .bottom_0()
+ .w_4()
+ .pl_1()
+ .cursor_default()
+ .on_mouse_move(cx.listener(|_, _, _window, cx| {
+ cx.notify();
+ cx.stop_propagation()
+ }))
+ .on_hover(|_, _window, cx| {
+ cx.stop_propagation();
+ })
+ .on_any_mouse_down(|_, _window, cx| {
+ cx.stop_propagation();
+ })
+ .on_scroll_wheel(cx.listener(|_, _, _window, cx| {
+ cx.notify();
+ }))
+ .children(Scrollbar::vertical(self.scrollbar_state.clone())),
+ )
+ }
+
+ fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
+ self.confirm_entry(self.selected_index, cx);
+ }
+
+ fn confirm_entry(&mut self, ix: usize, cx: &mut Context<Self>) {
+ let Some(entry) = self.get_match(ix) else {
+ return;
+ };
+ cx.emit(ThreadHistoryEvent::Open(entry.clone()));
+ }
+
+ fn remove_selected_thread(
+ &mut self,
+ _: &RemoveSelectedThread,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.remove_thread(self.selected_index, cx)
+ }
+
+ fn remove_thread(&mut self, ix: usize, cx: &mut Context<Self>) {
+ let Some(entry) = self.get_match(ix) else {
+ return;
+ };
+
+ let task = match entry {
+ HistoryEntry::AcpThread(thread) => self
+ .history_store
+ .update(cx, |this, cx| this.delete_thread(thread.id.clone(), cx)),
+ HistoryEntry::TextThread(context) => self.history_store.update(cx, |this, cx| {
+ this.delete_text_thread(context.path.clone(), cx)
+ }),
+ };
+ task.detach_and_log_err(cx);
+ }
+
+ fn list_items(
+ &mut self,
+ range: Range<usize>,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Vec<AnyElement> {
+ match &self.search_state {
+ SearchState::Empty => self
+ .separated_items
+ .get(range)
+ .iter()
+ .flat_map(|items| {
+ items
+ .iter()
+ .map(|item| self.render_list_item(item, vec![], cx))
+ })
+ .collect(),
+ SearchState::Searched { matches, .. } => matches[range]
+ .iter()
+ .filter_map(|m| {
+ let entry = self.all_entries.get(m.candidate_id)?;
+ Some(self.render_history_entry(
+ entry,
+ EntryTimeFormat::DateAndTime,
+ m.candidate_id,
+ m.positions.clone(),
+ cx,
+ ))
+ })
+ .collect(),
+ SearchState::Searching { .. } => {
+ vec![]
+ }
+ }
+ }
+
+ fn render_list_item(
+ &self,
+ item: &ListItemType,
+ highlight_positions: Vec<usize>,
+ cx: &Context<Self>,
+ ) -> AnyElement {
+ match item {
+ ListItemType::Entry { index, format } => match self.all_entries.get(*index) {
+ Some(entry) => self
+ .render_history_entry(entry, *format, *index, highlight_positions, cx)
+ .into_any(),
+ None => Empty.into_any_element(),
+ },
+ ListItemType::BucketSeparator(bucket) => div()
+ .px(DynamicSpacing::Base06.rems(cx))
+ .pt_2()
+ .pb_1()
+ .child(
+ Label::new(bucket.to_string())
+ .size(LabelSize::XSmall)
+ .color(Color::Muted),
+ )
+ .into_any_element(),
+ }
+ }
+
+ fn render_history_entry(
+ &self,
+ entry: &HistoryEntry,
+ format: EntryTimeFormat,
+ list_entry_ix: usize,
+ highlight_positions: Vec<usize>,
+ cx: &Context<Self>,
+ ) -> AnyElement {
+ let selected = list_entry_ix == self.selected_index;
+ let hovered = Some(list_entry_ix) == self.hovered_index;
+ let timestamp = entry.updated_at().timestamp();
+ let thread_timestamp = format.format_timestamp(timestamp, self.local_timezone);
+
+ h_flex()
+ .w_full()
+ .pb_1()
+ .child(
+ ListItem::new(list_entry_ix)
+ .rounded()
+ .toggle_state(selected)
+ .spacing(ListItemSpacing::Sparse)
+ .start_slot(
+ h_flex()
+ .w_full()
+ .gap_2()
+ .justify_between()
+ .child(
+ HighlightedLabel::new(entry.title(), highlight_positions)
+ .size(LabelSize::Small)
+ .truncate(),
+ )
+ .child(
+ Label::new(thread_timestamp)
+ .color(Color::Muted)
+ .size(LabelSize::XSmall),
+ ),
+ )
+ .on_hover(cx.listener(move |this, is_hovered, _window, cx| {
+ if *is_hovered {
+ this.hovered_index = Some(list_entry_ix);
+ } else if this.hovered_index == Some(list_entry_ix) {
+ this.hovered_index = None;
+ }
+
+ cx.notify();
+ }))
+ .end_slot::<IconButton>(if hovered || selected {
+ Some(
+ IconButton::new("delete", IconName::Trash)
+ .shape(IconButtonShape::Square)
+ .icon_size(IconSize::XSmall)
+ .icon_color(Color::Muted)
+ .tooltip(move |window, cx| {
+ Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
+ })
+ .on_click(cx.listener(move |this, _, _, cx| {
+ this.remove_thread(list_entry_ix, cx)
+ })),
+ )
+ } else {
+ None
+ })
+ .on_click(
+ cx.listener(move |this, _, _, cx| this.confirm_entry(list_entry_ix, cx)),
+ ),
+ )
+ .into_any_element()
+ }
+}
+
+impl Focusable for AcpThreadHistory {
+ fn focus_handle(&self, cx: &App) -> FocusHandle {
+ self.search_editor.focus_handle(cx)
+ }
+}
+
+impl Render for AcpThreadHistory {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ v_flex()
+ .key_context("ThreadHistory")
+ .size_full()
+ .on_action(cx.listener(Self::select_previous))
+ .on_action(cx.listener(Self::select_next))
+ .on_action(cx.listener(Self::select_first))
+ .on_action(cx.listener(Self::select_last))
+ .on_action(cx.listener(Self::confirm))
+ .on_action(cx.listener(Self::remove_selected_thread))
+ .when(!self.all_entries.is_empty(), |parent| {
+ parent.child(
+ h_flex()
+ .h(px(41.)) // Match the toolbar perfectly
+ .w_full()
+ .py_1()
+ .px_2()
+ .gap_2()
+ .justify_between()
+ .border_b_1()
+ .border_color(cx.theme().colors().border)
+ .child(
+ Icon::new(IconName::MagnifyingGlass)
+ .color(Color::Muted)
+ .size(IconSize::Small),
+ )
+ .child(self.search_editor.clone()),
+ )
+ })
+ .child({
+ let view = v_flex()
+ .id("list-container")
+ .relative()
+ .overflow_hidden()
+ .flex_grow();
+
+ if self.all_entries.is_empty() {
+ view.justify_center()
+ .child(
+ h_flex().w_full().justify_center().child(
+ Label::new("You don't have any past threads yet.")
+ .size(LabelSize::Small),
+ ),
+ )
+ } else if self.search_produced_no_matches() {
+ view.justify_center().child(
+ h_flex().w_full().justify_center().child(
+ Label::new("No threads match your search.").size(LabelSize::Small),
+ ),
+ )
+ } else {
+ view.pr_5()
+ .child(
+ uniform_list(
+ "thread-history",
+ self.list_item_count(),
+ cx.processor(|this, range: Range<usize>, window, cx| {
+ this.list_items(range, window, cx)
+ }),
+ )
+ .p_1()
+ .track_scroll(self.scroll_handle.clone())
+ .flex_grow(),
+ )
+ .when_some(self.render_scrollbar(cx), |div, scrollbar| {
+ div.child(scrollbar)
+ })
+ }
+ })
+ }
+}
+
+#[derive(Clone, Copy)]
+pub enum EntryTimeFormat {
+ DateAndTime,
+ TimeOnly,
+}
+
+impl EntryTimeFormat {
+ fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String {
+ let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap();
+
+ match self {
+ EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp(
+ timestamp,
+ OffsetDateTime::now_utc(),
+ timezone,
+ time_format::TimestampFormat::EnhancedAbsolute,
+ ),
+ EntryTimeFormat::TimeOnly => time_format::format_time(timestamp),
+ }
+ }
+}
+
+impl From<TimeBucket> for EntryTimeFormat {
+ fn from(bucket: TimeBucket) -> Self {
+ match bucket {
+ TimeBucket::Today => EntryTimeFormat::TimeOnly,
+ TimeBucket::Yesterday => EntryTimeFormat::TimeOnly,
+ TimeBucket::ThisWeek => EntryTimeFormat::DateAndTime,
+ TimeBucket::PastWeek => EntryTimeFormat::DateAndTime,
+ TimeBucket::All => EntryTimeFormat::DateAndTime,
+ }
+ }
+}
+
+#[derive(PartialEq, Eq, Clone, Copy, Debug)]
+enum TimeBucket {
+ Today,
+ Yesterday,
+ ThisWeek,
+ PastWeek,
+ All,
+}
+
+impl TimeBucket {
+ fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self {
+ if date == reference {
+ return TimeBucket::Today;
+ }
+
+ if date == reference - TimeDelta::days(1) {
+ return TimeBucket::Yesterday;
+ }
+
+ let week = date.iso_week();
+
+ if reference.iso_week() == week {
+ return TimeBucket::ThisWeek;
+ }
+
+ let last_week = (reference - TimeDelta::days(7)).iso_week();
+
+ if week == last_week {
+ return TimeBucket::PastWeek;
+ }
+
+ TimeBucket::All
+ }
+}
+
+impl Display for TimeBucket {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TimeBucket::Today => write!(f, "Today"),
+ TimeBucket::Yesterday => write!(f, "Yesterday"),
+ TimeBucket::ThisWeek => write!(f, "This Week"),
+ TimeBucket::PastWeek => write!(f, "Past Week"),
+ TimeBucket::All => write!(f, "All"),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use chrono::NaiveDate;
+
+ #[test]
+ fn test_time_bucket_from_dates() {
+ let today = NaiveDate::from_ymd_opt(2023, 1, 15).unwrap();
+
+ let date = today;
+ assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Today);
+
+ let date = NaiveDate::from_ymd_opt(2023, 1, 14).unwrap();
+ assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::Yesterday);
+
+ let date = NaiveDate::from_ymd_opt(2023, 1, 13).unwrap();
+ assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek);
+
+ let date = NaiveDate::from_ymd_opt(2023, 1, 11).unwrap();
+ assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::ThisWeek);
+
+ let date = NaiveDate::from_ymd_opt(2023, 1, 8).unwrap();
+ assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek);
+
+ let date = NaiveDate::from_ymd_opt(2023, 1, 5).unwrap();
+ assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::PastWeek);
+
+ // All: not in this week or last week
+ let date = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap();
+ assert_eq!(TimeBucket::from_dates(today, date), TimeBucket::All);
+
+ // Test year boundary cases
+ let new_year = NaiveDate::from_ymd_opt(2023, 1, 1).unwrap();
+
+ let date = NaiveDate::from_ymd_opt(2022, 12, 31).unwrap();
+ assert_eq!(
+ TimeBucket::from_dates(new_year, date),
+ TimeBucket::Yesterday
+ );
+
+ let date = NaiveDate::from_ymd_opt(2022, 12, 28).unwrap();
+ assert_eq!(TimeBucket::from_dates(new_year, date), TimeBucket::ThisWeek);
+ }
+}
@@ -9,6 +9,7 @@ use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol::{self as acp};
use agent_servers::{AgentServer, ClaudeCode};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, NotifyWhenAgentWaiting};
+use agent2::DbThreadMetadata;
use anyhow::bail;
use audio::{Audio, Sound};
use buffer_diff::BufferDiff;
@@ -155,6 +156,7 @@ enum ThreadState {
impl AcpThreadView {
pub fn new(
agent: Rc<dyn AgentServer>,
+ resume_thread: Option<DbThreadMetadata>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Entity<ThreadStore>,
@@ -203,7 +205,7 @@ impl AcpThreadView {
workspace: workspace.clone(),
project: project.clone(),
entry_view_state,
- thread_state: Self::initial_state(agent, workspace, project, window, cx),
+ thread_state: Self::initial_state(agent, resume_thread, workspace, project, window, cx),
message_editor,
model_selector: None,
profile_selector: None,
@@ -228,6 +230,7 @@ impl AcpThreadView {
fn initial_state(
agent: Rc<dyn AgentServer>,
+ resume_thread: Option<DbThreadMetadata>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
window: &mut Window,
@@ -254,28 +257,27 @@ impl AcpThreadView {
}
};
- // this.update_in(cx, |_this, _window, cx| {
- // let status = connection.exit_status(cx);
- // cx.spawn(async move |this, cx| {
- // let status = status.await.ok();
- // this.update(cx, |this, cx| {
- // this.thread_state = ThreadState::ServerExited { status };
- // cx.notify();
- // })
- // .ok();
- // })
- // .detach();
- // })
- // .ok();
-
- let Some(result) = cx
- .update(|_, cx| {
+ let result = if let Some(native_agent) = connection
+ .clone()
+ .downcast::<agent2::NativeAgentConnection>()
+ && let Some(resume) = resume_thread
+ {
+ cx.update(|_, cx| {
+ native_agent
+ .0
+ .update(cx, |agent, cx| agent.open_thread(resume.id, cx))
+ })
+ .log_err()
+ } else {
+ cx.update(|_, cx| {
connection
.clone()
.new_thread(project.clone(), &root_dir, cx)
})
.log_err()
- else {
+ };
+
+ let Some(result) = result else {
return;
};
@@ -382,6 +384,7 @@ impl AcpThreadView {
this.update(cx, |this, cx| {
this.thread_state = Self::initial_state(
agent.clone(),
+ None,
this.workspace.clone(),
this.project.clone(),
window,
@@ -842,6 +845,7 @@ impl AcpThreadView {
} else {
this.thread_state = Self::initial_state(
agent,
+ None,
this.workspace.clone(),
project.clone(),
window,
@@ -4044,6 +4048,7 @@ pub(crate) mod tests {
cx.new(|cx| {
AcpThreadView::new(
Rc::new(agent),
+ None,
workspace.downgrade(),
project,
thread_store.clone(),
@@ -4248,6 +4253,7 @@ pub(crate) mod tests {
cx.new(|cx| {
AcpThreadView::new(
Rc::new(StubAgentServer::new(connection.as_ref().clone())),
+ None,
workspace.downgrade(),
project.clone(),
thread_store.clone(),
@@ -4,10 +4,11 @@ use std::rc::Rc;
use std::sync::Arc;
use std::time::Duration;
+use agent2::{DbThreadMetadata, HistoryEntry};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use serde::{Deserialize, Serialize};
-use crate::NewExternalAgentThread;
+use crate::acp::{AcpThreadHistory, ThreadHistoryEvent};
use crate::agent_diff::AgentDiffThread;
use crate::{
AddContextServer, AgentDiffPane, ContinueThread, ContinueWithBurnMode,
@@ -28,6 +29,7 @@ use crate::{
thread_history::{HistoryEntryElement, ThreadHistory},
ui::{AgentOnboardingModal, EndTrialUpsell},
};
+use crate::{ExternalAgent, NewExternalAgentThread};
use agent::{
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio,
context_store::ContextStore,
@@ -117,7 +119,7 @@ pub fn init(cx: &mut App) {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
panel.update(cx, |panel, cx| {
- panel.new_external_thread(action.agent, window, cx)
+ panel.external_thread(action.agent, None, window, cx)
});
}
})
@@ -360,6 +362,7 @@ impl ActiveView {
pub fn prompt_editor(
context_editor: Entity<TextThreadEditor>,
history_store: Entity<HistoryStore>,
+ acp_history_store: Entity<agent2::HistoryStore>,
language_registry: Arc<LanguageRegistry>,
window: &mut Window,
cx: &mut App,
@@ -437,6 +440,18 @@ impl ActiveView {
);
}
});
+
+ acp_history_store.update(cx, |history_store, cx| {
+ if let Some(old_path) = old_path {
+ history_store
+ .replace_recently_opened_text_thread(old_path, new_path, cx);
+ } else {
+ history_store.push_recently_opened_entry(
+ agent2::HistoryEntryId::TextThread(new_path.clone()),
+ cx,
+ );
+ }
+ });
}
_ => {}
}
@@ -465,6 +480,8 @@ pub struct AgentPanel {
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
+ acp_history: Entity<AcpThreadHistory>,
+ acp_history_store: Entity<agent2::HistoryStore>,
_default_model_subscription: Subscription,
context_store: Entity<TextThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
@@ -631,6 +648,29 @@ impl AgentPanel {
)
});
+ let acp_history_store =
+ cx.new(|cx| agent2::HistoryStore::new(context_store.clone(), [], cx));
+ let acp_history = cx.new(|cx| AcpThreadHistory::new(acp_history_store.clone(), window, cx));
+ cx.subscribe_in(
+ &acp_history,
+ window,
+ |this, _, event, window, cx| match event {
+ ThreadHistoryEvent::Open(HistoryEntry::AcpThread(thread)) => {
+ this.external_thread(
+ Some(crate::ExternalAgent::NativeAgent),
+ Some(thread.clone()),
+ window,
+ cx,
+ );
+ }
+ ThreadHistoryEvent::Open(HistoryEntry::TextThread(thread)) => {
+ this.open_saved_prompt_editor(thread.path.clone(), window, cx)
+ .detach_and_log_err(cx);
+ }
+ },
+ )
+ .detach();
+
cx.observe(&history_store, |_, _, cx| cx.notify()).detach();
let active_thread = cx.new(|cx| {
@@ -669,6 +709,7 @@ impl AgentPanel {
ActiveView::prompt_editor(
context_editor,
history_store.clone(),
+ acp_history_store.clone(),
language_registry.clone(),
window,
cx,
@@ -685,7 +726,11 @@ impl AgentPanel {
let assistant_navigation_menu =
ContextMenu::build_persistent(window, cx, move |mut menu, _window, cx| {
if let Some(panel) = panel.upgrade() {
- menu = Self::populate_recently_opened_menu_section(menu, panel, cx);
+ if cx.has_flag::<AcpFeatureFlag>() {
+ menu = Self::populate_recently_opened_menu_section_new(menu, panel, cx);
+ } else {
+ menu = Self::populate_recently_opened_menu_section_old(menu, panel, cx);
+ }
}
menu.action("View All", Box::new(OpenHistory))
.end_slot_action(DeleteRecentlyOpenThread.boxed_clone())
@@ -773,6 +818,8 @@ impl AgentPanel {
zoomed: false,
pending_serialization: None,
onboarding,
+ acp_history,
+ acp_history_store,
selected_agent: AgentType::default(),
}
}
@@ -939,6 +986,7 @@ impl AgentPanel {
ActiveView::prompt_editor(
context_editor.clone(),
self.history_store.clone(),
+ self.acp_history_store.clone(),
self.language_registry.clone(),
window,
cx,
@@ -949,9 +997,10 @@ impl AgentPanel {
context_editor.focus_handle(cx).focus(window);
}
- fn new_external_thread(
+ fn external_thread(
&mut self,
agent_choice: Option<crate::ExternalAgent>,
+ resume_thread: Option<DbThreadMetadata>,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -968,6 +1017,7 @@ impl AgentPanel {
let thread_store = self.thread_store.clone();
let text_thread_store = self.context_store.clone();
+ let history = self.acp_history_store.clone();
cx.spawn_in(window, async move |this, cx| {
let ext_agent = match agent_choice {
@@ -1001,7 +1051,7 @@ impl AgentPanel {
}
};
- let server = ext_agent.server(fs);
+ let server = ext_agent.server(fs, history);
this.update_in(cx, |this, window, cx| {
match ext_agent {
@@ -1020,6 +1070,7 @@ impl AgentPanel {
let thread_view = cx.new(|cx| {
crate::acp::AcpThreadView::new(
server,
+ resume_thread,
workspace.clone(),
project,
thread_store.clone(),
@@ -1114,6 +1165,7 @@ impl AgentPanel {
ActiveView::prompt_editor(
editor.clone(),
self.history_store.clone(),
+ self.acp_history_store.clone(),
self.language_registry.clone(),
window,
cx,
@@ -1580,7 +1632,7 @@ impl AgentPanel {
self.focus_handle(cx).focus(window);
}
- fn populate_recently_opened_menu_section(
+ fn populate_recently_opened_menu_section_old(
mut menu: ContextMenu,
panel: Entity<Self>,
cx: &mut Context<ContextMenu>,
@@ -1644,6 +1696,72 @@ impl AgentPanel {
menu
}
+ fn populate_recently_opened_menu_section_new(
+ mut menu: ContextMenu,
+ panel: Entity<Self>,
+ cx: &mut Context<ContextMenu>,
+ ) -> ContextMenu {
+ let entries = panel
+ .read(cx)
+ .acp_history_store
+ .read(cx)
+ .recently_opened_entries(cx);
+
+ if entries.is_empty() {
+ return menu;
+ }
+
+ menu = menu.header("Recently Opened");
+
+ for entry in entries {
+ let title = entry.title().clone();
+
+ menu = menu.entry_with_end_slot_on_hover(
+ title,
+ None,
+ {
+ let panel = panel.downgrade();
+ let entry = entry.clone();
+ move |window, cx| {
+ let entry = entry.clone();
+ panel
+ .update(cx, move |this, cx| match &entry {
+ agent2::HistoryEntry::AcpThread(entry) => this.external_thread(
+ Some(ExternalAgent::NativeAgent),
+ Some(entry.clone()),
+ window,
+ cx,
+ ),
+ agent2::HistoryEntry::TextThread(entry) => this
+ .open_saved_prompt_editor(entry.path.clone(), window, cx)
+ .detach_and_log_err(cx),
+ })
+ .ok();
+ }
+ },
+ IconName::Close,
+ "Close Entry".into(),
+ {
+ let panel = panel.downgrade();
+ let id = entry.id();
+ move |_window, cx| {
+ panel
+ .update(cx, |this, cx| {
+ this.acp_history_store.update(cx, |history_store, cx| {
+ history_store.remove_recently_opened_entry(&id, cx);
+ });
+ })
+ .ok();
+ }
+ },
+ );
+ }
+
+ menu = menu.separator();
+
+ menu
+ }
+
pub fn set_selected_agent(
&mut self,
agent: AgentType,
@@ -1653,8 +1771,8 @@ impl AgentPanel {
if self.selected_agent != agent {
self.selected_agent = agent;
self.serialize(cx);
- self.new_agent_thread(agent, window, cx);
}
+ self.new_agent_thread(agent, window, cx);
}
pub fn selected_agent(&self) -> AgentType {
@@ -1681,13 +1799,13 @@ impl AgentPanel {
window.dispatch_action(NewTextThread.boxed_clone(), cx);
}
AgentType::NativeAgent => {
- self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), window, cx)
+ self.external_thread(Some(crate::ExternalAgent::NativeAgent), None, window, cx)
}
AgentType::Gemini => {
- self.new_external_thread(Some(crate::ExternalAgent::Gemini), window, cx)
+ self.external_thread(Some(crate::ExternalAgent::Gemini), None, window, cx)
}
AgentType::ClaudeCode => {
- self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), window, cx)
+ self.external_thread(Some(crate::ExternalAgent::ClaudeCode), None, window, cx)
}
}
}
@@ -1698,7 +1816,13 @@ impl Focusable for AgentPanel {
match &self.active_view {
ActiveView::Thread { message_editor, .. } => message_editor.focus_handle(cx),
ActiveView::ExternalAgentThread { thread_view, .. } => thread_view.focus_handle(cx),
- ActiveView::History => self.history.focus_handle(cx),
+ ActiveView::History => {
+ if cx.has_flag::<feature_flags::AcpFeatureFlag>() {
+ self.acp_history.focus_handle(cx)
+ } else {
+ self.history.focus_handle(cx)
+ }
+ }
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
ActiveView::Configuration => {
if let Some(configuration) = self.configuration.as_ref() {
@@ -3534,7 +3658,13 @@ impl Render for AgentPanel {
ActiveView::ExternalAgentThread { thread_view, .. } => parent
.child(thread_view.clone())
.child(self.render_drag_target(cx)),
- ActiveView::History => parent.child(self.history.clone()),
+ ActiveView::History => {
+ if cx.has_flag::<feature_flags::AcpFeatureFlag>() {
+ parent.child(self.acp_history.clone())
+ } else {
+ parent.child(self.history.clone())
+ }
+ }
ActiveView::TextThread {
context_editor,
buffer_search_bar,
@@ -156,11 +156,15 @@ enum ExternalAgent {
}
impl ExternalAgent {
- pub fn server(&self, fs: Arc<dyn fs::Fs>) -> Rc<dyn agent_servers::AgentServer> {
+ pub fn server(
+ &self,
+ fs: Arc<dyn fs::Fs>,
+ history: Entity<agent2::HistoryStore>,
+ ) -> Rc<dyn agent_servers::AgentServer> {
match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
- ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)),
+ ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs, history)),
}
}
}