@@ -73,7 +73,7 @@ pub struct StreamingEditFileToolInput {
/// <example>
/// `frontend/db.js`
/// </example>
- pub path: String,
+ pub path: PathBuf,
/// The mode of operation on the file. Possible values:
/// - 'write': Replace the entire contents of the file. If the file doesn't exist, it will be created. Requires 'content' field.
@@ -93,7 +93,7 @@ pub struct StreamingEditFileToolInput {
pub edits: Option<Vec<Edit>>,
}
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
/// Overwrite the file with new content (replacing any existing content).
@@ -187,20 +187,23 @@ impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
}
pub struct StreamingEditFileTool {
+ project: Entity<Project>,
thread: WeakEntity<Thread>,
+ action_log: Entity<ActionLog>,
language_registry: Arc<LanguageRegistry>,
- project: Entity<Project>,
}
impl StreamingEditFileTool {
pub fn new(
project: Entity<Project>,
thread: WeakEntity<Thread>,
+ action_log: Entity<ActionLog>,
language_registry: Arc<LanguageRegistry>,
) -> Self {
Self {
project,
thread,
+ action_log,
language_registry,
}
}
@@ -264,11 +267,11 @@ impl AgentTool for StreamingEditFileTool {
.read(cx)
.short_full_path_for_project_path(&project_path, cx)
})
- .unwrap_or(input.path)
+ .unwrap_or(input.path.to_string_lossy().into_owned())
.into(),
Err(raw_input) => {
- if let Some(input) =
- serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
+ if let Ok(input) =
+ serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input)
{
let path = input.path.unwrap_or_default();
let path = path.trim();
@@ -311,24 +314,37 @@ impl AgentTool for StreamingEditFileTool {
partial = input.recv_partial().fuse() => {
let Some(partial_value) = partial else { break };
if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial_value) {
- if state.is_none() && let Some(path_str) = &parsed.path
- && let Some(display_description) = &parsed.display_description
- && let Some(mode) = parsed.mode.clone() {
- state = Some(
- EditSession::new(
- path_str,
- display_description,
- mode,
- &self,
- &event_stream,
- cx,
- )
- .await?,
- );
+ if state.is_none()
+ && let StreamingEditFileToolPartialInput {
+ path: Some(path),
+ display_description: Some(display_description),
+ mode: Some(mode),
+ ..
+ } = &parsed
+ {
+ match EditSession::new(
+ &PathBuf::from(path),
+ display_description,
+ *mode,
+ &self,
+ &event_stream,
+ cx,
+ )
+ .await
+ {
+ Ok(session) => state = Some(session),
+ Err(e) => {
+ log::error!("Failed to create edit session: {}", e);
+ return Err(e);
+ }
+ }
}
if let Some(state) = &mut state {
- state.process(parsed, &self, &event_stream, cx)?;
+ if let Err(e) = state.process(parsed, &self, &event_stream, cx) {
+ log::error!("Failed to process edit: {}", e);
+ return Err(e);
+ }
}
}
}
@@ -341,22 +357,39 @@ impl AgentTool for StreamingEditFileTool {
input
.recv()
.await
- .map_err(|e| StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")))?;
+ .map_err(|e| {
+ let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}"));
+ log::error!("Failed to receive tool input: {e}");
+ err
+ })?;
let mut state = if let Some(state) = state {
state
} else {
- EditSession::new(
+ match EditSession::new(
&full_input.path,
&full_input.display_description,
- full_input.mode.clone(),
+ full_input.mode,
&self,
&event_stream,
cx,
)
- .await?
+ .await
+ {
+ Ok(session) => session,
+ Err(e) => {
+ log::error!("Failed to create edit session: {}", e);
+ return Err(e);
+ }
+ }
};
- state.finalize(full_input, &self, &event_stream, cx).await
+ match state.finalize(full_input, &self, &event_stream, cx).await {
+ Ok(output) => Ok(output),
+ Err(e) => {
+ log::error!("Failed to finalize edit: {}", e);
+ Err(e)
+ }
+ }
})
}
@@ -442,30 +475,24 @@ impl EditPipeline {
}
}
-/// Compute the `LineIndent` of the first line in a set of query lines.
-fn query_first_line_indent(query_lines: &[String]) -> text::LineIndent {
- let first_line = query_lines.first().map(|s| s.as_str()).unwrap_or("");
- text::LineIndent::from_iter(first_line.chars())
-}
-
impl EditSession {
async fn new(
- path_str: &str,
+ path: &PathBuf,
display_description: &str,
mode: StreamingEditFileMode,
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
) -> Result<Self, StreamingEditFileToolOutput> {
- let path = PathBuf::from(path_str);
let project_path = cx
- .update(|cx| resolve_path(mode.clone(), &path, &tool.project, cx))
+ .update(|cx| resolve_path(mode, &path, &tool.project, cx))
.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx))
else {
return Err(StreamingEditFileToolOutput::error(format!(
- "Worktree at '{path_str}' does not exist"
+ "Worktree at '{}' does not exist",
+ path.to_string_lossy()
)));
};
@@ -483,12 +510,7 @@ impl EditSession {
.await
.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
- let action_log = tool
- .thread
- .read_with(cx, |thread, _cx| thread.action_log().clone())
- .ok();
-
- ensure_buffer_saved(&buffer, &abs_path, tool, action_log.as_ref(), cx)?;
+ ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
event_stream.update_diff(diff.clone());
@@ -500,9 +522,8 @@ impl EditSession {
}
}) as Box<dyn FnOnce()>);
- if let Some(action_log) = &action_log {
- action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
- }
+ tool.action_log
+ .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let old_text = cx
@@ -531,69 +552,31 @@ impl EditSession {
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
) -> Result<StreamingEditFileToolOutput, StreamingEditFileToolOutput> {
- let Self {
- buffer,
- old_text,
- diff,
- abs_path,
- parser,
- pipeline,
- ..
- } = self;
-
- let action_log = tool
- .thread
- .read_with(cx, |thread, _cx| thread.action_log().clone())
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+ let old_text = self.old_text.clone();
match input.mode {
StreamingEditFileMode::Write => {
- action_log.update(cx, |log, cx| {
- log.buffer_created(buffer.clone(), cx);
- });
let content = input.content.ok_or_else(|| {
StreamingEditFileToolOutput::error("'content' field is required for write mode")
})?;
- let events = parser.finalize_content(&content);
- Self::process_events(
- &events,
- buffer,
- diff,
- pipeline,
- abs_path,
- tool,
- event_stream,
- cx,
- )?;
+ let events = self.parser.finalize_content(&content);
+ self.process_events(&events, tool, event_stream, cx)?;
+
+ tool.action_log.update(cx, |log, cx| {
+ log.buffer_created(self.buffer.clone(), cx);
+ });
}
StreamingEditFileMode::Edit => {
let edits = input.edits.ok_or_else(|| {
StreamingEditFileToolOutput::error("'edits' field is required for edit mode")
})?;
-
- let final_edits = edits
- .into_iter()
- .map(|e| Edit {
- old_text: e.old_text,
- new_text: e.new_text,
- })
- .collect::<Vec<_>>();
- let events = parser.finalize_edits(&final_edits);
- Self::process_events(
- &events,
- buffer,
- diff,
- pipeline,
- abs_path,
- tool,
- event_stream,
- cx,
- )?;
+ let events = self.parser.finalize_edits(&edits);
+ self.process_events(&events, tool, event_stream, cx)?;
}
}
- let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
+ let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| {
let settings = language_settings::language_settings(
buffer.language().map(|l| l.name()),
buffer.file(),
@@ -603,13 +586,13 @@ impl EditSession {
});
if format_on_save_enabled {
- action_log.update(cx, |log, cx| {
- log.buffer_edited(buffer.clone(), cx);
+ tool.action_log.update(cx, |log, cx| {
+ log.buffer_edited(self.buffer.clone(), cx);
});
let format_task = tool.project.update(cx, |project, cx| {
project.format(
- HashSet::from_iter([buffer.clone()]),
+ HashSet::from_iter([self.buffer.clone()]),
LspFormatTarget::Buffers,
false,
FormatTrigger::Save,
@@ -624,9 +607,9 @@ impl EditSession {
};
}
- let save_task = tool
- .project
- .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx));
+ let save_task = tool.project.update(cx, |project, cx| {
+ project.save_buffer(self.buffer.clone(), cx)
+ });
futures::select! {
result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; },
_ = event_stream.cancelled_by_user().fuse() => {
@@ -634,11 +617,11 @@ impl EditSession {
}
};
- action_log.update(cx, |log, cx| {
- log.buffer_edited(buffer.clone(), cx);
+ tool.action_log.update(cx, |log, cx| {
+ log.buffer_edited(self.buffer.clone(), cx);
});
- let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let (new_text, unified_diff) = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
@@ -652,7 +635,7 @@ impl EditSession {
.await;
let output = StreamingEditFileToolOutput::Success {
- input_path: PathBuf::from(input.path),
+ input_path: input.path,
new_text,
old_text: old_text.clone(),
diff: unified_diff,
@@ -671,31 +654,13 @@ impl EditSession {
StreamingEditFileMode::Write => {
if let Some(content) = &partial.content {
let events = self.parser.push_content(content);
- Self::process_events(
- &events,
- &self.buffer,
- &self.diff,
- &mut self.pipeline,
- &self.abs_path,
- tool,
- event_stream,
- cx,
- )?;
+ self.process_events(&events, tool, event_stream, cx)?;
}
}
StreamingEditFileMode::Edit => {
if let Some(edits) = partial.edits {
let events = self.parser.push_edits(&edits);
- Self::process_events(
- &events,
- &self.buffer,
- &self.diff,
- &mut self.pipeline,
- &self.abs_path,
- tool,
- event_stream,
- cx,
- )?;
+ self.process_events(&events, tool, event_stream, cx)?;
}
}
}
@@ -703,46 +668,38 @@ impl EditSession {
}
fn process_events(
+ &mut self,
events: &[ToolEditEvent],
- buffer: &Entity<Buffer>,
- diff: &Entity<Diff>,
- pipeline: &mut EditPipeline,
- abs_path: &PathBuf,
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
) -> Result<(), StreamingEditFileToolOutput> {
- let action_log = tool
- .thread
- .read_with(cx, |thread, _cx| thread.action_log().clone())
- .ok();
-
for event in events {
match event {
ToolEditEvent::ContentChunk { chunk } => {
- let (buffer_id, insert_at) = buffer.read_with(cx, |buffer, _cx| {
- let insert_at = if !pipeline.content_written && buffer.len() > 0 {
- 0..buffer.len()
- } else {
- let len = buffer.len();
- len..len
- };
- (buffer.remote_id(), insert_at)
- });
+ let (buffer_id, buffer_len) = self
+ .buffer
+ .read_with(cx, |buffer, _cx| (buffer.remote_id(), buffer.len()));
+ let edit_range = if self.pipeline.content_written {
+ buffer_len..buffer_len
+ } else {
+ 0..buffer_len
+ };
+
agent_edit_buffer(
- buffer,
- [(insert_at, chunk.as_str())],
- action_log.as_ref(),
+ &self.buffer,
+ [(edit_range, chunk.as_str())],
+ &tool.action_log,
cx,
);
cx.update(|cx| {
tool.set_agent_location(
- buffer.downgrade(),
+ self.buffer.downgrade(),
text::Anchor::max_for_buffer(buffer_id),
cx,
);
});
- pipeline.content_written = true;
+ self.pipeline.content_written = true;
}
ToolEditEvent::OldTextChunk {
@@ -750,23 +707,24 @@ impl EditSession {
chunk,
done: false,
} => {
- pipeline.ensure_resolving_old_text(*edit_index, buffer, cx);
+ self.pipeline
+ .ensure_resolving_old_text(*edit_index, &self.buffer, cx);
if let EditPipelineEntry::ResolvingOldText { matcher } =
- &mut pipeline.edits[*edit_index]
+ &mut self.pipeline.edits[*edit_index]
+ && !chunk.is_empty()
{
- if !chunk.is_empty() {
- if let Some(match_range) = matcher.push(chunk, None) {
- let anchor_range = buffer.read_with(cx, |buffer, _cx| {
- buffer.anchor_range_between(match_range.clone())
- });
- diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
-
- cx.update(|cx| {
- let position = buffer.read(cx).anchor_before(match_range.end);
- tool.set_agent_location(buffer.downgrade(), position, cx);
- });
- }
+ if let Some(match_range) = matcher.push(chunk, None) {
+ let anchor_range = self.buffer.read_with(cx, |buffer, _cx| {
+ buffer.anchor_range_between(match_range.clone())
+ });
+ self.diff
+ .update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
+
+ cx.update(|cx| {
+ let position = self.buffer.read(cx).anchor_before(match_range.end);
+ tool.set_agent_location(self.buffer.downgrade(), position, cx);
+ });
}
}
}
@@ -776,10 +734,11 @@ impl EditSession {
chunk,
done: true,
} => {
- pipeline.ensure_resolving_old_text(*edit_index, buffer, cx);
+ self.pipeline
+ .ensure_resolving_old_text(*edit_index, &self.buffer, cx);
let EditPipelineEntry::ResolvingOldText { matcher } =
- &mut pipeline.edits[*edit_index]
+ &mut self.pipeline.edits[*edit_index]
else {
continue;
};
@@ -787,60 +746,47 @@ impl EditSession {
if !chunk.is_empty() {
matcher.push(chunk, None);
}
- let matches = matcher.finish();
-
- if matches.is_empty() {
- return Err(StreamingEditFileToolOutput::error(format!(
- "Could not find matching text for edit at index {}. \
- The old_text did not match any content in the file. \
- Please read the file again to get the current content.",
- edit_index,
- )));
- }
- if matches.len() > 1 {
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let lines = matches
- .iter()
- .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
- .collect::<Vec<_>>()
- .join(", ");
- return Err(StreamingEditFileToolOutput::error(format!(
- "Edit {} matched multiple locations in the file at lines: {}. \
- Please provide more context in old_text to uniquely \
- identify the location.",
- edit_index, lines
- )));
- }
-
- let range = matches.into_iter().next().expect("checked len above");
+ let range = extract_match(matcher.finish(), &self.buffer, edit_index, cx)?;
- let anchor_range = buffer
+ let anchor_range = self
+ .buffer
.read_with(cx, |buffer, _cx| buffer.anchor_range_between(range.clone()));
- diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
+ self.diff
+ .update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let line = snapshot.offset_to_point(range.start).row;
event_stream.update_fields(
- ToolCallUpdateFields::new()
- .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
+ ToolCallUpdateFields::new().locations(vec![
+ ToolCallLocation::new(&self.abs_path).line(Some(line)),
+ ]),
);
let EditPipelineEntry::ResolvingOldText { matcher } =
- &pipeline.edits[*edit_index]
+ &self.pipeline.edits[*edit_index]
else {
continue;
};
let buffer_indent =
snapshot.line_indent_for_row(snapshot.offset_to_point(range.start).row);
- let query_indent = query_first_line_indent(matcher.query_lines());
+ let query_indent = text::LineIndent::from_iter(
+ matcher
+ .query_lines()
+ .first()
+ .map(|s| s.as_str())
+ .unwrap_or("")
+ .chars(),
+ );
let indent_delta = compute_indent_delta(buffer_indent, query_indent);
let old_text_in_buffer =
snapshot.text_for_range(range.clone()).collect::<String>();
- let text_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot());
- pipeline.edits[*edit_index] = EditPipelineEntry::StreamingNewText {
+ let text_snapshot = self
+ .buffer
+ .read_with(cx, |buffer, _cx| buffer.text_snapshot());
+ self.pipeline.edits[*edit_index] = EditPipelineEntry::StreamingNewText {
streaming_diff: StreamingDiff::new(old_text_in_buffer),
edit_cursor: range.start,
reindenter: Reindenter::new(indent_delta),
@@ -848,8 +794,8 @@ impl EditSession {
};
cx.update(|cx| {
- let position = buffer.read(cx).anchor_before(range.end);
- tool.set_agent_location(buffer.downgrade(), position, cx);
+ let position = self.buffer.read(cx).anchor_before(range.end);
+ tool.set_agent_location(self.buffer.downgrade(), position, cx);
});
}
@@ -858,7 +804,7 @@ impl EditSession {
chunk,
done: false,
} => {
- if *edit_index >= pipeline.edits.len() {
+ if *edit_index >= self.pipeline.edits.len() {
continue;
}
let EditPipelineEntry::StreamingNewText {
@@ -867,7 +813,7 @@ impl EditSession {
reindenter,
original_snapshot,
..
- } = &mut pipeline.edits[*edit_index]
+ } = &mut self.pipeline.edits[*edit_index]
else {
continue;
};
@@ -878,18 +824,18 @@ impl EditSession {
}
let char_ops = streaming_diff.push_new(&reindented);
- Self::apply_char_operations(
+ apply_char_operations(
&char_ops,
- buffer,
+ &self.buffer,
original_snapshot,
edit_cursor,
- action_log.as_ref(),
+ &tool.action_log,
cx,
);
let position = original_snapshot.anchor_before(*edit_cursor);
cx.update(|cx| {
- tool.set_agent_location(buffer.downgrade(), position, cx);
+ tool.set_agent_location(self.buffer.downgrade(), position, cx);
});
}
@@ -898,7 +844,7 @@ impl EditSession {
chunk,
done: true,
} => {
- if *edit_index >= pipeline.edits.len() {
+ if *edit_index >= self.pipeline.edits.len() {
continue;
}
@@ -908,7 +854,7 @@ impl EditSession {
mut reindenter,
original_snapshot,
} = std::mem::replace(
- &mut pipeline.edits[*edit_index],
+ &mut self.pipeline.edits[*edit_index],
EditPipelineEntry::Done,
)
else {
@@ -921,64 +867,95 @@ impl EditSession {
if !final_text.is_empty() {
let char_ops = streaming_diff.push_new(&final_text);
- Self::apply_char_operations(
+ apply_char_operations(
&char_ops,
- buffer,
+ &self.buffer,
&original_snapshot,
&mut edit_cursor,
- action_log.as_ref(),
+ &tool.action_log,
cx,
);
}
let remaining_ops = streaming_diff.finish();
- Self::apply_char_operations(
+ apply_char_operations(
&remaining_ops,
- buffer,
+ &self.buffer,
&original_snapshot,
&mut edit_cursor,
- action_log.as_ref(),
+ &tool.action_log,
cx,
);
let position = original_snapshot.anchor_before(edit_cursor);
cx.update(|cx| {
- tool.set_agent_location(buffer.downgrade(), position, cx);
+ tool.set_agent_location(self.buffer.downgrade(), position, cx);
});
}
}
}
Ok(())
}
+}
- fn apply_char_operations(
- ops: &[CharOperation],
- buffer: &Entity<Buffer>,
- snapshot: &text::BufferSnapshot,
- edit_cursor: &mut usize,
- action_log: Option<&Entity<ActionLog>>,
- cx: &mut AsyncApp,
- ) {
- for op in ops {
- match op {
- CharOperation::Insert { text } => {
- let anchor = snapshot.anchor_after(*edit_cursor);
- agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx);
- }
- CharOperation::Delete { bytes } => {
- let delete_end = *edit_cursor + bytes;
- let anchor_range = snapshot.anchor_range_around(*edit_cursor..delete_end);
- agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx);
- *edit_cursor = delete_end;
- }
- CharOperation::Keep { bytes } => {
- *edit_cursor += bytes;
- }
+fn apply_char_operations(
+ ops: &[CharOperation],
+ buffer: &Entity<Buffer>,
+ snapshot: &text::BufferSnapshot,
+ edit_cursor: &mut usize,
+ action_log: &Entity<ActionLog>,
+ cx: &mut AsyncApp,
+) {
+ for op in ops {
+ match op {
+ CharOperation::Insert { text } => {
+ let anchor = snapshot.anchor_after(*edit_cursor);
+ agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx);
+ }
+ CharOperation::Delete { bytes } => {
+ let delete_end = *edit_cursor + bytes;
+ let anchor_range = snapshot.anchor_range_around(*edit_cursor..delete_end);
+ agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx);
+ *edit_cursor = delete_end;
+ }
+ CharOperation::Keep { bytes } => {
+ *edit_cursor += bytes;
}
}
}
}
+fn extract_match(
+ matches: Vec<Range<usize>>,
+ buffer: &Entity<Buffer>,
+ edit_index: &usize,
+ cx: &mut AsyncApp,
+) -> Result<Range<usize>, StreamingEditFileToolOutput> {
+ match matches.len() {
+ 0 => Err(StreamingEditFileToolOutput::error(format!(
+ "Could not find matching text for edit at index {}. \
+ The old_text did not match any content in the file. \
+ Please read the file again to get the current content.",
+ edit_index,
+ ))),
+ 1 => Ok(matches.into_iter().next().unwrap()),
+ _ => {
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let lines = matches
+ .iter()
+ .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
+ .collect::<Vec<_>>()
+ .join(", ");
+ Err(StreamingEditFileToolOutput::error(format!(
+ "Edit {} matched multiple locations in the file at lines: {}. \
+ Please provide more context in old_text to uniquely \
+ identify the location.",
+ edit_index, lines
+ )))
+ }
+ }
+}
+
/// Edits a buffer and reports the edit to the action log in the same effect
/// cycle. This ensures the action log's subscription handler sees the version
/// already updated by `buffer_edited`, so it does not misattribute the agent's
@@ -986,7 +963,7 @@ impl EditSession {
fn agent_edit_buffer<I, S, T>(
buffer: &Entity<Buffer>,
edits: I,
- action_log: Option<&Entity<ActionLog>>,
+ action_log: &Entity<ActionLog>,
cx: &mut AsyncApp,
) where
I: IntoIterator<Item = (Range<S>, T)>,
@@ -997,9 +974,7 @@ fn agent_edit_buffer<I, S, T>(
buffer.update(cx, |buffer, cx| {
buffer.edit(edits, None, cx);
});
- if let Some(action_log) = action_log {
- action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
- }
+ action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
}
@@ -1007,11 +982,11 @@ fn ensure_buffer_saved(
buffer: &Entity<Buffer>,
abs_path: &PathBuf,
tool: &StreamingEditFileTool,
- action_log: Option<&Entity<ActionLog>>,
cx: &mut AsyncApp,
) -> Result<(), StreamingEditFileToolOutput> {
- let last_read_mtime =
- action_log.and_then(|log| log.read_with(cx, |log, _| log.file_read_time(abs_path)));
+ let last_read_mtime = tool
+ .action_log
+ .read_with(cx, |log, _| log.file_read_time(abs_path));
let check_result = tool.thread.read_with(cx, |thread, cx| {
let current = buffer
.read(cx)
@@ -1140,42 +1115,17 @@ mod tests {
#[gpui::test]
async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree("/root", json!({"dir": {}})).await;
- let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
- let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- crate::Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
-
+ let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await;
let result = cx
.update(|cx| {
- let input = StreamingEditFileToolInput {
- display_description: "Create new file".into(),
- path: "root/dir/new_file.txt".into(),
- mode: StreamingEditFileMode::Write,
- content: Some("Hello, World!".into()),
- edits: None,
- };
- Arc::new(StreamingEditFileTool::new(
- project.clone(),
- thread.downgrade(),
- language_registry,
- ))
- .run(
- ToolInput::resolved(input),
+ tool.clone().run(
+ ToolInput::resolved(StreamingEditFileToolInput {
+ display_description: "Create new file".into(),
+ path: "root/dir/new_file.txt".into(),
+ mode: StreamingEditFileMode::Write,
+ content: Some("Hello, World!".into()),
+ edits: None,
+ }),
ToolCallEventStream::test().0,
cx,
)
@@ -1191,43 +1141,18 @@ mod tests {
#[gpui::test]
async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree("/root", json!({"file.txt": "old content"}))
- .await;
- let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
- let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- crate::Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
-
+ let (tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"file.txt": "old content"})).await;
let result = cx
.update(|cx| {
- let input = StreamingEditFileToolInput {
- display_description: "Overwrite file".into(),
- path: "root/file.txt".into(),
- mode: StreamingEditFileMode::Write,
- content: Some("new content".into()),
- edits: None,
- };
- Arc::new(StreamingEditFileTool::new(
- project.clone(),
- thread.downgrade(),
- language_registry,
- ))
- .run(
- ToolInput::resolved(input),
+ tool.clone().run(
+ ToolInput::resolved(StreamingEditFileToolInput {
+ display_description: "Overwrite file".into(),
+ path: "root/file.txt".into(),
+ mode: StreamingEditFileMode::Write,
+ content: Some("new content".into()),
+ edits: None,
+ }),
ToolCallEventStream::test().0,
cx,
)
@@ -1246,51 +1171,21 @@ mod tests {
#[gpui::test]
async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "file.txt": "line 1\nline 2\nline 3\n"
- }),
- )
- .await;
- let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
- let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- crate::Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
-
+ let (tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
let result = cx
.update(|cx| {
- let input = StreamingEditFileToolInput {
- display_description: "Edit lines".into(),
- path: "root/file.txt".into(),
- mode: StreamingEditFileMode::Edit,
- content: None,
- edits: Some(vec![Edit {
- old_text: "line 2".into(),
- new_text: "modified line 2".into(),
- }]),
- };
- Arc::new(StreamingEditFileTool::new(
- project.clone(),
- thread.downgrade(),
- language_registry,
- ))
- .run(
- ToolInput::resolved(input),
+ tool.clone().run(
+ ToolInput::resolved(StreamingEditFileToolInput {
+ display_description: "Edit lines".into(),
+ path: "root/file.txt".into(),
+ mode: StreamingEditFileMode::Edit,
+ content: None,
+ edits: Some(vec![Edit {
+ old_text: "line 2".into(),
+ new_text: "modified line 2".into(),
+ }]),
+ }),
ToolCallEventStream::test().0,
cx,
)
@@ -1305,57 +1200,30 @@ mod tests {
#[gpui::test]
async fn test_streaming_edit_multiple_edits(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
- }),
+ let (tool, _project, _action_log, _fs, _thread) = setup_test(
+ cx,
+ json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
)
.await;
- let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
- let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- crate::Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
-
let result = cx
.update(|cx| {
- let input = StreamingEditFileToolInput {
- display_description: "Edit multiple lines".into(),
- path: "root/file.txt".into(),
- mode: StreamingEditFileMode::Edit,
- content: None,
- edits: Some(vec![
- Edit {
- old_text: "line 5".into(),
- new_text: "modified line 5".into(),
- },
- Edit {
- old_text: "line 1".into(),
- new_text: "modified line 1".into(),
- },
- ]),
- };
- Arc::new(StreamingEditFileTool::new(
- project.clone(),
- thread.downgrade(),
- language_registry,
- ))
- .run(
- ToolInput::resolved(input),
+ tool.clone().run(
+ ToolInput::resolved(StreamingEditFileToolInput {
+ display_description: "Edit multiple lines".into(),
+ path: "root/file.txt".into(),
+ mode: StreamingEditFileMode::Edit,
+ content: None,
+ edits: Some(vec![
+ Edit {
+ old_text: "line 5".into(),
+ new_text: "modified line 5".into(),
+ },
+ Edit {
+ old_text: "line 1".into(),
+ new_text: "modified line 1".into(),
+ },
+ ]),
+ }),
ToolCallEventStream::test().0,
cx,
)