Detailed changes
@@ -2,7 +2,6 @@ use crate::{AcpThread, AcpThreadMetadata};
use agent_client_protocol::{self as acp};
use anyhow::Result;
use collections::IndexMap;
-use futures::channel::mpsc::UnboundedReceiver;
use gpui::{Entity, SharedString, Task};
use project::Project;
use serde::{Deserialize, Serialize};
@@ -27,6 +26,8 @@ pub trait AgentConnection {
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>>;
+ // todo!(expose a history trait, and include list_threads and load_thread)
+ // todo!(write a test)
fn list_threads(
&self,
_cx: &mut App,
@@ -5,16 +5,15 @@ use crate::{
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
UserMessageContent, WebSearchTool, templates::Templates,
};
-use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id};
+use crate::{ThreadsDatabase, generate_session_id};
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
-use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
-use futures::future::Shared;
-use futures::{SinkExt, StreamExt, future};
+use futures::channel::mpsc;
+use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
@@ -30,6 +29,7 @@ use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
+use std::time::Duration;
use util::ResultExt;
const RULES_FILE_NAMES: [&'static str; 9] = [
@@ -174,7 +174,7 @@ pub struct NativeAgent {
prompt_store: Option<Entity<PromptStore>>,
thread_database: Arc<ThreadsDatabase>,
history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
- load_history: Task<Result<()>>,
+ load_history: Task<()>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
}
@@ -212,7 +212,7 @@ impl NativeAgent {
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
watch::channel(());
- let this = Self {
+ let mut this = Self {
sessions: HashMap::new(),
project_context: Rc::new(RefCell::new(project_context)),
project_context_needs_refresh: project_context_needs_refresh_tx,
@@ -229,7 +229,7 @@ impl NativeAgent {
prompt_store,
fs,
history: watch::channel(None).0,
- load_history: Task::ready(Ok(())),
+ load_history: Task::ready(()),
_subscriptions: subscriptions,
};
this.reload_history(cx);
@@ -249,7 +249,7 @@ impl NativeAgent {
Session {
thread: thread.clone(),
acp_thread: acp_thread.downgrade(),
- save_task: Task::ready(()),
+ save_task: Task::ready(Ok(())),
_subscriptions: vec![
cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
@@ -280,24 +280,30 @@ impl NativeAgent {
}
fn reload_history(&mut self, cx: &mut Context<Self>) {
+ dbg!("");
let thread_database = self.thread_database.clone();
self.load_history = cx.spawn(async move |this, cx| {
let results = cx
.background_spawn(async move {
let results = thread_database.list_threads().await?;
- Ok(results
- .into_iter()
- .map(|thread| AcpThreadMetadata {
- agent: NATIVE_AGENT_SERVER_NAME.clone(),
- id: thread.id.into(),
- title: thread.title,
- updated_at: thread.updated_at,
- })
- .collect())
+ dbg!(&results);
+ anyhow::Ok(
+ results
+ .into_iter()
+ .map(|thread| AcpThreadMetadata {
+ agent: NATIVE_AGENT_SERVER_NAME.clone(),
+ id: thread.id.into(),
+ title: thread.title,
+ updated_at: thread.updated_at,
+ })
+ .collect(),
+ )
})
- .await?;
- this.update(cx, |this, cx| this.history.send(Some(results)))?;
- anyhow::Ok(())
+ .await;
+ if let Some(results) = results.log_err() {
+ this.update(cx, |this, _| this.history.send(Some(results)))
+ .ok();
+ }
});
}
@@ -509,10 +515,10 @@ impl NativeAgent {
) {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
- session.thread.update(cx, |thread, _| {
+ session.thread.update(cx, |thread, cx| {
let model_id = LanguageModels::model_id(&thread.model());
if let Some(model) = self.models.model_from_id(&model_id) {
- thread.set_model(model.clone());
+ thread.set_model(model.clone(), cx);
}
});
}
@@ -715,8 +721,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
- thread.update(cx, |thread, _cx| {
- thread.set_model(model.clone());
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
});
update_settings_file::<AgentSettings>(
@@ -867,12 +873,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
session_id: acp::SessionId,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
- let thread_id = ThreadId::from(session_id.clone());
let database = self.0.update(cx, |this, _| this.thread_database.clone());
cx.spawn(async move |cx| {
- let database = database.await.map_err(|e| anyhow!(e))?;
let db_thread = database
- .load_thread(thread_id.clone())
+ .load_thread(session_id.clone())
.await?
.context("no such thread found")?;
@@ -915,7 +919,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let thread = cx.new(|cx| {
let mut thread = Thread::from_db(
- thread_id,
+ session_id,
db_thread,
project.clone(),
agent.project_context.clone(),
@@ -934,7 +938,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Store the session
agent.update(cx, |agent, cx| {
- agent.insert_session(session_id, thread, acp_thread, cx)
+ agent.insert_session(thread.clone(), acp_thread.clone(), cx)
})?;
let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
@@ -995,7 +999,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) {
- agent.thread.update(cx, |thread, _cx| thread.cancel());
+ agent.thread.update(cx, |thread, cx| thread.cancel(cx));
}
});
}
@@ -1022,7 +1026,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
- Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
+ Task::ready(
+ self.0
+ .update(cx, |thread, cx| thread.truncate(message_id, cx)),
+ )
}
}
@@ -1,4 +1,4 @@
-use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent};
+use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
use agent::thread_store;
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, CompletionMode};
@@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbThreadMetadata {
- pub id: ThreadId,
+ pub id: acp::SessionId,
#[serde(alias = "summary")]
pub title: SharedString,
pub updated_at: DateTime<Utc>,
@@ -323,7 +323,7 @@ impl ThreadsDatabase {
for (id, summary, updated_at) in rows {
threads.push(DbThreadMetadata {
- id: ThreadId(id),
+ id: acp::SessionId(id),
title: summary.into(),
updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
});
@@ -333,7 +333,7 @@ impl ThreadsDatabase {
})
}
- pub fn load_thread(&self, id: ThreadId) -> Task<Result<Option<DbThread>>> {
+ pub fn load_thread(&self, id: acp::SessionId) -> Task<Result<Option<DbThread>>> {
let connection = self.connection.clone();
self.executor.spawn(async move {
@@ -1,17 +1,13 @@
use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
-use agent::{ThreadId, thread_store::ThreadStore};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use assistant_context::SavedContextMetadata;
use chrono::{DateTime, Utc};
use collections::HashMap;
-use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*};
-use itertools::Itertools;
-use paths::contexts_dir;
+use gpui::{SharedString, Task, prelude::*};
use serde::{Deserialize, Serialize};
use smol::stream::StreamExt;
-use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration};
-use util::ResultExt as _;
+use std::{path::Path, sync::Arc, time::Duration};
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
@@ -64,16 +60,16 @@ enum SerializedRecentOpen {
}
pub struct AgentHistory {
- entries: HashMap<acp::SessionId, AcpThreadMetadata>,
- _task: Task<Result<()>>,
+ entries: watch::Receiver<Option<Vec<AcpThreadMetadata>>>,
+ _task: Task<()>,
}
pub struct HistoryStore {
- agents: HashMap<AgentServerName, AgentHistory>,
+ agents: HashMap<AgentServerName, AgentHistory>, // todo!() text threads
}
impl HistoryStore {
- pub fn new(cx: &mut Context<Self>) -> Self {
+ pub fn new(_cx: &mut Context<Self>) -> Self {
Self {
agents: HashMap::default(),
}
@@ -88,33 +84,18 @@ impl HistoryStore {
let Some(mut history) = connection.list_threads(cx) else {
return;
};
- let task = cx.spawn(async move |this, cx| {
- while let Some(updated_history) = history.next().await {
- dbg!(&updated_history);
- this.update(cx, |this, cx| {
- for entry in updated_history {
- let agent = this
- .agents
- .get_mut(&entry.agent)
- .context("agent not found")?;
- agent.entries.insert(entry.id.clone(), entry);
- }
- cx.notify();
- anyhow::Ok(())
- })??
- }
- Ok(())
- });
- self.agents.insert(
- agent_name,
- AgentHistory {
- entries: Default::default(),
- _task: task,
- },
- );
+ let history = AgentHistory {
+ entries: history.clone(),
+ _task: cx.spawn(async move |this, cx| {
+ while history.changed().await.is_ok() {
+ this.update(cx, |_, cx| cx.notify()).ok();
+ }
+ }),
+ };
+ self.agents.insert(agent_name.clone(), history);
}
- pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
+ pub fn entries(&mut self, _cx: &mut Context<Self>) -> Vec<HistoryEntry> {
let mut history_entries = Vec::new();
#[cfg(debug_assertions)]
@@ -124,9 +105,8 @@ impl HistoryStore {
history_entries.extend(
self.agents
- .values()
- .flat_map(|agent| agent.entries.values())
- .cloned()
+ .values_mut()
+ .flat_map(|history| history.entries.borrow().clone().unwrap_or_default()) // todo!("surface the loading state?")
.map(HistoryEntry::Thread),
);
// todo!() include the text threads in here.
@@ -135,7 +115,7 @@ impl HistoryStore {
history_entries
}
- pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
+ pub fn recent_entries(&mut self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
self.entries(cx).into_iter().take(limit).collect()
}
}
@@ -938,7 +938,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
- thread.update(cx, |thread, _cx| thread.cancel());
+ thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::<Vec<_>>().await;
let last_event = events.last();
assert!(
@@ -1113,7 +1113,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
thread
- .update(cx, |thread, _cx| thread.truncate(message_id))
+ .update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
@@ -802,16 +802,18 @@ impl Thread {
&self.model
}
- pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
+ pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
self.model = model;
+ cx.notify()
}
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
- pub fn set_completion_mode(&mut self, mode: CompletionMode) {
+ pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
self.completion_mode = mode;
+ cx.notify()
}
#[cfg(any(test, feature = "test-support"))]
@@ -839,21 +841,22 @@ impl Thread {
self.profile_id = profile_id;
}
- pub fn cancel(&mut self) {
+ pub fn cancel(&mut self, cx: &mut Context<Self>) {
if let Some(running_turn) = self.running_turn.take() {
running_turn.cancel();
}
- self.flush_pending_message();
+ self.flush_pending_message(cx);
}
- pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
- self.cancel();
+ pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
+ self.cancel(cx);
let Some(position) = self.messages.iter().position(
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
) else {
return Err(anyhow!("Message not found"));
};
self.messages.truncate(position);
+ cx.notify();
Ok(())
}
@@ -900,7 +903,7 @@ impl Thread {
}
fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
- self.cancel();
+ self.cancel(cx);
let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
@@ -938,8 +941,8 @@ impl Thread {
LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
- this.update(cx, |this, _cx| {
- this.flush_pending_message();
+ this.update(cx, |this, cx| {
+ this.flush_pending_message(cx);
this.messages.truncate(message_ix);
})?;
return Ok(());
@@ -991,7 +994,7 @@ impl Thread {
log::info!("No tool uses found, completing turn");
return Ok(());
} else {
- this.update(cx, |this, _| this.flush_pending_message())?;
+ this.update(cx, |this, cx| this.flush_pending_message(cx))?;
completion_intent = CompletionIntent::ToolResults;
}
}
@@ -1005,8 +1008,8 @@ impl Thread {
log::info!("Turn execution completed successfully");
}
- this.update(cx, |this, _| {
- this.flush_pending_message();
+ this.update(cx, |this, cx| {
+ this.flush_pending_message(cx);
this.running_turn.take();
})
.ok();
@@ -1046,7 +1049,7 @@ impl Thread {
match event {
StartMessage { .. } => {
- self.flush_pending_message();
+ self.flush_pending_message(cx);
self.pending_message = Some(AgentMessage::default());
}
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
@@ -1255,7 +1258,7 @@ impl Thread {
self.pending_message.get_or_insert_default()
}
- fn flush_pending_message(&mut self) {
+ fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
let Some(mut message) = self.pending_message.take() else {
return;
};
@@ -1280,6 +1283,7 @@ impl Thread {
}
self.messages.push(Message::Agent(message));
+ cx.notify()
}
pub(crate) fn build_completion_request(
@@ -2487,12 +2487,15 @@ impl AcpThreadView {
return;
};
- thread.update(cx, |thread, _cx| {
+ thread.update(cx, |thread, cx| {
let current_mode = thread.completion_mode();
- thread.set_completion_mode(match current_mode {
- CompletionMode::Burn => CompletionMode::Normal,
- CompletionMode::Normal => CompletionMode::Burn,
- });
+ thread.set_completion_mode(
+ match current_mode {
+ CompletionMode::Burn => CompletionMode::Normal,
+ CompletionMode::Normal => CompletionMode::Burn,
+ },
+ cx,
+ );
});
}
@@ -3274,8 +3277,8 @@ impl AcpThreadView {
.tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
.on_click({
cx.listener(move |this, _, _window, cx| {
- thread.update(cx, |thread, _cx| {
- thread.set_completion_mode(CompletionMode::Burn);
+ thread.update(cx, |thread, cx| {
+ thread.set_completion_mode(CompletionMode::Burn, cx);
});
this.resume_chat(cx);
})