Detailed changes
@@ -191,6 +191,7 @@ version = "0.1.0"
dependencies = [
"acp_thread",
"action_log",
+ "agent",
"agent-client-protocol",
"agent_servers",
"agent_settings",
@@ -208,6 +209,7 @@ dependencies = [
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
+ "git",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
@@ -256,6 +258,7 @@ name = "agent_servers"
version = "0.1.0"
dependencies = [
"acp_thread",
+ "action_log",
"agent-client-protocol",
"agent_settings",
"agentic-coding-protocol",
@@ -537,9 +537,15 @@ impl ToolCallContent {
acp::ToolCallContent::Content { content } => {
Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
}
- acp::ToolCallContent::Diff { diff } => {
- Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
- }
+ acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
+ Diff::finalized(
+ diff.path,
+ diff.old_text,
+ diff.new_text,
+ language_registry,
+ cx,
+ )
+ })),
}
}
@@ -682,6 +688,7 @@ pub struct AcpThread {
#[derive(Debug)]
pub enum AcpThreadEvent {
NewEntry,
+ TitleUpdated,
EntryUpdated(usize),
EntriesRemoved(Range<usize>),
ToolAuthorizationRequired,
@@ -728,11 +735,9 @@ impl AcpThread {
title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>,
project: Entity<Project>,
+ action_log: Entity<ActionLog>,
session_id: acp::SessionId,
- cx: &mut Context<Self>,
) -> Self {
- let action_log = cx.new(|_| ActionLog::new(project.clone()));
-
Self {
action_log,
shared_buffers: Default::default(),
@@ -926,6 +931,12 @@ impl AcpThread {
cx.emit(AcpThreadEvent::NewEntry);
}
+ pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
+ self.title = title;
+ cx.emit(AcpThreadEvent::TitleUpdated);
+ Ok(())
+ }
+
pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
cx.emit(AcpThreadEvent::Retry(status));
}
@@ -1657,7 +1668,7 @@ mod tests {
use super::*;
use anyhow::anyhow;
use futures::{channel::mpsc, future::LocalBoxFuture, select};
- use gpui::{AsyncApp, TestAppContext, WeakEntity};
+ use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
use project::{FakeFs, Fs};
use rand::Rng as _;
@@ -2327,7 +2338,7 @@ mod tests {
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
- cx: &mut gpui::App,
+ cx: &mut App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(
rand::thread_rng()
@@ -2337,8 +2348,16 @@ mod tests {
.collect::<String>()
.into(),
);
- let thread =
- cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let thread = cx.new(|_cx| {
+ AcpThread::new(
+ "Test",
+ self.clone(),
+ project,
+ action_log,
+ session_id.clone(),
+ )
+ });
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
@@ -5,11 +5,12 @@ use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId;
use project::Project;
+use serde::{Deserialize, Serialize};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessageId(Arc<str>);
impl UserMessageId {
@@ -208,6 +209,7 @@ impl AgentModelList {
mod test_support {
use std::sync::Arc;
+ use action_log::ActionLog;
use collections::HashMap;
use futures::{channel::oneshot, future::try_join_all};
use gpui::{AppContext as _, WeakEntity};
@@ -295,8 +297,16 @@ mod test_support {
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
- let thread =
- cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let thread = cx.new(|_cx| {
+ AcpThread::new(
+ "Test",
+ self.clone(),
+ project,
+ action_log,
+ session_id.clone(),
+ )
+ });
self.sessions.lock().insert(
session_id,
Session {
@@ -1,4 +1,3 @@
-use agent_client_protocol as acp;
use anyhow::Result;
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{MultiBuffer, PathKey};
@@ -21,17 +20,13 @@ pub enum Diff {
}
impl Diff {
- pub fn from_acp(
- diff: acp::Diff,
+ pub fn finalized(
+ path: PathBuf,
+ old_text: Option<String>,
+ new_text: String,
language_registry: Arc<LanguageRegistry>,
cx: &mut Context<Self>,
) -> Self {
- let acp::Diff {
- path,
- old_text,
- new_text,
- } = diff;
-
let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
@@ -2,6 +2,7 @@ use agent::ThreadId;
use anyhow::{Context as _, Result, bail};
use file_icons::FileIcons;
use prompt_store::{PromptId, UserPromptId};
+use serde::{Deserialize, Serialize};
use std::{
fmt,
ops::Range,
@@ -11,7 +12,7 @@ use std::{
use ui::{App, IconName, SharedString};
use url::Url;
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum MentionUri {
File {
abs_path: PathBuf,
@@ -14,6 +14,7 @@ workspace = true
[dependencies]
acp_thread.workspace = true
action_log.workspace = true
+agent.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
agent_settings.workspace = true
@@ -26,6 +27,7 @@ collections.workspace = true
context_server.workspace = true
fs.workspace = true
futures.workspace = true
+git.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
@@ -59,6 +61,7 @@ which.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
+agent = { workspace = true, "features" = ["test-support"] }
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
@@ -66,6 +69,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
+git = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }
@@ -1,10 +1,11 @@
use crate::{
- AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
- DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
- MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
- ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
+ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
+ EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
+ OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
+ UserMessageContent, WebSearchTool, templates::Templates,
};
use acp_thread::AgentModelSelector;
+use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
@@ -427,18 +428,19 @@ impl NativeAgent {
) {
self.models.refresh_list(cx);
- let default_model = LanguageModelRegistry::read_global(cx)
- .default_model()
- .map(|m| m.model.clone());
+ let registry = LanguageModelRegistry::read_global(cx);
+ let default_model = registry.default_model().map(|m| m.model.clone());
+ let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, cx| {
if thread.model().is_none()
&& let Some(model) = default_model.clone()
{
- thread.set_model(model);
+ thread.set_model(model, cx);
cx.notify();
}
+ thread.set_summarization_model(summarization_model.clone(), cx);
});
}
}
@@ -462,10 +464,7 @@ impl NativeAgentConnection {
session_id: acp::SessionId,
cx: &mut App,
f: impl 'static
- + FnOnce(
- Entity<Thread>,
- &mut App,
- ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
+ + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent
@@ -489,7 +488,18 @@ impl NativeAgentConnection {
log::trace!("Received completion event: {:?}", event);
match event {
- AgentResponseEvent::Text(text) => {
+ ThreadEvent::UserMessage(message) => {
+ acp_thread.update(cx, |thread, cx| {
+ for content in message.content {
+ thread.push_user_content_block(
+ Some(message.id.clone()),
+ content.into(),
+ cx,
+ );
+ }
+ })?;
+ }
+ ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@@ -501,7 +511,7 @@ impl NativeAgentConnection {
)
})?;
}
- AgentResponseEvent::Thinking(text) => {
+ ThreadEvent::AgentThinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@@ -513,7 +523,7 @@ impl NativeAgentConnection {
)
})?;
}
- AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
+ ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
@@ -536,22 +546,26 @@ impl NativeAgentConnection {
})
.detach();
}
- AgentResponseEvent::ToolCall(tool_call) => {
+ ThreadEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})??;
}
- AgentResponseEvent::ToolCallUpdate(update) => {
+ ThreadEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
- AgentResponseEvent::Retry(status) => {
+ ThreadEvent::TitleUpdate(title) => {
+ acp_thread
+ .update(cx, |thread, cx| thread.update_title(title, cx))??;
+ }
+ ThreadEvent::Retry(status) => {
acp_thread.update(cx, |thread, cx| {
thread.update_retry_status(status, cx)
})?;
}
- AgentResponseEvent::Stop(stop_reason) => {
+ ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
@@ -604,8 +618,8 @@ impl AgentModelSelector for NativeAgentConnection {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
- thread.update(cx, |thread, _cx| {
- thread.set_model(model.clone());
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
});
update_settings_file::<AgentSettings>(
@@ -665,30 +679,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx.spawn(async move |cx| {
log::debug!("Starting thread creation in async context");
- // Generate session ID
- let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
- log::info!("Created session with ID: {}", session_id);
-
- // Create AcpThread
- let acp_thread = cx.update(|cx| {
- cx.new(|cx| {
- acp_thread::AcpThread::new(
- "agent2",
- self.clone(),
- project.clone(),
- session_id.clone(),
- cx,
- )
- })
- })?;
- let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
-
+ let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
// Create Thread
let thread = agent.update(
cx,
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
// Fetch default model from registry settings
let registry = LanguageModelRegistry::read_global(cx);
+ let language_registry = project.read(cx).languages().clone();
// Log available models for debugging
let available_count = registry.available_models(cx).count();
@@ -699,6 +697,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.models
.model_from_id(&LanguageModels::model_id(&default_model.model))
});
+ let summarization_model = registry.thread_summary_model().map(|c| c.model);
let thread = cx.new(|cx| {
let mut thread = Thread::new(
@@ -708,13 +707,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
action_log.clone(),
agent.templates.clone(),
default_model,
+ summarization_model,
cx,
);
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
- thread.add_tool(EditFileTool::new(cx.entity()));
+ thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
@@ -722,7 +722,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(NowTool);
thread.add_tool(OpenTool::new(project.clone()));
- thread.add_tool(ReadFileTool::new(project.clone(), action_log));
+ thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
thread.add_tool(TerminalTool::new(project.clone(), cx));
thread.add_tool(ThinkingTool);
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
@@ -733,6 +733,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
},
)??;
+ let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
+ log::info!("Created session with ID: {}", session_id);
+ // Create AcpThread
+ let acp_thread = cx.update(|cx| {
+ cx.new(|_cx| {
+ acp_thread::AcpThread::new(
+ "agent2",
+ self.clone(),
+ project.clone(),
+ action_log.clone(),
+ session_id.clone(),
+ )
+ })
+ })?;
+
// Store the session
agent.update(cx, |agent, cx| {
agent.sessions.insert(
@@ -803,7 +818,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
if let Some(agent) = agent.sessions.get(session_id) {
- agent.thread.update(cx, |thread, _cx| thread.cancel());
+ agent.thread.update(cx, |thread, cx| thread.cancel(cx));
}
});
}
@@ -830,7 +845,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
- Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
+ Task::ready(
+ self.0
+ .update(cx, |thread, cx| thread.truncate(message_id, cx)),
+ )
}
}
@@ -345,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let mut saw_partial_tool_use = false;
while let Some(event) = events.next().await {
- if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
+ if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
let message = thread.last_message().unwrap();
@@ -735,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
);
}
-async fn expect_tool_call(
- events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
-) -> acp::ToolCall {
+async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
let event = events
.next()
.await
.expect("no tool call authorization event received")
.unwrap();
match event {
- AgentResponseEvent::ToolCall(tool_call) => return tool_call,
+ ThreadEvent::ToolCall(tool_call) => return tool_call,
event => {
panic!("Unexpected event {event:?}");
}
@@ -752,7 +750,7 @@ async fn expect_tool_call(
}
async fn expect_tool_call_update_fields(
- events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
+ events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
@@ -760,7 +758,7 @@ async fn expect_tool_call_update_fields(
.expect("no tool call authorization event received")
.unwrap();
match event {
- AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
+ ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
return update;
}
event => {
@@ -770,7 +768,7 @@ async fn expect_tool_call_update_fields(
}
async fn next_tool_call_authorization(
- events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
+ events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization {
loop {
let event = events
@@ -778,7 +776,7 @@ async fn next_tool_call_authorization(
.await
.expect("no tool call authorization event received")
.unwrap();
- if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
+ if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
let permission_kinds = tool_call_authorization
.options
.iter()
@@ -945,13 +943,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
let mut echo_completed = false;
while let Some(event) = events.next().await {
match event.unwrap() {
- AgentResponseEvent::ToolCall(tool_call) => {
+ ThreadEvent::ToolCall(tool_call) => {
assert_eq!(tool_call.title, expected_tools.remove(0));
if tool_call.title == "Echo" {
echo_id = Some(tool_call.id);
}
}
- AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
+ ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
acp::ToolCallUpdate {
id,
fields:
@@ -973,13 +971,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running.
- thread.update(cx, |thread, _cx| thread.cancel());
+ thread.update(cx, |thread, cx| thread.cancel(cx));
let events = events.collect::<Vec<_>>().await;
let last_event = events.last();
assert!(
matches!(
last_event,
- Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
+ Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
),
"unexpected event {last_event:?}"
);
@@ -1161,7 +1159,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
thread
- .update(cx, |thread, _cx| thread.truncate(message_id))
+ .update(cx, |thread, cx| thread.truncate(message_id, cx))
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
@@ -1203,6 +1201,51 @@ async fn test_truncate(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_title_generation(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let summary_model = Arc::new(FakeLanguageModel::default());
+ thread.update(cx, |thread, cx| {
+ thread.set_summarization_model(Some(summary_model.clone()), cx)
+ });
+
+ let send = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("Hey!");
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+ thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
+
+ // Ensure the summary model has been invoked to generate a title.
+ summary_model.send_last_completion_stream_text_chunk("Hello ");
+ summary_model.send_last_completion_stream_text_chunk("world\nG");
+ summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
+ summary_model.end_last_completion_stream();
+ send.collect::<Vec<_>>().await;
+ thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+
+ // Send another message, ensuring no title is generated this time.
+ let send = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello again"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Hey again!");
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+ assert_eq!(summary_model.pending_completions(), Vec::new());
+ send.collect::<Vec<_>>().await;
+ thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+}
+
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init);
@@ -1442,7 +1485,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
- thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@@ -1454,10 +1497,10 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
- AgentResponseEvent::Retry(retry_status) => {
+ ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
- AgentResponseEvent::Stop(..) => break,
+ ThreadEvent::Stop(..) => break,
_ => {}
}
}
@@ -1486,7 +1529,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
- thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@@ -1507,10 +1550,10 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(Ok(event)) = events.next().await {
match event {
- AgentResponseEvent::Retry(retry_status) => {
+ ThreadEvent::Retry(retry_status) => {
retry_events.push(retry_status);
}
- AgentResponseEvent::Stop(..) => break,
+ ThreadEvent::Stop(..) => break,
_ => {}
}
}
@@ -1543,7 +1586,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut events = thread
.update(cx, |thread, cx| {
- thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+ thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx)
})
.unwrap();
@@ -1565,10 +1608,10 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut retry_events = Vec::new();
while let Some(event) = events.next().await {
match event {
- Ok(AgentResponseEvent::Retry(retry_status)) => {
+ Ok(ThreadEvent::Retry(retry_status)) => {
retry_events.push(retry_status);
}
- Ok(AgentResponseEvent::Stop(..)) => break,
+ Ok(ThreadEvent::Stop(..)) => break,
Err(error) => errors.push(error),
_ => {}
}
@@ -1592,11 +1635,11 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
}
/// Filters out the stop events for asserting against in tests
-fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
+fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
- AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
+ ThreadEvent::Stop(stop_reason) => Some(stop_reason),
_ => None,
})
.collect()
@@ -1713,6 +1756,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
action_log,
templates,
Some(model.clone()),
+ None,
cx,
)
});
@@ -1,25 +1,34 @@
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
+use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
use agent_client_protocol as acp;
-use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
+use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
+use chrono::{DateTime, Utc};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
use collections::IndexMap;
use fs::Fs;
use futures::{
+ FutureExt,
channel::{mpsc, oneshot},
+ future::Shared,
stream::FuturesUnordered,
};
+use git::repository::DiffType;
use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+ TokenUsage,
+};
+use project::{
+ Project,
+ git_store::{GitStore, RepositoryState},
};
-use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
@@ -35,28 +44,7 @@ use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
-#[derive(
- Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
-)]
-pub struct ThreadId(Arc<str>);
-
-impl ThreadId {
- pub fn new() -> Self {
- Self(Uuid::new_v4().to_string().into())
- }
-}
-
-impl std::fmt::Display for ThreadId {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-impl From<&str> for ThreadId {
- fn from(value: &str) -> Self {
- Self(value.into())
- }
-}
+const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
/// The ID of the user prompt that initiated a request.
///
@@ -91,7 +79,7 @@ enum RetryStrategy {
},
}
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
User(UserMessage),
Agent(AgentMessage),
@@ -106,6 +94,18 @@ impl Message {
}
}
+ pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
+ match self {
+ Message::User(message) => vec![message.to_request()],
+ Message::Agent(message) => message.to_request(),
+ Message::Resume => vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Continue where you left off".into()],
+ cache: false,
+ }],
+ }
+ }
+
pub fn to_markdown(&self) -> String {
match self {
Message::User(message) => message.to_markdown(),
@@ -113,15 +113,22 @@ impl Message {
Message::Resume => "[resumed after tool use limit was reached]".into(),
}
}
+
+ pub fn role(&self) -> Role {
+ match self {
+ Message::User(_) | Message::Resume => Role::User,
+ Message::Agent(_) => Role::Assistant,
+ }
+ }
}
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
pub id: UserMessageId,
pub content: Vec<UserMessageContent>,
}
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
Text(String),
Mention { uri: MentionUri, content: String },
@@ -345,9 +352,6 @@ impl AgentMessage {
AgentMessageContent::RedactedThinking(_) => {
markdown.push_str("<redacted_thinking />\n")
}
- AgentMessageContent::Image(_) => {
- markdown.push_str("<image />\n");
- }
AgentMessageContent::ToolUse(tool_use) => {
markdown.push_str(&format!(
"**Tool Use**: {} (ID: {})\n",
@@ -418,9 +422,6 @@ impl AgentMessage {
AgentMessageContent::ToolUse(value) => {
language_model::MessageContent::ToolUse(value.clone())
}
- AgentMessageContent::Image(value) => {
- language_model::MessageContent::Image(value.clone())
- }
};
assistant_message.content.push(chunk);
}
@@ -450,13 +451,13 @@ impl AgentMessage {
}
}
-#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
pub content: Vec<AgentMessageContent>,
pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
Text(String),
Thinking {
@@ -464,17 +465,18 @@ pub enum AgentMessageContent {
signature: Option<String>,
},
RedactedThinking(String),
- Image(LanguageModelImage),
ToolUse(LanguageModelToolUse),
}
#[derive(Debug)]
-pub enum AgentResponseEvent {
- Text(String),
- Thinking(String),
+pub enum ThreadEvent {
+ UserMessage(UserMessage),
+ AgentText(String),
+ AgentThinking(String),
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
+ TitleUpdate(SharedString),
Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
}
@@ -487,8 +489,12 @@ pub struct ToolCallAuthorization {
}
pub struct Thread {
- id: ThreadId,
+ id: acp::SessionId,
prompt_id: PromptId,
+ updated_at: DateTime<Utc>,
+ title: Option<SharedString>,
+ #[allow(unused)]
+ summary: DetailedSummaryState,
messages: Vec<Message>,
completion_mode: CompletionMode,
/// Holds the task that handles agent interaction until the end of the turn.
@@ -498,11 +504,18 @@ pub struct Thread {
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool,
+ #[allow(unused)]
+ request_token_usage: Vec<TokenUsage>,
+ #[allow(unused)]
+ cumulative_token_usage: TokenUsage,
+ #[allow(unused)]
+ initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
context_server_registry: Entity<ContextServerRegistry>,
profile_id: AgentProfileId,
project_context: Entity<ProjectContext>,
templates: Arc<Templates>,
model: Option<Arc<dyn LanguageModel>>,
+ summarization_model: Option<Arc<dyn LanguageModel>>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
@@ -515,36 +528,254 @@ impl Thread {
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
model: Option<Arc<dyn LanguageModel>>,
+ summarization_model: Option<Arc<dyn LanguageModel>>,
cx: &mut Context<Self>,
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
Self {
- id: ThreadId::new(),
+ id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
prompt_id: PromptId::new(),
+ updated_at: Utc::now(),
+ title: None,
+ summary: DetailedSummaryState::default(),
messages: Vec::new(),
completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
tool_use_limit_reached: false,
+ request_token_usage: Vec::new(),
+ cumulative_token_usage: TokenUsage::default(),
+ initial_project_snapshot: {
+ let project_snapshot = Self::project_snapshot(project.clone(), cx);
+ cx.foreground_executor()
+ .spawn(async move { Some(project_snapshot.await) })
+ .shared()
+ },
context_server_registry,
profile_id,
project_context,
templates,
model,
+ summarization_model,
project,
action_log,
}
}
- pub fn project(&self) -> &Entity<Project> {
- &self.project
+ pub fn id(&self) -> &acp::SessionId {
+ &self.id
+ }
+
+ pub fn replay(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
+ let (tx, rx) = mpsc::unbounded();
+ let stream = ThreadEventStream(tx);
+ for message in &self.messages {
+ match message {
+ Message::User(user_message) => stream.send_user_message(user_message),
+ Message::Agent(assistant_message) => {
+ for content in &assistant_message.content {
+ match content {
+ AgentMessageContent::Text(text) => stream.send_text(text),
+ AgentMessageContent::Thinking { text, .. } => {
+ stream.send_thinking(text)
+ }
+ AgentMessageContent::RedactedThinking(_) => {}
+ AgentMessageContent::ToolUse(tool_use) => {
+ self.replay_tool_call(
+ tool_use,
+ assistant_message.tool_results.get(&tool_use.id),
+ &stream,
+ cx,
+ );
+ }
+ }
+ }
+ }
+ Message::Resume => {}
+ }
+ }
+ rx
+ }
+
+ fn replay_tool_call(
+ &self,
+ tool_use: &LanguageModelToolUse,
+ tool_result: Option<&LanguageModelToolResult>,
+ stream: &ThreadEventStream,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
+ stream
+ .0
+ .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
+ id: acp::ToolCallId(tool_use.id.to_string().into()),
+ title: tool_use.name.to_string(),
+ kind: acp::ToolKind::Other,
+ status: acp::ToolCallStatus::Failed,
+ content: Vec::new(),
+ locations: Vec::new(),
+ raw_input: Some(tool_use.input.clone()),
+ raw_output: None,
+ })))
+ .ok();
+ return;
+ };
+
+ let title = tool.initial_title(tool_use.input.clone());
+ let kind = tool.kind();
+ stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
+
+ let output = tool_result
+ .as_ref()
+ .and_then(|result| result.output.clone());
+ if let Some(output) = output.clone() {
+ let tool_event_stream = ToolCallEventStream::new(
+ tool_use.id.clone(),
+ stream.clone(),
+ Some(self.project.read(cx).fs().clone()),
+ );
+ tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
+ .log_err();
+ }
+
+ stream.update_tool_call_fields(
+ &tool_use.id,
+ acp::ToolCallUpdateFields {
+ status: Some(acp::ToolCallStatus::Completed),
+ raw_output: output,
+ ..Default::default()
+ },
+ );
+ }
+
+ /// Create a snapshot of the current project state including git information and unsaved buffers.
+ fn project_snapshot(
+ project: Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
+ let git_store = project.read(cx).git_store().clone();
+ let worktree_snapshots: Vec<_> = project
+ .read(cx)
+ .visible_worktrees(cx)
+ .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
+ .collect();
+
+ cx.spawn(async move |_, cx| {
+ let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
+
+ let mut unsaved_buffers = Vec::new();
+ cx.update(|app_cx| {
+ let buffer_store = project.read(app_cx).buffer_store();
+ for buffer_handle in buffer_store.read(app_cx).buffers() {
+ let buffer = buffer_handle.read(app_cx);
+ if buffer.is_dirty()
+ && let Some(file) = buffer.file()
+ {
+ let path = file.path().to_string_lossy().to_string();
+ unsaved_buffers.push(path);
+ }
+ }
+ })
+ .ok();
+
+ Arc::new(ProjectSnapshot {
+ worktree_snapshots,
+ unsaved_buffer_paths: unsaved_buffers,
+ timestamp: Utc::now(),
+ })
+ })
+ }
+
+ fn worktree_snapshot(
+ worktree: Entity<project::Worktree>,
+ git_store: Entity<GitStore>,
+ cx: &App,
+ ) -> Task<agent::thread::WorktreeSnapshot> {
+ cx.spawn(async move |cx| {
+ // Get worktree path and snapshot
+ let worktree_info = cx.update(|app_cx| {
+ let worktree = worktree.read(app_cx);
+ let path = worktree.abs_path().to_string_lossy().to_string();
+ let snapshot = worktree.snapshot();
+ (path, snapshot)
+ });
+
+ let Ok((worktree_path, _snapshot)) = worktree_info else {
+ return WorktreeSnapshot {
+ worktree_path: String::new(),
+ git_state: None,
+ };
+ };
+
+ let git_state = git_store
+ .update(cx, |git_store, cx| {
+ git_store
+ .repositories()
+ .values()
+ .find(|repo| {
+ repo.read(cx)
+ .abs_path_to_repo_path(&worktree.read(cx).abs_path())
+ .is_some()
+ })
+ .cloned()
+ })
+ .ok()
+ .flatten()
+ .map(|repo| {
+ repo.update(cx, |repo, _| {
+ let current_branch =
+ repo.branch.as_ref().map(|branch| branch.name().to_owned());
+ repo.send_job(None, |state, _| async move {
+ let RepositoryState::Local { backend, .. } = state else {
+ return GitState {
+ remote_url: None,
+ head_sha: None,
+ current_branch,
+ diff: None,
+ };
+ };
+
+ let remote_url = backend.remote_url("origin");
+ let head_sha = backend.head_sha().await;
+ let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
+
+ GitState {
+ remote_url,
+ head_sha,
+ current_branch,
+ diff,
+ }
+ })
+ })
+ });
+
+ let git_state = match git_state {
+ Some(git_state) => match git_state.ok() {
+ Some(git_state) => git_state.await.ok(),
+ None => None,
+ },
+ None => None,
+ };
+
+ WorktreeSnapshot {
+ worktree_path,
+ git_state,
+ }
+ })
}
pub fn project_context(&self) -> &Entity<ProjectContext> {
&self.project_context
}
+ pub fn project(&self) -> &Entity<Project> {
+ &self.project
+ }
+
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
@@ -553,16 +784,27 @@ impl Thread {
self.model.as_ref()
}
- pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
+ pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
self.model = Some(model);
+ cx.notify()
+ }
+
+ pub fn set_summarization_model(
+ &mut self,
+ model: Option<Arc<dyn LanguageModel>>,
+ cx: &mut Context<Self>,
+ ) {
+ self.summarization_model = model;
+ cx.notify()
}
pub fn completion_mode(&self) -> CompletionMode {
self.completion_mode
}
- pub fn set_completion_mode(&mut self, mode: CompletionMode) {
+ pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
self.completion_mode = mode;
+ cx.notify()
}
#[cfg(any(test, feature = "test-support"))]
@@ -590,29 +832,29 @@ impl Thread {
self.profile_id = profile_id;
}
- pub fn cancel(&mut self) {
+ pub fn cancel(&mut self, cx: &mut Context<Self>) {
if let Some(running_turn) = self.running_turn.take() {
running_turn.cancel();
}
- self.flush_pending_message();
+ self.flush_pending_message(cx);
}
- pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
- self.cancel();
+ pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
+ self.cancel(cx);
let Some(position) = self.messages.iter().position(
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
) else {
return Err(anyhow!("Message not found"));
};
self.messages.truncate(position);
+ cx.notify();
Ok(())
}
pub fn resume(
&mut self,
cx: &mut Context<Self>,
- ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
- anyhow::ensure!(self.model.is_some(), "Model not set");
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
anyhow::ensure!(
self.tool_use_limit_reached,
"can only resume after tool use limit is reached"
@@ -633,7 +875,7 @@ impl Thread {
id: UserMessageId,
content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>,
- ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
where
T: Into<UserMessageContent>,
{
@@ -656,22 +898,19 @@ impl Thread {
fn run_turn(
&mut self,
cx: &mut Context<Self>,
- ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
- self.cancel();
-
- let model = self
- .model()
- .cloned()
- .context("No language model configured")?;
- let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
- let event_stream = AgentResponseEventStream(events_tx);
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
+ self.cancel(cx);
+
+ let model = self.model.clone().context("No language model configured")?;
+ let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
+ let event_stream = ThreadEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false;
self.running_turn = Some(RunningTurn {
event_stream: event_stream.clone(),
_task: cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
- let turn_result: Result<()> = async {
+ let turn_result: Result<StopReason> = async {
let mut completion_intent = CompletionIntent::UserPrompt;
loop {
log::debug!(
@@ -685,18 +924,27 @@ impl Thread {
log::info!("Calling model.stream_completion");
let mut tool_use_limit_reached = false;
+ let mut refused = false;
+ let mut reached_max_tokens = false;
let mut tool_uses = Self::stream_completion_with_retries(
this.clone(),
model.clone(),
request,
- message_ix,
&event_stream,
&mut tool_use_limit_reached,
+ &mut refused,
+ &mut reached_max_tokens,
cx,
)
.await?;
- let used_tools = tool_uses.is_empty();
+ if refused {
+ return Ok(StopReason::Refusal);
+ } else if reached_max_tokens {
+ return Ok(StopReason::MaxTokens);
+ }
+
+ let end_turn = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
@@ -724,29 +972,42 @@ impl Thread {
log::info!("Tool use limit reached, completing turn");
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
return Err(language_model::ToolUseLimitReachedError.into());
- } else if used_tools {
+ } else if end_turn {
log::info!("No tool uses found, completing turn");
- return Ok(());
+ return Ok(StopReason::EndTurn);
} else {
- this.update(cx, |this, _| this.flush_pending_message())?;
+ this.update(cx, |this, cx| this.flush_pending_message(cx))?;
completion_intent = CompletionIntent::ToolResults;
}
}
}
.await;
+ _ = this.update(cx, |this, cx| this.flush_pending_message(cx));
+
+ match turn_result {
+ Ok(reason) => {
+ log::info!("Turn execution completed: {:?}", reason);
+
+ let update_title = this
+ .update(cx, |this, cx| this.update_title(&event_stream, cx))
+ .ok()
+ .flatten();
+ if let Some(update_title) = update_title {
+ update_title.await.context("update title failed").log_err();
+ }
- if let Err(error) = turn_result {
- log::error!("Turn execution failed: {:?}", error);
- event_stream.send_error(error);
- } else {
- log::info!("Turn execution completed successfully");
+ event_stream.send_stop(reason);
+ if reason == StopReason::Refusal {
+ _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
+ }
+ }
+ Err(error) => {
+ log::error!("Turn execution failed: {:?}", error);
+ event_stream.send_error(error);
+ }
}
- this.update(cx, |this, _| {
- this.flush_pending_message();
- this.running_turn.take();
- })
- .ok();
+ _ = this.update(cx, |this, _| this.running_turn.take());
}),
});
Ok(events_rx)
@@ -756,9 +1017,10 @@ impl Thread {
this: WeakEntity<Self>,
model: Arc<dyn LanguageModel>,
request: LanguageModelRequest,
- message_ix: usize,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
tool_use_limit_reached: &mut bool,
+ refusal: &mut bool,
+ max_tokens_reached: &mut bool,
cx: &mut AsyncApp,
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
log::debug!("Stream completion started successfully");
@@ -774,16 +1036,17 @@ impl Thread {
)) => {
*tool_use_limit_reached = true;
}
- Ok(LanguageModelCompletionEvent::Stop(reason)) => {
- event_stream.send_stop(reason);
- if reason == StopReason::Refusal {
- this.update(cx, |this, _cx| {
- this.flush_pending_message();
- this.messages.truncate(message_ix);
- })?;
- return Ok(tool_uses);
- }
+ Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
+ *refusal = true;
+ return Ok(FuturesUnordered::default());
+ }
+ Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
+ *max_tokens_reached = true;
+ return Ok(FuturesUnordered::default());
}
+ Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::ToolUse | StopReason::EndTurn,
+ )) => break,
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
this.update(cx, |this, cx| {
@@ -843,6 +1106,7 @@ impl Thread {
}
}
}
+
return Ok(tool_uses);
}
}
@@ -870,7 +1134,7 @@ impl Thread {
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
log::trace!("Handling streamed completion event: {:?}", event);
@@ -878,7 +1142,7 @@ impl Thread {
match event {
StartMessage { .. } => {
- self.flush_pending_message();
+ self.flush_pending_message(cx);
self.pending_message = Some(AgentMessage::default());
}
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
@@ -912,7 +1176,7 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) {
event_stream.send_text(&new_text);
@@ -933,7 +1197,7 @@ impl Thread {
&mut self,
new_text: String,
new_signature: Option<String>,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) {
event_stream.send_thinking(&new_text);
@@ -965,7 +1229,7 @@ impl Thread {
fn handle_tool_use_event(
&mut self,
tool_use: LanguageModelToolUse,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
cx.notify();
@@ -1083,11 +1347,85 @@ impl Thread {
}
}
+ pub fn title(&self) -> SharedString {
+ self.title.clone().unwrap_or("New Thread".into())
+ }
+
+ fn update_title(
+ &mut self,
+ event_stream: &ThreadEventStream,
+ cx: &mut Context<Self>,
+ ) -> Option<Task<Result<()>>> {
+ if self.title.is_some() {
+ log::debug!("Skipping title generation because we already have one.");
+ return None;
+ }
+
+ log::info!(
+ "Generating title with model: {:?}",
+ self.summarization_model.as_ref().map(|model| model.name())
+ );
+ let model = self.summarization_model.clone()?;
+ let event_stream = event_stream.clone();
+ let mut request = LanguageModelRequest {
+ intent: Some(CompletionIntent::ThreadSummarization),
+ temperature: AgentSettings::temperature_for_model(&model, cx),
+ ..Default::default()
+ };
+
+ for message in &self.messages {
+ request.messages.extend(message.to_request());
+ }
+
+ request.messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![SUMMARIZE_THREAD_PROMPT.into()],
+ cache: false,
+ });
+ Some(cx.spawn(async move |this, cx| {
+ let mut title = String::new();
+ let mut messages = model.stream_completion(request, cx).await?;
+ while let Some(event) = messages.next().await {
+ let event = event?;
+ let text = match event {
+ LanguageModelCompletionEvent::Text(text) => text,
+ LanguageModelCompletionEvent::StatusUpdate(
+ CompletionRequestStatus::UsageUpdated { .. },
+ ) => {
+ // this.update(cx, |thread, cx| {
+ // thread.update_model_request_usage(amount as u32, limit, cx);
+ // })?;
+ // TODO: handle usage update
+ continue;
+ }
+ _ => continue,
+ };
+
+ let mut lines = text.lines();
+ title.extend(lines.next());
+
+ // Stop if the LLM generated multiple lines.
+ if lines.next().is_some() {
+ break;
+ }
+ }
+
+ log::info!("Setting title: {}", title);
+
+ this.update(cx, |this, cx| {
+ let title = SharedString::from(title);
+ event_stream.send_title_update(title.clone());
+ this.title = Some(title);
+ cx.notify();
+ })
+ }))
+ }
+
fn pending_message(&mut self) -> &mut AgentMessage {
self.pending_message.get_or_insert_default()
}
- fn flush_pending_message(&mut self) {
+ fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
let Some(mut message) = self.pending_message.take() else {
return;
};
@@ -1104,9 +1442,7 @@ impl Thread {
tool_use_id: tool_use.id.clone(),
tool_name: tool_use.name.clone(),
is_error: true,
- content: LanguageModelToolResultContent::Text(
- "Tool canceled by user".into(),
- ),
+ content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
output: None,
},
);
@@ -1114,6 +1450,8 @@ impl Thread {
}
self.messages.push(Message::Agent(message));
+ self.updated_at = Utc::now();
+ cx.notify()
}
pub(crate) fn build_completion_request(
@@ -1205,15 +1543,7 @@ impl Thread {
);
let mut messages = vec![self.build_system_message(cx)];
for message in &self.messages {
- match message {
- Message::User(message) => messages.push(message.to_request()),
- Message::Agent(message) => messages.extend(message.to_request()),
- Message::Resume => messages.push(LanguageModelRequestMessage {
- role: Role::User,
- content: vec!["Continue where you left off".into()],
- cache: false,
- }),
- }
+ messages.extend(message.to_request());
}
if let Some(message) = self.pending_message.as_ref() {
@@ -1367,7 +1697,7 @@ struct RunningTurn {
_task: Task<()>,
/// The current event stream for the running turn. Used to report a final
/// cancellation event if we cancel the turn.
- event_stream: AgentResponseEventStream,
+ event_stream: ThreadEventStream,
}
impl RunningTurn {
@@ -1420,6 +1750,17 @@ where
cx: &mut App,
) -> Task<Result<Self::Output>>;
+ /// Emits events for a previous execution of the tool.
+ fn replay(
+ &self,
+ _input: Self::Input,
+ _output: Self::Output,
+ _event_stream: ToolCallEventStream,
+ _cx: &mut App,
+ ) -> Result<()> {
+ Ok(())
+ }
+
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
}
@@ -1447,6 +1788,13 @@ pub trait AnyAgentTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<AgentToolOutput>>;
+ fn replay(
+ &self,
+ input: serde_json::Value,
+ output: serde_json::Value,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Result<()>;
}
impl<T> AnyAgentTool for Erased<Arc<T>>
@@ -1498,21 +1846,45 @@ where
})
})
}
+
+ fn replay(
+ &self,
+ input: serde_json::Value,
+ output: serde_json::Value,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Result<()> {
+ let input = serde_json::from_value(input)?;
+ let output = serde_json::from_value(output)?;
+ self.0.replay(input, output, event_stream, cx)
+ }
}
#[derive(Clone)]
-struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
+struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
+
+impl ThreadEventStream {
+ fn send_title_update(&self, text: SharedString) {
+ self.0
+ .unbounded_send(Ok(ThreadEvent::TitleUpdate(text)))
+ .ok();
+ }
+
+ fn send_user_message(&self, message: &UserMessage) {
+ self.0
+ .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
+ .ok();
+ }
-impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
+ .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
.ok();
}
fn send_thinking(&self, text: &str) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
+ .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
.ok();
}
@@ -1524,7 +1896,7 @@ impl AgentResponseEventStream {
input: serde_json::Value,
) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
+ .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
id,
title.to_string(),
kind,
@@ -1557,7 +1929,7 @@ impl AgentResponseEventStream {
fields: acp::ToolCallUpdateFields,
) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
+ .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.to_string().into()),
fields,
@@ -1568,26 +1940,24 @@ impl AgentResponseEventStream {
}
fn send_retry(&self, status: acp_thread::RetryStatus) {
- self.0
- .unbounded_send(Ok(AgentResponseEvent::Retry(status)))
- .ok();
+ self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
+ .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
.ok();
}
StopReason::MaxTokens => {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
+ .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
.ok();
}
StopReason::Refusal => {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
+ .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
.ok();
}
StopReason::ToolUse => {}
@@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
})
})
}
+
+ fn replay(
+ &self,
+ _input: serde_json::Value,
+ _output: serde_json::Value,
+ _event_stream: ToolCallEventStream,
+ _cx: &mut App,
+ ) -> Result<()> {
+ Ok(())
+ }
}
@@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
-use gpui::{App, AppContext, AsyncApp, Entity, Task};
+use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use indoc::formatdoc;
-use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
+use language::{LanguageRegistry, ToPoint};
use language_model::LanguageModelToolResultContent;
use paths;
use project::lsp_store::{FormatTrigger, LspFormatTarget};
@@ -98,11 +98,13 @@ pub enum EditFileMode {
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
+ #[serde(alias = "original_path")]
input_path: PathBuf,
- project_path: PathBuf,
new_text: String,
old_text: Arc<String>,
+ #[serde(default)]
diff: String,
+ #[serde(alias = "raw_output")]
edit_agent_output: EditAgentOutput,
}
@@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
}
pub struct EditFileTool {
- thread: Entity<Thread>,
+ thread: WeakEntity<Thread>,
+ language_registry: Arc<LanguageRegistry>,
}
impl EditFileTool {
- pub fn new(thread: Entity<Thread>) -> Self {
- Self { thread }
+ pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
+ Self {
+ thread,
+ language_registry,
+ }
}
fn authorize(
@@ -167,8 +173,11 @@ impl EditFileTool {
// Check if path is inside the global config directory
// First check if it's already inside project - if not, try to canonicalize
- let thread = self.thread.read(cx);
- let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
+ let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
+ thread.project().read(cx).find_project_path(&input.path, cx)
+ }) else {
+ return Task::ready(Err(anyhow!("thread was dropped")));
+ };
// If the path is inside the project, and it's not one of the above edge cases,
// then no confirmation is necessary. Otherwise, confirmation is necessary.
@@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
- let project = self.thread.read(cx).project().clone();
+ let Ok(project) = self
+ .thread
+ .read_with(cx, |thread, _cx| thread.project().clone())
+ else {
+ return Task::ready(Err(anyhow!("thread was dropped")));
+ };
let project_path = match resolve_path(&input, project.clone(), cx) {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
@@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
});
}
- let Some(request) = self.thread.update(cx, |thread, cx| {
- thread
- .build_completion_request(CompletionIntent::ToolResults, cx)
- .ok()
- }) else {
- return Task::ready(Err(anyhow!("Failed to build completion request")));
- };
- let thread = self.thread.read(cx);
- let Some(model) = thread.model().cloned() else {
- return Task::ready(Err(anyhow!("No language model configured")));
- };
- let action_log = thread.action_log().clone();
-
let authorize = self.authorize(&input, &event_stream, cx);
cx.spawn(async move |cx: &mut AsyncApp| {
authorize.await?;
+ let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
+ let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
+ (request, thread.model().cloned(), thread.action_log().clone())
+ })?;
+ let request = request?;
+ let model = model.context("No language model configured")?;
+
let edit_format = EditFormat::from_model(model.clone())?;
let edit_agent = EditAgent::new(
model,
@@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
Ok(EditFileToolOutput {
input_path: input.path,
- project_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text,
diff: unified_diff,
@@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
})
})
}
+
+ fn replay(
+ &self,
+ _input: Self::Input,
+ output: Self::Output,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Result<()> {
+ event_stream.update_diff(cx.new(|cx| {
+ Diff::finalized(
+ output.input_path,
+ Some(output.old_text.to_string()),
+ output.new_text,
+ self.language_registry.clone(),
+ cx,
+ )
+ }));
+ Ok(())
+ }
}
/// Validate that the file path is valid, meaning:
@@ -515,6 +541,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -527,6 +554,7 @@ mod tests {
action_log,
Templates::new(),
Some(model),
+ None,
cx,
)
});
@@ -537,7 +565,11 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
- Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+ Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
+ input,
+ ToolCallEventStream::test().0,
+ cx,
+ )
})
.await;
assert_eq!(
@@ -724,6 +756,7 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
+ None,
cx,
)
});
@@ -750,9 +783,10 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool {
- thread: thread.clone(),
- })
+ Arc::new(EditFileTool::new(
+ thread.downgrade(),
+ language_registry.clone(),
+ ))
.run(input, ToolCallEventStream::test().0, cx)
});
@@ -806,7 +840,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+ Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
+ input,
+ ToolCallEventStream::test().0,
+ cx,
+ )
});
// Stream the unformatted content
@@ -850,6 +888,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
@@ -860,6 +899,7 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
+ None,
cx,
)
});
@@ -887,9 +927,10 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool {
- thread: thread.clone(),
- })
+ Arc::new(EditFileTool::new(
+ thread.downgrade(),
+ language_registry.clone(),
+ ))
.run(input, ToolCallEventStream::test().0, cx)
});
@@ -938,10 +979,11 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
- Arc::new(EditFileTool {
- thread: thread.clone(),
- })
- .run(input, ToolCallEventStream::test().0, cx)
+ Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
+ input,
+ ToolCallEventStream::test().0,
+ cx,
+ )
});
// Stream the content with trailing whitespace
@@ -976,6 +1018,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
@@ -986,10 +1029,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
+ None,
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@@ -1111,6 +1155,7 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
@@ -1123,10 +1168,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
+ None,
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@@ -1220,7 +1266,7 @@ mod tests {
cx,
)
.await;
-
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1233,10 +1279,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
+ None,
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test files in different worktrees
let test_cases = vec![
@@ -1302,6 +1349,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1314,10 +1362,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
+ None,
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test edge cases
let test_cases = vec![
@@ -1386,6 +1435,7 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1398,10 +1448,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
+ None,
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
// Test different EditFileMode values
let modes = vec![
@@ -1467,6 +1518,7 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1479,10 +1531,11 @@ mod tests {
action_log.clone(),
Templates::new(),
Some(model.clone()),
+ None,
cx,
)
});
- let tool = Arc::new(EditFileTool { thread });
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
assert_eq!(
tool.initial_title(Err(json!({
@@ -319,7 +319,7 @@ mod tests {
use theme::ThemeSettings;
use util::test::TempTree;
- use crate::AgentResponseEvent;
+ use crate::ThreadEvent;
use super::*;
@@ -396,7 +396,7 @@ mod tests {
});
cx.run_until_parked();
let event = stream_rx.try_next();
- if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event {
+ if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
auth.response.send(auth.options[0].id.clone()).unwrap();
}
@@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
}
};
- let result_text = if response.results.len() == 1 {
- "1 result".to_string()
- } else {
- format!("{} results", response.results.len())
- };
- event_stream.update_fields(acp::ToolCallUpdateFields {
- title: Some(format!("Searched the web: {result_text}")),
- content: Some(
- response
- .results
- .iter()
- .map(|result| acp::ToolCallContent::Content {
- content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
- name: result.title.clone(),
- uri: result.url.clone(),
- title: Some(result.title.clone()),
- description: Some(result.text.clone()),
- mime_type: None,
- annotations: None,
- size: None,
- }),
- })
- .collect(),
- ),
- ..Default::default()
- });
+ emit_update(&response, &event_stream);
Ok(WebSearchToolOutput(response))
})
}
+
+ fn replay(
+ &self,
+ _input: Self::Input,
+ output: Self::Output,
+ event_stream: ToolCallEventStream,
+ _cx: &mut App,
+ ) -> Result<()> {
+ emit_update(&output.0, &event_stream);
+ Ok(())
+ }
+}
+
+fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
+ let result_text = if response.results.len() == 1 {
+ "1 result".to_string()
+ } else {
+ format!("{} results", response.results.len())
+ };
+ event_stream.update_fields(acp::ToolCallUpdateFields {
+ title: Some(format!("Searched the web: {result_text}")),
+ content: Some(
+ response
+ .results
+ .iter()
+ .map(|result| acp::ToolCallContent::Content {
+ content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
+ name: result.title.clone(),
+ uri: result.url.clone(),
+ title: Some(result.title.clone()),
+ description: Some(result.text.clone()),
+ mime_type: None,
+ annotations: None,
+ size: None,
+ }),
+ })
+ .collect(),
+ ),
+ ..Default::default()
+ });
}
@@ -18,6 +18,7 @@ doctest = false
[dependencies]
acp_thread.workspace = true
+action_log.workspace = true
agent-client-protocol.workspace = true
agent_settings.workspace = true
agentic-coding-protocol.workspace = true
@@ -1,4 +1,5 @@
// Translates old acp agents into the new schema
+use action_log::ActionLog;
use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
use anyhow::{Context as _, Result, anyhow};
@@ -443,7 +444,8 @@ impl AgentConnection for AcpConnection {
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
- AcpThread::new(self.name, self.clone(), project, session_id, cx)
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ AcpThread::new(self.name, self.clone(), project, action_log, session_id)
});
current_thread.replace(thread.downgrade());
thread
@@ -1,3 +1,4 @@
+use action_log::ActionLog;
use agent_client_protocol::{self as acp, Agent as _};
use anyhow::anyhow;
use collections::HashMap;
@@ -153,14 +154,14 @@ impl AgentConnection for AcpConnection {
})?;
let session_id = response.session_id;
-
- let thread = cx.new(|cx| {
+ let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
+ let thread = cx.new(|_cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
+ action_log,
session_id.clone(),
- cx,
)
})?;
@@ -1,6 +1,7 @@
mod mcp_server;
pub mod tools;
+use action_log::ActionLog;
use collections::HashMap;
use context_server::listener::McpServerTool;
use language_models::provider::anthropic::AnthropicLanguageModelProvider;
@@ -215,8 +216,15 @@ impl AgentConnection for ClaudeAgentConnection {
}
});
- let thread = cx.new(|cx| {
- AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
+ let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
+ let thread = cx.new(|_cx| {
+ AcpThread::new(
+ "Claude Code",
+ self.clone(),
+ project,
+ action_log,
+ session_id.clone(),
+ )
})?;
thread_tx.send(thread.downgrade())?;
@@ -303,8 +303,13 @@ impl AcpThreadView {
let action_log_subscription =
cx.observe(&action_log, |_, _, cx| cx.notify());
- this.list_state
- .splice(0..0, thread.read(cx).entries().len());
+ let count = thread.read(cx).entries().len();
+ this.list_state.splice(0..0, count);
+ this.entry_view_state.update(cx, |view_state, cx| {
+ for ix in 0..count {
+ view_state.sync_entry(ix, &thread, window, cx);
+ }
+ });
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
@@ -808,6 +813,7 @@ impl AcpThreadView {
self.thread_retry_status.take();
self.thread_state = ThreadState::ServerExited { status: *status };
}
+ AcpThreadEvent::TitleUpdated => {}
}
cx.notify();
}
@@ -2816,12 +2822,15 @@ impl AcpThreadView {
return;
};
- thread.update(cx, |thread, _cx| {
+ thread.update(cx, |thread, cx| {
let current_mode = thread.completion_mode();
- thread.set_completion_mode(match current_mode {
- CompletionMode::Burn => CompletionMode::Normal,
- CompletionMode::Normal => CompletionMode::Burn,
- });
+ thread.set_completion_mode(
+ match current_mode {
+ CompletionMode::Burn => CompletionMode::Normal,
+ CompletionMode::Normal => CompletionMode::Burn,
+ },
+ cx,
+ );
});
}
@@ -3572,8 +3581,9 @@ impl AcpThreadView {
))
.on_click({
cx.listener(move |this, _, _window, cx| {
- thread.update(cx, |thread, _cx| {
- thread.set_completion_mode(CompletionMode::Burn);
+ thread.update(cx, |thread, cx| {
+ thread
+ .set_completion_mode(CompletionMode::Burn, cx);
});
this.resume_chat(cx);
})
@@ -4156,12 +4166,13 @@ pub(crate) mod tests {
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
Task::ready(Ok(cx.new(|cx| {
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
AcpThread::new(
"SaboteurAgentConnection",
self,
project,
+ action_log,
SessionId("test".into()),
- cx,
)
})))
}
@@ -199,24 +199,21 @@ impl AgentDiffPane {
let action_log = thread.action_log(cx).clone();
let mut this = Self {
- _subscriptions: [
- Some(
- cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
- this.update_excerpts(window, cx)
- }),
- ),
+ _subscriptions: vec![
+ cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
+ this.update_excerpts(window, cx)
+ }),
match &thread {
- AgentDiffThread::Native(thread) => {
- Some(cx.subscribe(thread, |this, _thread, event, cx| {
- this.handle_thread_event(event, cx)
- }))
- }
- AgentDiffThread::AcpThread(_) => None,
+ AgentDiffThread::Native(thread) => cx
+ .subscribe(thread, |this, _thread, event, cx| {
+ this.handle_native_thread_event(event, cx)
+ }),
+ AgentDiffThread::AcpThread(thread) => cx
+ .subscribe(thread, |this, _thread, event, cx| {
+ this.handle_acp_thread_event(event, cx)
+ }),
},
- ]
- .into_iter()
- .flatten()
- .collect(),
+ ],
title: SharedString::default(),
multibuffer,
editor,
@@ -324,13 +321,20 @@ impl AgentDiffPane {
}
}
- fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
+ fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
match event {
ThreadEvent::SummaryGenerated => self.update_title(cx),
_ => {}
}
}
+ fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
+ match event {
+ AcpThreadEvent::TitleUpdated => self.update_title(cx),
+ _ => {}
+ }
+ }
+
pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
self.editor.update(cx, |editor, cx| {
@@ -1523,7 +1527,8 @@ impl AgentDiff {
AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {
self.update_reviewing_editors(workspace, window, cx);
}
- AcpThreadEvent::EntriesRemoved(_)
+ AcpThreadEvent::TitleUpdated
+ | AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Retry(_) => {}
}