Cargo.lock 🔗
@@ -158,6 +158,7 @@ dependencies = [
"serde_json",
"settings",
"smol",
+ "strum 0.27.1",
"tempfile",
"ui",
"util",
Created by Ben Brandt
- **Fix cancellation of tool calls**
- **Make tool_call test more resilient**
- **Fix tool call confirmation test**
Release Notes:
- N/A
Cargo.lock | 1
crates/acp_thread/src/acp_thread.rs | 13 ++
crates/agent_servers/Cargo.toml | 1
crates/agent_servers/src/claude.rs | 14 ++-
crates/agent_servers/src/claude/tools.rs | 99 +++++++++++++++++++++++++
crates/agent_servers/src/e2e_tests.rs | 98 ++++++++++++++++++-------
6 files changed, 188 insertions(+), 38 deletions(-)
@@ -158,6 +158,7 @@ dependencies = [
"serde_json",
"settings",
"smol",
+ "strum 0.27.1",
"tempfile",
"ui",
"util",
@@ -664,7 +664,7 @@ impl AcpThread {
cx: &mut Context<Self>,
) -> Result<ToolCallRequest> {
let project = self.project.read(cx).languages().clone();
- let Some((_, call)) = self.tool_call_mut(tool_call_id) else {
+ let Some((idx, call)) = self.tool_call_mut(tool_call_id) else {
anyhow::bail!("Tool call not found");
};
@@ -675,6 +675,8 @@ impl AcpThread {
respond_tx: tx,
};
+ cx.emit(AcpThreadEvent::EntryUpdated(idx));
+
Ok(ToolCallRequest {
id: tool_call_id,
outcome: rx,
@@ -768,8 +770,13 @@ impl AcpThread {
let language_registry = self.project.read(cx).languages().clone();
let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
- call.content = new_content
- .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx));
+ if let Some(new_content) = new_content {
+ call.content = Some(ToolCallContent::from_acp(
+ new_content,
+ language_registry,
+ cx,
+ ));
+ }
match &mut call.status {
ToolCallStatus::Allowed { status } => {
@@ -33,6 +33,7 @@ serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
+strum.workspace = true
tempfile.workspace = true
ui.workspace = true
util.workspace = true
@@ -281,14 +281,18 @@ impl ClaudeAgentConnection {
} => {
let id = tool_id_map.borrow_mut().remove(&tool_use_id);
if let Some(id) = id {
+ let content = content.to_string();
delegate
.update_tool_call(UpdateToolCallParams {
tool_call_id: id,
status: acp::ToolCallStatus::Finished,
- content: Some(ToolCallContent::Markdown {
- // For now we only include text content
- markdown: content.to_string(),
- }),
+ // Don't unset existing content
+ content: (!content.is_empty()).then_some(
+ ToolCallContent::Markdown {
+ // For now we only include text content
+ markdown: content,
+ },
+ ),
})
.await
.log_err();
@@ -577,7 +581,7 @@ pub(crate) mod tests {
use super::*;
use serde_json::json;
- // crate::common_e2e_tests!(ClaudeCode);
+ crate::common_e2e_tests!(ClaudeCode);
pub fn local_command() -> AgentServerCommand {
AgentServerCommand {
@@ -118,13 +118,106 @@ impl ClaudeTool {
pub fn content(&self) -> Option<acp::ToolCallContent> {
match &self {
- ClaudeTool::Other { input, .. } => Some(acp::ToolCallContent::Markdown {
+ Self::Other { input, .. } => Some(acp::ToolCallContent::Markdown {
markdown: format!(
"```json\n{}```",
serde_json::to_string_pretty(&input).unwrap_or("{}".to_string())
),
}),
- _ => None,
+ Self::Task(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.prompt.clone(),
+ }),
+ Self::NotebookRead(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.notebook_path.display().to_string(),
+ }),
+ Self::NotebookEdit(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.new_source.clone(),
+ }),
+ Self::Terminal(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: format!(
+ "`{}`\n\n{}",
+ params.command,
+ params.description.as_deref().unwrap_or_default()
+ ),
+ }),
+ Self::ReadFile(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.abs_path.display().to_string(),
+ }),
+ Self::Ls(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.path.display().to_string(),
+ }),
+ Self::Glob(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.to_string(),
+ }),
+ Self::Grep(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: format!("`{params}`"),
+ }),
+ Self::WebFetch(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.prompt.clone(),
+ }),
+ Self::WebSearch(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.to_string(),
+ }),
+ Self::TodoWrite(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params
+ .todos
+ .iter()
+ .map(|todo| {
+ format!(
+ "- {} {}: {}",
+ match todo.status {
+ TodoStatus::Completed => "✅",
+ TodoStatus::InProgress => "🚧",
+ TodoStatus::Pending => "⬜",
+ },
+ todo.priority,
+ todo.content
+ )
+ })
+ .join("\n"),
+ }),
+ Self::ExitPlanMode(Some(params)) => Some(acp::ToolCallContent::Markdown {
+ markdown: params.plan.clone(),
+ }),
+ Self::Edit(Some(params)) => Some(acp::ToolCallContent::Diff {
+ diff: acp::Diff {
+ path: params.abs_path.clone(),
+ old_text: Some(params.old_text.clone()),
+ new_text: params.new_text.clone(),
+ },
+ }),
+ Self::Write(Some(params)) => Some(acp::ToolCallContent::Diff {
+ diff: acp::Diff {
+ path: params.file_path.clone(),
+ old_text: None,
+ new_text: params.content.clone(),
+ },
+ }),
+ Self::MultiEdit(Some(params)) => {
+ // todo: show multiple edits in a multibuffer?
+ params.edits.first().map(|edit| acp::ToolCallContent::Diff {
+ diff: acp::Diff {
+ path: params.file_path.clone(),
+ old_text: Some(edit.old_string.clone()),
+ new_text: edit.new_string.clone(),
+ },
+ })
+ }
+ Self::Task(None)
+ | Self::NotebookRead(None)
+ | Self::NotebookEdit(None)
+ | Self::Terminal(None)
+ | Self::ReadFile(None)
+ | Self::Ls(None)
+ | Self::Glob(None)
+ | Self::Grep(None)
+ | Self::WebFetch(None)
+ | Self::WebSearch(None)
+ | Self::TodoWrite(None)
+ | Self::ExitPlanMode(None)
+ | Self::Edit(None)
+ | Self::Write(None)
+ | Self::MultiEdit(None) => None,
}
}
@@ -513,7 +606,7 @@ impl std::fmt::Display for GrepToolParams {
}
}
-#[derive(Deserialize, Serialize, JsonSchema, Debug)]
+#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TodoPriority {
High,
@@ -111,18 +111,21 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp
.await
.unwrap();
thread.read_with(cx, |thread, _cx| {
- assert!(matches!(
- &thread.entries()[2],
- AgentThreadEntry::ToolCall(ToolCall {
- status: ToolCallStatus::Allowed { .. },
- ..
- })
- ));
-
- assert!(matches!(
- thread.entries()[3],
- AgentThreadEntry::AssistantMessage(_)
- ));
+ assert!(thread.entries().iter().any(|entry| {
+ matches!(
+ entry,
+ AgentThreadEntry::ToolCall(ToolCall {
+ status: ToolCallStatus::Allowed { .. },
+ ..
+ })
+ )
+ }));
+ assert!(
+ thread
+ .entries()
+ .iter()
+ .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
+ );
});
}
@@ -134,10 +137,26 @@ pub async fn test_tool_call_with_confirmation(
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| {
- thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
+ thread.send_raw(
+ r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
+ cx,
+ )
});
- run_until_first_tool_call(&thread, cx).await;
+ run_until_first_tool_call(
+ &thread,
+ |entry| {
+ matches!(
+ entry,
+ AgentThreadEntry::ToolCall(ToolCall {
+ status: ToolCallStatus::WaitingForConfirmation { .. },
+ ..
+ })
+ )
+ },
+ cx,
+ )
+ .await;
let tool_call_id = thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall {
@@ -148,12 +167,16 @@ pub async fn test_tool_call_with_confirmation(
..
},
..
- }) = &thread.entries()[2]
+ }) = &thread
+ .entries()
+ .iter()
+ .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
+ .unwrap()
else {
panic!();
};
- assert_eq!(root_command, "echo");
+ assert!(root_command.contains("touch"));
*id
});
@@ -161,13 +184,13 @@ pub async fn test_tool_call_with_confirmation(
thread.update(cx, |thread, cx| {
thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
- assert!(matches!(
- &thread.entries()[2],
+ assert!(thread.entries().iter().any(|entry| matches!(
+ entry,
AgentThreadEntry::ToolCall(ToolCall {
status: ToolCallStatus::Allowed { .. },
..
})
- ));
+ )));
});
full_turn.await.unwrap();
@@ -177,15 +200,19 @@ pub async fn test_tool_call_with_confirmation(
content: Some(ToolCallContent::Markdown { markdown }),
status: ToolCallStatus::Allowed { .. },
..
- }) = &thread.entries()[2]
+ }) = thread
+ .entries()
+ .iter()
+ .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
+ .unwrap()
else {
panic!();
};
markdown.read_with(cx, |md, _cx| {
assert!(
- md.source().contains("Hello, world!"),
- r#"Expected '{}' to contain "Hello, world!""#,
+ md.source().contains("Hello"),
+ r#"Expected '{}' to contain "Hello""#,
md.source()
);
});
@@ -198,10 +225,26 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| {
- thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
+ thread.send_raw(
+ r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
+ cx,
+ )
});
- let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
+ let first_tool_call_ix = run_until_first_tool_call(
+ &thread,
+ |entry| {
+ matches!(
+ entry,
+ AgentThreadEntry::ToolCall(ToolCall {
+ status: ToolCallStatus::WaitingForConfirmation { .. },
+ ..
+ })
+ )
+ },
+ cx,
+ )
+ .await;
thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall {
@@ -217,7 +260,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
panic!("{:?}", thread.entries()[1]);
};
- assert_eq!(root_command, "echo");
+ assert!(root_command.contains("touch"));
*id
});
@@ -340,6 +383,7 @@ pub async fn new_test_thread(
pub async fn run_until_first_tool_call(
thread: &Entity<AcpThread>,
+ wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
cx: &mut TestAppContext,
) -> usize {
let (mut tx, mut rx) = mpsc::channel::<usize>(1);
@@ -347,7 +391,7 @@ pub async fn run_until_first_tool_call(
let subscription = cx.update(|cx| {
cx.subscribe(thread, move |thread, _, cx| {
for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
- if matches!(entry, AgentThreadEntry::ToolCall(_)) {
+ if wait_until(entry) {
return tx.try_send(ix).unwrap();
}
}
@@ -357,7 +401,7 @@ pub async fn run_until_first_tool_call(
select! {
// We have to use a smol timer here because
// cx.background_executor().timer isn't real in the test context
- _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
+ _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
panic!("Timeout waiting for tool call")
}
ix = rx.next().fuse() => {