From 5259c8692dde8366c2eda6ec250ebdff0e6e0601 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 13:45:55 +0200 Subject: [PATCH] WIP --- crates/agent2/src/agent.rs | 33 +-- crates/agent2/src/db.rs | 8 +- crates/agent2/src/tests/mod.rs | 24 +- crates/agent2/src/thread.rs | 257 ++++++++++++++---- .../src/tools/context_server_registry.rs | 10 + crates/agent2/src/tools/terminal_tool.rs | 4 +- 6 files changed, 250 insertions(+), 86 deletions(-) diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index fa6201a4832250747b10b2e65efbee8f7f6593b9..2a17f23fd0fb04d5372126d7b457f57efa6a8d7d 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,9 +1,9 @@ use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME; use crate::{ - AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, - DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, - MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, - ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool, + EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, + OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, + UserMessageContent, WebSearchTool, templates::Templates, }; use crate::{DbThread, ThreadsDatabase}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; @@ -461,10 +461,7 @@ impl NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, f: impl 'static - + FnOnce( - Entity, - &mut App, - ) -> Result>>, + + FnOnce(Entity, &mut App) -> Result>>, ) -> Task> { let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { agent @@ -488,7 +485,10 @@ impl NativeAgentConnection { log::trace!("Received completion event: {:?}", event); match event { - AgentResponseEvent::Text(text) => { + ThreadEvent::UserMessage(message) => { + todo!() + } + ThreadEvent::AgentText(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block( acp::ContentBlock::Text(acp::TextContent { @@ -500,7 +500,7 @@ impl NativeAgentConnection { ) })?; } - AgentResponseEvent::Thinking(text) => { + ThreadEvent::AgentThinking(text) => { acp_thread.update(cx, |thread, cx| { thread.push_assistant_content_block( acp::ContentBlock::Text(acp::TextContent { @@ -512,7 +512,7 @@ impl NativeAgentConnection { ) })?; } - AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization { + ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { tool_call, options, response, @@ -535,17 +535,17 @@ impl NativeAgentConnection { }) .detach(); } - AgentResponseEvent::ToolCall(tool_call) => { + ThreadEvent::ToolCall(tool_call) => { acp_thread.update(cx, |thread, cx| { thread.upsert_tool_call(tool_call, cx) })??; } - AgentResponseEvent::ToolCallUpdate(update) => { + ThreadEvent::ToolCallUpdate(update) => { acp_thread.update(cx, |thread, cx| { thread.update_tool_call(update, cx) })??; } - AgentResponseEvent::Stop(stop_reason) => { + ThreadEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); } @@ -786,7 +786,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { .into_iter() .map(|thread| AcpThreadMetadata { agent: NATIVE_AGENT_SERVER_NAME.clone(), - id: thread.id, + id: thread.id.into(), title: thread.title, updated_at: thread.updated_at, }) @@ -806,11 +806,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection { session_id: acp::SessionId, cx: &mut App, ) -> Task>> { + let thread_id = session_id.clone().into(); let database = self.0.update(cx, |this, _| this.thread_database.clone()); cx.spawn(async move |cx| { let database = database.await.map_err(|e| anyhow!(e))?; let db_thread = database - .load_thread(session_id.clone()) + .load_thread(thread_id) .await? .context("no such thread found")?; diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index a7240df5c790b4094fa759b0a2cf62e378c30231..5332da276f2beb97255657817342b467b5b6ce7e 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,4 +1,4 @@ -use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; +use crate::{AgentMessage, AgentMessageContent, ThreadId, UserMessage, UserMessageContent}; use agent::thread_store; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; @@ -24,7 +24,7 @@ pub type DbLanguageModel = thread_store::SerializedLanguageModel; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbThreadMetadata { - pub id: acp::SessionId, + pub id: ThreadId, #[serde(alias = "summary")] pub title: SharedString, pub updated_at: DateTime, @@ -323,7 +323,7 @@ impl ThreadsDatabase { for (id, summary, updated_at) in rows { threads.push(DbThreadMetadata { - id: acp::SessionId(id), + id: ThreadId(id), title: summary.into(), updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), }); @@ -333,7 +333,7 @@ impl ThreadsDatabase { }) } - pub fn load_thread(&self, id: acp::SessionId) -> Task>> { + pub fn load_thread(&self, id: ThreadId) -> Task>> { let connection = self.connection.clone(); self.executor.spawn(async move { diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 48a16bf685575ace9360d4285a63702740a6fc86..5cbddafded88e1373200317dd9342686232c7298 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -329,7 +329,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { let mut saw_partial_tool_use = false; while let Some(event) = events.next().await { - if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event { + if let Ok(ThreadEvent::ToolCall(tool_call)) = event { thread.update(cx, |thread, _cx| { // Look for a tool use in the thread's last message let message = thread.last_message().unwrap(); @@ -710,7 +710,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { } async fn expect_tool_call( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCall { let event = events .next() @@ -718,7 +718,7 @@ async fn expect_tool_call( .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCall(tool_call) => return tool_call, + ThreadEvent::ToolCall(tool_call) => return tool_call, event => { panic!("Unexpected event {event:?}"); } @@ -726,7 +726,7 @@ async fn expect_tool_call( } async fn expect_tool_call_update_fields( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> acp::ToolCallUpdate { let event = events .next() @@ -734,7 +734,7 @@ async fn expect_tool_call_update_fields( .expect("no tool call authorization event received") .unwrap(); match event { - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => { return update; } event => { @@ -744,7 +744,7 @@ async fn expect_tool_call_update_fields( } async fn next_tool_call_authorization( - events: &mut UnboundedReceiver>, + events: &mut UnboundedReceiver>, ) -> ToolCallAuthorization { loop { let event = events @@ -752,7 +752,7 @@ async fn next_tool_call_authorization( .await .expect("no tool call authorization event received") .unwrap(); - if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event { + if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event { let permission_kinds = tool_call_authorization .options .iter() @@ -912,13 +912,13 @@ async fn test_cancellation(cx: &mut TestAppContext) { let mut echo_completed = false; while let Some(event) = events.next().await { match event.unwrap() { - AgentResponseEvent::ToolCall(tool_call) => { + ThreadEvent::ToolCall(tool_call) => { assert_eq!(tool_call.title, expected_tools.remove(0)); if tool_call.title == "Echo" { echo_id = Some(tool_call.id); } } - AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( acp::ToolCallUpdate { id, fields: @@ -946,7 +946,7 @@ async fn test_cancellation(cx: &mut TestAppContext) { assert!( matches!( last_event, - Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled))) ), "unexpected event {last_event:?}" ); @@ -1386,11 +1386,11 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { } /// Filters out the stop events for asserting against in tests -fn stop_events(result_events: Vec>) -> Vec { +fn stop_events(result_events: Vec>) -> Vec { result_events .into_iter() .filter_map(|event| match event.unwrap() { - AgentResponseEvent::Stop(stop_reason) => Some(stop_reason), + ThreadEvent::Stop(stop_reason) => Some(stop_reason), _ => None, }) .collect() diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index bd6d78b2f07b385ad0ddb355c5b927691ed4d104..87f4803dafddc55a62bcc3eb0110c14a2c715fa8 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,4 +1,4 @@ -use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates}; +use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates}; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; use agent_client_protocol as acp; @@ -30,10 +30,12 @@ use std::{fmt::Write, ops::Range}; use util::{ResultExt, markdown::MarkdownCodeBlock}; use uuid::Uuid; +const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; + #[derive( Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, )] -pub struct ThreadId(Arc); +pub struct ThreadId(pub(crate) Arc); impl ThreadId { pub fn new() -> Self { @@ -53,6 +55,18 @@ impl From<&str> for ThreadId { } } +impl From for ThreadId { + fn from(value: acp::SessionId) -> Self { + Self(value.0) + } +} + +impl From for acp::SessionId { + fn from(value: ThreadId) -> Self { + Self(value.0) + } +} + /// The ID of the user prompt that initiated a request. /// /// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). @@ -313,9 +327,6 @@ impl AgentMessage { AgentMessageContent::RedactedThinking(_) => { markdown.push_str("\n") } - AgentMessageContent::Image(_) => { - markdown.push_str("\n"); - } AgentMessageContent::ToolUse(tool_use) => { markdown.push_str(&format!( "**Tool Use**: {} (ID: {})\n", @@ -386,9 +397,6 @@ impl AgentMessage { AgentMessageContent::ToolUse(value) => { language_model::MessageContent::ToolUse(value.clone()) } - AgentMessageContent::Image(value) => { - language_model::MessageContent::Image(value.clone()) - } }; assistant_message.content.push(chunk); } @@ -432,14 +440,14 @@ pub enum AgentMessageContent { signature: Option, }, RedactedThinking(String), - Image(LanguageModelImage), ToolUse(LanguageModelToolUse), } #[derive(Debug)] -pub enum AgentResponseEvent { - Text(String), - Thinking(String), +pub enum ThreadEvent { + UserMessage(UserMessage), + AgentText(String), + AgentThinking(String), ToolCall(acp::ToolCall), ToolCallUpdate(acp_thread::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), @@ -504,6 +512,121 @@ impl Thread { } } + pub fn from_db( + id: ThreadId, + db_thread: DbThread, + project: Entity, + project_context: Rc>, + context_server_registry: Entity, + action_log: Entity, + templates: Arc, + model: Arc, + cx: &mut Context, + ) -> Self { + let profile_id = db_thread + .profile + .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); + Self { + id, + prompt_id: PromptId::new(), + messages: db_thread.messages, + completion_mode: CompletionMode::Normal, + running_turn: None, + pending_message: None, + tools: BTreeMap::default(), + tool_use_limit_reached: false, + context_server_registry, + profile_id, + project_context, + templates, + model, + project, + action_log, + } + } + + pub fn replay(&self, cx: &mut Context) -> mpsc::UnboundedReceiver> { + let (tx, rx) = mpsc::unbounded(); + let stream = ThreadEventStream(tx); + for message in &self.messages { + match message { + Message::User(user_message) => stream.send_user_message(&user_message), + Message::Agent(assistant_message) => { + for content in &assistant_message.content { + match content { + AgentMessageContent::Text(text) => stream.send_text(text), + AgentMessageContent::Thinking { text, .. } => { + stream.send_thinking(text) + } + AgentMessageContent::RedactedThinking(_) => {} + AgentMessageContent::ToolUse(tool_use) => { + self.replay_tool_call( + tool_use, + assistant_message.tool_results.get(&tool_use.id), + &stream, + cx, + ); + } + } + } + } + Message::Resume => {} + } + } + rx + } + + fn replay_tool_call( + &self, + tool_use: &LanguageModelToolUse, + tool_result: Option<&LanguageModelToolResult>, + stream: &ThreadEventStream, + cx: &mut Context, + ) { + let Some(tool) = self.tools.get(tool_use.name.as_ref()) else { + stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Failed, + content: Vec::new(), + locations: Vec::new(), + raw_input: Some(tool_use.input.clone()), + raw_output: None, + }))) + .ok(); + return; + }; + + let title = tool.initial_title(tool_use.input.clone()); + let kind = tool.kind(); + stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + + if let Some(output) = tool_result + .as_ref() + .and_then(|result| result.output.clone()) + { + let tool_event_stream = ToolCallEventStream::new( + tool_use.id.clone(), + stream.clone(), + Some(self.project.read(cx).fs().clone()), + ); + tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) + .log_err(); + } else { + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + content: Some(vec![TOOL_CANCELED_MESSAGE.into()]), + status: Some(acp::ToolCallStatus::Failed), + ..Default::default() + }, + ); + } + } + pub fn project(&self) -> &Entity { &self.project } @@ -574,7 +697,7 @@ impl Thread { pub fn resume( &mut self, cx: &mut Context, - ) -> Result>> { + ) -> Result>> { anyhow::ensure!( self.tool_use_limit_reached, "can only resume after tool use limit is reached" @@ -595,7 +718,7 @@ impl Thread { id: UserMessageId, content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> + ) -> mpsc::UnboundedReceiver> where T: Into, { @@ -613,15 +736,12 @@ impl Thread { self.run_turn(cx) } - fn run_turn( - &mut self, - cx: &mut Context, - ) -> mpsc::UnboundedReceiver> { + fn run_turn(&mut self, cx: &mut Context) -> mpsc::UnboundedReceiver> { self.cancel(); let model = self.model.clone(); - let (events_tx, events_rx) = mpsc::unbounded::>(); - let event_stream = AgentResponseEventStream(events_tx); + let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = ThreadEventStream(events_tx); let message_ix = self.messages.len().saturating_sub(1); self.tool_use_limit_reached = false; self.running_turn = Some(RunningTurn { @@ -755,7 +875,7 @@ impl Thread { fn handle_streamed_completion_event( &mut self, event: LanguageModelCompletionEvent, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { log::trace!("Handling streamed completion event: {:?}", event); @@ -797,7 +917,7 @@ impl Thread { fn handle_text_event( &mut self, new_text: String, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_text(&new_text); @@ -818,7 +938,7 @@ impl Thread { &mut self, new_text: String, new_signature: Option, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) { event_stream.send_thinking(&new_text); @@ -850,7 +970,7 @@ impl Thread { fn handle_tool_use_event( &mut self, tool_use: LanguageModelToolUse, - event_stream: &AgentResponseEventStream, + event_stream: &ThreadEventStream, cx: &mut Context, ) -> Option> { cx.notify(); @@ -989,9 +1109,7 @@ impl Thread { tool_use_id: tool_use.id.clone(), tool_name: tool_use.name.clone(), is_error: true, - content: LanguageModelToolResultContent::Text( - "Tool canceled by user".into(), - ), + content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()), output: None, }, ); @@ -1143,7 +1261,7 @@ struct RunningTurn { _task: Task<()>, /// The current event stream for the running turn. Used to report a final /// cancellation event if we cancel the turn. - event_stream: AgentResponseEventStream, + event_stream: ThreadEventStream, } impl RunningTurn { @@ -1196,6 +1314,17 @@ where cx: &mut App, ) -> Task>; + /// Emits events for a previous execution of the tool. + fn replay( + &self, + _input: Self::Input, + _output: Self::Output, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } + fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) } @@ -1223,6 +1352,13 @@ pub trait AnyAgentTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Task>; + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()>; } impl AnyAgentTool for Erased> @@ -1274,21 +1410,39 @@ where }) }) } + + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + let input = serde_json::from_value(input)?; + let output = serde_json::from_value(output)?; + self.0.replay(input, output, event_stream, cx) + } } #[derive(Clone)] -struct AgentResponseEventStream(mpsc::UnboundedSender>); +struct ThreadEventStream(mpsc::UnboundedSender>); + +impl ThreadEventStream { + fn send_user_message(&self, message: &UserMessage) { + self.0 + .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone()))) + .ok(); + } -impl AgentResponseEventStream { fn send_text(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string()))) .ok(); } fn send_thinking(&self, text: &str) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string()))) + .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string()))) .ok(); } @@ -1300,7 +1454,7 @@ impl AgentResponseEventStream { input: serde_json::Value, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( + .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call( id, title.to_string(), kind, @@ -1333,7 +1487,7 @@ impl AgentResponseEventStream { fields: acp::ToolCallUpdateFields, ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.to_string().into()), fields, @@ -1347,17 +1501,17 @@ impl AgentResponseEventStream { match reason { StopReason::EndTurn => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn))) .ok(); } StopReason::MaxTokens => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens))) .ok(); } StopReason::Refusal => { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal))) .ok(); } StopReason::ToolUse => {} @@ -1366,7 +1520,7 @@ impl AgentResponseEventStream { fn send_canceled(&self) { self.0 - .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled))) .ok(); } @@ -1378,24 +1532,23 @@ impl AgentResponseEventStream { #[derive(Clone)] pub struct ToolCallEventStream { tool_use_id: LanguageModelToolUseId, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, } impl ToolCallEventStream { #[cfg(test)] pub fn test() -> (Self, ToolCallEventStreamReceiver) { - let (events_tx, events_rx) = mpsc::unbounded::>(); + let (events_tx, events_rx) = mpsc::unbounded::>(); - let stream = - ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None); + let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None); (stream, ToolCallEventStreamReceiver(events_rx)) } fn new( tool_use_id: LanguageModelToolUseId, - stream: AgentResponseEventStream, + stream: ThreadEventStream, fs: Option>, ) -> Self { Self { @@ -1413,7 +1566,7 @@ impl ToolCallEventStream { pub fn update_diff(&self, diff: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateDiff { id: acp::ToolCallId(self.tool_use_id.to_string().into()), diff, @@ -1426,7 +1579,7 @@ impl ToolCallEventStream { pub fn update_terminal(&self, terminal: Entity) { self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( acp_thread::ToolCallUpdateTerminal { id: acp::ToolCallId(self.tool_use_id.to_string().into()), terminal, @@ -1444,7 +1597,7 @@ impl ToolCallEventStream { let (response_tx, response_rx) = oneshot::channel(); self.stream .0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( + .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization( ToolCallAuthorization { tool_call: acp::ToolCallUpdate { id: acp::ToolCallId(self.tool_use_id.to_string().into()), @@ -1494,13 +1647,13 @@ impl ToolCallEventStream { } #[cfg(test)] -pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); +pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); #[cfg(test)] impl ToolCallEventStreamReceiver { pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { + if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event { auth } else { panic!("Expected ToolCallAuthorization but got: {:?}", event); @@ -1509,9 +1662,9 @@ impl ToolCallEventStreamReceiver { pub async fn expect_terminal(&mut self) -> Entity { let event = self.0.next().await; - if let Some(Ok(AgentResponseEvent::ToolCallUpdate( - acp_thread::ToolCallUpdate::UpdateTerminal(update), - ))) = event + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( + update, + )))) = event { update.terminal } else { @@ -1522,7 +1675,7 @@ impl ToolCallEventStreamReceiver { #[cfg(test)] impl std::ops::Deref for ToolCallEventStreamReceiver { - type Target = mpsc::UnboundedReceiver>; + type Target = mpsc::UnboundedReceiver>; fn deref(&self) -> &Self::Target { &self.0 diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent2/src/tools/context_server_registry.rs index db39e9278c250865d63922cd802e1ce9fb1d003f..3b5225317f75dd12785c05d0b44742faebc313a8 100644 --- a/crates/agent2/src/tools/context_server_registry.rs +++ b/crates/agent2/src/tools/context_server_registry.rs @@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool { }) }) } + + fn replay( + &self, + _input: serde_json::Value, + _output: serde_json::Value, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) + } } diff --git a/crates/agent2/src/tools/terminal_tool.rs b/crates/agent2/src/tools/terminal_tool.rs index ecb855ac34d655caefab5ed0bd4f33d60be547f8..6984475d18efacede23057b44c028bc8f062f21b 100644 --- a/crates/agent2/src/tools/terminal_tool.rs +++ b/crates/agent2/src/tools/terminal_tool.rs @@ -319,7 +319,7 @@ mod tests { use theme::ThemeSettings; use util::test::TempTree; - use crate::AgentResponseEvent; + use crate::ThreadEvent; use super::*; @@ -396,7 +396,7 @@ mod tests { }); cx.run_until_parked(); let event = stream_rx.try_next(); - if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event { + if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event { auth.response.send(auth.options[0].id.clone()).unwrap(); }