@@ -1,9 +1,9 @@
use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
use crate::{
- AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
- DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
- MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
- ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
+ ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
+ EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
+ OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
+ UserMessageContent, WebSearchTool, templates::Templates,
};
use crate::{DbThread, ThreadsDatabase};
use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
@@ -461,10 +461,7 @@ impl NativeAgentConnection {
session_id: acp::SessionId,
cx: &mut App,
f: impl 'static
- + FnOnce(
- Entity<Thread>,
- &mut App,
- ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
+ + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
) -> Task<Result<acp::PromptResponse>> {
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
agent
@@ -488,7 +485,10 @@ impl NativeAgentConnection {
log::trace!("Received completion event: {:?}", event);
match event {
- AgentResponseEvent::Text(text) => {
+ ThreadEvent::UserMessage(message) => {
+ todo!()
+ }
+ ThreadEvent::AgentText(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@@ -500,7 +500,7 @@ impl NativeAgentConnection {
)
})?;
}
- AgentResponseEvent::Thinking(text) => {
+ ThreadEvent::AgentThinking(text) => {
acp_thread.update(cx, |thread, cx| {
thread.push_assistant_content_block(
acp::ContentBlock::Text(acp::TextContent {
@@ -512,7 +512,7 @@ impl NativeAgentConnection {
)
})?;
}
- AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
+ ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
tool_call,
options,
response,
@@ -535,17 +535,17 @@ impl NativeAgentConnection {
})
.detach();
}
- AgentResponseEvent::ToolCall(tool_call) => {
+ ThreadEvent::ToolCall(tool_call) => {
acp_thread.update(cx, |thread, cx| {
thread.upsert_tool_call(tool_call, cx)
})??;
}
- AgentResponseEvent::ToolCallUpdate(update) => {
+ ThreadEvent::ToolCallUpdate(update) => {
acp_thread.update(cx, |thread, cx| {
thread.update_tool_call(update, cx)
})??;
}
- AgentResponseEvent::Stop(stop_reason) => {
+ ThreadEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
}
@@ -786,7 +786,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.into_iter()
.map(|thread| AcpThreadMetadata {
agent: NATIVE_AGENT_SERVER_NAME.clone(),
- id: thread.id,
+ id: thread.id.into(),
title: thread.title,
updated_at: thread.updated_at,
})
@@ -806,11 +806,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
session_id: acp::SessionId,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
+ let thread_id = session_id.clone().into();
let database = self.0.update(cx, |this, _| this.thread_database.clone());
cx.spawn(async move |cx| {
let database = database.await.map_err(|e| anyhow!(e))?;
let db_thread = database
- .load_thread(session_id.clone())
+ .load_thread(thread_id)
.await?
.context("no such thread found")?;
@@ -1,4 +1,4 @@
-use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
+use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol as acp;
@@ -30,10 +30,12 @@ use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
+const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
+
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
-pub struct ThreadId(Arc<str>);
+pub struct ThreadId(pub(crate) Arc<str>);
impl ThreadId {
pub fn new() -> Self {
@@ -53,6 +55,18 @@ impl From<&str> for ThreadId {
}
}
+impl From<acp::SessionId> for ThreadId {
+ fn from(value: acp::SessionId) -> Self {
+ Self(value.0)
+ }
+}
+
+impl From<ThreadId> for acp::SessionId {
+ fn from(value: ThreadId) -> Self {
+ Self(value.0)
+ }
+}
+
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
@@ -313,9 +327,6 @@ impl AgentMessage {
AgentMessageContent::RedactedThinking(_) => {
markdown.push_str("<redacted_thinking />\n")
}
- AgentMessageContent::Image(_) => {
- markdown.push_str("<image />\n");
- }
AgentMessageContent::ToolUse(tool_use) => {
markdown.push_str(&format!(
"**Tool Use**: {} (ID: {})\n",
@@ -386,9 +397,6 @@ impl AgentMessage {
AgentMessageContent::ToolUse(value) => {
language_model::MessageContent::ToolUse(value.clone())
}
- AgentMessageContent::Image(value) => {
- language_model::MessageContent::Image(value.clone())
- }
};
assistant_message.content.push(chunk);
}
@@ -432,14 +440,14 @@ pub enum AgentMessageContent {
signature: Option<String>,
},
RedactedThinking(String),
- Image(LanguageModelImage),
ToolUse(LanguageModelToolUse),
}
#[derive(Debug)]
-pub enum AgentResponseEvent {
- Text(String),
- Thinking(String),
+pub enum ThreadEvent {
+ UserMessage(UserMessage),
+ AgentText(String),
+ AgentThinking(String),
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
@@ -504,6 +512,121 @@ impl Thread {
}
}
+ pub fn from_db(
+ id: ThreadId,
+ db_thread: DbThread,
+ project: Entity<Project>,
+ project_context: Rc<RefCell<ProjectContext>>,
+ context_server_registry: Entity<ContextServerRegistry>,
+ action_log: Entity<ActionLog>,
+ templates: Arc<Templates>,
+ model: Arc<dyn LanguageModel>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let profile_id = db_thread
+ .profile
+ .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
+ Self {
+ id,
+ prompt_id: PromptId::new(),
+ messages: db_thread.messages,
+ completion_mode: CompletionMode::Normal,
+ running_turn: None,
+ pending_message: None,
+ tools: BTreeMap::default(),
+ tool_use_limit_reached: false,
+ context_server_registry,
+ profile_id,
+ project_context,
+ templates,
+ model,
+ project,
+ action_log,
+ }
+ }
+
+ pub fn replay(&self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
+ let (tx, rx) = mpsc::unbounded();
+ let stream = ThreadEventStream(tx);
+ for message in &self.messages {
+ match message {
+ Message::User(user_message) => stream.send_user_message(&user_message),
+ Message::Agent(assistant_message) => {
+ for content in &assistant_message.content {
+ match content {
+ AgentMessageContent::Text(text) => stream.send_text(text),
+ AgentMessageContent::Thinking { text, .. } => {
+ stream.send_thinking(text)
+ }
+ AgentMessageContent::RedactedThinking(_) => {}
+ AgentMessageContent::ToolUse(tool_use) => {
+ self.replay_tool_call(
+ tool_use,
+ assistant_message.tool_results.get(&tool_use.id),
+ &stream,
+ cx,
+ );
+ }
+ }
+ }
+ }
+ Message::Resume => {}
+ }
+ }
+ rx
+ }
+
+ fn replay_tool_call(
+ &self,
+ tool_use: &LanguageModelToolUse,
+ tool_result: Option<&LanguageModelToolResult>,
+ stream: &ThreadEventStream,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
+ stream
+ .0
+ .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
+ id: acp::ToolCallId(tool_use.id.to_string().into()),
+ title: tool_use.name.to_string(),
+ kind: acp::ToolKind::Other,
+ status: acp::ToolCallStatus::Failed,
+ content: Vec::new(),
+ locations: Vec::new(),
+ raw_input: Some(tool_use.input.clone()),
+ raw_output: None,
+ })))
+ .ok();
+ return;
+ };
+
+ let title = tool.initial_title(tool_use.input.clone());
+ let kind = tool.kind();
+ stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
+
+ if let Some(output) = tool_result
+ .as_ref()
+ .and_then(|result| result.output.clone())
+ {
+ let tool_event_stream = ToolCallEventStream::new(
+ tool_use.id.clone(),
+ stream.clone(),
+ Some(self.project.read(cx).fs().clone()),
+ );
+ tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
+ .log_err();
+ } else {
+ stream.update_tool_call_fields(
+ &tool_use.id,
+ acp::ToolCallUpdateFields {
+ content: Some(vec![TOOL_CANCELED_MESSAGE.into()]),
+ status: Some(acp::ToolCallStatus::Failed),
+ ..Default::default()
+ },
+ );
+ }
+ }
+
pub fn project(&self) -> &Entity<Project> {
&self.project
}
@@ -574,7 +697,7 @@ impl Thread {
pub fn resume(
&mut self,
cx: &mut Context<Self>,
- ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
+ ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
anyhow::ensure!(
self.tool_use_limit_reached,
"can only resume after tool use limit is reached"
@@ -595,7 +718,7 @@ impl Thread {
id: UserMessageId,
content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>,
- ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
+ ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
where
T: Into<UserMessageContent>,
{
@@ -613,15 +736,12 @@ impl Thread {
self.run_turn(cx)
}
- fn run_turn(
- &mut self,
- cx: &mut Context<Self>,
- ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
+ fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
self.cancel();
let model = self.model.clone();
- let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
- let event_stream = AgentResponseEventStream(events_tx);
+ let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
+ let event_stream = ThreadEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false;
self.running_turn = Some(RunningTurn {
@@ -755,7 +875,7 @@ impl Thread {
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
log::trace!("Handling streamed completion event: {:?}", event);
@@ -797,7 +917,7 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) {
event_stream.send_text(&new_text);
@@ -818,7 +938,7 @@ impl Thread {
&mut self,
new_text: String,
new_signature: Option<String>,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) {
event_stream.send_thinking(&new_text);
@@ -850,7 +970,7 @@ impl Thread {
fn handle_tool_use_event(
&mut self,
tool_use: LanguageModelToolUse,
- event_stream: &AgentResponseEventStream,
+ event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
cx.notify();
@@ -989,9 +1109,7 @@ impl Thread {
tool_use_id: tool_use.id.clone(),
tool_name: tool_use.name.clone(),
is_error: true,
- content: LanguageModelToolResultContent::Text(
- "Tool canceled by user".into(),
- ),
+ content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
output: None,
},
);
@@ -1143,7 +1261,7 @@ struct RunningTurn {
_task: Task<()>,
/// The current event stream for the running turn. Used to report a final
/// cancellation event if we cancel the turn.
- event_stream: AgentResponseEventStream,
+ event_stream: ThreadEventStream,
}
impl RunningTurn {
@@ -1196,6 +1314,17 @@ where
cx: &mut App,
) -> Task<Result<Self::Output>>;
+ /// Emits events for a previous execution of the tool.
+ fn replay(
+ &self,
+ _input: Self::Input,
+ _output: Self::Output,
+ _event_stream: ToolCallEventStream,
+ _cx: &mut App,
+ ) -> Result<()> {
+ Ok(())
+ }
+
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
}
@@ -1223,6 +1352,13 @@ pub trait AnyAgentTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<AgentToolOutput>>;
+ fn replay(
+ &self,
+ input: serde_json::Value,
+ output: serde_json::Value,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Result<()>;
}
impl<T> AnyAgentTool for Erased<Arc<T>>
@@ -1274,21 +1410,39 @@ where
})
})
}
+
+ fn replay(
+ &self,
+ input: serde_json::Value,
+ output: serde_json::Value,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Result<()> {
+ let input = serde_json::from_value(input)?;
+ let output = serde_json::from_value(output)?;
+ self.0.replay(input, output, event_stream, cx)
+ }
}
#[derive(Clone)]
-struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
+struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
+
+impl ThreadEventStream {
+ fn send_user_message(&self, message: &UserMessage) {
+ self.0
+ .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
+ .ok();
+ }
-impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
+ .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
.ok();
}
fn send_thinking(&self, text: &str) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
+ .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
.ok();
}
@@ -1300,7 +1454,7 @@ impl AgentResponseEventStream {
input: serde_json::Value,
) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
+ .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
id,
title.to_string(),
kind,
@@ -1333,7 +1487,7 @@ impl AgentResponseEventStream {
fields: acp::ToolCallUpdateFields,
) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
+ .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.to_string().into()),
fields,
@@ -1347,17 +1501,17 @@ impl AgentResponseEventStream {
match reason {
StopReason::EndTurn => {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
+ .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
.ok();
}
StopReason::MaxTokens => {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
+ .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
.ok();
}
StopReason::Refusal => {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
+ .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
.ok();
}
StopReason::ToolUse => {}
@@ -1366,7 +1520,7 @@ impl AgentResponseEventStream {
fn send_canceled(&self) {
self.0
- .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
+ .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
.ok();
}
@@ -1378,24 +1532,23 @@ impl AgentResponseEventStream {
#[derive(Clone)]
pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId,
- stream: AgentResponseEventStream,
+ stream: ThreadEventStream,
fs: Option<Arc<dyn Fs>>,
}
impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
- let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
+ let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
- let stream =
- ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None);
+ let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
(stream, ToolCallEventStreamReceiver(events_rx))
}
fn new(
tool_use_id: LanguageModelToolUseId,
- stream: AgentResponseEventStream,
+ stream: ThreadEventStream,
fs: Option<Arc<dyn Fs>>,
) -> Self {
Self {
@@ -1413,7 +1566,7 @@ impl ToolCallEventStream {
pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
self.stream
.0
- .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
+ .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
acp_thread::ToolCallUpdateDiff {
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
diff,
@@ -1426,7 +1579,7 @@ impl ToolCallEventStream {
pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
self.stream
.0
- .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
+ .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
acp_thread::ToolCallUpdateTerminal {
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
terminal,
@@ -1444,7 +1597,7 @@ impl ToolCallEventStream {
let (response_tx, response_rx) = oneshot::channel();
self.stream
.0
- .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
+ .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
ToolCallAuthorization {
tool_call: acp::ToolCallUpdate {
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
@@ -1494,13 +1647,13 @@ impl ToolCallEventStream {
}
#[cfg(test)]
-pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
+pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
#[cfg(test)]
impl ToolCallEventStreamReceiver {
pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
let event = self.0.next().await;
- if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
+ if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event {
auth
} else {
panic!("Expected ToolCallAuthorization but got: {:?}", event);
@@ -1509,9 +1662,9 @@ impl ToolCallEventStreamReceiver {
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
let event = self.0.next().await;
- if let Some(Ok(AgentResponseEvent::ToolCallUpdate(
- acp_thread::ToolCallUpdate::UpdateTerminal(update),
- ))) = event
+ if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
+ update,
+ )))) = event
{
update.terminal
} else {
@@ -1522,7 +1675,7 @@ impl ToolCallEventStreamReceiver {
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
- type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
+ type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;
fn deref(&self) -> &Self::Target {
&self.0