1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream, ToolPermissionDecision,
3 decide_permission_from_settings, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use language::LanguageRegistry;
10use language_model::LanguageModelToolResultContent;
11use paths;
12use project::{Project, ProjectPath};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use settings::Settings;
16use std::ffi::OsStr;
17use std::ops::Range;
18use std::path::{Path, PathBuf};
19use std::sync::Arc;
20use text::BufferSnapshot;
21use ui::SharedString;
22use util::rel_path::RelPath;
23
/// Title shown for a tool call when neither a path nor a description can be extracted
/// from the (possibly partial) input; returned as the last fallback of `initial_title`.
const DEFAULT_UI_TEXT: &str = "Editing file";
25
26/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
27///
28/// Before using this tool:
29///
30/// 1. Use the `read_file` tool to understand the file's contents and context
31///
32/// 2. Verify the directory path is correct (only applicable when creating new files):
33/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
34#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
35pub struct StreamingEditFileToolInput {
36 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
37 ///
38 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
39 ///
40 /// NEVER mention the file path in this description.
41 ///
42 /// <example>Fix API endpoint URLs</example>
43 /// <example>Update copyright year in `page_footer`</example>
44 ///
45 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
46 pub display_description: String,
47
48 /// The full path of the file to create or modify in the project.
49 ///
50 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
51 ///
52 /// The following examples assume we have two root directories in the project:
53 /// - /a/b/backend
54 /// - /c/d/frontend
55 ///
56 /// <example>
57 /// `backend/src/main.rs`
58 ///
59 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
60 /// </example>
61 ///
62 /// <example>
63 /// `frontend/db.js`
64 /// </example>
65 pub path: PathBuf,
66
67 /// The mode of operation on the file. Possible values:
68 /// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
69 /// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
70 /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
71 ///
72 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
73 pub mode: StreamingEditFileMode,
74
75 /// The complete content for the new file (required for 'create' and 'overwrite' modes).
76 /// This field should contain the entire file content.
77 #[serde(default, skip_serializing_if = "Option::is_none")]
78 pub content: Option<String>,
79
80 /// List of edit operations to apply sequentially (required for 'edit' mode).
81 /// Each edit finds `old_text` in the file and replaces it with `new_text`.
82 #[serde(default, skip_serializing_if = "Option::is_none")]
83 pub edits: Option<Vec<EditOperation>>,
84}
85
// Wire names are "create", "overwrite" and "edit" (snake_case via `rename_all`).
// NOTE(review): the `///` variant docs below are presumably emitted into the JSON schema
// by `JsonSchema`, so they are model-facing. `Create` is stricter than its doc suggests:
// `resolve_path` fails with "Can't create file: file already exists" when the target
// already exists in the project.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
    /// Create a new file if it doesn't exist
    Create,
    /// Replace the entire contents of an existing file
    Overwrite,
    /// Make granular edits to an existing file
    Edit,
}
96
/// A single edit operation that replaces old text with new text
// Resolved by `resolve_edit`, which feeds `old_text` to `StreamingFuzzyMatcher` and accepts
// only a single match. `apply_edits` turns zero or multiple matches into an error and, in
// that case, applies none of the edits in the request.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The exact text to find in the file. This will be matched using fuzzy matching
    /// to handle minor differences in whitespace or formatting.
    pub old_text: String,
    /// The text to replace it with
    pub new_text: String,
}
106
// Lenient view of a tool input that is still streaming in. `initial_title` parses the raw
// JSON into this type so it can show a path or description before the full input is
// valid; both fields fall back to an empty string when missing (`#[serde(default)]`).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct StreamingEditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
114
/// Result of a successful `streaming_edit_file` call. Produced by `run`, turned into model
/// text by the `From` impl below, and used by `replay` to rebuild a finalized diff view.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingEditFileToolOutput {
    /// The path exactly as it was given in the tool input.
    // `original_path` is also accepted when deserializing — presumably a key name from an
    // older serialized form (NOTE(review): confirm).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit was applied and the buffer saved.
    new_text: String,
    /// Full buffer text captured before the edit was applied.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; an empty diff is reported to the model
    /// as "No edits were made.". Defaults to empty when absent during deserialization.
    #[serde(default)]
    diff: String,
}
124
125impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
126 fn from(output: StreamingEditFileToolOutput) -> Self {
127 if output.diff.is_empty() {
128 "No edits were made.".into()
129 } else {
130 format!(
131 "Edited {}:\n\n```diff\n{}\n```",
132 output.input_path.display(),
133 output.diff
134 )
135 .into()
136 }
137 }
138}
139
/// Agent tool that creates, overwrites, or applies `old_text`/`new_text` edits to a single
/// project file per call (see [`StreamingEditFileToolInput`]).
pub struct StreamingEditFileTool {
    /// Owning thread; read for its project, action log, available tools and file read times.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when replaying a past call.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to turn the input path into a display path.
    project: Entity<Project>,
    // Not read anywhere in this file; the allow keeps the field (and `new`'s signature)
    // without a dead-code warning.
    #[allow(dead_code)]
    templates: Arc<Templates>,
}
147
148impl StreamingEditFileTool {
149 pub fn new(
150 project: Entity<Project>,
151 thread: WeakEntity<Thread>,
152 language_registry: Arc<LanguageRegistry>,
153 templates: Arc<Templates>,
154 ) -> Self {
155 Self {
156 project,
157 thread,
158 language_registry,
159 templates,
160 }
161 }
162
163 fn authorize(
164 &self,
165 input: &StreamingEditFileToolInput,
166 event_stream: &ToolCallEventStream,
167 cx: &mut App,
168 ) -> Task<Result<()>> {
169 let path_str = input.path.to_string_lossy();
170 let settings = agent_settings::AgentSettings::get_global(cx);
171 let decision = decide_permission_from_settings(Self::name(), &path_str, settings);
172
173 match decision {
174 ToolPermissionDecision::Allow => return Task::ready(Ok(())),
175 ToolPermissionDecision::Deny(reason) => {
176 return Task::ready(Err(anyhow!("{}", reason)));
177 }
178 ToolPermissionDecision::Confirm => {}
179 }
180
181 let local_settings_folder = paths::local_settings_folder_name();
182 let path = Path::new(&input.path);
183 if path.components().any(|component| {
184 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
185 }) {
186 let context = crate::ToolPermissionContext {
187 tool_name: "edit_file".to_string(),
188 input_value: path_str.to_string(),
189 };
190 return event_stream.authorize(
191 format!("{} (local settings)", input.display_description),
192 context,
193 cx,
194 );
195 }
196
197 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
198 && canonical_path.starts_with(paths::config_dir())
199 {
200 let context = crate::ToolPermissionContext {
201 tool_name: "edit_file".to_string(),
202 input_value: path_str.to_string(),
203 };
204 return event_stream.authorize(
205 format!("{} (global settings)", input.display_description),
206 context,
207 cx,
208 );
209 }
210
211 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
212 thread.project().read(cx).find_project_path(&input.path, cx)
213 }) else {
214 return Task::ready(Err(anyhow!("thread was dropped")));
215 };
216
217 if project_path.is_some() {
218 Task::ready(Ok(()))
219 } else {
220 let context = crate::ToolPermissionContext {
221 tool_name: "edit_file".to_string(),
222 input_value: path_str.to_string(),
223 };
224 event_stream.authorize(&input.display_description, context, cx)
225 }
226 }
227}
228
229impl AgentTool for StreamingEditFileTool {
230 type Input = StreamingEditFileToolInput;
231 type Output = StreamingEditFileToolOutput;
232
233 fn name() -> &'static str {
234 "streaming_edit_file"
235 }
236
237 fn kind() -> acp::ToolKind {
238 acp::ToolKind::Edit
239 }
240
241 fn initial_title(
242 &self,
243 input: Result<Self::Input, serde_json::Value>,
244 cx: &mut App,
245 ) -> SharedString {
246 match input {
247 Ok(input) => self
248 .project
249 .read(cx)
250 .find_project_path(&input.path, cx)
251 .and_then(|project_path| {
252 self.project
253 .read(cx)
254 .short_full_path_for_project_path(&project_path, cx)
255 })
256 .unwrap_or(input.path.to_string_lossy().into_owned())
257 .into(),
258 Err(raw_input) => {
259 if let Some(input) =
260 serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
261 {
262 let path = input.path.trim();
263 if !path.is_empty() {
264 return self
265 .project
266 .read(cx)
267 .find_project_path(&input.path, cx)
268 .and_then(|project_path| {
269 self.project
270 .read(cx)
271 .short_full_path_for_project_path(&project_path, cx)
272 })
273 .unwrap_or(input.path)
274 .into();
275 }
276
277 let description = input.display_description.trim();
278 if !description.is_empty() {
279 return description.to_string().into();
280 }
281 }
282
283 DEFAULT_UI_TEXT.into()
284 }
285 }
286 }
287
288 fn run(
289 self: Arc<Self>,
290 input: Self::Input,
291 event_stream: ToolCallEventStream,
292 cx: &mut App,
293 ) -> Task<Result<Self::Output>> {
294 let Ok(project) = self
295 .thread
296 .read_with(cx, |thread, _cx| thread.project().clone())
297 else {
298 return Task::ready(Err(anyhow!("thread was dropped")));
299 };
300
301 let project_path = match resolve_path(&input, project.clone(), cx) {
302 Ok(path) => path,
303 Err(err) => return Task::ready(Err(anyhow!(err))),
304 };
305
306 let abs_path = project.read(cx).absolute_path(&project_path, cx);
307 if let Some(abs_path) = abs_path.clone() {
308 event_stream.update_fields(
309 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
310 );
311 }
312
313 let authorize = self.authorize(&input, &event_stream, cx);
314
315 cx.spawn(async move |cx: &mut AsyncApp| {
316 authorize.await?;
317
318 let buffer = project
319 .update(cx, |project, cx| {
320 project.open_buffer(project_path.clone(), cx)
321 })
322 .await?;
323
324 if let Some(abs_path) = abs_path.as_ref() {
325 let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
326 self.thread.update(cx, |thread, cx| {
327 let last_read = thread.file_read_times.get(abs_path).copied();
328 let current = buffer
329 .read(cx)
330 .file()
331 .and_then(|file| file.disk_state().mtime());
332 let dirty = buffer.read(cx).is_dirty();
333 let has_save = thread.has_tool("save_file");
334 let has_restore = thread.has_tool("restore_file_from_disk");
335 (last_read, current, dirty, has_save, has_restore)
336 })?;
337
338 if is_dirty {
339 let message = match (has_save_tool, has_restore_tool) {
340 (true, true) => {
341 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
342 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
343 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
344 }
345 (true, false) => {
346 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
347 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
348 If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
349 }
350 (false, true) => {
351 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
352 If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
353 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
354 }
355 (false, false) => {
356 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
357 then ask them to save or revert the file manually and inform you when it's ok to proceed."
358 }
359 };
360 anyhow::bail!("{}", message);
361 }
362
363 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
364 if current != last_read {
365 anyhow::bail!(
366 "The file {} has been modified since you last read it. \
367 Please read the file again to get the current state before editing it.",
368 input.path.display()
369 );
370 }
371 }
372 }
373
374 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
375 event_stream.update_diff(diff.clone());
376 let _finalize_diff = util::defer({
377 let diff = diff.downgrade();
378 let mut cx = cx.clone();
379 move || {
380 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
381 }
382 });
383
384 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
385 let old_text = cx
386 .background_spawn({
387 let old_snapshot = old_snapshot.clone();
388 async move { Arc::new(old_snapshot.text()) }
389 })
390 .await;
391
392 let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
393
394 // Edit the buffer and report edits to the action log as part of the
395 // same effect cycle, otherwise the edit will be reported as if the
396 // user made it (due to the buffer subscription in action_log).
397 match input.mode {
398 StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
399 action_log.update(cx, |log, cx| {
400 log.buffer_created(buffer.clone(), cx);
401 });
402 let content = input.content.ok_or_else(|| {
403 anyhow!("'content' field is required for create and overwrite modes")
404 })?;
405 cx.update(|cx| {
406 buffer.update(cx, |buffer, cx| {
407 buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
408 });
409 action_log.update(cx, |log, cx| {
410 log.buffer_edited(buffer.clone(), cx);
411 });
412 });
413 }
414 StreamingEditFileMode::Edit => {
415 action_log.update(cx, |log, cx| {
416 log.buffer_read(buffer.clone(), cx);
417 });
418 let edits = input.edits.ok_or_else(|| {
419 anyhow!("'edits' field is required for edit mode")
420 })?;
421 // apply_edits now handles buffer_edited internally in the same effect cycle
422 apply_edits(&buffer, &action_log, &edits, &diff, &event_stream, &abs_path, cx)?;
423 }
424 }
425
426 project
427 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
428 .await?;
429
430 action_log.update(cx, |log, cx| {
431 log.buffer_edited(buffer.clone(), cx);
432 });
433
434 if let Some(abs_path) = abs_path.as_ref() {
435 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
436 buffer.file().and_then(|file| file.disk_state().mtime())
437 }) {
438 self.thread.update(cx, |thread, _| {
439 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
440 })?;
441 }
442 }
443
444 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
445 let (new_text, unified_diff) = cx
446 .background_spawn({
447 let new_snapshot = new_snapshot.clone();
448 let old_text = old_text.clone();
449 async move {
450 let new_text = new_snapshot.text();
451 let diff = language::unified_diff(&old_text, &new_text);
452 (new_text, diff)
453 }
454 })
455 .await;
456
457 let output = StreamingEditFileToolOutput {
458 input_path: input.path,
459 new_text,
460 old_text,
461 diff: unified_diff,
462 };
463
464 Ok(output)
465 })
466 }
467
468 fn replay(
469 &self,
470 _input: Self::Input,
471 output: Self::Output,
472 event_stream: ToolCallEventStream,
473 cx: &mut App,
474 ) -> Result<()> {
475 event_stream.update_diff(cx.new(|cx| {
476 Diff::finalized(
477 output.input_path.to_string_lossy().into_owned(),
478 Some(output.old_text.to_string()),
479 output.new_text,
480 self.language_registry.clone(),
481 cx,
482 )
483 }));
484 Ok(())
485 }
486}
487
488fn apply_edits(
489 buffer: &Entity<language::Buffer>,
490 action_log: &Entity<action_log::ActionLog>,
491 edits: &[EditOperation],
492 diff: &Entity<Diff>,
493 event_stream: &ToolCallEventStream,
494 abs_path: &Option<PathBuf>,
495 cx: &mut AsyncApp,
496) -> Result<()> {
497 let mut failed_edits = Vec::new();
498 let mut ambiguous_edits = Vec::new();
499 let mut resolved_edits: Vec<(Range<usize>, String)> = Vec::new();
500 let mut first_edit_line: Option<u32> = None;
501
502 // First pass: resolve all edits without applying them
503 for (index, edit) in edits.iter().enumerate() {
504 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
505 let result = resolve_edit(&snapshot, edit);
506
507 match result {
508 Ok(Some((range, new_text))) => {
509 if first_edit_line.is_none() {
510 first_edit_line = Some(snapshot.offset_to_point(range.start).row);
511 }
512 // Reveal the range in the diff view
513 let start_anchor =
514 buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(range.start));
515 let end_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_after(range.end));
516 diff.update(cx, |card, cx| {
517 card.reveal_range(start_anchor..end_anchor, cx)
518 });
519 resolved_edits.push((range, new_text));
520 }
521 Ok(None) => {
522 failed_edits.push(index);
523 }
524 Err(ranges) => {
525 ambiguous_edits.push((index, ranges));
526 }
527 }
528 }
529
530 // Check for errors before applying any edits
531 if !failed_edits.is_empty() {
532 let indices = failed_edits
533 .iter()
534 .map(|i| i.to_string())
535 .collect::<Vec<_>>()
536 .join(", ");
537 anyhow::bail!(
538 "Could not find matching text for edit(s) at index(es): {}. \
539 The old_text did not match any content in the file. \
540 Please read the file again to get the current content.",
541 indices
542 );
543 }
544
545 if !ambiguous_edits.is_empty() {
546 let details: Vec<String> = ambiguous_edits
547 .iter()
548 .map(|(index, ranges)| {
549 let lines = ranges
550 .iter()
551 .map(|r| r.start.to_string())
552 .collect::<Vec<_>>()
553 .join(", ");
554 format!("edit {}: matches at lines {}", index, lines)
555 })
556 .collect();
557 anyhow::bail!(
558 "Some edits matched multiple locations in the file:\n{}. \
559 Please provide more context in old_text to uniquely identify the location.",
560 details.join("\n")
561 );
562 }
563
564 // Emit location for the first edit
565 if let Some(line) = first_edit_line {
566 if let Some(abs_path) = abs_path.clone() {
567 event_stream.update_fields(
568 ToolCallUpdateFields::new()
569 .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
570 );
571 }
572 }
573
574 // Second pass: apply all edits and report to action_log in the same effect cycle.
575 // This prevents the buffer subscription from treating these as user edits.
576 if !resolved_edits.is_empty() {
577 cx.update(|cx| {
578 buffer.update(cx, |buffer, cx| {
579 // Apply edits in reverse order so offsets remain valid
580 let mut edits_sorted: Vec<_> = resolved_edits.into_iter().collect();
581 edits_sorted.sort_by(|a, b| b.0.start.cmp(&a.0.start));
582 for (range, new_text) in edits_sorted {
583 buffer.edit([(range, new_text.as_str())], None, cx);
584 }
585 });
586 action_log.update(cx, |log, cx| {
587 log.buffer_edited(buffer.clone(), cx);
588 });
589 });
590 }
591
592 Ok(())
593}
594
595/// Resolves an edit operation by finding the matching text in the buffer.
596/// Returns Ok(Some((range, new_text))) if a unique match is found,
597/// Ok(None) if no match is found, or Err(ranges) if multiple matches are found.
598fn resolve_edit(
599 snapshot: &BufferSnapshot,
600 edit: &EditOperation,
601) -> std::result::Result<Option<(Range<usize>, String)>, Vec<Range<usize>>> {
602 let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
603 matcher.push(&edit.old_text, None);
604 let matches = matcher.finish();
605
606 if matches.is_empty() {
607 return Ok(None);
608 }
609
610 if matches.len() > 1 {
611 return Err(matches);
612 }
613
614 let match_range = matches.into_iter().next().expect("checked len above");
615 Ok(Some((match_range, edit.new_text.clone())))
616}
617
618fn resolve_path(
619 input: &StreamingEditFileToolInput,
620 project: Entity<Project>,
621 cx: &mut App,
622) -> Result<ProjectPath> {
623 let project = project.read(cx);
624
625 match input.mode {
626 StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
627 let path = project
628 .find_project_path(&input.path, cx)
629 .context("Can't edit file: path not found")?;
630
631 let entry = project
632 .entry_for_path(&path, cx)
633 .context("Can't edit file: path not found")?;
634
635 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
636 Ok(path)
637 }
638
639 StreamingEditFileMode::Create => {
640 if let Some(path) = project.find_project_path(&input.path, cx) {
641 anyhow::ensure!(
642 project.entry_for_path(&path, cx).is_none(),
643 "Can't create file: file already exists"
644 );
645 }
646
647 let parent_path = input
648 .path
649 .parent()
650 .context("Can't create file: incorrect path")?;
651
652 let parent_project_path = project.find_project_path(&parent_path, cx);
653
654 let parent_entry = parent_project_path
655 .as_ref()
656 .and_then(|path| project.entry_for_path(path, cx))
657 .context("Can't create file: parent directory doesn't exist")?;
658
659 anyhow::ensure!(
660 parent_entry.is_dir(),
661 "Can't create file: parent is not a directory"
662 );
663
664 let file_name = input
665 .path
666 .file_name()
667 .and_then(|file_name| file_name.to_str())
668 .and_then(|file_name| RelPath::unix(file_name).ok())
669 .context("Can't create file: invalid filename")?;
670
671 let new_file_path = parent_project_path.map(|parent| ProjectPath {
672 path: parent.path.join(file_name),
673 ..parent
674 });
675
676 new_file_path.context("Can't create file")
677 }
678 }
679}
680
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ContextServerRegistry, Templates};
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use prompt_store::ProjectContext;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Shared fixture: a project rooted at `/root` on a fake filesystem populated from a
    /// JSON tree, plus a thread over that project.
    struct TestSetup {
        project: Entity<Project>,
        thread: Entity<Thread>,
        language_registry: Arc<LanguageRegistry>,
    }

    impl TestSetup {
        async fn new(cx: &mut TestAppContext, tree: serde_json::Value) -> Self {
            init_test(cx);

            let fs = project::FakeFs::new(cx.executor());
            fs.insert_tree("/root", tree).await;
            let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
            let language_registry =
                project.read_with(cx, |project, _cx| project.languages().clone());
            let context_server_registry = cx.new(|cx| {
                ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
            });
            let model = Arc::new(FakeLanguageModel::default());
            let thread = cx.new(|cx| {
                crate::Thread::new(
                    project.clone(),
                    cx.new(|_cx| ProjectContext::default()),
                    context_server_registry,
                    Templates::new(),
                    Some(model),
                    cx,
                )
            });

            Self {
                project,
                thread,
                language_registry,
            }
        }

        /// Runs a fresh tool instance once with `input` and waits for the result.
        async fn run(
            &self,
            cx: &mut TestAppContext,
            input: StreamingEditFileToolInput,
        ) -> Result<StreamingEditFileToolOutput> {
            let tool = Arc::new(StreamingEditFileTool::new(
                self.project.clone(),
                self.thread.downgrade(),
                self.language_registry.clone(),
                Templates::new(),
            ));
            cx.update(|cx| tool.run(input, ToolCallEventStream::test().0, cx))
                .await
        }
    }

    #[gpui::test]
    async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
        let setup = TestSetup::new(cx, json!({"dir": {}})).await;

        let result = setup
            .run(
                cx,
                StreamingEditFileToolInput {
                    display_description: "Create new file".into(),
                    path: "root/dir/new_file.txt".into(),
                    mode: StreamingEditFileMode::Create,
                    content: Some("Hello, World!".into()),
                    edits: None,
                },
            )
            .await;

        let output = result.expect("creating a new file should succeed");
        assert_eq!(output.new_text, "Hello, World!");
        assert!(!output.diff.is_empty());
    }

    #[gpui::test]
    async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
        let setup = TestSetup::new(cx, json!({"file.txt": "old content"})).await;

        let result = setup
            .run(
                cx,
                StreamingEditFileToolInput {
                    display_description: "Overwrite file".into(),
                    path: "root/file.txt".into(),
                    mode: StreamingEditFileMode::Overwrite,
                    content: Some("new content".into()),
                    edits: None,
                },
            )
            .await;

        let output = result.expect("overwriting an existing file should succeed");
        assert_eq!(output.new_text, "new content");
        assert_eq!(*output.old_text, "old content");
    }

    #[gpui::test]
    async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
        let setup = TestSetup::new(
            cx,
            json!({
                "file.txt": "line 1\nline 2\nline 3\n"
            }),
        )
        .await;

        let result = setup
            .run(
                cx,
                StreamingEditFileToolInput {
                    display_description: "Edit lines".into(),
                    path: "root/file.txt".into(),
                    mode: StreamingEditFileMode::Edit,
                    content: None,
                    edits: Some(vec![EditOperation {
                        old_text: "line 2".into(),
                        new_text: "modified line 2".into(),
                    }]),
                },
            )
            .await;

        let output = result.expect("a uniquely matching edit should succeed");
        assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
    }

    #[gpui::test]
    async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
        let setup = TestSetup::new(cx, json!({})).await;

        let result = setup
            .run(
                cx,
                StreamingEditFileToolInput {
                    display_description: "Some edit".into(),
                    path: "root/nonexistent_file.txt".into(),
                    mode: StreamingEditFileMode::Edit,
                    content: None,
                    edits: Some(vec![EditOperation {
                        old_text: "foo".into(),
                        new_text: "bar".into(),
                    }]),
                },
            )
            .await;

        assert_eq!(
            result.unwrap_err().to_string(),
            "Can't edit file: path not found"
        );
    }

    #[gpui::test]
    async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
        let setup = TestSetup::new(cx, json!({"file.txt": "hello world"})).await;

        let result = setup
            .run(
                cx,
                StreamingEditFileToolInput {
                    display_description: "Edit file".into(),
                    path: "root/file.txt".into(),
                    mode: StreamingEditFileMode::Edit,
                    content: None,
                    edits: Some(vec![EditOperation {
                        old_text: "nonexistent text that is not in the file".into(),
                        new_text: "replacement".into(),
                    }]),
                },
            )
            .await;

        let error = result.expect_err("an unmatched old_text should fail");
        assert!(error.to_string().contains("Could not find matching text"));
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }
}