1use std::{mem, ops::Range, path::Path, path::PathBuf, sync::Arc};
2
3use anyhow::{Context as _, Result, anyhow};
4use collections::{HashMap, hash_map::Entry};
5use gpui::{AsyncApp, Entity};
6use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot, text_diff};
7use postage::stream::Stream as _;
8use project::Project;
9use util::{paths::PathStyle, rel_path::RelPath};
10use worktree::Worktree;
11use zeta_prompt::udiff::{
12 DiffEvent, DiffParser, FileStatus, Hunk, disambiguate_by_line_number, find_context_candidates,
13};
14
15pub use zeta_prompt::udiff::{
16 DiffLine, HunkLocation, apply_diff_to_string, apply_diff_to_string_with_hunk_offset,
17 strip_diff_metadata, strip_diff_path_prefix,
18};
19
20#[derive(Clone, Debug)]
21pub struct OpenedBuffers(HashMap<String, Entity<Buffer>>);
22
23impl OpenedBuffers {
24 pub fn get(&self, path: &str) -> Option<&Entity<Buffer>> {
25 self.0.get(path)
26 }
27
28 pub fn buffers(&self) -> impl Iterator<Item = &Entity<Buffer>> {
29 self.0.values()
30 }
31}
32
33#[must_use]
34pub async fn apply_diff(
35 diff_str: &str,
36 project: &Entity<Project>,
37 cx: &mut AsyncApp,
38) -> Result<OpenedBuffers> {
39 let worktree = project
40 .read_with(cx, |project, cx| project.visible_worktrees(cx).next())
41 .context("project has no worktree")?;
42
43 let paths: Vec<_> = diff_str
44 .lines()
45 .filter_map(|line| {
46 if let DiffLine::OldPath { path } = DiffLine::parse(line) {
47 if path != "/dev/null" {
48 return Some(PathBuf::from(path.as_ref()));
49 }
50 }
51 None
52 })
53 .collect();
54 refresh_worktree_entries(&worktree, paths.iter().map(|p| p.as_path()), cx).await?;
55
56 let mut included_files: HashMap<String, Entity<Buffer>> = HashMap::default();
57
58 let mut diff = DiffParser::new(diff_str);
59 let mut current_file = None;
60 let mut edits: Vec<(std::ops::Range<Anchor>, Arc<str>)> = vec![];
61
62 while let Some(event) = diff.next()? {
63 match event {
64 DiffEvent::Hunk { path, hunk, status } => {
65 if status == FileStatus::Deleted {
66 let delete_task = project.update(cx, |project, cx| {
67 if let Some(path) = project.find_project_path(path.as_ref(), cx) {
68 project.delete_file(path, false, cx)
69 } else {
70 None
71 }
72 });
73
74 if let Some(delete_task) = delete_task {
75 delete_task.await?;
76 };
77
78 continue;
79 }
80
81 let buffer = match current_file {
82 None => {
83 let buffer = match included_files.entry(path.to_string()) {
84 Entry::Occupied(entry) => entry.get().clone(),
85 Entry::Vacant(entry) => {
86 let buffer: Entity<Buffer> = if status == FileStatus::Created {
87 project
88 .update(cx, |project, cx| {
89 project.create_buffer(None, true, cx)
90 })
91 .await?
92 } else {
93 let project_path = project
94 .update(cx, |project, cx| {
95 project.find_project_path(path.as_ref(), cx)
96 })
97 .with_context(|| format!("no such path: {}", path))?;
98 project
99 .update(cx, |project, cx| {
100 project.open_buffer(project_path, cx)
101 })
102 .await?
103 };
104 entry.insert(buffer.clone());
105 buffer
106 }
107 };
108 current_file = Some(buffer);
109 current_file.as_ref().unwrap()
110 }
111 Some(ref current) => current,
112 };
113
114 buffer.read_with(cx, |buffer, _| {
115 edits.extend(resolve_hunk_edits_in_buffer(
116 hunk,
117 buffer,
118 &[Anchor::min_max_range_for_buffer(buffer.remote_id())],
119 status,
120 )?);
121 anyhow::Ok(())
122 })?;
123 }
124 DiffEvent::FileEnd { renamed_to } => {
125 let buffer = current_file
126 .take()
127 .context("Got a FileEnd event before an Hunk event")?;
128
129 if let Some(renamed_to) = renamed_to {
130 project
131 .update(cx, |project, cx| {
132 let new_project_path = project
133 .find_project_path(Path::new(renamed_to.as_ref()), cx)
134 .with_context(|| {
135 format!("Failed to find worktree for new path: {}", renamed_to)
136 })?;
137
138 let project_file = project::File::from_dyn(buffer.read(cx).file())
139 .expect("Wrong file type");
140
141 anyhow::Ok(project.rename_entry(
142 project_file.entry_id.unwrap(),
143 new_project_path,
144 cx,
145 ))
146 })?
147 .await?;
148 }
149
150 let edits = mem::take(&mut edits);
151 buffer.update(cx, |buffer, cx| {
152 buffer.edit(edits, None, cx);
153 });
154 }
155 }
156 }
157
158 Ok(OpenedBuffers(included_files))
159}
160
161pub async fn refresh_worktree_entries(
162 worktree: &Entity<Worktree>,
163 paths: impl IntoIterator<Item = &Path>,
164 cx: &mut AsyncApp,
165) -> Result<()> {
166 let mut rel_paths = Vec::new();
167 for path in paths {
168 if let Ok(rel_path) = RelPath::new(path, PathStyle::Posix) {
169 rel_paths.push(rel_path.into_arc());
170 }
171
172 let path_without_root: PathBuf = path.components().skip(1).collect();
173 if let Ok(rel_path) = RelPath::new(&path_without_root, PathStyle::Posix) {
174 rel_paths.push(rel_path.into_arc());
175 }
176 }
177
178 if !rel_paths.is_empty() {
179 worktree
180 .update(cx, |worktree, _| {
181 worktree
182 .as_local()
183 .unwrap()
184 .refresh_entries_for_paths(rel_paths)
185 })
186 .recv()
187 .await;
188 }
189
190 Ok(())
191}
192
193/// Returns the individual edits that would be applied by a diff to the given content.
194/// Each edit is a tuple of (byte_range_in_content, replacement_text).
195/// Uses sub-line diffing to find the precise character positions of changes.
196/// Returns an empty vec if the hunk context is not found or is ambiguous.
197pub fn edits_for_diff(content: &str, diff_str: &str) -> Result<Vec<(Range<usize>, String)>> {
198 let mut diff = DiffParser::new(diff_str);
199 let mut result = Vec::new();
200
201 while let Some(event) = diff.next()? {
202 match event {
203 DiffEvent::Hunk {
204 mut hunk,
205 path: _,
206 status: _,
207 } => {
208 if hunk.context.is_empty() {
209 return Ok(Vec::new());
210 }
211
212 let candidates = find_context_candidates(content, &mut hunk);
213
214 let Some(context_offset) =
215 disambiguate_by_line_number(&candidates, hunk.start_line, &|offset| {
216 content[..offset].matches('\n').count() as u32
217 })
218 else {
219 return Ok(Vec::new());
220 };
221
222 // Use sub-line diffing to find precise edit positions
223 for edit in &hunk.edits {
224 let old_text = &content
225 [context_offset + edit.range.start..context_offset + edit.range.end];
226 let edits_within_hunk = text_diff(old_text, &edit.text);
227 for (inner_range, inner_text) in edits_within_hunk {
228 let absolute_start = context_offset + edit.range.start + inner_range.start;
229 let absolute_end = context_offset + edit.range.start + inner_range.end;
230 result.push((absolute_start..absolute_end, inner_text.to_string()));
231 }
232 }
233 }
234 DiffEvent::FileEnd { .. } => {}
235 }
236 }
237
238 Ok(result)
239}
240
241fn resolve_hunk_edits_in_buffer(
242 mut hunk: Hunk,
243 buffer: &TextBufferSnapshot,
244 ranges: &[Range<Anchor>],
245 status: FileStatus,
246) -> Result<impl Iterator<Item = (Range<Anchor>, Arc<str>)>, anyhow::Error> {
247 let context_offset = if status == FileStatus::Created || hunk.context.is_empty() {
248 0
249 } else {
250 let mut candidates: Vec<usize> = Vec::new();
251 for range in ranges {
252 let range = range.to_offset(buffer);
253 let text = buffer.text_for_range(range.clone()).collect::<String>();
254 for ix in find_context_candidates(&text, &mut hunk) {
255 candidates.push(range.start + ix);
256 }
257 }
258
259 disambiguate_by_line_number(&candidates, hunk.start_line, &|offset| {
260 buffer.offset_to_point(offset).row
261 })
262 .ok_or_else(|| {
263 if candidates.is_empty() {
264 anyhow!("Failed to match context:\n\n```\n{}```\n", hunk.context,)
265 } else {
266 anyhow!("Context is not unique enough:\n{}", hunk.context)
267 }
268 })?
269 };
270
271 if let Some(edit) = hunk.edits.iter().find(|edit| edit.range.end > buffer.len()) {
272 return Err(anyhow!("Edit range {:?} exceeds buffer length", edit.range));
273 }
274
275 let iter = hunk.edits.into_iter().flat_map(move |edit| {
276 let old_text = buffer
277 .text_for_range(context_offset + edit.range.start..context_offset + edit.range.end)
278 .collect::<String>();
279 let edits_within_hunk = language::text_diff(&old_text, &edit.text);
280 edits_within_hunk
281 .into_iter()
282 .map(move |(inner_range, inner_text)| {
283 (
284 buffer.anchor_after(context_offset + edit.range.start + inner_range.start)
285 ..buffer.anchor_before(context_offset + edit.range.start + inner_range.end),
286 inner_text,
287 )
288 })
289 });
290 Ok(iter)
291}
292
// Tests for resolving diffs against plain text (`edits_for_diff`) and for applying
// diffs to project buffers (`apply_diff`).
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use indoc::indoc;
    use pretty_assertions::assert_eq;
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    #[test]
    fn test_line_number_disambiguation() {
        // Test that line numbers from hunk headers are used to disambiguate
        // when context before the operation appears multiple times
        let content = indoc! {"
            repeated line
            first unique
            repeated line
            second unique
        "};

        // Context "repeated line" appears twice - line number selects first occurrence
        let diff = indoc! {"
            --- a/file.txt
            +++ b/file.txt
            @@ -1,2 +1,2 @@
             repeated line
            -first unique
            +REPLACED
        "};

        let result = edits_for_diff(content, diff).unwrap();
        assert_eq!(result.len(), 1);

        // The edit should replace "first unique" (after first "repeated line\n" at offset 14)
        let (range, text) = &result[0];
        assert_eq!(range.start, 14);
        assert_eq!(range.end, 26); // "first unique" is 12 bytes
        assert_eq!(text, "REPLACED");
    }

    #[test]
    fn test_line_number_disambiguation_second_match() {
        // Test disambiguation when the edit should apply to a later occurrence
        let content = indoc! {"
            repeated line
            first unique
            repeated line
            second unique
        "};

        // Context "repeated line" appears twice - line number selects second occurrence
        let diff = indoc! {"
            --- a/file.txt
            +++ b/file.txt
            @@ -3,2 +3,2 @@
             repeated line
            -second unique
            +REPLACED
        "};

        let result = edits_for_diff(content, diff).unwrap();
        assert_eq!(result.len(), 1);

        // The edit should replace "second unique" (after second "repeated line\n")
        // Offset: "repeated line\n" (14) + "first unique\n" (13) + "repeated line\n" (14) = 41
        let (range, text) = &result[0];
        assert_eq!(range.start, 41);
        assert_eq!(range.end, 54); // "second unique" is 13 bytes
        assert_eq!(text, "REPLACED");
    }

    // Several sections targeting the same file are applied one after another. Each
    // later section's context (e.g. "3") only exists after the earlier sections have
    // been applied, so this also checks that sections see prior results.
    #[gpui::test]
    async fn test_apply_diff_successful(cx: &mut TestAppContext) {
        let fs = init_test(cx);

        let buffer_1_text = indoc! {r#"
            one
            two
            three
            four
            five
        "# };

        let buffer_1_text_final = indoc! {r#"
            3
            4
            5
        "# };

        let buffer_2_text = indoc! {r#"
            six
            seven
            eight
            nine
            ten
        "# };

        let buffer_2_text_final = indoc! {r#"
            5
            six
            seven
            7.5
            eight
            nine
            ten
            11
        "# };

        fs.insert_tree(
            path!("/root"),
            json!({
                "file1": buffer_1_text,
                "file2": buffer_2_text,
            }),
        )
        .await;

        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;

        // Hunks without `@@` headers: they are located purely by their context lines.
        let diff = indoc! {r#"
            --- a/file1
            +++ b/file1
             one
             two
            -three
            +3
             four
             five
            --- a/file1
            +++ b/file1
             3
            -four
            -five
            +4
            +5
            --- a/file1
            +++ b/file1
            -one
            -two
             3
             4
            --- a/file2
            +++ b/file2
            +5
             six
            --- a/file2
            +++ b/file2
             seven
            +7.5
             eight
            --- a/file2
            +++ b/file2
             ten
            +11
        "#};

        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
            .await
            .unwrap();
        // Re-open the buffers through the project to observe the applied edits.
        let buffer_1 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_1.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), buffer_1_text_final);
        });
        let buffer_2 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_2.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), buffer_2_text_final);
        });
    }

    // "four\nfive" occurs twice in the file. The leading context ("one", "two",
    // "three") makes the hunk's position unique, so only the first occurrence changes.
    #[gpui::test]
    async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
        let fs = init_test(cx);

        let start = indoc! {r#"
            one
            two
            three
            four
            five

            four
            five
        "# };

        let end = indoc! {r#"
            one
            two
            3
            four
            5

            four
            five
        "# };

        fs.insert_tree(
            path!("/root"),
            json!({
                "file1": start,
            }),
        )
        .await;

        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;

        let diff = indoc! {r#"
            --- a/file1
            +++ b/file1
             one
             two
            -three
            +3
             four
            -five
            +5
        "#};

        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
            .await
            .unwrap();

        let buffer_1 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_1.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), end);
        });
    }

    // Installs a test `SettingsStore` global and returns a fresh in-memory filesystem.
    fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });

        FakeFs::new(cx.background_executor.clone())
    }

    #[test]
    fn test_edits_for_diff() {
        let content = indoc! {"
            fn main() {
                let x = 1;
                let y = 2;
                println!(\"{} {}\", x, y);
            }
        "};

        let diff = indoc! {"
            --- a/file.rs
            +++ b/file.rs
            @@ -1,5 +1,5 @@
             fn main() {
            -    let x = 1;
            +    let x = 42;
                 let y = 2;
                 println!(\"{} {}\", x, y);
             }
        "};

        let edits = edits_for_diff(content, diff).unwrap();
        assert_eq!(edits.len(), 1);

        let (range, replacement) = &edits[0];
        // With sub-line diffing, the edit should start at "1" (the actual changed character)
        let expected_start = content.find("let x = 1;").unwrap() + "let x = ".len();
        assert_eq!(range.start, expected_start);
        // The deleted text is just "1"
        assert_eq!(range.end, expected_start + "1".len());
        // The replacement text
        assert_eq!(replacement, "42");

        // Verify the cursor would be positioned at the column of "1"
        let line_start = content[..range.start]
            .rfind('\n')
            .map(|p| p + 1)
            .unwrap_or(0);
        let cursor_column = range.start - line_start;
        // "    let x = " is 12 characters, so column 12
        assert_eq!(cursor_column, "    let x = ".len());
    }

    // The content lacks a trailing newline while the diff's last context line
    // ("baz") ends with one; the hunk must still be matched.
    #[test]
    fn test_edits_for_diff_no_trailing_newline() {
        let content = "foo\nbar\nbaz";
        let diff = indoc! {"
            --- a/file.txt
            +++ b/file.txt
            @@ -1,3 +1,3 @@
             foo
            -bar
            +qux
             baz
        "};

        let result = edits_for_diff(content, diff).unwrap();
        assert_eq!(result.len(), 1);
        let (range, text) = &result[0];
        assert_eq!(&content[range.clone()], "bar");
        assert_eq!(text, "qux");
    }
}