1use super::edit_file_tool::EditFileTool;
2use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
3use super::save_file_tool::SaveFileTool;
4use crate::{
5 AgentTool, Templates, Thread, ToolCallEventStream, ToolPermissionDecision,
6 decide_permission_from_settings, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
7};
8use acp_thread::Diff;
9use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
10use anyhow::{Context as _, Result, anyhow};
11use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
12use language::LanguageRegistry;
13use language_model::LanguageModelToolResultContent;
14use paths;
15use project::{Project, ProjectPath};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use std::ffi::OsStr;
20use std::ops::Range;
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23use text::BufferSnapshot;
24use ui::SharedString;
25use util::rel_path::RelPath;
26
/// Fallback tool-call title used by `initial_title` when neither a path nor a description
/// can be recovered from the (possibly partial) tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
28
29/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
30///
31/// Before using this tool:
32///
33/// 1. Use the `read_file` tool to understand the file's contents and context
34///
35/// 2. Verify the directory path is correct (only applicable when creating new files):
36/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
37#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
38pub struct StreamingEditFileToolInput {
39 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
40 ///
41 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
42 ///
43 /// NEVER mention the file path in this description.
44 ///
45 /// <example>Fix API endpoint URLs</example>
46 /// <example>Update copyright year in `page_footer`</example>
47 ///
48 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
49 pub display_description: String,
50
51 /// The full path of the file to create or modify in the project.
52 ///
53 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
54 ///
55 /// The following examples assume we have two root directories in the project:
56 /// - /a/b/backend
57 /// - /c/d/frontend
58 ///
59 /// <example>
60 /// `backend/src/main.rs`
61 ///
62 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
63 /// </example>
64 ///
65 /// <example>
66 /// `frontend/db.js`
67 /// </example>
68 pub path: PathBuf,
69
70 /// The mode of operation on the file. Possible values:
71 /// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
72 /// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
73 /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
74 ///
75 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
76 pub mode: StreamingEditFileMode,
77
78 /// The complete content for the new file (required for 'create' and 'overwrite' modes).
79 /// This field should contain the entire file content.
80 #[serde(default, skip_serializing_if = "Option::is_none")]
81 pub content: Option<String>,
82
83 /// List of edit operations to apply sequentially (required for 'edit' mode).
84 /// Each edit finds `old_text` in the file and replaces it with `new_text`.
85 #[serde(default, skip_serializing_if = "Option::is_none")]
86 pub edits: Option<Vec<EditOperation>>,
87}
88
// Serialized as the snake_case variant names ("create", "overwrite", "edit") via `rename_all`.
// The `///` comments on the variants are model-facing schema text (`JsonSchema` derive).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
    /// Create a new file if it doesn't exist
    // `resolve_path` rejects this mode when a project entry already exists at the path, and
    // requires the parent to be an existing project directory.
    Create,
    /// Replace the entire contents of an existing file
    // `resolve_path` requires an existing project entry that is a file.
    Overwrite,
    /// Make granular edits to an existing file
    // Same path requirement as `Overwrite`; the edits themselves are resolved by `apply_edits`.
    Edit,
}
99
/// A single edit operation that replaces old text with new text
// `resolve_edit` runs `old_text` through `StreamingFuzzyMatcher`: zero matches or more than
// one match makes the whole tool call fail (see `apply_edits`). The `///` comments here are
// model-facing schema text produced by the `JsonSchema` derive.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The exact text to find in the file. This will be matched using fuzzy matching
    /// to handle minor differences in whitespace or formatting.
    pub old_text: String,
    /// The text to replace it with
    pub new_text: String,
}
109
// Lenient view of a tool input that is still streaming in. `initial_title` deserializes the
// raw JSON into this type when the full `StreamingEditFileToolInput` cannot be parsed yet.
// Both fields default to an empty string, so an object missing either key still deserializes.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct StreamingEditFileToolPartialInput {
    // The target path as typed so far (may be empty or whitespace-padded).
    #[serde(default)]
    path: String,
    // The description as typed so far (may be empty).
    #[serde(default)]
    display_description: String,
}
117
/// Result of a successful `streaming_edit_file` call.
///
/// It is rendered for the model through `From<StreamingEditFileToolOutput>` and used by
/// `replay` to rebuild a finalized diff from the stored before/after text.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingEditFileToolOutput {
    /// The path exactly as supplied in the tool input. Deserialization also accepts the key
    /// `original_path` for this field.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the change was applied and the buffer saved.
    new_text: String,
    /// Full buffer text before the change.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`. An empty diff is reported to the model as
    /// "No edits were made."; the field defaults to empty when missing from serialized data.
    #[serde(default)]
    diff: String,
}
127
128impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
129 fn from(output: StreamingEditFileToolOutput) -> Self {
130 if output.diff.is_empty() {
131 "No edits were made.".into()
132 } else {
133 format!(
134 "Edited {}:\n\n```diff\n{}\n```",
135 output.input_path.display(),
136 output.diff
137 )
138 .into()
139 }
140 }
141}
142
/// Agent tool that creates, overwrites, or applies find-and-replace edits to a project file,
/// then saves it.
pub struct StreamingEditFileTool {
    /// Owning thread, held weakly. Provides the project, action log, per-file read times and
    /// the set of registered tools.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when replaying a completed call.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to render project-relative display paths.
    project: Entity<Project>,
    /// Not read anywhere in this file (hence the `dead_code` allowance).
    #[allow(dead_code)]
    templates: Arc<Templates>,
}
150
151impl StreamingEditFileTool {
152 pub fn new(
153 project: Entity<Project>,
154 thread: WeakEntity<Thread>,
155 language_registry: Arc<LanguageRegistry>,
156 templates: Arc<Templates>,
157 ) -> Self {
158 Self {
159 project,
160 thread,
161 language_registry,
162 templates,
163 }
164 }
165
166 fn authorize(
167 &self,
168 input: &StreamingEditFileToolInput,
169 event_stream: &ToolCallEventStream,
170 cx: &mut App,
171 ) -> Task<Result<()>> {
172 let path_str = input.path.to_string_lossy();
173 let settings = agent_settings::AgentSettings::get_global(cx);
174 let decision = decide_permission_from_settings(Self::NAME, &path_str, settings);
175
176 match decision {
177 ToolPermissionDecision::Allow => return Task::ready(Ok(())),
178 ToolPermissionDecision::Deny(reason) => {
179 return Task::ready(Err(anyhow!("{}", reason)));
180 }
181 ToolPermissionDecision::Confirm => {}
182 }
183
184 let local_settings_folder = paths::local_settings_folder_name();
185 let path = Path::new(&input.path);
186 if path.components().any(|component| {
187 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
188 }) {
189 let context = crate::ToolPermissionContext {
190 tool_name: EditFileTool::NAME.to_string(),
191 input_value: path_str.to_string(),
192 };
193 return event_stream.authorize(
194 format!("{} (local settings)", input.display_description),
195 context,
196 cx,
197 );
198 }
199
200 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
201 && canonical_path.starts_with(paths::config_dir())
202 {
203 let context = crate::ToolPermissionContext {
204 tool_name: EditFileTool::NAME.to_string(),
205 input_value: path_str.to_string(),
206 };
207 return event_stream.authorize(
208 format!("{} (global settings)", input.display_description),
209 context,
210 cx,
211 );
212 }
213
214 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
215 thread.project().read(cx).find_project_path(&input.path, cx)
216 }) else {
217 return Task::ready(Err(anyhow!("thread was dropped")));
218 };
219
220 if project_path.is_some() {
221 Task::ready(Ok(()))
222 } else {
223 let context = crate::ToolPermissionContext {
224 tool_name: EditFileTool::NAME.to_string(),
225 input_value: path_str.to_string(),
226 };
227 event_stream.authorize(&input.display_description, context, cx)
228 }
229 }
230}
231
232impl AgentTool for StreamingEditFileTool {
233 type Input = StreamingEditFileToolInput;
234 type Output = StreamingEditFileToolOutput;
235
236 const NAME: &'static str = "streaming_edit_file";
237
238 fn kind() -> acp::ToolKind {
239 acp::ToolKind::Edit
240 }
241
242 fn initial_title(
243 &self,
244 input: Result<Self::Input, serde_json::Value>,
245 cx: &mut App,
246 ) -> SharedString {
247 match input {
248 Ok(input) => self
249 .project
250 .read(cx)
251 .find_project_path(&input.path, cx)
252 .and_then(|project_path| {
253 self.project
254 .read(cx)
255 .short_full_path_for_project_path(&project_path, cx)
256 })
257 .unwrap_or(input.path.to_string_lossy().into_owned())
258 .into(),
259 Err(raw_input) => {
260 if let Some(input) =
261 serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
262 {
263 let path = input.path.trim();
264 if !path.is_empty() {
265 return self
266 .project
267 .read(cx)
268 .find_project_path(&input.path, cx)
269 .and_then(|project_path| {
270 self.project
271 .read(cx)
272 .short_full_path_for_project_path(&project_path, cx)
273 })
274 .unwrap_or(input.path)
275 .into();
276 }
277
278 let description = input.display_description.trim();
279 if !description.is_empty() {
280 return description.to_string().into();
281 }
282 }
283
284 DEFAULT_UI_TEXT.into()
285 }
286 }
287 }
288
289 fn run(
290 self: Arc<Self>,
291 input: Self::Input,
292 event_stream: ToolCallEventStream,
293 cx: &mut App,
294 ) -> Task<Result<Self::Output>> {
295 let Ok(project) = self
296 .thread
297 .read_with(cx, |thread, _cx| thread.project().clone())
298 else {
299 return Task::ready(Err(anyhow!("thread was dropped")));
300 };
301
302 let project_path = match resolve_path(&input, project.clone(), cx) {
303 Ok(path) => path,
304 Err(err) => return Task::ready(Err(anyhow!(err))),
305 };
306
307 let abs_path = project.read(cx).absolute_path(&project_path, cx);
308 if let Some(abs_path) = abs_path.clone() {
309 event_stream.update_fields(
310 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
311 );
312 }
313
314 let authorize = self.authorize(&input, &event_stream, cx);
315
316 cx.spawn(async move |cx: &mut AsyncApp| {
317 authorize.await?;
318
319 let buffer = project
320 .update(cx, |project, cx| {
321 project.open_buffer(project_path.clone(), cx)
322 })
323 .await?;
324
325 if let Some(abs_path) = abs_path.as_ref() {
326 let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
327 self.thread.update(cx, |thread, cx| {
328 let last_read = thread.file_read_times.get(abs_path).copied();
329 let current = buffer
330 .read(cx)
331 .file()
332 .and_then(|file| file.disk_state().mtime());
333 let dirty = buffer.read(cx).is_dirty();
334 let has_save = thread.has_tool(SaveFileTool::NAME);
335 let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
336 (last_read, current, dirty, has_save, has_restore)
337 })?;
338
339 if is_dirty {
340 let message = match (has_save_tool, has_restore_tool) {
341 (true, true) => {
342 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
343 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
344 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
345 }
346 (true, false) => {
347 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
348 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
349 If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
350 }
351 (false, true) => {
352 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
353 If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
354 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
355 }
356 (false, false) => {
357 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
358 then ask them to save or revert the file manually and inform you when it's ok to proceed."
359 }
360 };
361 anyhow::bail!("{}", message);
362 }
363
364 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
365 if current != last_read {
366 anyhow::bail!(
367 "The file {} has been modified since you last read it. \
368 Please read the file again to get the current state before editing it.",
369 input.path.display()
370 );
371 }
372 }
373 }
374
375 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
376 event_stream.update_diff(diff.clone());
377 let _finalize_diff = util::defer({
378 let diff = diff.downgrade();
379 let mut cx = cx.clone();
380 move || {
381 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
382 }
383 });
384
385 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
386 let old_text = cx
387 .background_spawn({
388 let old_snapshot = old_snapshot.clone();
389 async move { Arc::new(old_snapshot.text()) }
390 })
391 .await;
392
393 let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
394
395 // Edit the buffer and report edits to the action log as part of the
396 // same effect cycle, otherwise the edit will be reported as if the
397 // user made it (due to the buffer subscription in action_log).
398 match input.mode {
399 StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
400 action_log.update(cx, |log, cx| {
401 log.buffer_created(buffer.clone(), cx);
402 });
403 let content = input.content.ok_or_else(|| {
404 anyhow!("'content' field is required for create and overwrite modes")
405 })?;
406 cx.update(|cx| {
407 buffer.update(cx, |buffer, cx| {
408 buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
409 });
410 action_log.update(cx, |log, cx| {
411 log.buffer_edited(buffer.clone(), cx);
412 });
413 });
414 }
415 StreamingEditFileMode::Edit => {
416 action_log.update(cx, |log, cx| {
417 log.buffer_read(buffer.clone(), cx);
418 });
419 let edits = input.edits.ok_or_else(|| {
420 anyhow!("'edits' field is required for edit mode")
421 })?;
422 // apply_edits now handles buffer_edited internally in the same effect cycle
423 apply_edits(&buffer, &action_log, &edits, &diff, &event_stream, &abs_path, cx)?;
424 }
425 }
426
427 project
428 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
429 .await?;
430
431 action_log.update(cx, |log, cx| {
432 log.buffer_edited(buffer.clone(), cx);
433 });
434
435 if let Some(abs_path) = abs_path.as_ref() {
436 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
437 buffer.file().and_then(|file| file.disk_state().mtime())
438 }) {
439 self.thread.update(cx, |thread, _| {
440 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
441 })?;
442 }
443 }
444
445 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
446 let (new_text, unified_diff) = cx
447 .background_spawn({
448 let new_snapshot = new_snapshot.clone();
449 let old_text = old_text.clone();
450 async move {
451 let new_text = new_snapshot.text();
452 let diff = language::unified_diff(&old_text, &new_text);
453 (new_text, diff)
454 }
455 })
456 .await;
457
458 let output = StreamingEditFileToolOutput {
459 input_path: input.path,
460 new_text,
461 old_text,
462 diff: unified_diff,
463 };
464
465 Ok(output)
466 })
467 }
468
469 fn replay(
470 &self,
471 _input: Self::Input,
472 output: Self::Output,
473 event_stream: ToolCallEventStream,
474 cx: &mut App,
475 ) -> Result<()> {
476 event_stream.update_diff(cx.new(|cx| {
477 Diff::finalized(
478 output.input_path.to_string_lossy().into_owned(),
479 Some(output.old_text.to_string()),
480 output.new_text,
481 self.language_registry.clone(),
482 cx,
483 )
484 }));
485 Ok(())
486 }
487}
488
489fn apply_edits(
490 buffer: &Entity<language::Buffer>,
491 action_log: &Entity<action_log::ActionLog>,
492 edits: &[EditOperation],
493 diff: &Entity<Diff>,
494 event_stream: &ToolCallEventStream,
495 abs_path: &Option<PathBuf>,
496 cx: &mut AsyncApp,
497) -> Result<()> {
498 let mut failed_edits = Vec::new();
499 let mut ambiguous_edits = Vec::new();
500 let mut resolved_edits: Vec<(Range<usize>, String)> = Vec::new();
501 let mut first_edit_line: Option<u32> = None;
502
503 // First pass: resolve all edits without applying them
504 for (index, edit) in edits.iter().enumerate() {
505 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
506 let result = resolve_edit(&snapshot, edit);
507
508 match result {
509 Ok(Some((range, new_text))) => {
510 if first_edit_line.is_none() {
511 first_edit_line = Some(snapshot.offset_to_point(range.start).row);
512 }
513 // Reveal the range in the diff view
514 let start_anchor =
515 buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(range.start));
516 let end_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_after(range.end));
517 diff.update(cx, |card, cx| {
518 card.reveal_range(start_anchor..end_anchor, cx)
519 });
520 resolved_edits.push((range, new_text));
521 }
522 Ok(None) => {
523 failed_edits.push(index);
524 }
525 Err(ranges) => {
526 ambiguous_edits.push((index, ranges));
527 }
528 }
529 }
530
531 // Check for errors before applying any edits
532 if !failed_edits.is_empty() {
533 let indices = failed_edits
534 .iter()
535 .map(|i| i.to_string())
536 .collect::<Vec<_>>()
537 .join(", ");
538 anyhow::bail!(
539 "Could not find matching text for edit(s) at index(es): {}. \
540 The old_text did not match any content in the file. \
541 Please read the file again to get the current content.",
542 indices
543 );
544 }
545
546 if !ambiguous_edits.is_empty() {
547 let details: Vec<String> = ambiguous_edits
548 .iter()
549 .map(|(index, ranges)| {
550 let lines = ranges
551 .iter()
552 .map(|r| r.start.to_string())
553 .collect::<Vec<_>>()
554 .join(", ");
555 format!("edit {}: matches at lines {}", index, lines)
556 })
557 .collect();
558 anyhow::bail!(
559 "Some edits matched multiple locations in the file:\n{}. \
560 Please provide more context in old_text to uniquely identify the location.",
561 details.join("\n")
562 );
563 }
564
565 // Emit location for the first edit
566 if let Some(line) = first_edit_line {
567 if let Some(abs_path) = abs_path.clone() {
568 event_stream.update_fields(
569 ToolCallUpdateFields::new()
570 .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
571 );
572 }
573 }
574
575 // Second pass: apply all edits and report to action_log in the same effect cycle.
576 // This prevents the buffer subscription from treating these as user edits.
577 if !resolved_edits.is_empty() {
578 cx.update(|cx| {
579 buffer.update(cx, |buffer, cx| {
580 // Apply edits in reverse order so offsets remain valid
581 let mut edits_sorted: Vec<_> = resolved_edits.into_iter().collect();
582 edits_sorted.sort_by(|a, b| b.0.start.cmp(&a.0.start));
583 for (range, new_text) in edits_sorted {
584 buffer.edit([(range, new_text.as_str())], None, cx);
585 }
586 });
587 action_log.update(cx, |log, cx| {
588 log.buffer_edited(buffer.clone(), cx);
589 });
590 });
591 }
592
593 Ok(())
594}
595
596/// Resolves an edit operation by finding the matching text in the buffer.
597/// Returns Ok(Some((range, new_text))) if a unique match is found,
598/// Ok(None) if no match is found, or Err(ranges) if multiple matches are found.
599fn resolve_edit(
600 snapshot: &BufferSnapshot,
601 edit: &EditOperation,
602) -> std::result::Result<Option<(Range<usize>, String)>, Vec<Range<usize>>> {
603 let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
604 matcher.push(&edit.old_text, None);
605 let matches = matcher.finish();
606
607 if matches.is_empty() {
608 return Ok(None);
609 }
610
611 if matches.len() > 1 {
612 return Err(matches);
613 }
614
615 let match_range = matches.into_iter().next().expect("checked len above");
616 Ok(Some((match_range, edit.new_text.clone())))
617}
618
619fn resolve_path(
620 input: &StreamingEditFileToolInput,
621 project: Entity<Project>,
622 cx: &mut App,
623) -> Result<ProjectPath> {
624 let project = project.read(cx);
625
626 match input.mode {
627 StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
628 let path = project
629 .find_project_path(&input.path, cx)
630 .context("Can't edit file: path not found")?;
631
632 let entry = project
633 .entry_for_path(&path, cx)
634 .context("Can't edit file: path not found")?;
635
636 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
637 Ok(path)
638 }
639
640 StreamingEditFileMode::Create => {
641 if let Some(path) = project.find_project_path(&input.path, cx) {
642 anyhow::ensure!(
643 project.entry_for_path(&path, cx).is_none(),
644 "Can't create file: file already exists"
645 );
646 }
647
648 let parent_path = input
649 .path
650 .parent()
651 .context("Can't create file: incorrect path")?;
652
653 let parent_project_path = project.find_project_path(&parent_path, cx);
654
655 let parent_entry = parent_project_path
656 .as_ref()
657 .and_then(|path| project.entry_for_path(path, cx))
658 .context("Can't create file: parent directory doesn't exist")?;
659
660 anyhow::ensure!(
661 parent_entry.is_dir(),
662 "Can't create file: parent is not a directory"
663 );
664
665 let file_name = input
666 .path
667 .file_name()
668 .and_then(|file_name| file_name.to_str())
669 .and_then(|file_name| RelPath::unix(file_name).ok())
670 .context("Can't create file: invalid filename")?;
671
672 let new_file_path = parent_project_path.map(|parent| ProjectPath {
673 path: parent.path.join(file_name),
674 ..parent
675 });
676
677 new_file_path.context("Can't create file")
678 }
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ContextServerRegistry, Templates};
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use prompt_store::ProjectContext;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    #[gpui::test]
    async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
        let output = run_tool(
            cx,
            json!({"dir": {}}),
            StreamingEditFileToolInput {
                display_description: "Create new file".into(),
                path: "root/dir/new_file.txt".into(),
                mode: StreamingEditFileMode::Create,
                content: Some("Hello, World!".into()),
                edits: None,
            },
        )
        .await
        .expect("creating a new file should succeed");

        assert_eq!(output.new_text, "Hello, World!");
        assert!(!output.diff.is_empty());
    }

    #[gpui::test]
    async fn test_streaming_edit_create_existing_file(cx: &mut TestAppContext) {
        let result = run_tool(
            cx,
            json!({"file.txt": "existing content"}),
            StreamingEditFileToolInput {
                display_description: "Create file".into(),
                path: "root/file.txt".into(),
                mode: StreamingEditFileMode::Create,
                content: Some("new content".into()),
                edits: None,
            },
        )
        .await;

        assert_eq!(
            result.unwrap_err().to_string(),
            "Can't create file: file already exists"
        );
    }

    #[gpui::test]
    async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
        let output = run_tool(
            cx,
            json!({"file.txt": "old content"}),
            StreamingEditFileToolInput {
                display_description: "Overwrite file".into(),
                path: "root/file.txt".into(),
                mode: StreamingEditFileMode::Overwrite,
                content: Some("new content".into()),
                edits: None,
            },
        )
        .await
        .expect("overwriting an existing file should succeed");

        assert_eq!(output.new_text, "new content");
        assert_eq!(*output.old_text, "old content");
    }

    #[gpui::test]
    async fn test_streaming_edit_overwrite_requires_content(cx: &mut TestAppContext) {
        let result = run_tool(
            cx,
            json!({"file.txt": "old content"}),
            StreamingEditFileToolInput {
                display_description: "Overwrite file".into(),
                path: "root/file.txt".into(),
                mode: StreamingEditFileMode::Overwrite,
                content: None,
                edits: None,
            },
        )
        .await;

        assert_eq!(
            result.unwrap_err().to_string(),
            "'content' field is required for create and overwrite modes"
        );
    }

    #[gpui::test]
    async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
        let output = run_tool(
            cx,
            json!({"file.txt": "line 1\nline 2\nline 3\n"}),
            StreamingEditFileToolInput {
                display_description: "Edit lines".into(),
                path: "root/file.txt".into(),
                mode: StreamingEditFileMode::Edit,
                content: None,
                edits: Some(vec![EditOperation {
                    old_text: "line 2".into(),
                    new_text: "modified line 2".into(),
                }]),
            },
        )
        .await
        .expect("a uniquely matching edit should succeed");

        assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
    }

    #[gpui::test]
    async fn test_streaming_edit_multiple_edits(cx: &mut TestAppContext) {
        let output = run_tool(
            cx,
            json!({"file.txt": "line 1\nline 2\nline 3\n"}),
            StreamingEditFileToolInput {
                display_description: "Edit first and last lines".into(),
                path: "root/file.txt".into(),
                mode: StreamingEditFileMode::Edit,
                content: None,
                edits: Some(vec![
                    EditOperation {
                        old_text: "line 1".into(),
                        new_text: "first line".into(),
                    },
                    EditOperation {
                        old_text: "line 3".into(),
                        new_text: "third line".into(),
                    },
                ]),
            },
        )
        .await
        .expect("disjoint edits should all be applied");

        assert_eq!(output.new_text, "first line\nline 2\nthird line\n");
    }

    #[gpui::test]
    async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
        let result = run_tool(
            cx,
            json!({}),
            StreamingEditFileToolInput {
                display_description: "Some edit".into(),
                path: "root/nonexistent_file.txt".into(),
                mode: StreamingEditFileMode::Edit,
                content: None,
                edits: Some(vec![EditOperation {
                    old_text: "foo".into(),
                    new_text: "bar".into(),
                }]),
            },
        )
        .await;

        assert_eq!(
            result.unwrap_err().to_string(),
            "Can't edit file: path not found"
        );
    }

    #[gpui::test]
    async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
        let result = run_tool(
            cx,
            json!({"file.txt": "hello world"}),
            StreamingEditFileToolInput {
                display_description: "Edit file".into(),
                path: "root/file.txt".into(),
                mode: StreamingEditFileMode::Edit,
                content: None,
                edits: Some(vec![EditOperation {
                    old_text: "nonexistent text that is not in the file".into(),
                    new_text: "replacement".into(),
                }]),
            },
        )
        .await;

        let error = result.unwrap_err().to_string();
        assert!(
            error.contains("Could not find matching text"),
            "unexpected error: {error}"
        );
    }

    /// Builds a fake project rooted at `/root` containing `tree`, creates a thread for it,
    /// and runs the tool once with `input`. The thread stays alive until the run completes.
    async fn run_tool(
        cx: &mut TestAppContext,
        tree: serde_json::Value,
        input: StreamingEditFileToolInput,
    ) -> Result<StreamingEditFileToolOutput> {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        // Use `path!` for both the fake tree and the project root so they agree on Windows.
        fs.insert_tree(path!("/root"), tree).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            crate::Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model),
                cx,
            )
        });

        cx.update(|cx| {
            Arc::new(StreamingEditFileTool::new(
                project.clone(),
                thread.downgrade(),
                language_registry,
                Templates::new(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        })
        .await
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }
}