Detailed changes
@@ -158,8 +158,10 @@ dependencies = [
"acp_thread",
"agent-client-protocol",
"agent_servers",
+ "agent_settings",
"anyhow",
"assistant_tool",
+ "assistant_tools",
"client",
"clock",
"cloud_llm_client",
@@ -177,6 +179,8 @@ dependencies = [
"language_model",
"language_models",
"log",
+ "lsp",
+ "paths",
"pretty_assertions",
"project",
"prompt_store",
@@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
-agent-client-protocol = { version = "0.0.23" }
+agent-client-protocol = "0.0.23"
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"
@@ -1,18 +1,17 @@
mod connection;
+mod diff;
+
pub use connection::*;
+pub use diff::*;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
use assistant_tool::ActionLog;
-use buffer_diff::BufferDiff;
-use editor::{Bias, MultiBuffer, PathKey};
+use editor::Bias;
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
use itertools::Itertools;
-use language::{
- Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
- text_diff,
-};
+use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
use markdown::Markdown;
use project::{AgentLocation, Project};
use std::collections::HashMap;
@@ -140,7 +139,7 @@ impl AgentThreadEntry {
}
}
- pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
+ pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
if let AgentThreadEntry::ToolCall(call) = self {
itertools::Either::Left(call.diffs())
} else {
@@ -249,7 +248,7 @@ impl ToolCall {
}
}
- pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
+ pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
self.content.iter().filter_map(|content| match content {
ToolCallContent::ContentBlock { .. } => None,
ToolCallContent::Diff { diff } => Some(diff),
@@ -389,7 +388,7 @@ impl ContentBlock {
#[derive(Debug)]
pub enum ToolCallContent {
ContentBlock { content: ContentBlock },
- Diff { diff: Diff },
+ Diff { diff: Entity<Diff> },
}
impl ToolCallContent {
@@ -403,7 +402,7 @@ impl ToolCallContent {
content: ContentBlock::new(content, &language_registry, cx),
},
acp::ToolCallContent::Diff { diff } => Self::Diff {
- diff: Diff::from_acp(diff, language_registry, cx),
+ diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
},
}
}
@@ -411,108 +410,11 @@ impl ToolCallContent {
pub fn to_markdown(&self, cx: &App) -> String {
match self {
Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
- Self::Diff { diff } => diff.to_markdown(cx),
+ Self::Diff { diff } => diff.read(cx).to_markdown(cx),
}
}
}
-#[derive(Debug)]
-pub struct Diff {
- pub multibuffer: Entity<MultiBuffer>,
- pub path: PathBuf,
- _task: Task<Result<()>>,
-}
-
-impl Diff {
- pub fn from_acp(
- diff: acp::Diff,
- language_registry: Arc<LanguageRegistry>,
- cx: &mut App,
- ) -> Self {
- let acp::Diff {
- path,
- old_text,
- new_text,
- } = diff;
-
- let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
-
- let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
- let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
- let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
- let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
-
- let task = cx.spawn({
- let multibuffer = multibuffer.clone();
- let path = path.clone();
- async move |cx| {
- let language = language_registry
- .language_for_file_path(&path)
- .await
- .log_err();
-
- new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
-
- let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
- buffer.set_language(language, cx);
- buffer.snapshot()
- })?;
-
- buffer_diff
- .update(cx, |diff, cx| {
- diff.set_base_text(
- old_buffer_snapshot,
- Some(language_registry),
- new_buffer_snapshot,
- cx,
- )
- })?
- .await?;
-
- multibuffer
- .update(cx, |multibuffer, cx| {
- let hunk_ranges = {
- let buffer = new_buffer.read(cx);
- let diff = buffer_diff.read(cx);
- diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
- .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
- .collect::<Vec<_>>()
- };
-
- multibuffer.set_excerpts_for_path(
- PathKey::for_buffer(&new_buffer, cx),
- new_buffer.clone(),
- hunk_ranges,
- editor::DEFAULT_MULTIBUFFER_CONTEXT,
- cx,
- );
- multibuffer.add_diff(buffer_diff, cx);
- })
- .log_err();
-
- anyhow::Ok(())
- }
- });
-
- Self {
- multibuffer,
- path,
- _task: task,
- }
- }
-
- fn to_markdown(&self, cx: &App) -> String {
- let buffer_text = self
- .multibuffer
- .read(cx)
- .all_buffers()
- .iter()
- .map(|buffer| buffer.read(cx).text())
- .join("\n");
- format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
- }
-}
-
#[derive(Debug, Default)]
pub struct Plan {
pub entries: Vec<PlanEntry>,
@@ -823,6 +725,21 @@ impl AcpThread {
Ok(())
}
+ pub fn set_tool_call_diff(
+ &mut self,
+ tool_call_id: &acp::ToolCallId,
+ diff: Entity<Diff>,
+ cx: &mut Context<Self>,
+ ) -> Result<()> {
+ let (ix, current_call) = self
+ .tool_call_mut(tool_call_id)
+ .context("Tool call not found")?;
+ current_call.content.clear();
+ current_call.content.push(ToolCallContent::Diff { diff });
+ cx.emit(AcpThreadEvent::EntryUpdated(ix));
+ Ok(())
+ }
+
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
let status = ToolCallStatus::Allowed {
@@ -0,0 +1,388 @@
+use agent_client_protocol as acp;
+use anyhow::Result;
+use buffer_diff::{BufferDiff, BufferDiffSnapshot};
+use editor::{MultiBuffer, PathKey};
+use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task};
+use itertools::Itertools;
+use language::{
+ Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _, Point, Rope, TextBuffer,
+};
+use std::{
+ cmp::Reverse,
+ ops::Range,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+use util::ResultExt;
+
+pub enum Diff {
+ Pending(PendingDiff),
+ Finalized(FinalizedDiff),
+}
+
+impl Diff {
+ pub fn from_acp(
+ diff: acp::Diff,
+ language_registry: Arc<LanguageRegistry>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let acp::Diff {
+ path,
+ old_text,
+ new_text,
+ } = diff;
+
+ let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
+
+ let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
+ let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
+ let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
+ let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
+
+ let task = cx.spawn({
+ let multibuffer = multibuffer.clone();
+ let path = path.clone();
+ async move |_, cx| {
+ let language = language_registry
+ .language_for_file_path(&path)
+ .await
+ .log_err();
+
+ new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
+
+ let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
+ buffer.set_language(language, cx);
+ buffer.snapshot()
+ })?;
+
+ buffer_diff
+ .update(cx, |diff, cx| {
+ diff.set_base_text(
+ old_buffer_snapshot,
+ Some(language_registry),
+ new_buffer_snapshot,
+ cx,
+ )
+ })?
+ .await?;
+
+ multibuffer
+ .update(cx, |multibuffer, cx| {
+ let hunk_ranges = {
+ let buffer = new_buffer.read(cx);
+ let diff = buffer_diff.read(cx);
+ diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
+ .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
+ .collect::<Vec<_>>()
+ };
+
+ multibuffer.set_excerpts_for_path(
+ PathKey::for_buffer(&new_buffer, cx),
+ new_buffer.clone(),
+ hunk_ranges,
+ editor::DEFAULT_MULTIBUFFER_CONTEXT,
+ cx,
+ );
+ multibuffer.add_diff(buffer_diff, cx);
+ })
+ .log_err();
+
+ anyhow::Ok(())
+ }
+ });
+
+ Self::Finalized(FinalizedDiff {
+ multibuffer,
+ path,
+ _update_diff: task,
+ })
+ }
+
+ pub fn new(buffer: Entity<Buffer>, cx: &mut Context<Self>) -> Self {
+ let buffer_snapshot = buffer.read(cx).snapshot();
+ let base_text = buffer_snapshot.text();
+ let language_registry = buffer.read(cx).language_registry();
+ let text_snapshot = buffer.read(cx).text_snapshot();
+ let buffer_diff = cx.new(|cx| {
+ let mut diff = BufferDiff::new(&text_snapshot, cx);
+ let _ = diff.set_base_text(
+ buffer_snapshot.clone(),
+ language_registry,
+ text_snapshot,
+ cx,
+ );
+ diff
+ });
+
+ let multibuffer = cx.new(|cx| {
+ let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly);
+ multibuffer.add_diff(buffer_diff.clone(), cx);
+ multibuffer
+ });
+
+ Self::Pending(PendingDiff {
+ multibuffer,
+ base_text: Arc::new(base_text),
+ _subscription: cx.observe(&buffer, |this, _, cx| {
+ if let Diff::Pending(diff) = this {
+ diff.update(cx);
+ }
+ }),
+ buffer,
+ diff: buffer_diff,
+ revealed_ranges: Vec::new(),
+ update_diff: Task::ready(Ok(())),
+ })
+ }
+
+ pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Self>) {
+ if let Self::Pending(diff) = self {
+ diff.reveal_range(range, cx);
+ }
+ }
+
+ pub fn finalize(&mut self, cx: &mut Context<Self>) {
+ if let Self::Pending(diff) = self {
+ *self = Self::Finalized(diff.finalize(cx));
+ }
+ }
+
+ pub fn multibuffer(&self) -> &Entity<MultiBuffer> {
+ match self {
+ Self::Pending(PendingDiff { multibuffer, .. }) => multibuffer,
+ Self::Finalized(FinalizedDiff { multibuffer, .. }) => multibuffer,
+ }
+ }
+
+ pub fn to_markdown(&self, cx: &App) -> String {
+ let buffer_text = self
+ .multibuffer()
+ .read(cx)
+ .all_buffers()
+ .iter()
+ .map(|buffer| buffer.read(cx).text())
+ .join("\n");
+ let path = match self {
+ Diff::Pending(PendingDiff { buffer, .. }) => {
+ buffer.read(cx).file().map(|file| file.path().as_ref())
+ }
+ Diff::Finalized(FinalizedDiff { path, .. }) => Some(path.as_path()),
+ };
+ format!(
+ "Diff: {}\n```\n{}\n```\n",
+ path.unwrap_or(Path::new("untitled")).display(),
+ buffer_text
+ )
+ }
+}
+
+pub struct PendingDiff {
+ multibuffer: Entity<MultiBuffer>,
+ base_text: Arc<String>,
+ buffer: Entity<Buffer>,
+ diff: Entity<BufferDiff>,
+ revealed_ranges: Vec<Range<Anchor>>,
+ _subscription: Subscription,
+ update_diff: Task<Result<()>>,
+}
+
+impl PendingDiff {
+ pub fn update(&mut self, cx: &mut Context<Diff>) {
+ let buffer = self.buffer.clone();
+ let buffer_diff = self.diff.clone();
+ let base_text = self.base_text.clone();
+ self.update_diff = cx.spawn(async move |diff, cx| {
+ let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?;
+ let diff_snapshot = BufferDiff::update_diff(
+ buffer_diff.clone(),
+ text_snapshot.clone(),
+ Some(base_text),
+ false,
+ false,
+ None,
+ None,
+ cx,
+ )
+ .await?;
+ buffer_diff.update(cx, |diff, cx| {
+ diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
+ })?;
+ diff.update(cx, |diff, cx| {
+ if let Diff::Pending(diff) = diff {
+ diff.update_visible_ranges(cx);
+ }
+ })
+ });
+ }
+
+ pub fn reveal_range(&mut self, range: Range<Anchor>, cx: &mut Context<Diff>) {
+ self.revealed_ranges.push(range);
+ self.update_visible_ranges(cx);
+ }
+
+ fn finalize(&self, cx: &mut Context<Diff>) -> FinalizedDiff {
+ let ranges = self.excerpt_ranges(cx);
+ let base_text = self.base_text.clone();
+ let language_registry = self.buffer.read(cx).language_registry().clone();
+
+ let path = self
+ .buffer
+ .read(cx)
+ .file()
+ .map(|file| file.path().as_ref())
+ .unwrap_or(Path::new("untitled"))
+ .into();
+
+ // Replace the buffer in the multibuffer with the snapshot
+ let buffer = cx.new(|cx| {
+ let language = self.buffer.read(cx).language().cloned();
+ let buffer = TextBuffer::new_normalized(
+ 0,
+ cx.entity_id().as_non_zero_u64().into(),
+ self.buffer.read(cx).line_ending(),
+ self.buffer.read(cx).as_rope().clone(),
+ );
+ let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
+ buffer.set_language(language, cx);
+ buffer
+ });
+
+ let buffer_diff = cx.spawn({
+ let buffer = buffer.clone();
+ let language_registry = language_registry.clone();
+ async move |_this, cx| {
+ build_buffer_diff(base_text, &buffer, language_registry, cx).await
+ }
+ });
+
+ let update_diff = cx.spawn(async move |this, cx| {
+ let buffer_diff = buffer_diff.await?;
+ this.update(cx, |this, cx| {
+ this.multibuffer().update(cx, |multibuffer, cx| {
+ let path_key = PathKey::for_buffer(&buffer, cx);
+ multibuffer.clear(cx);
+ multibuffer.set_excerpts_for_path(
+ path_key,
+ buffer,
+ ranges,
+ editor::DEFAULT_MULTIBUFFER_CONTEXT,
+ cx,
+ );
+ multibuffer.add_diff(buffer_diff.clone(), cx);
+ });
+
+ cx.notify();
+ })
+ });
+
+ FinalizedDiff {
+ path,
+ multibuffer: self.multibuffer.clone(),
+ _update_diff: update_diff,
+ }
+ }
+
+ fn update_visible_ranges(&mut self, cx: &mut Context<Diff>) {
+ let ranges = self.excerpt_ranges(cx);
+ self.multibuffer.update(cx, |multibuffer, cx| {
+ multibuffer.set_excerpts_for_path(
+ PathKey::for_buffer(&self.buffer, cx),
+ self.buffer.clone(),
+ ranges,
+ editor::DEFAULT_MULTIBUFFER_CONTEXT,
+ cx,
+ );
+ let end = multibuffer.len(cx);
+ Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
+ });
+ cx.notify();
+ }
+
+ fn excerpt_ranges(&self, cx: &App) -> Vec<Range<Point>> {
+ let buffer = self.buffer.read(cx);
+ let diff = self.diff.read(cx);
+ let mut ranges = diff
+ .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
+ .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
+ .collect::<Vec<_>>();
+ ranges.extend(
+ self.revealed_ranges
+ .iter()
+ .map(|range| range.to_point(&buffer)),
+ );
+ ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
+
+ // Merge adjacent ranges
+ let mut ranges = ranges.into_iter().peekable();
+ let mut merged_ranges = Vec::new();
+ while let Some(mut range) = ranges.next() {
+ while let Some(next_range) = ranges.peek() {
+ if range.end >= next_range.start {
+ range.end = range.end.max(next_range.end);
+ ranges.next();
+ } else {
+ break;
+ }
+ }
+
+ merged_ranges.push(range);
+ }
+ merged_ranges
+ }
+}
+
+pub struct FinalizedDiff {
+ path: PathBuf,
+ multibuffer: Entity<MultiBuffer>,
+ _update_diff: Task<Result<()>>,
+}
+
+async fn build_buffer_diff(
+ old_text: Arc<String>,
+ buffer: &Entity<Buffer>,
+ language_registry: Option<Arc<LanguageRegistry>>,
+ cx: &mut AsyncApp,
+) -> Result<Entity<BufferDiff>> {
+ let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
+
+ let old_text_rope = cx
+ .background_spawn({
+ let old_text = old_text.clone();
+ async move { Rope::from(old_text.as_str()) }
+ })
+ .await;
+ let base_buffer = cx
+ .update(|cx| {
+ Buffer::build_snapshot(
+ old_text_rope,
+ buffer.language().cloned(),
+ language_registry,
+ cx,
+ )
+ })?
+ .await;
+
+ let diff_snapshot = cx
+ .update(|cx| {
+ BufferDiffSnapshot::new_with_base_buffer(
+ buffer.text.clone(),
+ Some(old_text),
+ base_buffer,
+ cx,
+ )
+ })?
+ .await;
+
+ let secondary_diff = cx.new(|cx| {
+ let mut diff = BufferDiff::new(&buffer, cx);
+ diff.set_snapshot(diff_snapshot.clone(), &buffer, cx);
+ diff
+ })?;
+
+ cx.new(|cx| {
+ let mut diff = BufferDiff::new(&buffer.text, cx);
+ diff.set_snapshot(diff_snapshot, &buffer, cx);
+ diff.set_secondary_diff(secondary_diff);
+ diff
+ })
+}
@@ -15,8 +15,10 @@ workspace = true
acp_thread.workspace = true
agent-client-protocol.workspace = true
agent_servers.workspace = true
+agent_settings.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
+assistant_tools.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
fs.workspace = true
@@ -29,6 +31,7 @@ language.workspace = true
language_model.workspace = true
language_models.workspace = true
log.workspace = true
+paths.workspace = true
project.workspace = true
prompt_store.workspace = true
rust-embed.workspace = true
@@ -53,6 +56,7 @@ gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language = { workspace = true, "features" = ["test-support"] }
language_model = { workspace = true, "features" = ["test-support"] }
+lsp = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }
@@ -1,5 +1,5 @@
use crate::{templates::Templates, AgentResponseEvent, Thread};
-use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
+use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
use acp_thread::ModelSelector;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
@@ -412,11 +412,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
anyhow!("No default model configured. Please configure a default model in settings.")
})?;
- let thread = cx.new(|_| {
+ let thread = cx.new(|cx| {
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
thread.add_tool(ThinkingTool);
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
+ thread.add_tool(EditFileTool::new(cx.entity()));
thread
});
@@ -564,6 +565,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
)
})??;
}
+ AgentResponseEvent::ToolCallDiff(tool_call_diff) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.set_tool_call_diff(
+ &tool_call_diff.tool_call_id,
+ tool_call_diff.diff,
+ cx,
+ )
+ })??;
+ }
AgentResponseEvent::Stop(stop_reason) => {
log::debug!("Assistant message complete: {:?}", stop_reason);
return Ok(acp::PromptResponse { stop_reason });
@@ -9,5 +9,6 @@ mod tests;
pub use agent::*;
pub use native_agent_server::NativeAgentServer;
+pub use templates::*;
pub use thread::*;
pub use tools::*;
@@ -1,5 +1,4 @@
use super::*;
-use crate::templates::Templates;
use acp_thread::AgentConnection;
use agent_client_protocol::{self as acp};
use anyhow::Result;
@@ -273,7 +272,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
- output: None
+ output: Some("Allowed".into())
}),
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
@@ -14,6 +14,7 @@ pub struct EchoTool;
impl AgentTool for EchoTool {
type Input = EchoToolInput;
+ type Output = String;
fn name(&self) -> SharedString {
"echo".into()
@@ -48,6 +49,7 @@ pub struct DelayTool;
impl AgentTool for DelayTool {
type Input = DelayToolInput;
+ type Output = String;
fn name(&self) -> SharedString {
"delay".into()
@@ -84,6 +86,7 @@ pub struct ToolRequiringPermission;
impl AgentTool for ToolRequiringPermission {
type Input = ToolRequiringPermissionInput;
+ type Output = String;
fn name(&self) -> SharedString {
"tool_requiring_permission".into()
@@ -99,14 +102,11 @@ impl AgentTool for ToolRequiringPermission {
fn run(
self: Arc<Self>,
- input: Self::Input,
+ _input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
- ) -> Task<Result<String>>
- where
- Self: Sized,
- {
- let auth_check = self.authorize(input, event_stream);
+ ) -> Task<Result<String>> {
+ let auth_check = event_stream.authorize("Authorize?".into());
cx.foreground_executor().spawn(async move {
auth_check.await?;
Ok("Allowed".to_string())
@@ -121,6 +121,7 @@ pub struct InfiniteTool;
impl AgentTool for InfiniteTool {
type Input = InfiniteToolInput;
+ type Output = String;
fn name(&self) -> SharedString {
"infinite".into()
@@ -171,19 +172,20 @@ pub struct WordListTool;
impl AgentTool for WordListTool {
type Input = WordListInput;
+ type Output = String;
fn name(&self) -> SharedString {
"word_list".into()
}
- fn initial_title(&self, _input: Self::Input) -> SharedString {
- "List of random words".into()
- }
-
fn kind(&self) -> acp::ToolKind {
acp::ToolKind::Other
}
+ fn initial_title(&self, _input: Self::Input) -> SharedString {
+ "List of random words".into()
+ }
+
fn run(
self: Arc<Self>,
_input: Self::Input,
@@ -1,4 +1,5 @@
-use crate::templates::{SystemPromptTemplate, Template, Templates};
+use crate::{SystemPromptTemplate, Template, Templates};
+use acp_thread::Diff;
use agent_client_protocol as acp;
use anyhow::{anyhow, Context as _, Result};
use assistant_tool::{adapt_schema_to_format, ActionLog};
@@ -103,6 +104,7 @@ pub enum AgentResponseEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
+ ToolCallDiff(ToolCallDiff),
Stop(acp::StopReason),
}
@@ -113,6 +115,12 @@ pub struct ToolCallAuthorization {
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
+#[derive(Debug)]
+pub struct ToolCallDiff {
+ pub tool_call_id: acp::ToolCallId,
+ pub diff: Entity<acp_thread::Diff>,
+}
+
pub struct Thread {
messages: Vec<AgentMessage>,
completion_mode: CompletionMode,
@@ -125,12 +133,13 @@ pub struct Thread {
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
pub selected_model: Arc<dyn LanguageModel>,
+ project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl Thread {
pub fn new(
- _project: Entity<Project>,
+ project: Entity<Project>,
project_context: Rc<RefCell<ProjectContext>>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
@@ -145,10 +154,19 @@ impl Thread {
project_context,
templates,
selected_model: default_model,
+ project,
action_log,
}
}
+ pub fn project(&self) -> &Entity<Project> {
+ &self.project
+ }
+
+ pub fn action_log(&self) -> &Entity<ActionLog> {
+ &self.action_log
+ }
+
pub fn set_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@@ -315,10 +333,6 @@ impl Thread {
events_rx
}
- pub fn action_log(&self) -> &Entity<ActionLog> {
- &self.action_log
- }
-
pub fn build_system_message(&self) -> AgentMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
@@ -490,15 +504,33 @@ impl Thread {
}));
};
- let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
+ let tool_event_stream =
+ ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
+ tool_event_stream.send_update(acp::ToolCallUpdateFields {
+ status: Some(acp::ToolCallStatus::InProgress),
+ ..Default::default()
+ });
+ let supports_images = self.selected_model.supports_images();
+ let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
Some(cx.foreground_executor().spawn(async move {
- match tool_result.await {
- Ok(tool_output) => LanguageModelToolResult {
+ let tool_result = tool_result.await.and_then(|output| {
+ if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
+ if !supports_images {
+ return Err(anyhow!(
+ "Attempted to read an image, but this model doesn't support it.",
+ ));
+ }
+ }
+ Ok(output)
+ });
+
+ match tool_result {
+ Ok(output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
- content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
- output: None,
+ content: output.llm_output,
+ output: Some(output.raw_output),
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
@@ -511,24 +543,6 @@ impl Thread {
}))
}
- fn run_tool(
- &self,
- tool: Arc<dyn AnyAgentTool>,
- tool_use: LanguageModelToolUse,
- event_stream: AgentResponseEventStream,
- cx: &mut Context<Self>,
- ) -> Task<Result<String>> {
- cx.spawn(async move |_this, cx| {
- let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
- tool_event_stream.send_update(acp::ToolCallUpdateFields {
- status: Some(acp::ToolCallStatus::InProgress),
- ..Default::default()
- });
- cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
- .await
- })
- }
-
fn handle_tool_use_json_parse_error_event(
&mut self,
tool_use_id: LanguageModelToolUseId,
@@ -572,7 +586,7 @@ impl Thread {
self.messages.last_mut().unwrap()
}
- fn build_completion_request(
+ pub(crate) fn build_completion_request(
&self,
completion_intent: CompletionIntent,
cx: &mut App,
@@ -662,6 +676,7 @@ where
Self: 'static + Sized,
{
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
+ type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
fn name(&self) -> SharedString;
@@ -685,23 +700,13 @@ where
schemars::schema_for!(Self::Input)
}
- /// Allows the tool to authorize a given tool call with the user if necessary
- fn authorize(
- &self,
- input: Self::Input,
- event_stream: ToolCallEventStream,
- ) -> impl use<Self> + Future<Output = Result<()>> {
- let json_input = serde_json::json!(&input);
- event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
- }
-
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
- ) -> Task<Result<String>>;
+ ) -> Task<Result<Self::Output>>;
fn erase(self) -> Arc<dyn AnyAgentTool> {
Arc::new(Erased(Arc::new(self)))
@@ -710,6 +715,11 @@ where
pub struct Erased<T>(T);
+pub struct AgentToolOutput {
+ llm_output: LanguageModelToolResultContent,
+ raw_output: serde_json::Value,
+}
+
pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self, cx: &mut App) -> SharedString;
@@ -721,7 +731,7 @@ pub trait AnyAgentTool {
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
- ) -> Task<Result<String>>;
+ ) -> Task<Result<AgentToolOutput>>;
}
impl<T> AnyAgentTool for Erased<Arc<T>>
@@ -756,12 +766,18 @@ where
input: serde_json::Value,
event_stream: ToolCallEventStream,
cx: &mut App,
- ) -> Task<Result<String>> {
- let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
- match parsed_input {
- Ok(input) => self.0.clone().run(input, event_stream, cx),
- Err(error) => Task::ready(Err(anyhow!(error))),
- }
+ ) -> Task<Result<AgentToolOutput>> {
+ cx.spawn(async move |cx| {
+ let input = serde_json::from_value(input)?;
+ let output = cx
+ .update(|cx| self.0.clone().run(input, event_stream, cx))?
+ .await?;
+ let raw_output = serde_json::to_value(&output)?;
+ Ok(AgentToolOutput {
+ llm_output: output.into(),
+ raw_output,
+ })
+ })
}
}
@@ -874,6 +890,12 @@ impl AgentResponseEventStream {
.ok();
}
+ fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
+ self.0
+ .unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
+ .ok();
+ }
+
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
@@ -903,13 +925,41 @@ impl AgentResponseEventStream {
#[derive(Clone)]
pub struct ToolCallEventStream {
tool_use_id: LanguageModelToolUseId,
+ kind: acp::ToolKind,
+ input: serde_json::Value,
stream: AgentResponseEventStream,
}
impl ToolCallEventStream {
- fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
+ #[cfg(test)]
+ pub fn test() -> (Self, ToolCallEventStreamReceiver) {
+ let (events_tx, events_rx) =
+ mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
+
+ let stream = ToolCallEventStream::new(
+ &LanguageModelToolUse {
+ id: "test_id".into(),
+ name: "test_tool".into(),
+ raw_input: String::new(),
+ input: serde_json::Value::Null,
+ is_input_complete: true,
+ },
+ acp::ToolKind::Other,
+ AgentResponseEventStream(events_tx),
+ );
+
+ (stream, ToolCallEventStreamReceiver(events_rx))
+ }
+
+ fn new(
+ tool_use: &LanguageModelToolUse,
+ kind: acp::ToolKind,
+ stream: AgentResponseEventStream,
+ ) -> Self {
Self {
- tool_use_id,
+ tool_use_id: tool_use.id.clone(),
+ kind,
+ input: tool_use.input.clone(),
stream,
}
}
@@ -918,38 +968,52 @@ impl ToolCallEventStream {
self.stream.send_tool_call_update(&self.tool_use_id, fields);
}
- pub fn authorize(
- &self,
- title: String,
- kind: acp::ToolKind,
- input: serde_json::Value,
- ) -> impl use<> + Future<Output = Result<()>> {
- self.stream
- .authorize_tool_call(&self.tool_use_id, title, kind, input)
+ pub fn send_diff(&self, diff: Entity<Diff>) {
+ self.stream.send_tool_call_diff(ToolCallDiff {
+ tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
+ diff,
+ });
+ }
+
+ pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
+ self.stream.authorize_tool_call(
+ &self.tool_use_id,
+ title,
+ self.kind.clone(),
+ self.input.clone(),
+ )
}
}
#[cfg(test)]
-pub struct TestToolCallEventStream {
- stream: ToolCallEventStream,
- _events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
-}
+pub struct ToolCallEventStreamReceiver(
+ mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+);
#[cfg(test)]
-impl TestToolCallEventStream {
- pub fn new() -> Self {
- let (events_tx, events_rx) =
- mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
+impl ToolCallEventStreamReceiver {
+ pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
+ let event = self.0.next().await;
+ if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
+ auth
+ } else {
+ panic!("Expected ToolCallAuthorization but got: {:?}", event);
+ }
+ }
+}
- let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
+#[cfg(test)]
+impl std::ops::Deref for ToolCallEventStreamReceiver {
+ type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
- Self {
- stream,
- _events_rx: events_rx,
- }
+ fn deref(&self) -> &Self::Target {
+ &self.0
}
+}
- pub fn stream(&self) -> ToolCallEventStream {
- self.stream.clone()
+#[cfg(test)]
+impl std::ops::DerefMut for ToolCallEventStreamReceiver {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
}
}
@@ -1,7 +1,9 @@
+mod edit_file_tool;
mod find_path_tool;
mod read_file_tool;
mod thinking_tool;
+pub use edit_file_tool::*;
pub use find_path_tool::*;
pub use read_file_tool::*;
pub use thinking_tool::*;
@@ -0,0 +1,1361 @@
+use acp_thread::Diff;
+use agent_client_protocol as acp;
+use anyhow::{anyhow, Context as _, Result};
+use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
+use cloud_llm_client::CompletionIntent;
+use collections::HashSet;
+use gpui::{App, AppContext, AsyncApp, Entity, Task};
+use indoc::formatdoc;
+use language::language_settings::{self, FormatOnSave};
+use language_model::LanguageModelToolResultContent;
+use paths;
+use project::lsp_store::{FormatTrigger, LspFormatTarget};
+use project::{Project, ProjectPath};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::Settings;
+use smol::stream::StreamExt as _;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
+use ui::SharedString;
+use util::ResultExt;
+
+use crate::{AgentTool, Thread, ToolCallEventStream};
+
+/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
+///
+/// Before using this tool:
+///
+/// 1. Use the `read_file` tool to understand the file's contents and context
+///
+/// 2. Verify the directory path is correct (only applicable when creating new files):
+/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct EditFileToolInput {
+ /// A one-line, user-friendly markdown description of the edit. This will be
+ /// shown in the UI and also passed to another model to perform the edit.
+ ///
+ /// Be terse, but also descriptive in what you want to achieve with this
+ /// edit. Avoid generic instructions.
+ ///
+ /// NEVER mention the file path in this description.
+ ///
+ /// <example>Fix API endpoint URLs</example>
+ /// <example>Update copyright year in `page_footer`</example>
+ ///
+ /// Make sure to include this field before all the others in the input object
+ /// so that we can display it immediately.
+ pub display_description: String,
+
+ /// The full path of the file to create or modify in the project.
+ ///
+ /// WARNING: When specifying which file path need changing, you MUST
+ /// start each path with one of the project's root directories.
+ ///
+ /// The following examples assume we have two root directories in the project:
+ /// - /a/b/backend
+ /// - /c/d/frontend
+ ///
+ /// <example>
+ /// `backend/src/main.rs`
+ ///
+ /// Notice how the file path starts with `backend`. Without that, the path
+ /// would be ambiguous and the call would fail!
+ /// </example>
+ ///
+ /// <example>
+ /// `frontend/db.js`
+ /// </example>
+ pub path: PathBuf,
+
+ /// The mode of operation on the file. Possible values:
+ /// - 'edit': Make granular edits to an existing file.
+ /// - 'create': Create a new file if it doesn't exist.
+ /// - 'overwrite': Replace the entire contents of an existing file.
+ ///
+ /// When a file already exists or you just created it, prefer editing
+ /// it as opposed to recreating it from scratch.
+ pub mode: EditFileMode,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+#[serde(rename_all = "lowercase")]
+pub enum EditFileMode {
+ Edit,
+ Create,
+ Overwrite,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct EditFileToolOutput {
+ input_path: PathBuf,
+ project_path: PathBuf,
+ new_text: String,
+ old_text: Arc<String>,
+ diff: String,
+ edit_agent_output: EditAgentOutput,
+}
+
+impl From<EditFileToolOutput> for LanguageModelToolResultContent {
+ fn from(output: EditFileToolOutput) -> Self {
+ if output.diff.is_empty() {
+ "No edits were made.".into()
+ } else {
+ format!(
+ "Edited {}:\n\n```diff\n{}\n```",
+ output.input_path.display(),
+ output.diff
+ )
+ .into()
+ }
+ }
+}
+
+pub struct EditFileTool {
+ thread: Entity<Thread>,
+}
+
+impl EditFileTool {
+ pub fn new(thread: Entity<Thread>) -> Self {
+ Self { thread }
+ }
+
+ fn authorize(
+ &self,
+ input: &EditFileToolInput,
+ event_stream: &ToolCallEventStream,
+ cx: &App,
+ ) -> Task<Result<()>> {
+ if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
+ return Task::ready(Ok(()));
+ }
+
+ // If any path component matches the local settings folder, then this could affect
+ // the editor in ways beyond the project source, so prompt.
+ let local_settings_folder = paths::local_settings_folder_relative_path();
+ let path = Path::new(&input.path);
+ if path
+ .components()
+ .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
+ {
+ return cx.foreground_executor().spawn(
+ event_stream.authorize(format!("{} (local settings)", input.display_description)),
+ );
+ }
+
+ // It's also possible that the global config dir is configured to be inside the project,
+ // so check for that edge case too.
+ if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
+ if canonical_path.starts_with(paths::config_dir()) {
+ return cx.foreground_executor().spawn(
+ event_stream
+ .authorize(format!("{} (global settings)", input.display_description)),
+ );
+ }
+ }
+
+ // Check if path is inside the global config directory
+ // First check if it's already inside project - if not, try to canonicalize
+ let thread = self.thread.read(cx);
+ let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
+
+ // If the path is inside the project, and it's not one of the above edge cases,
+ // then no confirmation is necessary. Otherwise, confirmation is necessary.
+ if project_path.is_some() {
+ Task::ready(Ok(()))
+ } else {
+ cx.foreground_executor()
+ .spawn(event_stream.authorize(input.display_description.clone()))
+ }
+ }
+}
+
+impl AgentTool for EditFileTool {
+ type Input = EditFileToolInput;
+ type Output = EditFileToolOutput;
+
+ fn name(&self) -> SharedString {
+ "edit_file".into()
+ }
+
+ fn kind(&self) -> acp::ToolKind {
+ acp::ToolKind::Edit
+ }
+
+ fn initial_title(&self, input: Self::Input) -> SharedString {
+ input.display_description.into()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ input: Self::Input,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Task<Result<Self::Output>> {
+ let project = self.thread.read(cx).project().clone();
+ let project_path = match resolve_path(&input, project.clone(), cx) {
+ Ok(path) => path,
+ Err(err) => return Task::ready(Err(anyhow!(err))),
+ };
+
+ let request = self.thread.update(cx, |thread, cx| {
+ thread.build_completion_request(CompletionIntent::ToolResults, cx)
+ });
+ let thread = self.thread.read(cx);
+ let model = thread.selected_model.clone();
+ let action_log = thread.action_log().clone();
+
+ let authorize = self.authorize(&input, &event_stream, cx);
+ cx.spawn(async move |cx: &mut AsyncApp| {
+ authorize.await?;
+
+ let edit_format = EditFormat::from_model(model.clone())?;
+ let edit_agent = EditAgent::new(
+ model,
+ project.clone(),
+ action_log.clone(),
+ // TODO: move edit agent to this crate so we can use our templates
+ assistant_tools::templates::Templates::new(),
+ edit_format,
+ );
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(project_path.clone(), cx)
+ })?
+ .await?;
+
+ let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
+ event_stream.send_diff(diff.clone());
+
+ let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+ let old_text = cx
+ .background_spawn({
+ let old_snapshot = old_snapshot.clone();
+ async move { Arc::new(old_snapshot.text()) }
+ })
+ .await;
+
+
+ let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
+ edit_agent.edit(
+ buffer.clone(),
+ input.display_description.clone(),
+ &request,
+ cx,
+ )
+ } else {
+ edit_agent.overwrite(
+ buffer.clone(),
+ input.display_description.clone(),
+ &request,
+ cx,
+ )
+ };
+
+ let mut hallucinated_old_text = false;
+ let mut ambiguous_ranges = Vec::new();
+ while let Some(event) = events.next().await {
+ match event {
+ EditAgentOutputEvent::Edited => {},
+ EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
+ EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
+ EditAgentOutputEvent::ResolvingEditRange(range) => {
+ diff.update(cx, |card, cx| card.reveal_range(range, cx))?;
+ }
+ }
+ }
+
+ // If format_on_save is enabled, format the buffer
+ let format_on_save_enabled = buffer
+ .read_with(cx, |buffer, cx| {
+ let settings = language_settings::language_settings(
+ buffer.language().map(|l| l.name()),
+ buffer.file(),
+ cx,
+ );
+ settings.format_on_save != FormatOnSave::Off
+ })
+ .unwrap_or(false);
+
+ let edit_agent_output = output.await?;
+
+ if format_on_save_enabled {
+ action_log.update(cx, |log, cx| {
+ log.buffer_edited(buffer.clone(), cx);
+ })?;
+
+ let format_task = project.update(cx, |project, cx| {
+ project.format(
+ HashSet::from_iter([buffer.clone()]),
+ LspFormatTarget::Buffers,
+ false, // Don't push to history since the tool did it.
+ FormatTrigger::Save,
+ cx,
+ )
+ })?;
+ format_task.await.log_err();
+ }
+
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
+ .await?;
+
+ action_log.update(cx, |log, cx| {
+ log.buffer_edited(buffer.clone(), cx);
+ })?;
+
+ let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+ let (new_text, unified_diff) = cx
+ .background_spawn({
+ let new_snapshot = new_snapshot.clone();
+ let old_text = old_text.clone();
+ async move {
+ let new_text = new_snapshot.text();
+ let diff = language::unified_diff(&old_text, &new_text);
+ (new_text, diff)
+ }
+ })
+ .await;
+
+ diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
+
+ let input_path = input.path.display();
+ if unified_diff.is_empty() {
+ anyhow::ensure!(
+ !hallucinated_old_text,
+ formatdoc! {"
+ Some edits were produced but none of them could be applied.
+ Read the relevant sections of {input_path} again so that
+ I can perform the requested edits.
+ "}
+ );
+ anyhow::ensure!(
+ ambiguous_ranges.is_empty(),
+ {
+ let line_numbers = ambiguous_ranges
+ .iter()
+ .map(|range| range.start.to_string())
+ .collect::<Vec<_>>()
+ .join(", ");
+ formatdoc! {"
+ <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
+ relevant sections of {input_path} again and extend <old_text> so
+ that I can perform the requested edits.
+ "}
+ }
+ );
+ }
+
+ Ok(EditFileToolOutput {
+ input_path: input.path,
+ project_path: project_path.path.to_path_buf(),
+ new_text: new_text.clone(),
+ old_text,
+ diff: unified_diff,
+ edit_agent_output,
+ })
+ })
+ }
+}
+
+/// Validate that the file path is valid, meaning:
+///
+/// - For `edit` and `overwrite`, the path must point to an existing file.
+/// - For `create`, the file must not already exist, but it's parent dir must exist.
+fn resolve_path(
+ input: &EditFileToolInput,
+ project: Entity<Project>,
+ cx: &mut App,
+) -> Result<ProjectPath> {
+ let project = project.read(cx);
+
+ match input.mode {
+ EditFileMode::Edit | EditFileMode::Overwrite => {
+ let path = project
+ .find_project_path(&input.path, cx)
+ .context("Can't edit file: path not found")?;
+
+ let entry = project
+ .entry_for_path(&path, cx)
+ .context("Can't edit file: path not found")?;
+
+ anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
+ Ok(path)
+ }
+
+ EditFileMode::Create => {
+ if let Some(path) = project.find_project_path(&input.path, cx) {
+ anyhow::ensure!(
+ project.entry_for_path(&path, cx).is_none(),
+ "Can't create file: file already exists"
+ );
+ }
+
+ let parent_path = input
+ .path
+ .parent()
+ .context("Can't create file: incorrect path")?;
+
+ let parent_project_path = project.find_project_path(&parent_path, cx);
+
+ let parent_entry = parent_project_path
+ .as_ref()
+ .and_then(|path| project.entry_for_path(&path, cx))
+ .context("Can't create file: parent directory doesn't exist")?;
+
+ anyhow::ensure!(
+ parent_entry.is_dir(),
+ "Can't create file: parent is not a directory"
+ );
+
+ let file_name = input
+ .path
+ .file_name()
+ .context("Can't create file: invalid filename")?;
+
+ let new_file_path = parent_project_path.map(|parent| ProjectPath {
+ path: Arc::from(parent.path.join(file_name)),
+ ..parent
+ });
+
+ new_file_path.context("Can't create file")
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::Templates;
+
+ use super::*;
+ use assistant_tool::ActionLog;
+ use client::TelemetrySettings;
+ use fs::Fs;
+ use gpui::{TestAppContext, UpdateGlobal};
+ use language_model::fake_provider::FakeLanguageModel;
+ use serde_json::json;
+ use settings::SettingsStore;
+ use std::rc::Rc;
+ use util::path;
+
+ #[gpui::test]
+ async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({})).await;
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread =
+ cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
+ let result = cx
+ .update(|cx| {
+ let input = EditFileToolInput {
+ display_description: "Some edit".into(),
+ path: "root/nonexistent_file.txt".into(),
+ mode: EditFileMode::Edit,
+ };
+ Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+ })
+ .await;
+ assert_eq!(
+ result.unwrap_err().to_string(),
+ "Can't edit file: path not found"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
+ let mode = &EditFileMode::Create;
+
+ let result = test_resolve_path(mode, "root/new.txt", cx);
+ assert_resolved_path_eq(result.await, "new.txt");
+
+ let result = test_resolve_path(mode, "new.txt", cx);
+ assert_resolved_path_eq(result.await, "new.txt");
+
+ let result = test_resolve_path(mode, "dir/new.txt", cx);
+ assert_resolved_path_eq(result.await, "dir/new.txt");
+
+ let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
+ assert_eq!(
+ result.await.unwrap_err().to_string(),
+ "Can't create file: file already exists"
+ );
+
+ let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
+ assert_eq!(
+ result.await.unwrap_err().to_string(),
+ "Can't create file: parent directory doesn't exist"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
+ let mode = &EditFileMode::Edit;
+
+ let path_with_root = "root/dir/subdir/existing.txt";
+ let path_without_root = "dir/subdir/existing.txt";
+ let result = test_resolve_path(mode, path_with_root, cx);
+ assert_resolved_path_eq(result.await, path_without_root);
+
+ let result = test_resolve_path(mode, path_without_root, cx);
+ assert_resolved_path_eq(result.await, path_without_root);
+
+ let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
+ assert_eq!(
+ result.await.unwrap_err().to_string(),
+ "Can't edit file: path not found"
+ );
+
+ let result = test_resolve_path(mode, "root/dir", cx);
+ assert_eq!(
+ result.await.unwrap_err().to_string(),
+ "Can't edit file: path is a directory"
+ );
+ }
+
+ async fn test_resolve_path(
+ mode: &EditFileMode,
+ path: &str,
+ cx: &mut TestAppContext,
+ ) -> anyhow::Result<ProjectPath> {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "dir": {
+ "subdir": {
+ "existing.txt": "hello"
+ }
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+
+ let input = EditFileToolInput {
+ display_description: "Some edit".into(),
+ path: path.into(),
+ mode: mode.clone(),
+ };
+
+ let result = cx.update(|cx| resolve_path(&input, project, cx));
+ result
+ }
+
+ fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
+ let actual = path
+ .expect("Should return valid path")
+ .path
+ .to_str()
+ .unwrap()
+ .replace("\\", "/"); // Naive Windows paths normalization
+ assert_eq!(actual, expected);
+ }
+
+ #[gpui::test]
+ async fn test_format_on_save(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({"src": {}})).await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+
+ // Set up a Rust language with LSP formatting support
+ let rust_language = Arc::new(language::Language::new(
+ language::LanguageConfig {
+ name: "Rust".into(),
+ matcher: language::LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ None,
+ ));
+
+ // Register the language and fake LSP
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_language);
+
+ let mut fake_language_servers = language_registry.register_fake_lsp(
+ "Rust",
+ language::FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ document_formatting_provider: Some(lsp::OneOf::Left(true)),
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ );
+
+ // Create the file
+ fs.save(
+ path!("/root/src/main.rs").as_ref(),
+ &"initial content".into(),
+ language::LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+
+ // Open the buffer to trigger LSP initialization
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/root/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ // Register the buffer with language servers
+ let _handle = project.update(cx, |project, cx| {
+ project.register_buffer_with_language_servers(&buffer, cx)
+ });
+
+ const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
+ const FORMATTED_CONTENT: &str =
+ "This file was formatted by the fake formatter in the test.\n";
+
+ // Get the fake language server and set up formatting handler
+ let fake_language_server = fake_language_servers.next().await.unwrap();
+ fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
+ |_, _| async move {
+ Ok(Some(vec![lsp::TextEdit {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
+ new_text: FORMATTED_CONTENT.to_string(),
+ }]))
+ }
+ });
+
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|_| {
+ Thread::new(
+ project,
+ Rc::default(),
+ action_log.clone(),
+ Templates::new(),
+ model.clone(),
+ )
+ });
+
+ // First, test with format_on_save enabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+ cx,
+ |settings| {
+ settings.defaults.format_on_save = Some(FormatOnSave::On);
+ settings.defaults.formatter =
+ Some(language::language_settings::SelectedFormatter::Auto);
+ },
+ );
+ });
+ });
+
+ // Have the model stream unformatted content
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = EditFileToolInput {
+ display_description: "Create main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ };
+ Arc::new(EditFileTool {
+ thread: thread.clone(),
+ })
+ .run(input, ToolCallEventStream::test().0, cx)
+ });
+
+ // Stream the unformatted content
+ cx.executor().run_until_parked();
+ model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Read the file to verify it was formatted automatically
+ let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ new_content.replace("\r\n", "\n"),
+ FORMATTED_CONTENT,
+ "Code should be formatted when format_on_save is enabled"
+ );
+
+ let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count());
+
+ assert_eq!(
+ stale_buffer_count, 0,
+ "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
+ This causes the agent to think the file was modified externally when it was just formatted.",
+ stale_buffer_count
+ );
+
+ // Next, test with format_on_save disabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+ cx,
+ |settings| {
+ settings.defaults.format_on_save = Some(FormatOnSave::Off);
+ },
+ );
+ });
+ });
+
+ // Stream unformatted edits again
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = EditFileToolInput {
+ display_description: "Update main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ };
+ Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+ });
+
+ // Stream the unformatted content
+ cx.executor().run_until_parked();
+ model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Verify the file was not formatted
+ let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ new_content.replace("\r\n", "\n"),
+ UNFORMATTED_CONTENT,
+ "Code should not be formatted when format_on_save is disabled"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({"src": {}})).await;
+
+ // Create a simple file with trailing whitespace
+ fs.save(
+ path!("/root/src/main.rs").as_ref(),
+ &"initial content".into(),
+ language::LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|_| {
+ Thread::new(
+ project,
+ Rc::default(),
+ action_log.clone(),
+ Templates::new(),
+ model.clone(),
+ )
+ });
+
+ // First, test with remove_trailing_whitespace_on_save enabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+ cx,
+ |settings| {
+ settings.defaults.remove_trailing_whitespace_on_save = Some(true);
+ },
+ );
+ });
+ });
+
+ const CONTENT_WITH_TRAILING_WHITESPACE: &str =
+ "fn main() { \n println!(\"Hello!\"); \n}\n";
+
+ // Have the model stream content that contains trailing whitespace
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = EditFileToolInput {
+ display_description: "Create main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ };
+ Arc::new(EditFileTool {
+ thread: thread.clone(),
+ })
+ .run(input, ToolCallEventStream::test().0, cx)
+ });
+
+ // Stream the content with trailing whitespace
+ cx.executor().run_until_parked();
+ model.send_last_completion_stream_text_chunk(
+ CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
+ );
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Read the file to verify trailing whitespace was removed automatically
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ fs.load(path!("/root/src/main.rs").as_ref())
+ .await
+ .unwrap()
+ .replace("\r\n", "\n"),
+ "fn main() {\n println!(\"Hello!\");\n}\n",
+ "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
+ );
+
+ // Next, test with remove_trailing_whitespace_on_save disabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+ cx,
+ |settings| {
+ settings.defaults.remove_trailing_whitespace_on_save = Some(false);
+ },
+ );
+ });
+ });
+
+ // Stream edits again with trailing whitespace
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = EditFileToolInput {
+ display_description: "Update main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ };
+ Arc::new(EditFileTool {
+ thread: thread.clone(),
+ })
+ .run(input, ToolCallEventStream::test().0, cx)
+ });
+
+ // Stream the content with trailing whitespace
+ cx.executor().run_until_parked();
+ model.send_last_completion_stream_text_chunk(
+ CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
+ );
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Verify the file still has trailing whitespace
+ // Read the file again - it should still have trailing whitespace
+ let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ final_content.replace("\r\n", "\n"),
+ CONTENT_WITH_TRAILING_WHITESPACE,
+ "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_authorize(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|_| {
+ Thread::new(
+ project,
+ Rc::default(),
+ action_log.clone(),
+ Templates::new(),
+ model.clone(),
+ )
+ });
+ let tool = Arc::new(EditFileTool { thread });
+ fs.insert_tree("/root", json!({})).await;
+
+ // Test 1: Path with .zed component should require confirmation
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let _auth = cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "test 1".into(),
+ path: ".zed/settings.json".into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ });
+
+ let event = stream_rx.expect_tool_authorization().await;
+ assert_eq!(event.tool_call.title, "test 1 (local settings)");
+
+ // Test 2: Path outside project should require confirmation
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let _auth = cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "test 2".into(),
+ path: "/etc/hosts".into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ });
+
+ let event = stream_rx.expect_tool_authorization().await;
+ assert_eq!(event.tool_call.title, "test 2");
+
+ // Test 3: Relative path without .zed should not require confirmation
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "test 3".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+ assert!(stream_rx.try_next().is_err());
+
+ // Test 4: Path with .zed in the middle should require confirmation
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let _auth = cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "test 4".into(),
+ path: "root/.zed/tasks.json".into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ });
+ let event = stream_rx.expect_tool_authorization().await;
+ assert_eq!(event.tool_call.title, "test 4 (local settings)");
+
+ // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.always_allow_tool_actions = true;
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "test 5.1".into(),
+ path: ".zed/settings.json".into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+ assert!(stream_rx.try_next().is_err());
+
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "test 5.2".into(),
+ path: "/etc/hosts".into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+ assert!(stream_rx.try_next().is_err());
+ }
+
+ #[gpui::test]
+ async fn test_authorize_global_config(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/project", json!({})).await;
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|_| {
+ Thread::new(
+ project,
+ Rc::default(),
+ action_log.clone(),
+ Templates::new(),
+ model.clone(),
+ )
+ });
+ let tool = Arc::new(EditFileTool { thread });
+
+ // Test global config paths - these should require confirmation if they exist and are outside the project
+ let test_cases = vec![
+ (
+ "/etc/hosts",
+ true,
+ "System file should require confirmation",
+ ),
+ (
+ "/usr/local/bin/script",
+ true,
+ "System bin file should require confirmation",
+ ),
+ (
+ "project/normal_file.rs",
+ false,
+ "Normal project file should not require confirmation",
+ ),
+ ];
+
+ for (path, should_confirm, description) in test_cases {
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let auth = cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "Edit file".into(),
+ path: path.into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ });
+
+ if should_confirm {
+ stream_rx.expect_tool_authorization().await;
+ } else {
+ auth.await.unwrap();
+ assert!(
+ stream_rx.try_next().is_err(),
+ "Failed for case: {} - path: {} - expected no confirmation but got one",
+ description,
+ path
+ );
+ }
+ }
+ }
+
+ #[gpui::test]
+ async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+
+ // Create multiple worktree directories
+ fs.insert_tree(
+ "/workspace/frontend",
+ json!({
+ "src": {
+ "main.js": "console.log('frontend');"
+ }
+ }),
+ )
+ .await;
+ fs.insert_tree(
+ "/workspace/backend",
+ json!({
+ "src": {
+ "main.rs": "fn main() {}"
+ }
+ }),
+ )
+ .await;
+ fs.insert_tree(
+ "/workspace/shared",
+ json!({
+ ".zed": {
+ "settings.json": "{}"
+ }
+ }),
+ )
+ .await;
+
+ // Create project with multiple worktrees
+ let project = Project::test(
+ fs.clone(),
+ [
+ path!("/workspace/frontend").as_ref(),
+ path!("/workspace/backend").as_ref(),
+ path!("/workspace/shared").as_ref(),
+ ],
+ cx,
+ )
+ .await;
+
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|_| {
+ Thread::new(
+ project.clone(),
+ Rc::default(),
+ action_log.clone(),
+ Templates::new(),
+ model.clone(),
+ )
+ });
+ let tool = Arc::new(EditFileTool { thread });
+
+ // Test files in different worktrees
+ let test_cases = vec![
+ ("frontend/src/main.js", false, "File in first worktree"),
+ ("backend/src/main.rs", false, "File in second worktree"),
+ (
+ "shared/.zed/settings.json",
+ true,
+ ".zed file in third worktree",
+ ),
+ ("/etc/hosts", true, "Absolute path outside all worktrees"),
+ (
+ "../outside/file.txt",
+ true,
+ "Relative path outside worktrees",
+ ),
+ ];
+
+ for (path, should_confirm, description) in test_cases {
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let auth = cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "Edit file".into(),
+ path: path.into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ });
+
+ if should_confirm {
+ stream_rx.expect_tool_authorization().await;
+ } else {
+ auth.await.unwrap();
+ assert!(
+ stream_rx.try_next().is_err(),
+ "Failed for case: {} - path: {} - expected no confirmation but got one",
+ description,
+ path
+ );
+ }
+ }
+ }
+
+ #[gpui::test]
+ async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/project",
+ json!({
+ ".zed": {
+ "settings.json": "{}"
+ },
+ "src": {
+ ".zed": {
+ "local.json": "{}"
+ }
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|_| {
+ Thread::new(
+ project.clone(),
+ Rc::default(),
+ action_log.clone(),
+ Templates::new(),
+ model.clone(),
+ )
+ });
+ let tool = Arc::new(EditFileTool { thread });
+
+ // Test edge cases
+ let test_cases = vec![
+ // Empty path - find_project_path returns Some for empty paths
+ ("", false, "Empty path is treated as project root"),
+ // Root directory
+ ("/", true, "Root directory should be outside project"),
+ // Parent directory references - find_project_path resolves these
+ (
+ "project/../other",
+ false,
+ "Path with .. is resolved by find_project_path",
+ ),
+ (
+ "project/./src/file.rs",
+ false,
+ "Path with . should work normally",
+ ),
+ // Windows-style paths (if on Windows)
+ #[cfg(target_os = "windows")]
+ ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
+ #[cfg(target_os = "windows")]
+ ("project\\src\\main.rs", false, "Windows-style project path"),
+ ];
+
+ for (path, should_confirm, description) in test_cases {
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let auth = cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "Edit file".into(),
+ path: path.into(),
+ mode: EditFileMode::Edit,
+ },
+ &stream_tx,
+ cx,
+ )
+ });
+
+ if should_confirm {
+ stream_rx.expect_tool_authorization().await;
+ } else {
+ auth.await.unwrap();
+ assert!(
+ stream_rx.try_next().is_err(),
+ "Failed for case: {} - path: {} - expected no confirmation but got one",
+ description,
+ path
+ );
+ }
+ }
+ }
+
+ #[gpui::test]
+ async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/project",
+ json!({
+ "existing.txt": "content",
+ ".zed": {
+ "settings.json": "{}"
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|_| {
+ Thread::new(
+ project.clone(),
+ Rc::default(),
+ action_log.clone(),
+ Templates::new(),
+ model.clone(),
+ )
+ });
+ let tool = Arc::new(EditFileTool { thread });
+
+ // Test different EditFileMode values
+ let modes = vec![
+ EditFileMode::Edit,
+ EditFileMode::Create,
+ EditFileMode::Overwrite,
+ ];
+
+ for mode in modes {
+ // Test .zed path with different modes
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let _auth = cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "Edit settings".into(),
+ path: "project/.zed/settings.json".into(),
+ mode: mode.clone(),
+ },
+ &stream_tx,
+ cx,
+ )
+ });
+
+ stream_rx.expect_tool_authorization().await;
+
+ // Test outside path with different modes
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let _auth = cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "Edit file".into(),
+ path: "/outside/file.txt".into(),
+ mode: mode.clone(),
+ },
+ &stream_tx,
+ cx,
+ )
+ });
+
+ stream_rx.expect_tool_authorization().await;
+
+ // Test normal path with different modes
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ cx.update(|cx| {
+ tool.authorize(
+ &EditFileToolInput {
+ display_description: "Edit file".into(),
+ path: "project/normal.txt".into(),
+ mode: mode.clone(),
+ },
+ &stream_tx,
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+ assert!(stream_rx.try_next().is_err());
+ }
+ }
+
+ fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ language::init(cx);
+ TelemetrySettings::register(cx);
+ agent_settings::AgentSettings::register(cx);
+ Project::init_settings(cx);
+ });
+ }
+}
@@ -1,6 +1,8 @@
+use crate::{AgentTool, ToolCallEventStream};
use agent_client_protocol as acp;
use anyhow::{anyhow, Result};
use gpui::{App, AppContext, Entity, SharedString, Task};
+use language_model::LanguageModelToolResultContent;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -8,8 +10,6 @@ use std::fmt::Write;
use std::{cmp, path::PathBuf, sync::Arc};
use util::paths::PathMatcher;
-use crate::{AgentTool, ToolCallEventStream};
-
/// Fast file path pattern matching tool that works with any codebase size
///
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
@@ -39,8 +39,35 @@ pub struct FindPathToolInput {
}
#[derive(Debug, Serialize, Deserialize)]
-struct FindPathToolOutput {
- paths: Vec<PathBuf>,
+pub struct FindPathToolOutput {
+ offset: usize,
+ current_matches_page: Vec<PathBuf>,
+ all_matches_len: usize,
+}
+
+impl From<FindPathToolOutput> for LanguageModelToolResultContent {
+ fn from(output: FindPathToolOutput) -> Self {
+ if output.current_matches_page.is_empty() {
+ "No matches found".into()
+ } else {
+ let mut llm_output = format!("Found {} total matches.", output.all_matches_len);
+ if output.all_matches_len > RESULTS_PER_PAGE {
+ write!(
+ &mut llm_output,
+ "\nShowing results {}-{} (provide 'offset' parameter for more results):",
+ output.offset + 1,
+ output.offset + output.current_matches_page.len()
+ )
+ .unwrap();
+ }
+
+ for mat in output.current_matches_page {
+ write!(&mut llm_output, "\n{}", mat.display()).unwrap();
+ }
+
+ llm_output.into()
+ }
+ }
}
const RESULTS_PER_PAGE: usize = 50;
@@ -57,6 +84,7 @@ impl FindPathTool {
impl AgentTool for FindPathTool {
type Input = FindPathToolInput;
+ type Output = FindPathToolOutput;
fn name(&self) -> SharedString {
"find_path".into()
@@ -75,7 +103,7 @@ impl AgentTool for FindPathTool {
input: Self::Input,
event_stream: ToolCallEventStream,
cx: &mut App,
- ) -> Task<Result<String>> {
+ ) -> Task<Result<FindPathToolOutput>> {
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
cx.background_spawn(async move {
@@ -113,26 +141,11 @@ impl AgentTool for FindPathTool {
..Default::default()
});
- if matches.is_empty() {
- Ok("No matches found".into())
- } else {
- let mut message = format!("Found {} total matches.", matches.len());
- if matches.len() > RESULTS_PER_PAGE {
- write!(
- &mut message,
- "\nShowing results {}-{} (provide 'offset' parameter for more results):",
- input.offset + 1,
- input.offset + paginated_matches.len()
- )
- .unwrap();
- }
-
- for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
- write!(&mut message, "\n{}", mat.display()).unwrap();
- }
-
- Ok(message)
- }
+ Ok(FindPathToolOutput {
+ offset: input.offset,
+ current_matches_page: paginated_matches.to_vec(),
+ all_matches_len: matches.len(),
+ })
})
}
}
@@ -1,10 +1,11 @@
use agent_client_protocol::{self as acp};
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
use assistant_tool::{outline, ActionLog};
use gpui::{Entity, Task};
use indoc::formatdoc;
use language::{Anchor, Point};
-use project::{AgentLocation, Project, WorktreeSettings};
+use language_model::{LanguageModelImage, LanguageModelToolResultContent};
+use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
@@ -59,6 +60,7 @@ impl ReadFileTool {
impl AgentTool for ReadFileTool {
type Input = ReadFileToolInput;
+ type Output = LanguageModelToolResultContent;
fn name(&self) -> SharedString {
"read_file".into()
@@ -91,9 +93,9 @@ impl AgentTool for ReadFileTool {
fn run(
self: Arc<Self>,
input: Self::Input,
- event_stream: ToolCallEventStream,
+ _event_stream: ToolCallEventStream,
cx: &mut App,
- ) -> Task<Result<String>> {
+ ) -> Task<Result<LanguageModelToolResultContent>> {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
};
@@ -132,51 +134,27 @@ impl AgentTool for ReadFileTool {
let file_path = input.path.clone();
- event_stream.send_update(acp::ToolCallUpdateFields {
- locations: Some(vec![acp::ToolCallLocation {
- path: project_path.path.to_path_buf(),
- line: input.start_line,
- // TODO (tracked): use full range
- }]),
- ..Default::default()
- });
-
- // TODO (tracked): images
- // if image_store::is_image_file(&self.project, &project_path, cx) {
- // let model = &self.thread.read(cx).selected_model;
-
- // if !model.supports_images() {
- // return Task::ready(Err(anyhow!(
- // "Attempted to read an image, but Zed doesn't currently support sending images to {}.",
- // model.name().0
- // )))
- // .into();
- // }
-
- // return cx.spawn(async move |cx| -> Result<ToolResultOutput> {
- // let image_entity: Entity<ImageItem> = cx
- // .update(|cx| {
- // self.project.update(cx, |project, cx| {
- // project.open_image(project_path.clone(), cx)
- // })
- // })?
- // .await?;
-
- // let image =
- // image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
-
- // let language_model_image = cx
- // .update(|cx| LanguageModelImage::from_image(image, cx))?
- // .await
- // .context("processing image")?;
-
- // Ok(ToolResultOutput {
- // content: ToolResultContent::Image(language_model_image),
- // output: None,
- // })
- // });
- // }
- //
+ if image_store::is_image_file(&self.project, &project_path, cx) {
+ return cx.spawn(async move |cx| {
+ let image_entity: Entity<ImageItem> = cx
+ .update(|cx| {
+ self.project.update(cx, |project, cx| {
+ project.open_image(project_path.clone(), cx)
+ })
+ })?
+ .await?;
+
+ let image =
+ image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
+
+ let language_model_image = cx
+ .update(|cx| LanguageModelImage::from_image(image, cx))?
+ .await
+ .context("processing image")?;
+
+ Ok(language_model_image.into())
+ });
+ }
let project = self.project.clone();
let action_log = self.action_log.clone();
@@ -244,7 +222,7 @@ impl AgentTool for ReadFileTool {
})?;
}
- Ok(result)
+ Ok(result.into())
} else {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
@@ -257,7 +235,7 @@ impl AgentTool for ReadFileTool {
log.buffer_read(buffer, cx);
})?;
- Ok(result)
+ Ok(result.into())
} else {
// File is too big, so return the outline
// and a suggestion to read again with line numbers.
@@ -276,7 +254,8 @@ impl AgentTool for ReadFileTool {
Alternatively, you can fall back to the `grep` tool (if available)
to search the file for specific content."
- })
+ }
+ .into())
}
}
})
@@ -285,8 +264,6 @@ impl AgentTool for ReadFileTool {
#[cfg(test)]
mod test {
- use crate::TestToolCallEventStream;
-
use super::*;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
@@ -304,7 +281,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
- let event_stream = TestToolCallEventStream::new();
+ let (event_stream, _) = ToolCallEventStream::test();
let result = cx
.update(|cx| {
@@ -313,7 +290,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.run(input, event_stream.stream(), cx)
+ tool.run(input, event_stream, cx)
})
.await;
assert_eq!(
@@ -321,6 +298,7 @@ mod test {
"root/nonexistent_file.txt not found"
);
}
+
#[gpui::test]
async fn test_read_small_file(cx: &mut TestAppContext) {
init_test(cx);
@@ -336,7 +314,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
- let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@@ -344,10 +321,10 @@ mod test {
start_line: None,
end_line: None,
};
- tool.run(input, event_stream.stream(), cx)
+ tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
- assert_eq!(result.unwrap(), "This is a small file content");
+ assert_eq!(result.unwrap(), "This is a small file content".into());
}
#[gpui::test]
@@ -367,18 +344,18 @@ mod test {
language_registry.add(Arc::new(rust_lang()));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
- let event_stream = TestToolCallEventStream::new();
- let content = cx
+ let result = cx
.update(|cx| {
let input = ReadFileToolInput {
path: "root/large_file.rs".into(),
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
+ let content = result.to_str().unwrap();
assert_eq!(
content.lines().skip(4).take(6).collect::<Vec<_>>(),
@@ -399,10 +376,11 @@ mod test {
start_line: None,
end_line: None,
};
- tool.run(input, event_stream.stream(), cx)
+ tool.run(input, ToolCallEventStream::test().0, cx)
})
- .await;
- let content = result.unwrap();
+ .await
+ .unwrap();
+ let content = result.to_str().unwrap();
let expected_content = (0..1000)
.flat_map(|i| {
vec![
@@ -438,7 +416,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
- let event_stream = TestToolCallEventStream::new();
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@@ -446,10 +423,10 @@ mod test {
start_line: Some(2),
end_line: Some(4),
};
- tool.run(input, event_stream.stream(), cx)
+ tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
- assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
+ assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into());
}
#[gpui::test]
@@ -467,7 +444,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
- let event_stream = TestToolCallEventStream::new();
// start_line of 0 should be treated as 1
let result = cx
@@ -477,10 +453,10 @@ mod test {
start_line: Some(0),
end_line: Some(2),
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
- assert_eq!(result.unwrap(), "Line 1\nLine 2");
+ assert_eq!(result.unwrap(), "Line 1\nLine 2".into());
// end_line of 0 should result in at least 1 line
let result = cx
@@ -490,10 +466,10 @@ mod test {
start_line: Some(1),
end_line: Some(0),
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
- assert_eq!(result.unwrap(), "Line 1");
+ assert_eq!(result.unwrap(), "Line 1".into());
// when start_line > end_line, should still return at least 1 line
let result = cx
@@ -503,10 +479,10 @@ mod test {
start_line: Some(3),
end_line: Some(2),
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
- assert_eq!(result.unwrap(), "Line 3");
+ assert_eq!(result.unwrap(), "Line 3".into());
}
fn init_test(cx: &mut TestAppContext) {
@@ -612,7 +588,6 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project, action_log));
- let event_stream = TestToolCallEventStream::new();
// Reading a file outside the project worktree should fail
let result = cx
@@ -622,7 +597,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@@ -638,7 +613,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@@ -654,7 +629,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@@ -669,7 +644,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@@ -685,7 +660,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@@ -700,7 +675,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@@ -715,7 +690,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@@ -731,11 +706,11 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(result.is_ok(), "Should be able to read normal files");
- assert_eq!(result.unwrap(), "Normal file content");
+ assert_eq!(result.unwrap(), "Normal file content".into());
// Path traversal attempts with .. should fail
let result = cx
@@ -745,7 +720,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.run(input, event_stream.stream(), cx)
+ tool.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert!(
@@ -826,7 +801,6 @@ mod test {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
- let event_stream = TestToolCallEventStream::new();
// Test reading allowed files in worktree1
let result = cx
@@ -836,12 +810,15 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
- assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }");
+ assert_eq!(
+ result,
+ "fn main() { println!(\"Hello from worktree1\"); }".into()
+ );
// Test reading private file in worktree1 should fail
let result = cx
@@ -851,7 +828,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@@ -872,7 +849,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@@ -893,14 +870,14 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await
.unwrap();
assert_eq!(
result,
- "export function greet() { return 'Hello from worktree2'; }"
+ "export function greet() { return 'Hello from worktree2'; }".into()
);
// Test reading private file in worktree2 should fail
@@ -911,7 +888,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@@ -932,7 +909,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@@ -954,7 +931,7 @@ mod test {
start_line: None,
end_line: None,
};
- tool.clone().run(input, event_stream.stream(), cx)
+ tool.clone().run(input, ToolCallEventStream::test().0, cx)
})
.await;
@@ -20,6 +20,7 @@ pub struct ThinkingTool;
impl AgentTool for ThinkingTool {
type Input = ThinkingToolInput;
+ type Output = String;
fn name(&self) -> SharedString {
"thinking".into()
@@ -42,7 +42,7 @@ use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
use ::acp_thread::{
- AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff,
+ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
};
@@ -732,7 +732,11 @@ impl AcpThreadView {
cx: &App,
) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
- Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
+ Some(
+ entry
+ .diffs()
+ .map(|diff| diff.read(cx).multibuffer().clone()),
+ )
}
fn authenticate(
@@ -1314,10 +1318,9 @@ impl AcpThreadView {
Empty.into_any_element()
}
}
- ToolCallContent::Diff {
- diff: Diff { multibuffer, .. },
- ..
- } => self.render_diff_editor(multibuffer),
+ ToolCallContent::Diff { diff, .. } => {
+ self.render_diff_editor(&diff.read(cx).multibuffer())
+ }
}
}
@@ -2,7 +2,7 @@ mod copy_path_tool;
mod create_directory_tool;
mod delete_path_tool;
mod diagnostics_tool;
-mod edit_agent;
+pub mod edit_agent;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
@@ -14,7 +14,7 @@ mod open_tool;
mod project_notifications_tool;
mod read_file_tool;
mod schema;
-mod templates;
+pub mod templates;
mod terminal_tool;
mod thinking_tool;
mod ui;
@@ -29,7 +29,6 @@ use serde::{Deserialize, Serialize};
use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
-use util::debug_panic;
#[derive(Serialize)]
struct CreateFilePromptTemplate {
@@ -682,11 +681,6 @@ impl EditAgent {
if last_message.content.is_empty() {
conversation.messages.pop();
}
- } else {
- debug_panic!(
- "Last message must be an Assistant tool calling! Got {:?}",
- last_message.content
- );
}
}
@@ -120,8 +120,6 @@ struct PartialInput {
display_description: String,
}
-const DEFAULT_UI_TEXT: &str = "Editing file";
-
impl Tool for EditFileTool {
fn name(&self) -> String {
"edit_file".into()
@@ -211,22 +209,6 @@ impl Tool for EditFileTool {
}
}
- fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
- if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
- let description = input.display_description.trim();
- if !description.is_empty() {
- return description.to_string();
- }
-
- let path = input.path.trim();
- if !path.is_empty() {
- return path.to_string();
- }
- }
-
- DEFAULT_UI_TEXT.to_string()
- }
-
fn run(
self: Arc<Self>,
input: serde_json::Value,
@@ -1370,73 +1352,6 @@ mod tests {
assert_eq!(actual, expected);
}
- #[test]
- fn still_streaming_ui_text_with_path() {
- let input = json!({
- "path": "src/main.rs",
- "display_description": "",
- "old_string": "old code",
- "new_string": "new code"
- });
-
- assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
- }
-
- #[test]
- fn still_streaming_ui_text_with_description() {
- let input = json!({
- "path": "",
- "display_description": "Fix error handling",
- "old_string": "old code",
- "new_string": "new code"
- });
-
- assert_eq!(
- EditFileTool.still_streaming_ui_text(&input),
- "Fix error handling",
- );
- }
-
- #[test]
- fn still_streaming_ui_text_with_path_and_description() {
- let input = json!({
- "path": "src/main.rs",
- "display_description": "Fix error handling",
- "old_string": "old code",
- "new_string": "new code"
- });
-
- assert_eq!(
- EditFileTool.still_streaming_ui_text(&input),
- "Fix error handling",
- );
- }
-
- #[test]
- fn still_streaming_ui_text_no_path_or_description() {
- let input = json!({
- "path": "",
- "display_description": "",
- "old_string": "old code",
- "new_string": "new code"
- });
-
- assert_eq!(
- EditFileTool.still_streaming_ui_text(&input),
- DEFAULT_UI_TEXT,
- );
- }
-
- #[test]
- fn still_streaming_ui_text_with_null() {
- let input = serde_json::Value::Null;
-
- assert_eq!(
- EditFileTool.still_streaming_ui_text(&input),
- DEFAULT_UI_TEXT,
- );
- }
-
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
@@ -297,6 +297,12 @@ impl From<String> for LanguageModelToolResultContent {
}
}
+impl From<LanguageModelImage> for LanguageModelToolResultContent {
+ fn from(image: LanguageModelImage) -> Self {
+ Self::Image(image)
+ }
+}
+
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
Text(String),