1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream, ToolPermissionDecision,
3 decide_permission_from_settings, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use language::{Anchor, LanguageRegistry, ToPoint};
10use language_model::LanguageModelToolResultContent;
11use paths;
12use project::{Project, ProjectPath};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use settings::Settings;
16use std::ffi::OsStr;
17use std::ops::Range;
18use std::path::{Path, PathBuf};
19use std::sync::Arc;
20use text::BufferSnapshot;
21use ui::SharedString;
22use util::rel_path::RelPath;
23
/// Fallback tool-call title, used by `initial_title` when the streamed input has
/// neither a usable path nor a display description yet.
const DEFAULT_UI_TEXT: &str = "Editing file";
25
// NOTE: `JsonSchema` is derived below, so every `///` doc comment on this type and its
// fields is emitted by schemars as a `description` in the generated JSON schema —
// presumably the schema the model sees for this tool's input. Treat that text as
// model-facing prompt text; the `//` comments here are for maintainers only.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    // Also used as the title of permission prompts raised by `authorize`.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    // Resolved to a `ProjectPath` with `Project::find_project_path` in `resolve_path`.
    pub path: PathBuf,

    /// The mode of operation on the file. Possible values:
    /// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
    /// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
    /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    // `resolve_path` rejects `create` when the file already exists, and rejects
    // `edit`/`overwrite` when the path is missing or is a directory.
    pub mode: StreamingEditFileMode,

    /// The complete content for the new file (required for 'create' and 'overwrite' modes).
    /// This field should contain the entire file content.
    // Optional at the schema level; `run` returns an error when it is absent in
    // create/overwrite mode, and ignores it in edit mode.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// List of edit operations to apply sequentially (required for 'edit' mode).
    /// Each edit finds `old_text` in the file and replaces it with `new_text`.
    // Optional at the schema level; `run` returns an error when it is absent in
    // edit mode, and ignores it in create/overwrite mode.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub edits: Option<Vec<EditOperation>>,
}
85
// How `StreamingEditFileToolInput` should be applied. Variants (de)serialize as the
// snake_case strings "create", "overwrite" and "edit" because of `rename_all` below.
// The `///` variant docs are emitted into the JSON schema via `JsonSchema`, so they
// are model-facing text.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
    /// Create a new file if it doesn't exist
    // `resolve_path` fails with "file already exists" if the path is already present.
    Create,
    /// Replace the entire contents of an existing file
    // `run` replaces the whole buffer (0..len) with `content`.
    Overwrite,
    /// Make granular edits to an existing file
    // `run` applies each `EditOperation` in order via `apply_edits`.
    Edit,
}
96
// The `///` docs below are emitted into the JSON schema via `JsonSchema`, so they are
// model-facing text; maintainer notes go in `//` comments.
/// A single edit operation that replaces old text with new text
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The exact text to find in the file. This will be matched using fuzzy matching
    /// to handle minor differences in whitespace or formatting.
    // Located with `StreamingFuzzyMatcher` in `apply_single_edit`. No match is
    // reported as "Could not find matching text", and more than one match is
    // reported as ambiguous; in both cases the edit is not applied.
    pub old_text: String,
    /// The text to replace it with
    pub new_text: String,
}
106
107#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
108struct StreamingEditFileToolPartialInput {
109 #[serde(default)]
110 path: String,
111 #[serde(default)]
112 display_description: String,
113}
114
/// Result of a successful `streaming_edit_file` run.
///
/// It is returned to the model (via the `From` impl into
/// `LanguageModelToolResultContent`) and is also used by `replay` to rebuild the diff view.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingEditFileToolOutput {
    /// The path exactly as it appeared in the tool input.
    // Deserialization also accepts the key `original_path`, presumably for
    // compatibility with previously persisted outputs — verify before removing.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit was applied and saved.
    new_text: String,
    /// Full buffer text before the edit; shared via `Arc` with the diff computation.
    old_text: Arc<String>,
    /// Unified diff between `old_text` and `new_text`. It defaults to empty when the
    /// key is missing, and an empty diff is reported to the model as
    /// "No edits were made.".
    #[serde(default)]
    diff: String,
}
124
125impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
126 fn from(output: StreamingEditFileToolOutput) -> Self {
127 if output.diff.is_empty() {
128 "No edits were made.".into()
129 } else {
130 format!(
131 "Edited {}:\n\n```diff\n{}\n```",
132 output.input_path.display(),
133 output.diff
134 )
135 .into()
136 }
137 }
138}
139
/// Agent tool that creates, overwrites, or applies granular find-and-replace edits
/// to a project file, and then saves it.
pub struct StreamingEditFileTool {
    /// Owning thread; read for its project, action log, tool set and file read times.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to build a finalized `Diff`.
    language_registry: Arc<LanguageRegistry>,
    /// Project used by `initial_title` to shorten displayed paths.
    project: Entity<Project>,
    // Stored by `new` but not read anywhere in this file, hence the `allow`.
    // NOTE(review): presumably kept for constructor parity with the other edit
    // tool — confirm, or drop the field.
    #[allow(dead_code)]
    templates: Arc<Templates>,
}
147
148impl StreamingEditFileTool {
149 pub fn new(
150 project: Entity<Project>,
151 thread: WeakEntity<Thread>,
152 language_registry: Arc<LanguageRegistry>,
153 templates: Arc<Templates>,
154 ) -> Self {
155 Self {
156 project,
157 thread,
158 language_registry,
159 templates,
160 }
161 }
162
163 fn authorize(
164 &self,
165 input: &StreamingEditFileToolInput,
166 event_stream: &ToolCallEventStream,
167 cx: &mut App,
168 ) -> Task<Result<()>> {
169 let path_str = input.path.to_string_lossy();
170 let settings = agent_settings::AgentSettings::get_global(cx);
171 let decision = decide_permission_from_settings(Self::name(), &path_str, settings);
172
173 match decision {
174 ToolPermissionDecision::Allow => return Task::ready(Ok(())),
175 ToolPermissionDecision::Deny(reason) => {
176 return Task::ready(Err(anyhow!("{}", reason)));
177 }
178 ToolPermissionDecision::Confirm => {}
179 }
180
181 let local_settings_folder = paths::local_settings_folder_name();
182 let path = Path::new(&input.path);
183 if path.components().any(|component| {
184 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
185 }) {
186 let context = crate::ToolPermissionContext {
187 tool_name: "edit_file".to_string(),
188 input_value: path_str.to_string(),
189 };
190 return event_stream.authorize(
191 format!("{} (local settings)", input.display_description),
192 context,
193 cx,
194 );
195 }
196
197 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
198 && canonical_path.starts_with(paths::config_dir())
199 {
200 let context = crate::ToolPermissionContext {
201 tool_name: "edit_file".to_string(),
202 input_value: path_str.to_string(),
203 };
204 return event_stream.authorize(
205 format!("{} (global settings)", input.display_description),
206 context,
207 cx,
208 );
209 }
210
211 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
212 thread.project().read(cx).find_project_path(&input.path, cx)
213 }) else {
214 return Task::ready(Err(anyhow!("thread was dropped")));
215 };
216
217 if project_path.is_some() {
218 Task::ready(Ok(()))
219 } else {
220 let context = crate::ToolPermissionContext {
221 tool_name: "edit_file".to_string(),
222 input_value: path_str.to_string(),
223 };
224 event_stream.authorize(&input.display_description, context, cx)
225 }
226 }
227}
228
229impl AgentTool for StreamingEditFileTool {
230 type Input = StreamingEditFileToolInput;
231 type Output = StreamingEditFileToolOutput;
232
233 fn name() -> &'static str {
234 "streaming_edit_file"
235 }
236
237 fn kind() -> acp::ToolKind {
238 acp::ToolKind::Edit
239 }
240
241 fn initial_title(
242 &self,
243 input: Result<Self::Input, serde_json::Value>,
244 cx: &mut App,
245 ) -> SharedString {
246 match input {
247 Ok(input) => self
248 .project
249 .read(cx)
250 .find_project_path(&input.path, cx)
251 .and_then(|project_path| {
252 self.project
253 .read(cx)
254 .short_full_path_for_project_path(&project_path, cx)
255 })
256 .unwrap_or(input.path.to_string_lossy().into_owned())
257 .into(),
258 Err(raw_input) => {
259 if let Some(input) =
260 serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
261 {
262 let path = input.path.trim();
263 if !path.is_empty() {
264 return self
265 .project
266 .read(cx)
267 .find_project_path(&input.path, cx)
268 .and_then(|project_path| {
269 self.project
270 .read(cx)
271 .short_full_path_for_project_path(&project_path, cx)
272 })
273 .unwrap_or(input.path)
274 .into();
275 }
276
277 let description = input.display_description.trim();
278 if !description.is_empty() {
279 return description.to_string().into();
280 }
281 }
282
283 DEFAULT_UI_TEXT.into()
284 }
285 }
286 }
287
288 fn run(
289 self: Arc<Self>,
290 input: Self::Input,
291 event_stream: ToolCallEventStream,
292 cx: &mut App,
293 ) -> Task<Result<Self::Output>> {
294 let Ok(project) = self
295 .thread
296 .read_with(cx, |thread, _cx| thread.project().clone())
297 else {
298 return Task::ready(Err(anyhow!("thread was dropped")));
299 };
300
301 let project_path = match resolve_path(&input, project.clone(), cx) {
302 Ok(path) => path,
303 Err(err) => return Task::ready(Err(anyhow!(err))),
304 };
305
306 let abs_path = project.read(cx).absolute_path(&project_path, cx);
307 if let Some(abs_path) = abs_path.clone() {
308 event_stream.update_fields(
309 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
310 );
311 }
312
313 let authorize = self.authorize(&input, &event_stream, cx);
314
315 cx.spawn(async move |cx: &mut AsyncApp| {
316 authorize.await?;
317
318 let buffer = project
319 .update(cx, |project, cx| {
320 project.open_buffer(project_path.clone(), cx)
321 })
322 .await?;
323
324 if let Some(abs_path) = abs_path.as_ref() {
325 let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
326 self.thread.update(cx, |thread, cx| {
327 let last_read = thread.file_read_times.get(abs_path).copied();
328 let current = buffer
329 .read(cx)
330 .file()
331 .and_then(|file| file.disk_state().mtime());
332 let dirty = buffer.read(cx).is_dirty();
333 let has_save = thread.has_tool("save_file");
334 let has_restore = thread.has_tool("restore_file_from_disk");
335 (last_read, current, dirty, has_save, has_restore)
336 })?;
337
338 if is_dirty {
339 let message = match (has_save_tool, has_restore_tool) {
340 (true, true) => {
341 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
342 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
343 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
344 }
345 (true, false) => {
346 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
347 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
348 If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
349 }
350 (false, true) => {
351 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
352 If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
353 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
354 }
355 (false, false) => {
356 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
357 then ask them to save or revert the file manually and inform you when it's ok to proceed."
358 }
359 };
360 anyhow::bail!("{}", message);
361 }
362
363 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
364 if current != last_read {
365 anyhow::bail!(
366 "The file {} has been modified since you last read it. \
367 Please read the file again to get the current state before editing it.",
368 input.path.display()
369 );
370 }
371 }
372 }
373
374 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
375 event_stream.update_diff(diff.clone());
376 let _finalize_diff = util::defer({
377 let diff = diff.downgrade();
378 let mut cx = cx.clone();
379 move || {
380 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
381 }
382 });
383
384 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
385 let old_text = cx
386 .background_spawn({
387 let old_snapshot = old_snapshot.clone();
388 async move { Arc::new(old_snapshot.text()) }
389 })
390 .await;
391
392 match input.mode {
393 StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
394 let content = input.content.ok_or_else(|| {
395 anyhow!("'content' field is required for create and overwrite modes")
396 })?;
397 buffer.update(cx, |buffer, cx| {
398 buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
399 });
400 }
401 StreamingEditFileMode::Edit => {
402 let edits = input.edits.ok_or_else(|| {
403 anyhow!("'edits' field is required for edit mode")
404 })?;
405 apply_edits(&buffer, &edits, &diff, &event_stream, &abs_path, cx)?;
406 }
407 }
408
409 let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
410
411 action_log.update(cx, |log, cx| {
412 log.buffer_edited(buffer.clone(), cx);
413 });
414
415 project
416 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
417 .await?;
418
419 action_log.update(cx, |log, cx| {
420 log.buffer_edited(buffer.clone(), cx);
421 });
422
423 if let Some(abs_path) = abs_path.as_ref() {
424 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
425 buffer.file().and_then(|file| file.disk_state().mtime())
426 }) {
427 self.thread.update(cx, |thread, _| {
428 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
429 })?;
430 }
431 }
432
433 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
434 let (new_text, unified_diff) = cx
435 .background_spawn({
436 let new_snapshot = new_snapshot.clone();
437 let old_text = old_text.clone();
438 async move {
439 let new_text = new_snapshot.text();
440 let diff = language::unified_diff(&old_text, &new_text);
441 (new_text, diff)
442 }
443 })
444 .await;
445
446 let output = StreamingEditFileToolOutput {
447 input_path: input.path,
448 new_text,
449 old_text,
450 diff: unified_diff,
451 };
452
453 Ok(output)
454 })
455 }
456
457 fn replay(
458 &self,
459 _input: Self::Input,
460 output: Self::Output,
461 event_stream: ToolCallEventStream,
462 cx: &mut App,
463 ) -> Result<()> {
464 event_stream.update_diff(cx.new(|cx| {
465 Diff::finalized(
466 output.input_path.to_string_lossy().into_owned(),
467 Some(output.old_text.to_string()),
468 output.new_text,
469 self.language_registry.clone(),
470 cx,
471 )
472 }));
473 Ok(())
474 }
475}
476
477fn apply_edits(
478 buffer: &Entity<language::Buffer>,
479 edits: &[EditOperation],
480 diff: &Entity<Diff>,
481 event_stream: &ToolCallEventStream,
482 abs_path: &Option<PathBuf>,
483 cx: &mut AsyncApp,
484) -> Result<()> {
485 let mut emitted_location = false;
486 let mut failed_edits = Vec::new();
487 let mut ambiguous_edits = Vec::new();
488
489 for (index, edit) in edits.iter().enumerate() {
490 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
491 let result = apply_single_edit(buffer, &snapshot, edit, diff, cx);
492
493 match result {
494 Ok(Some(range)) => {
495 if !emitted_location {
496 let line = buffer.update(cx, |buffer, _cx| {
497 range.start.to_point(&buffer.snapshot()).row
498 });
499 if let Some(abs_path) = abs_path.clone() {
500 event_stream.update_fields(
501 ToolCallUpdateFields::new()
502 .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
503 );
504 }
505 emitted_location = true;
506 }
507 }
508 Ok(None) => {
509 failed_edits.push(index);
510 }
511 Err(ranges) => {
512 ambiguous_edits.push((index, ranges));
513 }
514 }
515 }
516
517 if !failed_edits.is_empty() {
518 let indices = failed_edits
519 .iter()
520 .map(|i| i.to_string())
521 .collect::<Vec<_>>()
522 .join(", ");
523 anyhow::bail!(
524 "Could not find matching text for edit(s) at index(es): {}. \
525 The old_text did not match any content in the file. \
526 Please read the file again to get the current content.",
527 indices
528 );
529 }
530
531 if !ambiguous_edits.is_empty() {
532 let details: Vec<String> = ambiguous_edits
533 .iter()
534 .map(|(index, ranges)| {
535 let lines = ranges
536 .iter()
537 .map(|r| r.start.to_string())
538 .collect::<Vec<_>>()
539 .join(", ");
540 format!("edit {}: matches at lines {}", index, lines)
541 })
542 .collect();
543 anyhow::bail!(
544 "Some edits matched multiple locations in the file:\n{}. \
545 Please provide more context in old_text to uniquely identify the location.",
546 details.join("\n")
547 );
548 }
549
550 Ok(())
551}
552
553fn apply_single_edit(
554 buffer: &Entity<language::Buffer>,
555 snapshot: &BufferSnapshot,
556 edit: &EditOperation,
557 diff: &Entity<Diff>,
558 cx: &mut AsyncApp,
559) -> std::result::Result<Option<Range<Anchor>>, Vec<Range<usize>>> {
560 let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
561 matcher.push(&edit.old_text, None);
562 let matches = matcher.finish();
563
564 if matches.is_empty() {
565 return Ok(None);
566 }
567
568 if matches.len() > 1 {
569 return Err(matches);
570 }
571
572 let match_range = matches.into_iter().next().expect("checked len above");
573
574 let start_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(match_range.start));
575 let end_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_after(match_range.end));
576
577 diff.update(cx, |card, cx| {
578 card.reveal_range(start_anchor..end_anchor, cx)
579 });
580
581 buffer.update(cx, |buffer, cx| {
582 buffer.edit([(match_range.clone(), edit.new_text.as_str())], None, cx);
583 });
584
585 let new_end = buffer.read_with(cx, |buffer, _cx| {
586 buffer.anchor_after(match_range.start + edit.new_text.len())
587 });
588
589 Ok(Some(start_anchor..new_end))
590}
591
592fn resolve_path(
593 input: &StreamingEditFileToolInput,
594 project: Entity<Project>,
595 cx: &mut App,
596) -> Result<ProjectPath> {
597 let project = project.read(cx);
598
599 match input.mode {
600 StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
601 let path = project
602 .find_project_path(&input.path, cx)
603 .context("Can't edit file: path not found")?;
604
605 let entry = project
606 .entry_for_path(&path, cx)
607 .context("Can't edit file: path not found")?;
608
609 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
610 Ok(path)
611 }
612
613 StreamingEditFileMode::Create => {
614 if let Some(path) = project.find_project_path(&input.path, cx) {
615 anyhow::ensure!(
616 project.entry_for_path(&path, cx).is_none(),
617 "Can't create file: file already exists"
618 );
619 }
620
621 let parent_path = input
622 .path
623 .parent()
624 .context("Can't create file: incorrect path")?;
625
626 let parent_project_path = project.find_project_path(&parent_path, cx);
627
628 let parent_entry = parent_project_path
629 .as_ref()
630 .and_then(|path| project.entry_for_path(path, cx))
631 .context("Can't create file: parent directory doesn't exist")?;
632
633 anyhow::ensure!(
634 parent_entry.is_dir(),
635 "Can't create file: parent is not a directory"
636 );
637
638 let file_name = input
639 .path
640 .file_name()
641 .and_then(|file_name| file_name.to_str())
642 .and_then(|file_name| RelPath::unix(file_name).ok())
643 .context("Can't create file: invalid filename")?;
644
645 let new_file_path = parent_project_path.map(|parent| ProjectPath {
646 path: parent.path.join(file_name),
647 ..parent
648 });
649
650 new_file_path.context("Can't create file")
651 }
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ContextServerRegistry, Templates};
    use gpui::TestAppContext;
    use language_model::fake_provider::FakeLanguageModel;
    use prompt_store::ProjectContext;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Creates a fake filesystem whose `/root` directory contains `tree`, opens it as
    /// a test project, and builds a thread for it backed by a fake language model.
    async fn setup(
        tree: serde_json::Value,
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<crate::Thread>, Arc<LanguageRegistry>) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        // Wrap the tree root in `path!` as well, so it matches the worktree root passed
        // to `Project::test` on every platform (e.g. a drive-prefixed path on Windows).
        fs.insert_tree(path!("/root"), tree).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            crate::Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model),
                cx,
            )
        });

        (project, thread, language_registry)
    }

    /// Runs the tool once with `input` and returns its result.
    async fn run_tool(
        input: StreamingEditFileToolInput,
        project: &Entity<Project>,
        thread: &Entity<crate::Thread>,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut TestAppContext,
    ) -> Result<StreamingEditFileToolOutput> {
        cx.update(|cx| {
            Arc::new(StreamingEditFileTool::new(
                project.clone(),
                thread.downgrade(),
                language_registry,
                Templates::new(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        })
        .await
    }

    #[gpui::test]
    async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
        let (project, thread, language_registry) = setup(json!({"dir": {}}), cx).await;

        let input = StreamingEditFileToolInput {
            display_description: "Create new file".into(),
            path: "root/dir/new_file.txt".into(),
            mode: StreamingEditFileMode::Create,
            content: Some("Hello, World!".into()),
            edits: None,
        };
        let output = run_tool(input, &project, &thread, language_registry, cx)
            .await
            .unwrap();

        assert_eq!(output.new_text, "Hello, World!");
        assert!(!output.diff.is_empty());
    }

    #[gpui::test]
    async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
        let (project, thread, language_registry) =
            setup(json!({"file.txt": "old content"}), cx).await;

        let input = StreamingEditFileToolInput {
            display_description: "Overwrite file".into(),
            path: "root/file.txt".into(),
            mode: StreamingEditFileMode::Overwrite,
            content: Some("new content".into()),
            edits: None,
        };
        let output = run_tool(input, &project, &thread, language_registry, cx)
            .await
            .unwrap();

        assert_eq!(output.new_text, "new content");
        assert_eq!(*output.old_text, "old content");
    }

    #[gpui::test]
    async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
        let (project, thread, language_registry) = setup(
            json!({
                "file.txt": "line 1\nline 2\nline 3\n"
            }),
            cx,
        )
        .await;

        let input = StreamingEditFileToolInput {
            display_description: "Edit lines".into(),
            path: "root/file.txt".into(),
            mode: StreamingEditFileMode::Edit,
            content: None,
            edits: Some(vec![EditOperation {
                old_text: "line 2".into(),
                new_text: "modified line 2".into(),
            }]),
        };
        let output = run_tool(input, &project, &thread, language_registry, cx)
            .await
            .unwrap();

        assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
    }

    #[gpui::test]
    async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
        let (project, thread, language_registry) = setup(json!({}), cx).await;

        let input = StreamingEditFileToolInput {
            display_description: "Some edit".into(),
            path: "root/nonexistent_file.txt".into(),
            mode: StreamingEditFileMode::Edit,
            content: None,
            edits: Some(vec![EditOperation {
                old_text: "foo".into(),
                new_text: "bar".into(),
            }]),
        };
        let result = run_tool(input, &project, &thread, language_registry, cx).await;

        assert_eq!(
            result.unwrap_err().to_string(),
            "Can't edit file: path not found"
        );
    }

    #[gpui::test]
    async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
        let (project, thread, language_registry) =
            setup(json!({"file.txt": "hello world"}), cx).await;

        let input = StreamingEditFileToolInput {
            display_description: "Edit file".into(),
            path: "root/file.txt".into(),
            mode: StreamingEditFileMode::Edit,
            content: None,
            edits: Some(vec![EditOperation {
                old_text: "nonexistent text that is not in the file".into(),
                new_text: "replacement".into(),
            }]),
        };
        let result = run_tool(input, &project, &thread, language_registry, cx).await;

        let error = result.unwrap_err().to_string();
        assert!(
            error.contains("Could not find matching text"),
            "unexpected error: {error}"
        );
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }
}