@@ -22,13 +22,13 @@ use client::UserStore;
use cloud_api_types::Plan;
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
-use futures::stream;
use futures::{
FutureExt,
channel::{mpsc, oneshot},
future::Shared,
stream::FuturesUnordered,
};
+use futures::{StreamExt, stream};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
@@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file};
-use smol::stream::StreamExt;
use std::{
collections::BTreeMap,
marker::PhantomData,
@@ -2095,7 +2094,7 @@ impl Thread {
this.update(cx, |this, _cx| {
this.pending_message()
.tool_results
- .insert(tool_result.tool_use_id.clone(), tool_result);
+ .insert(tool_result.tool_use_id.clone(), tool_result)
})?;
Ok(())
}
@@ -2195,15 +2194,15 @@ impl Thread {
raw_input,
json_parse_error,
} => {
- return Ok(Some(Task::ready(
- self.handle_tool_use_json_parse_error_event(
- id,
- tool_name,
- raw_input,
- json_parse_error,
- event_stream,
- ),
- )));
+ return Ok(self.handle_tool_use_json_parse_error_event(
+ id,
+ tool_name,
+ raw_input,
+ json_parse_error,
+ event_stream,
+ cancellation_rx,
+ cx,
+ ));
}
UsageUpdate(usage) => {
telemetry::event!(
@@ -2304,12 +2303,12 @@ impl Thread {
if !tool_use.is_input_complete {
if tool.supports_input_streaming() {
let running_turn = self.running_turn.as_mut()?;
- if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) {
+ if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) {
sender.send_partial(tool_use.input);
return None;
}
- let (sender, tool_input) = ToolInputSender::channel();
+ let (mut sender, tool_input) = ToolInputSender::channel();
sender.send_partial(tool_use.input);
running_turn
.streaming_tool_inputs
@@ -2331,13 +2330,13 @@ impl Thread {
}
}
- if let Some(sender) = self
+ if let Some(mut sender) = self
.running_turn
.as_mut()?
.streaming_tool_inputs
.remove(&tool_use.id)
{
- sender.send_final(tool_use.input);
+ sender.send_full(tool_use.input);
return None;
}
@@ -2410,10 +2409,12 @@ impl Thread {
raw_input: Arc<str>,
json_parse_error: String,
event_stream: &ThreadEventStream,
- ) -> LanguageModelToolResult {
+ cancellation_rx: watch::Receiver<bool>,
+ cx: &mut Context<Self>,
+ ) -> Option<Task<LanguageModelToolResult>> {
let tool_use = LanguageModelToolUse {
- id: tool_use_id.clone(),
- name: tool_name.clone(),
+ id: tool_use_id,
+ name: tool_name,
raw_input: raw_input.to_string(),
input: serde_json::json!({}),
is_input_complete: true,
@@ -2426,14 +2427,43 @@ impl Thread {
event_stream,
);
- let tool_output = format!("Error parsing input JSON: {json_parse_error}");
- LanguageModelToolResult {
- tool_use_id,
- tool_name,
- is_error: true,
- content: LanguageModelToolResultContent::Text(tool_output.into()),
- output: Some(serde_json::Value::String(raw_input.to_string())),
+ let tool = self.tool(tool_use.name.as_ref());
+
+ let Some(tool) = tool else {
+ let content = format!("No tool named {} exists", tool_use.name);
+ return Some(Task::ready(LanguageModelToolResult {
+ content: LanguageModelToolResultContent::Text(Arc::from(content)),
+ tool_use_id: tool_use.id,
+ tool_name: tool_use.name,
+ is_error: true,
+ output: None,
+ }));
+ };
+
+ let error_message = format!("Error parsing input JSON: {json_parse_error}");
+
+ if tool.supports_input_streaming()
+ && let Some(mut sender) = self
+ .running_turn
+ .as_mut()?
+ .streaming_tool_inputs
+ .remove(&tool_use.id)
+ {
+ sender.send_invalid_json(error_message);
+ return None;
}
+
+ log::debug!("Running tool {}. Received invalid JSON", tool_use.name);
+ let tool_input = ToolInput::invalid_json(error_message);
+ Some(self.run_tool(
+ tool,
+ tool_input,
+ tool_use.id,
+ tool_use.name,
+ event_stream,
+ cancellation_rx,
+ cx,
+ ))
}
fn send_or_update_tool_use(
@@ -3114,8 +3144,7 @@ impl EventEmitter<TitleUpdated> for Thread {}
/// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams
/// them, followed by the final complete input available through `.recv()`.
pub struct ToolInput<T> {
- partial_rx: mpsc::UnboundedReceiver<serde_json::Value>,
- final_rx: oneshot::Receiver<serde_json::Value>,
+ rx: mpsc::UnboundedReceiver<ToolInputPayload<serde_json::Value>>,
_phantom: PhantomData<T>,
}
@@ -3127,13 +3156,20 @@ impl<T: DeserializeOwned> ToolInput<T> {
}
pub fn ready(value: serde_json::Value) -> Self {
- let (partial_tx, partial_rx) = mpsc::unbounded();
- drop(partial_tx);
- let (final_tx, final_rx) = oneshot::channel();
- final_tx.send(value).ok();
+ let (tx, rx) = mpsc::unbounded();
+ tx.unbounded_send(ToolInputPayload::Full(value)).ok();
Self {
- partial_rx,
- final_rx,
+ rx,
+ _phantom: PhantomData,
+ }
+ }
+
+ pub fn invalid_json(error_message: String) -> Self {
+ let (tx, rx) = mpsc::unbounded();
+ tx.unbounded_send(ToolInputPayload::InvalidJson { error_message })
+ .ok();
+ Self {
+ rx,
_phantom: PhantomData,
}
}
@@ -3147,65 +3183,89 @@ impl<T: DeserializeOwned> ToolInput<T> {
/// Wait for the final deserialized input, ignoring all partial updates.
/// Non-streaming tools can use this to wait until the whole input is available.
pub async fn recv(mut self) -> Result<T> {
- // Drain any remaining partials
- while self.partial_rx.next().await.is_some() {}
+ while let Ok(value) = self.next().await {
+ match value {
+ ToolInputPayload::Full(value) => return Ok(value),
+ ToolInputPayload::Partial(_) => {}
+ ToolInputPayload::InvalidJson { error_message } => {
+ return Err(anyhow!(error_message));
+ }
+ }
+ }
+ Err(anyhow!("tool input was not fully received"))
+ }
+
+ pub async fn next(&mut self) -> Result<ToolInputPayload<T>> {
let value = self
- .final_rx
+ .rx
+ .next()
.await
- .map_err(|_| anyhow!("tool input was not fully received"))?;
- serde_json::from_value(value).map_err(Into::into)
- }
+ .ok_or_else(|| anyhow!("tool input was not fully received"))?;
- /// Returns the next partial JSON snapshot, or `None` when input is complete.
- /// Once this returns `None`, call `recv()` to get the final input.
- pub async fn recv_partial(&mut self) -> Option<serde_json::Value> {
- self.partial_rx.next().await
+ Ok(match value {
+ ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload),
+ ToolInputPayload::Full(payload) => {
+ ToolInputPayload::Full(serde_json::from_value(payload)?)
+ }
+ ToolInputPayload::InvalidJson { error_message } => {
+ ToolInputPayload::InvalidJson { error_message }
+ }
+ })
}
fn cast<U: DeserializeOwned>(self) -> ToolInput<U> {
ToolInput {
- partial_rx: self.partial_rx,
- final_rx: self.final_rx,
+ rx: self.rx,
_phantom: PhantomData,
}
}
}
+pub enum ToolInputPayload<T> {
+ Partial(serde_json::Value),
+ Full(T),
+ InvalidJson { error_message: String },
+}
+
pub struct ToolInputSender {
- partial_tx: mpsc::UnboundedSender<serde_json::Value>,
- final_tx: Option<oneshot::Sender<serde_json::Value>>,
+ has_received_final: bool,
+ tx: mpsc::UnboundedSender<ToolInputPayload<serde_json::Value>>,
}
impl ToolInputSender {
pub(crate) fn channel() -> (Self, ToolInput<serde_json::Value>) {
- let (partial_tx, partial_rx) = mpsc::unbounded();
- let (final_tx, final_rx) = oneshot::channel();
+ let (tx, rx) = mpsc::unbounded();
let sender = Self {
- partial_tx,
- final_tx: Some(final_tx),
+ tx,
+ has_received_final: false,
};
let input = ToolInput {
- partial_rx,
- final_rx,
+ rx,
_phantom: PhantomData,
};
(sender, input)
}
pub(crate) fn has_received_final(&self) -> bool {
- self.final_tx.is_none()
+ self.has_received_final
}
- pub(crate) fn send_partial(&self, value: serde_json::Value) {
- self.partial_tx.unbounded_send(value).ok();
+ pub fn send_partial(&mut self, payload: serde_json::Value) {
+ self.tx
+ .unbounded_send(ToolInputPayload::Partial(payload))
+ .ok();
}
- pub(crate) fn send_final(mut self, value: serde_json::Value) {
- // Close the partial channel so recv_partial() returns None
- self.partial_tx.close_channel();
- if let Some(final_tx) = self.final_tx.take() {
- final_tx.send(value).ok();
- }
+ pub fn send_full(&mut self, payload: serde_json::Value) {
+ self.has_received_final = true;
+ self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok();
+ }
+
+ pub fn send_invalid_json(&mut self, error_message: String) {
+ self.has_received_final = true;
+ self.tx
+ .unbounded_send(ToolInputPayload::InvalidJson { error_message })
+ .ok();
}
}
@@ -4251,68 +4311,78 @@ mod tests {
) {
let (thread, event_stream) = setup_thread_for_test(cx).await;
- cx.update(|cx| {
- thread.update(cx, |thread, _cx| {
- let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
- let tool_name: Arc<str> = Arc::from("test_tool");
- let raw_input: Arc<str> = Arc::from("{invalid json");
- let json_parse_error = "expected value at line 1 column 1".to_string();
-
- // Call the function under test
- let result = thread.handle_tool_use_json_parse_error_event(
- tool_use_id.clone(),
- tool_name.clone(),
- raw_input.clone(),
- json_parse_error,
- &event_stream,
- );
-
- // Verify the result is an error
- assert!(result.is_error);
- assert_eq!(result.tool_use_id, tool_use_id);
- assert_eq!(result.tool_name, tool_name);
- assert!(matches!(
- result.content,
- LanguageModelToolResultContent::Text(_)
- ));
-
- // Verify the tool use was added to the message content
- {
- let last_message = thread.pending_message();
- assert_eq!(
- last_message.content.len(),
- 1,
- "Should have one tool_use in content"
- );
-
- match &last_message.content[0] {
- AgentMessageContent::ToolUse(tool_use) => {
- assert_eq!(tool_use.id, tool_use_id);
- assert_eq!(tool_use.name, tool_name);
- assert_eq!(tool_use.raw_input, raw_input.to_string());
- assert!(tool_use.is_input_complete);
- // Should fall back to empty object for invalid JSON
- assert_eq!(tool_use.input, json!({}));
- }
- _ => panic!("Expected ToolUse content"),
- }
- }
-
- // Insert the tool result (simulating what the caller does)
- thread
- .pending_message()
- .tool_results
- .insert(result.tool_use_id.clone(), result);
+ let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
+ let tool_name: Arc<str> = Arc::from("test_tool");
+ let raw_input: Arc<str> = Arc::from("{invalid json");
+ let json_parse_error = "expected value at line 1 column 1".to_string();
+
+ let (_cancellation_tx, cancellation_rx) = watch::channel(false);
+
+ let result = cx
+ .update(|cx| {
+ thread.update(cx, |thread, cx| {
+ // Call the function under test
+ thread
+ .handle_tool_use_json_parse_error_event(
+ tool_use_id.clone(),
+ tool_name.clone(),
+ raw_input.clone(),
+ json_parse_error,
+ &event_stream,
+ cancellation_rx,
+ cx,
+ )
+ .unwrap()
+ })
+ })
+ .await;
+
+ // Verify the result is an error
+ assert!(result.is_error);
+ assert_eq!(result.tool_use_id, tool_use_id);
+ assert_eq!(result.tool_name, tool_name);
+ assert!(matches!(
+ result.content,
+ LanguageModelToolResultContent::Text(_)
+ ));
- // Verify the tool result was added
+ thread.update(cx, |thread, _cx| {
+ // Verify the tool use was added to the message content
+ {
let last_message = thread.pending_message();
assert_eq!(
- last_message.tool_results.len(),
+ last_message.content.len(),
1,
- "Should have one tool_result"
+ "Should have one tool_use in content"
);
- assert!(last_message.tool_results.contains_key(&tool_use_id));
- });
- });
+
+ match &last_message.content[0] {
+ AgentMessageContent::ToolUse(tool_use) => {
+ assert_eq!(tool_use.id, tool_use_id);
+ assert_eq!(tool_use.name, tool_name);
+ assert_eq!(tool_use.raw_input, raw_input.to_string());
+ assert!(tool_use.is_input_complete);
+ // Should fall back to empty object for invalid JSON
+ assert_eq!(tool_use.input, json!({}));
+ }
+ _ => panic!("Expected ToolUse content"),
+ }
+ }
+
+ // Insert the tool result (simulating what the caller does)
+ thread
+ .pending_message()
+ .tool_results
+ .insert(result.tool_use_id.clone(), result);
+
+ // Verify the tool result was added
+ let last_message = thread.pending_message();
+ assert_eq!(
+ last_message.tool_results.len(),
+ 1,
+ "Should have one tool_result"
+ );
+ assert!(last_message.tool_results.contains_key(&tool_use_id));
+ })
}
}
@@ -2,6 +2,7 @@ use super::edit_file_tool::EditFileTool;
use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
use super::save_file_tool::SaveFileTool;
use super::tool_edit_parser::{ToolEditEvent, ToolEditParser};
+use crate::ToolInputPayload;
use crate::{
AgentTool, Thread, ToolCallEventStream, ToolInput,
edit_agent::{
@@ -12,7 +13,7 @@ use crate::{
use acp_thread::Diff;
use action_log::ActionLog;
use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
-use anyhow::{Context as _, Result};
+use anyhow::Result;
use collections::HashSet;
use futures::FutureExt as _;
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
@@ -188,6 +189,10 @@ pub enum StreamingEditFileToolOutput {
},
Error {
error: String,
+ #[serde(default)]
+ input_path: Option<PathBuf>,
+ #[serde(default)]
+ diff: String,
},
}
@@ -195,6 +200,8 @@ impl StreamingEditFileToolOutput {
pub fn error(error: impl Into<String>) -> Self {
Self::Error {
error: error.into(),
+ input_path: None,
+ diff: String::new(),
}
}
}
@@ -215,7 +222,24 @@ impl std::fmt::Display for StreamingEditFileToolOutput {
)
}
}
- StreamingEditFileToolOutput::Error { error } => write!(f, "{error}"),
+ StreamingEditFileToolOutput::Error {
+ error,
+ diff,
+ input_path,
+ } => {
+ write!(f, "{error}\n")?;
+ if let Some(input_path) = input_path
+ && !diff.is_empty()
+ {
+ write!(
+ f,
+ "Edited {}:\n\n```diff\n{diff}\n```",
+ input_path.display()
+ )
+ } else {
+ write!(f, "No edits were made.")
+ }
+ }
}
}
}
@@ -233,6 +257,14 @@ pub struct StreamingEditFileTool {
language_registry: Arc<LanguageRegistry>,
}
+enum EditSessionResult {
+ Completed(EditSession),
+ Failed {
+ error: String,
+ session: Option<EditSession>,
+ },
+}
+
impl StreamingEditFileTool {
pub fn new(
project: Entity<Project>,
@@ -276,6 +308,158 @@ impl StreamingEditFileTool {
});
}
}
+
+ async fn ensure_buffer_saved(&self, buffer: &Entity<Buffer>, cx: &mut AsyncApp) {
+ let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
+ let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
+ settings.format_on_save != FormatOnSave::Off
+ });
+
+ if format_on_save_enabled {
+ self.project
+ .update(cx, |project, cx| {
+ project.format(
+ HashSet::from_iter([buffer.clone()]),
+ LspFormatTarget::Buffers,
+ false,
+ FormatTrigger::Save,
+ cx,
+ )
+ })
+ .await
+ .log_err();
+ }
+
+ self.project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+ .await
+ .log_err();
+
+ self.action_log.update(cx, |log, cx| {
+ log.buffer_edited(buffer.clone(), cx);
+ });
+ }
+
+ async fn process_streaming_edits(
+ &self,
+ input: &mut ToolInput<StreamingEditFileToolInput>,
+ event_stream: &ToolCallEventStream,
+ cx: &mut AsyncApp,
+ ) -> EditSessionResult {
+ let mut session: Option<EditSession> = None;
+ let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
+
+ loop {
+ futures::select! {
+ payload = input.next().fuse() => {
+ match payload {
+ Ok(payload) => match payload {
+ ToolInputPayload::Partial(partial) => {
+ if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial) {
+ let path_complete = parsed.path.is_some()
+ && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref());
+
+ last_partial = Some(parsed.clone());
+
+ if session.is_none()
+ && path_complete
+ && let StreamingEditFileToolPartialInput {
+ path: Some(path),
+ display_description: Some(display_description),
+ mode: Some(mode),
+ ..
+ } = &parsed
+ {
+ match EditSession::new(
+ PathBuf::from(path),
+ display_description,
+ *mode,
+ self,
+ event_stream,
+ cx,
+ )
+ .await
+ {
+ Ok(created_session) => session = Some(created_session),
+ Err(error) => {
+ log::error!("Failed to create edit session: {}", error);
+ return EditSessionResult::Failed {
+ error,
+ session: None,
+ };
+ }
+ }
+ }
+
+ if let Some(current_session) = &mut session
+ && let Err(error) = current_session.process(parsed, self, event_stream, cx)
+ {
+ log::error!("Failed to process edit: {}", error);
+ return EditSessionResult::Failed { error, session };
+ }
+ }
+ }
+ ToolInputPayload::Full(full_input) => {
+ let mut session = if let Some(session) = session {
+ session
+ } else {
+ match EditSession::new(
+ full_input.path.clone(),
+ &full_input.display_description,
+ full_input.mode,
+ self,
+ event_stream,
+ cx,
+ )
+ .await
+ {
+ Ok(created_session) => created_session,
+ Err(error) => {
+ log::error!("Failed to create edit session: {}", error);
+ return EditSessionResult::Failed {
+ error,
+ session: None,
+ };
+ }
+ }
+ };
+
+ return match session.finalize(full_input, self, event_stream, cx).await {
+ Ok(()) => EditSessionResult::Completed(session),
+ Err(error) => {
+ log::error!("Failed to finalize edit: {}", error);
+ EditSessionResult::Failed {
+ error,
+ session: Some(session),
+ }
+ }
+ };
+ }
+ ToolInputPayload::InvalidJson { error_message } => {
+ log::error!("Received invalid JSON: {error_message}");
+ return EditSessionResult::Failed {
+ error: error_message,
+ session,
+ };
+ }
+ },
+ Err(error) => {
+ return EditSessionResult::Failed {
+ error: format!("Failed to receive tool input: {error}"),
+ session,
+ };
+ }
+ }
+ }
+ _ = event_stream.cancelled_by_user().fuse() => {
+ return EditSessionResult::Failed {
+ error: "Edit cancelled by user".to_string(),
+ session,
+ };
+ }
+ }
+ }
+ }
}
impl AgentTool for StreamingEditFileTool {
@@ -348,94 +532,40 @@ impl AgentTool for StreamingEditFileTool {
cx: &mut App,
) -> Task<Result<Self::Output, Self::Output>> {
cx.spawn(async move |cx: &mut AsyncApp| {
- let mut state: Option<EditSession> = None;
- let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
- loop {
- futures::select! {
- partial = input.recv_partial().fuse() => {
- let Some(partial_value) = partial else { break };
- if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial_value) {
- let path_complete = parsed.path.is_some()
- && parsed.path.as_ref() == last_partial.as_ref().and_then(|p| p.path.as_ref());
-
- last_partial = Some(parsed.clone());
-
- if state.is_none()
- && path_complete
- && let StreamingEditFileToolPartialInput {
- path: Some(path),
- display_description: Some(display_description),
- mode: Some(mode),
- ..
- } = &parsed
- {
- match EditSession::new(
- &PathBuf::from(path),
- display_description,
- *mode,
- &self,
- &event_stream,
- cx,
- )
- .await
- {
- Ok(session) => state = Some(session),
- Err(e) => {
- log::error!("Failed to create edit session: {}", e);
- return Err(e);
- }
- }
- }
-
- if let Some(state) = &mut state {
- if let Err(e) = state.process(parsed, &self, &event_stream, cx) {
- log::error!("Failed to process edit: {}", e);
- return Err(e);
- }
- }
- }
- }
- _ = event_stream.cancelled_by_user().fuse() => {
- return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
- }
- }
- }
- let full_input =
- input
- .recv()
- .await
- .map_err(|e| {
- let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}"));
- log::error!("Failed to receive tool input: {e}");
- err
- })?;
-
- let mut state = if let Some(state) = state {
- state
- } else {
- match EditSession::new(
- &full_input.path,
- &full_input.display_description,
- full_input.mode,
- &self,
- &event_stream,
- cx,
- )
+ match self
+ .process_streaming_edits(&mut input, &event_stream, cx)
.await
- {
- Ok(session) => session,
- Err(e) => {
- log::error!("Failed to create edit session: {}", e);
- return Err(e);
- }
+ {
+ EditSessionResult::Completed(session) => {
+ self.ensure_buffer_saved(&session.buffer, cx).await;
+ let (new_text, diff) = session.compute_new_text_and_diff(cx).await;
+ Ok(StreamingEditFileToolOutput::Success {
+ old_text: session.old_text.clone(),
+ new_text,
+ input_path: session.input_path,
+ diff,
+ })
}
- };
- match state.finalize(full_input, &self, &event_stream, cx).await {
- Ok(output) => Ok(output),
- Err(e) => {
- log::error!("Failed to finalize edit: {}", e);
- Err(e)
+ EditSessionResult::Failed {
+ error,
+ session: Some(session),
+ } => {
+ self.ensure_buffer_saved(&session.buffer, cx).await;
+ let (_new_text, diff) = session.compute_new_text_and_diff(cx).await;
+ Err(StreamingEditFileToolOutput::Error {
+ error,
+ input_path: Some(session.input_path),
+ diff,
+ })
}
+ EditSessionResult::Failed {
+ error,
+ session: None,
+ } => Err(StreamingEditFileToolOutput::Error {
+ error,
+ input_path: None,
+ diff: String::new(),
+ }),
}
})
}
@@ -472,6 +602,7 @@ impl AgentTool for StreamingEditFileTool {
pub struct EditSession {
abs_path: PathBuf,
+ input_path: PathBuf,
buffer: Entity<Buffer>,
old_text: Arc<String>,
diff: Entity<Diff>,
@@ -518,23 +649,21 @@ impl EditPipeline {
impl EditSession {
async fn new(
- path: &PathBuf,
+ path: PathBuf,
display_description: &str,
mode: StreamingEditFileMode,
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
- ) -> Result<Self, StreamingEditFileToolOutput> {
- let project_path = cx
- .update(|cx| resolve_path(mode, &path, &tool.project, cx))
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+ ) -> Result<Self, String> {
+ let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?;
let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx))
else {
- return Err(StreamingEditFileToolOutput::error(format!(
+ return Err(format!(
"Worktree at '{}' does not exist",
path.to_string_lossy()
- )));
+ ));
};
event_stream.update_fields(
@@ -543,13 +672,13 @@ impl EditSession {
cx.update(|cx| tool.authorize(&path, &display_description, event_stream, cx))
.await
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+ .map_err(|e| e.to_string())?;
let buffer = tool
.project
.update(cx, |project, cx| project.open_buffer(project_path, cx))
.await
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+ .map_err(|e| e.to_string())?;
ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
@@ -578,6 +707,7 @@ impl EditSession {
Ok(Self {
abs_path,
+ input_path: path,
buffer,
old_text,
diff,
@@ -594,22 +724,20 @@ impl EditSession {
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
- ) -> Result<StreamingEditFileToolOutput, StreamingEditFileToolOutput> {
- let old_text = self.old_text.clone();
-
+ ) -> Result<(), String> {
match input.mode {
StreamingEditFileMode::Write => {
- let content = input.content.ok_or_else(|| {
- StreamingEditFileToolOutput::error("'content' field is required for write mode")
- })?;
+ let content = input
+ .content
+ .ok_or_else(|| "'content' field is required for write mode".to_string())?;
let events = self.parser.finalize_content(&content);
self.process_events(&events, tool, event_stream, cx)?;
}
StreamingEditFileMode::Edit => {
- let edits = input.edits.ok_or_else(|| {
- StreamingEditFileToolOutput::error("'edits' field is required for edit mode")
- })?;
+ let edits = input
+ .edits
+ .ok_or_else(|| "'edits' field is required for edit mode".to_string())?;
let events = self.parser.finalize_edits(&edits);
self.process_events(&events, tool, event_stream, cx)?;
@@ -625,53 +753,15 @@ impl EditSession {
}
}
}
+ Ok(())
+ }
- let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| {
- let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
- settings.format_on_save != FormatOnSave::Off
- });
-
- if format_on_save_enabled {
- tool.action_log.update(cx, |log, cx| {
- log.buffer_edited(self.buffer.clone(), cx);
- });
-
- let format_task = tool.project.update(cx, |project, cx| {
- project.format(
- HashSet::from_iter([self.buffer.clone()]),
- LspFormatTarget::Buffers,
- false,
- FormatTrigger::Save,
- cx,
- )
- });
- futures::select! {
- result = format_task.fuse() => { result.log_err(); },
- _ = event_stream.cancelled_by_user().fuse() => {
- return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
- }
- };
- }
-
- let save_task = tool.project.update(cx, |project, cx| {
- project.save_buffer(self.buffer.clone(), cx)
- });
- futures::select! {
- result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; },
- _ = event_stream.cancelled_by_user().fuse() => {
- return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
- }
- };
-
- tool.action_log.update(cx, |log, cx| {
- log.buffer_edited(self.buffer.clone(), cx);
- });
-
+ async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) {
let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let (new_text, unified_diff) = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
- let old_text = old_text.clone();
+ let old_text = self.old_text.clone();
async move {
let new_text = new_snapshot.text();
let diff = language::unified_diff(&old_text, &new_text);
@@ -679,14 +769,7 @@ impl EditSession {
}
})
.await;
-
- let output = StreamingEditFileToolOutput::Success {
- input_path: input.path,
- new_text,
- old_text: old_text.clone(),
- diff: unified_diff,
- };
- Ok(output)
+ (new_text, unified_diff)
}
fn process(
@@ -695,7 +778,7 @@ impl EditSession {
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
- ) -> Result<(), StreamingEditFileToolOutput> {
+ ) -> Result<(), String> {
match &self.mode {
StreamingEditFileMode::Write => {
if let Some(content) = &partial.content {
@@ -719,7 +802,7 @@ impl EditSession {
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
- ) -> Result<(), StreamingEditFileToolOutput> {
+ ) -> Result<(), String> {
for event in events {
match event {
ToolEditEvent::ContentChunk { chunk } => {
@@ -969,14 +1052,14 @@ fn extract_match(
buffer: &Entity<Buffer>,
edit_index: &usize,
cx: &mut AsyncApp,
-) -> Result<Range<usize>, StreamingEditFileToolOutput> {
+) -> Result<Range<usize>, String> {
match matches.len() {
- 0 => Err(StreamingEditFileToolOutput::error(format!(
+ 0 => Err(format!(
"Could not find matching text for edit at index {}. \
The old_text did not match any content in the file. \
Please read the file again to get the current content.",
edit_index,
- ))),
+ )),
1 => Ok(matches.into_iter().next().unwrap()),
_ => {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
@@ -985,12 +1068,12 @@ fn extract_match(
.map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
.collect::<Vec<_>>()
.join(", ");
- Err(StreamingEditFileToolOutput::error(format!(
+ Err(format!(
"Edit {} matched multiple locations in the file at lines: {}. \
Please provide more context in old_text to uniquely \
identify the location.",
edit_index, lines
- )))
+ ))
}
}
}
@@ -1022,7 +1105,7 @@ fn ensure_buffer_saved(
abs_path: &PathBuf,
tool: &StreamingEditFileTool,
cx: &mut AsyncApp,
-) -> Result<(), StreamingEditFileToolOutput> {
+) -> Result<(), String> {
let last_read_mtime = tool
.action_log
.read_with(cx, |log, _| log.file_read_time(abs_path));
@@ -1063,15 +1146,14 @@ fn ensure_buffer_saved(
then ask them to save or revert the file manually and inform you when it's ok to proceed."
}
};
- return Err(StreamingEditFileToolOutput::error(message));
+ return Err(message.to_string());
}
if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
if current != last_read {
- return Err(StreamingEditFileToolOutput::error(
- "The file has been modified since you last read it. \
- Please read the file again to get the current state before editing it.",
- ));
+ return Err("The file has been modified since you last read it. \
+ Please read the file again to get the current state before editing it."
+ .to_string());
}
}
@@ -1083,56 +1165,63 @@ fn resolve_path(
path: &PathBuf,
project: &Entity<Project>,
cx: &mut App,
-) -> Result<ProjectPath> {
+) -> Result<ProjectPath, String> {
let project = project.read(cx);
match mode {
StreamingEditFileMode::Edit => {
let path = project
.find_project_path(&path, cx)
- .context("Can't edit file: path not found")?;
+ .ok_or_else(|| "Can't edit file: path not found".to_string())?;
let entry = project
.entry_for_path(&path, cx)
- .context("Can't edit file: path not found")?;
+ .ok_or_else(|| "Can't edit file: path not found".to_string())?;
- anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
- Ok(path)
+ if entry.is_file() {
+ Ok(path)
+ } else {
+ Err("Can't edit file: path is a directory".to_string())
+ }
}
StreamingEditFileMode::Write => {
if let Some(path) = project.find_project_path(&path, cx)
&& let Some(entry) = project.entry_for_path(&path, cx)
{
- anyhow::ensure!(entry.is_file(), "Can't write to file: path is a directory");
- return Ok(path);
+ if entry.is_file() {
+ return Ok(path);
+ } else {
+ return Err("Can't write to file: path is a directory".to_string());
+ }
}
- let parent_path = path.parent().context("Can't create file: incorrect path")?;
+ let parent_path = path
+ .parent()
+ .ok_or_else(|| "Can't create file: incorrect path".to_string())?;
let parent_project_path = project.find_project_path(&parent_path, cx);
let parent_entry = parent_project_path
.as_ref()
.and_then(|path| project.entry_for_path(path, cx))
- .context("Can't create file: parent directory doesn't exist")?;
+ .ok_or_else(|| "Can't create file: parent directory doesn't exist")?;
- anyhow::ensure!(
- parent_entry.is_dir(),
- "Can't create file: parent is not a directory"
- );
+ if !parent_entry.is_dir() {
+ return Err("Can't create file: parent is not a directory".to_string());
+ }
let file_name = path
.file_name()
.and_then(|file_name| file_name.to_str())
.and_then(|file_name| RelPath::unix(file_name).ok())
- .context("Can't create file: invalid filename")?;
+ .ok_or_else(|| "Can't create file: invalid filename".to_string())?;
let new_file_path = parent_project_path.map(|parent| ProjectPath {
path: parent.path.join(file_name),
..parent
});
- new_file_path.context("Can't create file")
+ new_file_path.ok_or_else(|| "Can't create file".to_string())
}
}
}
@@ -1382,10 +1471,17 @@ mod tests {
})
.await;
- let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+ let StreamingEditFileToolOutput::Error {
+ error,
+ diff,
+ input_path,
+ } = result.unwrap_err()
+ else {
panic!("expected error");
};
assert_eq!(error, "Can't edit file: path not found");
+ assert!(diff.is_empty());
+ assert_eq!(input_path, None);
}
#[gpui::test]
@@ -1411,7 +1507,7 @@ mod tests {
})
.await;
- let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+ let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else {
panic!("expected error");
};
assert!(
@@ -1424,7 +1520,7 @@ mod tests {
async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1447,7 +1543,7 @@ mod tests {
cx.run_until_parked();
// Now send the final complete input
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1465,7 +1561,7 @@ mod tests {
async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "hello world"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1485,7 +1581,7 @@ mod tests {
cx.run_until_parked();
// Send final
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Overwrite file",
"path": "root/file.txt",
"mode": "write",
@@ -1503,7 +1599,7 @@ mod tests {
async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "hello world"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver, mut cancellation_tx) =
ToolCallEventStream::test_with_cancellation();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1521,7 +1617,7 @@ mod tests {
drop(sender);
let result = task.await;
- let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+ let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else {
panic!("expected error");
};
assert!(
@@ -1537,7 +1633,7 @@ mod tests {
json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
)
.await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1578,7 +1674,7 @@ mod tests {
cx.run_until_parked();
// Send final complete input
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit multiple lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1601,7 +1697,7 @@ mod tests {
#[gpui::test]
async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1625,7 +1721,7 @@ mod tests {
cx.run_until_parked();
// Final with full content
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Create new file",
"path": "root/dir/new_file.txt",
"mode": "write",
@@ -1643,12 +1739,12 @@ mod tests {
async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
// Send final immediately with no partials (simulates non-streaming path)
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1669,7 +1765,7 @@ mod tests {
json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
)
.await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1739,7 +1835,7 @@ mod tests {
);
// Send final complete input
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit multiple lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1767,7 +1863,7 @@ mod tests {
async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) {
let (tool, project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1835,7 +1931,7 @@ mod tests {
assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nCCC\nddd\nEEEeee\n"));
// Send final
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit three lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1857,7 +1953,7 @@ mod tests {
async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) {
let (tool, project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1893,16 +1989,17 @@ mod tests {
}));
cx.run_until_parked();
- // Verify edit 1 was applied
- let buffer_text = project.update(cx, |project, cx| {
+ let buffer = project.update(cx, |project, cx| {
let pp = project
.find_project_path(&PathBuf::from("root/file.txt"), cx)
.unwrap();
- project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text())
+ project.get_open_buffer(&pp, cx).unwrap()
});
+
+ // Verify edit 1 was applied
+ let buffer_text = buffer.read_with(cx, |buffer, _cx| buffer.text());
assert_eq!(
- buffer_text.as_deref(),
- Some("MODIFIED\nline 2\nline 3\n"),
+ buffer_text, "MODIFIED\nline 2\nline 3\n",
"First edit should be applied even though second edit will fail"
);
@@ -1925,20 +2022,32 @@ mod tests {
drop(sender);
let result = task.await;
- let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+ let StreamingEditFileToolOutput::Error {
+ error,
+ diff,
+ input_path,
+ } = result.unwrap_err()
+ else {
panic!("expected error");
};
+
assert!(
error.contains("Could not find matching text for edit at index 1"),
"Expected error about edit 1 failing, got: {error}"
);
+ // Ensure that first edit was applied successfully and that we saved the buffer
+ assert_eq!(input_path, Some(PathBuf::from("root/file.txt")));
+ assert_eq!(
+ diff,
+ "@@ -1,3 +1,3 @@\n-line 1\n+MODIFIED\n line 2\n line 3\n"
+ );
}
#[gpui::test]
async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) {
let (tool, project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "hello world\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1975,7 +2084,7 @@ mod tests {
);
// Send final — the edit is applied during finalization
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Single edit",
"path": "root/file.txt",
"mode": "edit",
@@ -1993,7 +2102,7 @@ mod tests {
async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
- let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+ let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
ToolInput::test();
let (event_stream, _event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2020,7 +2129,7 @@ mod tests {
cx.run_until_parked();
// Send the final complete input
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit lines",
"path": "root/file.txt",
"mode": "edit",
@@ -2038,7 +2147,7 @@ mod tests {
async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "hello world\n"})).await;
- let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+ let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
ToolInput::test();
let (event_stream, _event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2064,7 +2173,7 @@ mod tests {
// Create a channel and send multiple partials before a final, then use
// ToolInput::resolved-style immediate delivery to confirm recv() works
// when partials are already buffered.
- let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+ let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
ToolInput::test();
let (event_stream, _event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2077,7 +2186,7 @@ mod tests {
"path": "root/dir/new.txt",
"mode": "write"
}));
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Create",
"path": "root/dir/new.txt",
"mode": "write",
@@ -2109,13 +2218,13 @@ mod tests {
let result = test_resolve_path(&mode, "root/dir/subdir", cx);
assert_eq!(
- result.await.unwrap_err().to_string(),
+ result.await.unwrap_err(),
"Can't write to file: path is a directory"
);
let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx);
assert_eq!(
- result.await.unwrap_err().to_string(),
+ result.await.unwrap_err(),
"Can't create file: parent directory doesn't exist"
);
}
@@ -2133,14 +2242,11 @@ mod tests {
assert_resolved_path_eq(result.await, rel_path(path_without_root));
let result = test_resolve_path(&mode, "root/nonexistent.txt", cx);
- assert_eq!(
- result.await.unwrap_err().to_string(),
- "Can't edit file: path not found"
- );
+ assert_eq!(result.await.unwrap_err(), "Can't edit file: path not found");
let result = test_resolve_path(&mode, "root/dir", cx);
assert_eq!(
- result.await.unwrap_err().to_string(),
+ result.await.unwrap_err(),
"Can't edit file: path is a directory"
);
}