udiff.rs

  1use std::{mem, ops::Range, path::Path, path::PathBuf, sync::Arc};
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use collections::{HashMap, hash_map::Entry};
  5use gpui::{AsyncApp, Entity};
  6use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot, text_diff};
  7use postage::stream::Stream as _;
  8use project::Project;
  9use util::{paths::PathStyle, rel_path::RelPath};
 10use worktree::Worktree;
 11use zeta_prompt::udiff::{
 12    DiffEvent, DiffParser, FileStatus, Hunk, disambiguate_by_line_number, find_context_candidates,
 13};
 14
 15pub use zeta_prompt::udiff::{
 16    DiffLine, HunkLocation, apply_diff_to_string, apply_diff_to_string_with_hunk_offset,
 17    strip_diff_metadata, strip_diff_path_prefix,
 18};
 19
 20#[derive(Clone, Debug)]
 21pub struct OpenedBuffers(HashMap<String, Entity<Buffer>>);
 22
 23impl OpenedBuffers {
 24    pub fn get(&self, path: &str) -> Option<&Entity<Buffer>> {
 25        self.0.get(path)
 26    }
 27
 28    pub fn buffers(&self) -> impl Iterator<Item = &Entity<Buffer>> {
 29        self.0.values()
 30    }
 31}
 32
 33#[must_use]
 34pub async fn apply_diff(
 35    diff_str: &str,
 36    project: &Entity<Project>,
 37    cx: &mut AsyncApp,
 38) -> Result<OpenedBuffers> {
 39    let worktree = project
 40        .read_with(cx, |project, cx| project.visible_worktrees(cx).next())
 41        .context("project has no worktree")?;
 42
 43    let paths: Vec<_> = diff_str
 44        .lines()
 45        .filter_map(|line| {
 46            if let DiffLine::OldPath { path } = DiffLine::parse(line) {
 47                if path != "/dev/null" {
 48                    return Some(PathBuf::from(path.as_ref()));
 49                }
 50            }
 51            None
 52        })
 53        .collect();
 54    refresh_worktree_entries(&worktree, paths.iter().map(|p| p.as_path()), cx).await?;
 55
 56    let mut included_files: HashMap<String, Entity<Buffer>> = HashMap::default();
 57
 58    let mut diff = DiffParser::new(diff_str);
 59    let mut current_file = None;
 60    let mut edits: Vec<(std::ops::Range<Anchor>, Arc<str>)> = vec![];
 61
 62    while let Some(event) = diff.next()? {
 63        match event {
 64            DiffEvent::Hunk { path, hunk, status } => {
 65                if status == FileStatus::Deleted {
 66                    let delete_task = project.update(cx, |project, cx| {
 67                        if let Some(path) = project.find_project_path(path.as_ref(), cx) {
 68                            project.delete_file(path, false, cx)
 69                        } else {
 70                            None
 71                        }
 72                    });
 73
 74                    if let Some(delete_task) = delete_task {
 75                        delete_task.await?;
 76                    };
 77
 78                    continue;
 79                }
 80
 81                let buffer = match current_file {
 82                    None => {
 83                        let buffer = match included_files.entry(path.to_string()) {
 84                            Entry::Occupied(entry) => entry.get().clone(),
 85                            Entry::Vacant(entry) => {
 86                                let buffer: Entity<Buffer> = if status == FileStatus::Created {
 87                                    project
 88                                        .update(cx, |project, cx| {
 89                                            project.create_buffer(None, true, cx)
 90                                        })
 91                                        .await?
 92                                } else {
 93                                    let project_path = project
 94                                        .update(cx, |project, cx| {
 95                                            project.find_project_path(path.as_ref(), cx)
 96                                        })
 97                                        .with_context(|| format!("no such path: {}", path))?;
 98                                    project
 99                                        .update(cx, |project, cx| {
100                                            project.open_buffer(project_path, cx)
101                                        })
102                                        .await?
103                                };
104                                entry.insert(buffer.clone());
105                                buffer
106                            }
107                        };
108                        current_file = Some(buffer);
109                        current_file.as_ref().unwrap()
110                    }
111                    Some(ref current) => current,
112                };
113
114                buffer.read_with(cx, |buffer, _| {
115                    edits.extend(resolve_hunk_edits_in_buffer(
116                        hunk,
117                        buffer,
118                        &[Anchor::min_max_range_for_buffer(buffer.remote_id())],
119                        status,
120                    )?);
121                    anyhow::Ok(())
122                })?;
123            }
124            DiffEvent::FileEnd { renamed_to } => {
125                let buffer = current_file
126                    .take()
127                    .context("Got a FileEnd event before an Hunk event")?;
128
129                if let Some(renamed_to) = renamed_to {
130                    project
131                        .update(cx, |project, cx| {
132                            let new_project_path = project
133                                .find_project_path(Path::new(renamed_to.as_ref()), cx)
134                                .with_context(|| {
135                                    format!("Failed to find worktree for new path: {}", renamed_to)
136                                })?;
137
138                            let project_file = project::File::from_dyn(buffer.read(cx).file())
139                                .expect("Wrong file type");
140
141                            anyhow::Ok(project.rename_entry(
142                                project_file.entry_id.unwrap(),
143                                new_project_path,
144                                cx,
145                            ))
146                        })?
147                        .await?;
148                }
149
150                let edits = mem::take(&mut edits);
151                buffer.update(cx, |buffer, cx| {
152                    buffer.edit(edits, None, cx);
153                });
154            }
155        }
156    }
157
158    Ok(OpenedBuffers(included_files))
159}
160
161pub async fn refresh_worktree_entries(
162    worktree: &Entity<Worktree>,
163    paths: impl IntoIterator<Item = &Path>,
164    cx: &mut AsyncApp,
165) -> Result<()> {
166    let mut rel_paths = Vec::new();
167    for path in paths {
168        if let Ok(rel_path) = RelPath::new(path, PathStyle::Posix) {
169            rel_paths.push(rel_path.into_arc());
170        }
171
172        let path_without_root: PathBuf = path.components().skip(1).collect();
173        if let Ok(rel_path) = RelPath::new(&path_without_root, PathStyle::Posix) {
174            rel_paths.push(rel_path.into_arc());
175        }
176    }
177
178    if !rel_paths.is_empty() {
179        worktree
180            .update(cx, |worktree, _| {
181                worktree
182                    .as_local()
183                    .unwrap()
184                    .refresh_entries_for_paths(rel_paths)
185            })
186            .recv()
187            .await;
188    }
189
190    Ok(())
191}
192
193/// Returns the individual edits that would be applied by a diff to the given content.
194/// Each edit is a tuple of (byte_range_in_content, replacement_text).
195/// Uses sub-line diffing to find the precise character positions of changes.
196/// Returns an empty vec if the hunk context is not found or is ambiguous.
197pub fn edits_for_diff(content: &str, diff_str: &str) -> Result<Vec<(Range<usize>, String)>> {
198    let mut diff = DiffParser::new(diff_str);
199    let mut result = Vec::new();
200
201    while let Some(event) = diff.next()? {
202        match event {
203            DiffEvent::Hunk {
204                mut hunk,
205                path: _,
206                status: _,
207            } => {
208                if hunk.context.is_empty() {
209                    return Ok(Vec::new());
210                }
211
212                let candidates = find_context_candidates(content, &mut hunk);
213
214                let Some(context_offset) =
215                    disambiguate_by_line_number(&candidates, hunk.start_line, &|offset| {
216                        content[..offset].matches('\n').count() as u32
217                    })
218                else {
219                    return Ok(Vec::new());
220                };
221
222                // Use sub-line diffing to find precise edit positions
223                for edit in &hunk.edits {
224                    let old_text = &content
225                        [context_offset + edit.range.start..context_offset + edit.range.end];
226                    let edits_within_hunk = text_diff(old_text, &edit.text);
227                    for (inner_range, inner_text) in edits_within_hunk {
228                        let absolute_start = context_offset + edit.range.start + inner_range.start;
229                        let absolute_end = context_offset + edit.range.start + inner_range.end;
230                        result.push((absolute_start..absolute_end, inner_text.to_string()));
231                    }
232                }
233            }
234            DiffEvent::FileEnd { .. } => {}
235        }
236    }
237
238    Ok(result)
239}
240
241fn resolve_hunk_edits_in_buffer(
242    mut hunk: Hunk,
243    buffer: &TextBufferSnapshot,
244    ranges: &[Range<Anchor>],
245    status: FileStatus,
246) -> Result<impl Iterator<Item = (Range<Anchor>, Arc<str>)>, anyhow::Error> {
247    let context_offset = if status == FileStatus::Created || hunk.context.is_empty() {
248        0
249    } else {
250        let mut candidates: Vec<usize> = Vec::new();
251        for range in ranges {
252            let range = range.to_offset(buffer);
253            let text = buffer.text_for_range(range.clone()).collect::<String>();
254            for ix in find_context_candidates(&text, &mut hunk) {
255                candidates.push(range.start + ix);
256            }
257        }
258
259        disambiguate_by_line_number(&candidates, hunk.start_line, &|offset| {
260            buffer.offset_to_point(offset).row
261        })
262        .ok_or_else(|| {
263            if candidates.is_empty() {
264                anyhow!("Failed to match context:\n\n```\n{}```\n", hunk.context,)
265            } else {
266                anyhow!("Context is not unique enough:\n{}", hunk.context)
267            }
268        })?
269    };
270
271    if let Some(edit) = hunk.edits.iter().find(|edit| edit.range.end > buffer.len()) {
272        return Err(anyhow!("Edit range {:?} exceeds buffer length", edit.range));
273    }
274
275    let iter = hunk.edits.into_iter().flat_map(move |edit| {
276        let old_text = buffer
277            .text_for_range(context_offset + edit.range.start..context_offset + edit.range.end)
278            .collect::<String>();
279        let edits_within_hunk = language::text_diff(&old_text, &edit.text);
280        edits_within_hunk
281            .into_iter()
282            .map(move |(inner_range, inner_text)| {
283                (
284                    buffer.anchor_after(context_offset + edit.range.start + inner_range.start)
285                        ..buffer.anchor_before(context_offset + edit.range.start + inner_range.end),
286                    inner_text,
287                )
288            })
289    });
290    Ok(iter)
291}
292
// Tests for diff application. The `edits_for_diff` tests run on plain strings; the
// `apply_diff` tests run against a `FakeFs`-backed test project.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use indoc::indoc;
    use pretty_assertions::assert_eq;
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    #[test]
    fn test_line_number_disambiguation() {
        // Test that line numbers from hunk headers are used to disambiguate
        // when context before the operation appears multiple times
        let content = indoc! {"
            repeated line
            first unique
            repeated line
            second unique
        "};

        // Context "repeated line" appears twice - line number selects first occurrence
        let diff = indoc! {"
            --- a/file.txt
            +++ b/file.txt
            @@ -1,2 +1,2 @@
             repeated line
            -first unique
            +REPLACED
        "};

        let result = edits_for_diff(content, diff).unwrap();
        assert_eq!(result.len(), 1);

        // The edit should replace "first unique" (after first "repeated line\n" at offset 14)
        let (range, text) = &result[0];
        assert_eq!(range.start, 14);
        assert_eq!(range.end, 26); // "first unique" is 12 bytes
        assert_eq!(text, "REPLACED");
    }

    #[test]
    fn test_line_number_disambiguation_second_match() {
        // Test disambiguation when the edit should apply to a later occurrence
        let content = indoc! {"
            repeated line
            first unique
            repeated line
            second unique
        "};

        // Context "repeated line" appears twice - line number selects second occurrence
        let diff = indoc! {"
            --- a/file.txt
            +++ b/file.txt
            @@ -3,2 +3,2 @@
             repeated line
            -second unique
            +REPLACED
        "};

        let result = edits_for_diff(content, diff).unwrap();
        assert_eq!(result.len(), 1);

        // The edit should replace "second unique" (after second "repeated line\n")
        // Offset: "repeated line\n" (14) + "first unique\n" (13) + "repeated line\n" (14) = 41
        let (range, text) = &result[0];
        assert_eq!(range.start, 41);
        assert_eq!(range.end, 54); // "second unique" is 13 bytes
        assert_eq!(text, "REPLACED");
    }

    // Applies several header-less hunks (no `@@` lines) to two files. Each hunk sits in its
    // own `---`/`+++` section, so its edits are applied at `FileEnd` before the next section
    // is resolved. Later sections (e.g. the third one for file1) match text produced by
    // earlier ones.
    #[gpui::test]
    async fn test_apply_diff_successful(cx: &mut TestAppContext) {
        let fs = init_test(cx);

        let buffer_1_text = indoc! {r#"
            one
            two
            three
            four
            five
        "# };

        let buffer_1_text_final = indoc! {r#"
            3
            4
            5
        "# };

        let buffer_2_text = indoc! {r#"
            six
            seven
            eight
            nine
            ten
        "# };

        let buffer_2_text_final = indoc! {r#"
            5
            six
            seven
            7.5
            eight
            nine
            ten
            11
        "# };

        fs.insert_tree(
            path!("/root"),
            json!({
                "file1": buffer_1_text,
                "file2": buffer_2_text,
            }),
        )
        .await;

        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;

        let diff = indoc! {r#"
            --- a/file1
            +++ b/file1
             one
             two
            -three
            +3
             four
             five
            --- a/file1
            +++ b/file1
             3
            -four
            -five
            +4
            +5
            --- a/file1
            +++ b/file1
            -one
            -two
             3
             4
            --- a/file2
            +++ b/file2
            +5
             six
            --- a/file2
            +++ b/file2
             seven
            +7.5
             eight
            --- a/file2
            +++ b/file2
             ten
            +11
        "#};

        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
            .await
            .unwrap();
        // Re-open each file through the project and check the in-memory buffer text.
        let buffer_1 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_1.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), buffer_1_text_final);
        });
        let buffer_2 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_2.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), buffer_2_text_final);
        });
    }

    // "four\nfive" occurs twice in the file. The hunk's leading context ("one", "two",
    // "three") makes the match unique, so only the first occurrence is edited.
    #[gpui::test]
    async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
        let fs = init_test(cx);

        let start = indoc! {r#"
            one
            two
            three
            four
            five

            four
            five
        "# };

        let end = indoc! {r#"
            one
            two
            3
            four
            5

            four
            five
        "# };

        fs.insert_tree(
            path!("/root"),
            json!({
                "file1": start,
            }),
        )
        .await;

        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;

        let diff = indoc! {r#"
            --- a/file1
            +++ b/file1
             one
             two
            -three
            +3
             four
            -five
            +5
        "#};

        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
            .await
            .unwrap();

        let buffer_1 = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        buffer_1.read_with(cx, |buffer, _cx| {
            assert_eq!(buffer.text(), end);
        });
    }

    /// Installs a test `SettingsStore` global and returns a `FakeFs` driven by the test's
    /// background executor.
    fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });

        FakeFs::new(cx.background_executor.clone())
    }

    // A single-line change should be narrowed by sub-line diffing to just the changed
    // characters ("1" -> "42"), not the whole line.
    #[test]
    fn test_edits_for_diff() {
        let content = indoc! {"
            fn main() {
                let x = 1;
                let y = 2;
                println!(\"{} {}\", x, y);
            }
        "};

        let diff = indoc! {"
            --- a/file.rs
            +++ b/file.rs
            @@ -1,5 +1,5 @@
             fn main() {
            -    let x = 1;
            +    let x = 42;
                 let y = 2;
                 println!(\"{} {}\", x, y);
             }
        "};

        let edits = edits_for_diff(content, diff).unwrap();
        assert_eq!(edits.len(), 1);

        let (range, replacement) = &edits[0];
        // With sub-line diffing, the edit should start at "1" (the actual changed character)
        let expected_start = content.find("let x = 1;").unwrap() + "let x = ".len();
        assert_eq!(range.start, expected_start);
        // The deleted text is just "1"
        assert_eq!(range.end, expected_start + "1".len());
        // The replacement text
        assert_eq!(replacement, "42");

        // Verify the cursor would be positioned at the column of "1"
        let line_start = content[..range.start]
            .rfind('\n')
            .map(|p| p + 1)
            .unwrap_or(0);
        let cursor_column = range.start - line_start;
        // "    let x = " is 12 characters, so column 12
        assert_eq!(cursor_column, "    let x = ".len());
    }

    // The diff's context ends with "baz\n" but the content has no trailing newline; the
    // edit must still be located and confined to "bar".
    #[test]
    fn test_edits_for_diff_no_trailing_newline() {
        let content = "foo\nbar\nbaz";
        let diff = indoc! {"
            --- a/file.txt
            +++ b/file.txt
            @@ -1,3 +1,3 @@
             foo
            -bar
            +qux
             baz
        "};

        let result = edits_for_diff(content, diff).unwrap();
        assert_eq!(result.len(), 1);
        let (range, text) = &result[0];
        assert_eq!(&content[range.clone()], "bar");
        assert_eq!(text, "qux");
    }
}