xml_edits.rs

  1use anyhow::{Context as _, Result};
  2use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
  3use std::{cmp, ops::Range, path::Path, sync::Arc};
  4
  5const EDITS_TAG_NAME: &'static str = "edits";
  6const OLD_TEXT_TAG_NAME: &'static str = "old_text";
  7const NEW_TEXT_TAG_NAME: &'static str = "new_text";
  8const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
  9
 10pub async fn parse_xml_edits<'a>(
 11    input: &'a str,
 12    get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
 13) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
 14    parse_xml_edits_inner(input, get_buffer)
 15        .await
 16        .with_context(|| format!("Failed to parse XML edits:\n{input}"))
 17}
 18
 19async fn parse_xml_edits_inner<'a>(
 20    input: &'a str,
 21    get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
 22) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
 23    let xml_edits = extract_xml_replacements(input)?;
 24
 25    let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
 26        .with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
 27
 28    let mut all_edits = vec![];
 29    for (old_text, new_text) in xml_edits.replacements {
 30        let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
 31        let matched_old_text = buffer
 32            .text_for_range(match_range.clone())
 33            .collect::<String>();
 34        let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
 35        all_edits.extend(
 36            edits_within_hunk
 37                .into_iter()
 38                .map(move |(inner_range, inner_text)| {
 39                    (
 40                        buffer.anchor_after(match_range.start + inner_range.start)
 41                            ..buffer.anchor_before(match_range.start + inner_range.end),
 42                        inner_text,
 43                    )
 44                }),
 45        );
 46    }
 47
 48    Ok((buffer, all_edits))
 49}
 50
 51fn fuzzy_match_in_ranges(
 52    old_text: &str,
 53    buffer: &BufferSnapshot,
 54    context_ranges: &[Range<Anchor>],
 55) -> Result<Range<usize>> {
 56    let mut state = FuzzyMatcher::new(buffer, old_text);
 57    let mut best_match = None;
 58    let mut tie_match_range = None;
 59
 60    for range in context_ranges {
 61        let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
 62        match (best_match_cost, state.match_range(range.to_offset(buffer))) {
 63            (Some(lowest_cost), Some((new_cost, new_range))) => {
 64                if new_cost == lowest_cost {
 65                    tie_match_range = Some(new_range);
 66                } else if new_cost < lowest_cost {
 67                    tie_match_range.take();
 68                    best_match = Some((new_cost, new_range));
 69                }
 70            }
 71            (None, Some(new_match)) => {
 72                best_match = Some(new_match);
 73            }
 74            (None, None) | (Some(_), None) => {}
 75        };
 76    }
 77
 78    if let Some((_, best_match_range)) = best_match {
 79        if let Some(tie_match_range) = tie_match_range {
 80            anyhow::bail!(
 81                "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
 82                best_match_range.clone(),
 83                buffer.text_for_range(best_match_range).collect::<String>(),
 84                tie_match_range.clone(),
 85                buffer.text_for_range(tie_match_range).collect::<String>()
 86            );
 87        }
 88        return Ok(best_match_range);
 89    }
 90
 91    anyhow::bail!(
 92        "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
 93        old_text,
 94        context_ranges
 95            .iter()
 96            .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
 97            .collect::<Vec<String>>()
 98            .join("```\n```")
 99    );
100}
101
/// The result of parsing an `<edits>` response: the target file and its text replacements.
///
/// Both fields borrow from the parsed input string.
#[derive(Debug)]
struct XmlEdits<'a> {
    /// Value of the `path` attribute of the `<edits>` tag, with surrounding quotes removed.
    file_path: &'a str,
    /// Vec of (old_text, new_text) pairs, in the order they appear in the input.
    replacements: Vec<(&'a str, &'a str)>,
}
108
109fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
110    let mut cursor = 0;
111
112    let (edits_body_start, edits_attrs) =
113        find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
114
115    let file_path = edits_attrs
116        .trim_start()
117        .strip_prefix("path")
118        .context("no path attribute on edits tag")?
119        .trim_end()
120        .strip_prefix('=')
121        .context("no value for path attribute")?
122        .trim()
123        .trim_start_matches('"')
124        .trim_end_matches('"');
125
126    cursor = edits_body_start;
127    let mut edits_list = Vec::new();
128
129    while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
130        let old_body_end = find_tag_close(input, &mut cursor)?;
131        let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
132
133        let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
134            .context("no new_text tag following old_text")?;
135        let new_body_end = find_tag_close(input, &mut cursor)?;
136        let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
137
138        edits_list.push((old_text, new_text));
139    }
140
141    Ok(XmlEdits {
142        file_path,
143        replacements: edits_list,
144    })
145}
146
/// Trims a single leading and a single trailing line break from `input`.
///
/// Both `\n` and `\r\n` are recognised. Without this, CRLF input would leave a stray `\r` at
/// the edges of the extracted text, and it would end up in the buffer edit.
fn trim_surrounding_newlines(input: &str) -> &str {
    let input = input
        .strip_prefix("\r\n")
        .or_else(|| input.strip_prefix('\n'))
        .unwrap_or(input);
    input
        .strip_suffix("\r\n")
        .or_else(|| input.strip_suffix('\n'))
        .unwrap_or(input)
}
153
154fn find_tag_open<'a>(
155    input: &'a str,
156    cursor: &mut usize,
157    expected_tag: &str,
158) -> Result<Option<(usize, &'a str)>> {
159    let mut search_pos = *cursor;
160
161    while search_pos < input.len() {
162        let Some(tag_start) = input[search_pos..].find("<") else {
163            break;
164        };
165        let tag_start = search_pos + tag_start;
166        if !input[tag_start + 1..].starts_with(expected_tag) {
167            search_pos = search_pos + tag_start + 1;
168            continue;
169        };
170
171        let after_tag_name = tag_start + expected_tag.len() + 1;
172        let close_bracket = input[after_tag_name..]
173            .find('>')
174            .with_context(|| format!("missing > after <{}", expected_tag))?;
175        let attrs_end = after_tag_name + close_bracket;
176        let body_start = attrs_end + 1;
177
178        let attributes = input[after_tag_name..attrs_end].trim();
179        *cursor = body_start;
180
181        return Ok(Some((body_start, attributes)));
182    }
183
184    Ok(None)
185}
186
187fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
188    let mut depth = 1;
189    let mut search_pos = *cursor;
190
191    while search_pos < input.len() && depth > 0 {
192        let Some(bracket_offset) = input[search_pos..].find('<') else {
193            break;
194        };
195        let bracket_pos = search_pos + bracket_offset;
196
197        if input[bracket_pos..].starts_with("</")
198            && let Some(close_end) = input[bracket_pos + 2..].find('>')
199        {
200            let close_start = bracket_pos + 2;
201            let tag_name = input[close_start..close_start + close_end].trim();
202
203            if XML_TAGS.contains(&tag_name) {
204                depth -= 1;
205                if depth == 0 {
206                    *cursor = close_start + close_end + 1;
207                    return Ok(bracket_pos);
208                }
209            }
210            search_pos = close_start + close_end + 1;
211            continue;
212        } else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
213            let close_bracket_pos = bracket_pos + close_bracket_offset;
214            let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
215            if XML_TAGS.contains(&tag_name) {
216                depth += 1;
217            }
218        }
219
220        search_pos = bracket_pos + 1;
221    }
222
223    anyhow::bail!("no closing tag found")
224}
225
/// Cost of skipping a query (`old_text`) line that has no counterpart in the buffer.
const DELETION_COST: u32 = 10;
/// Cost of skipping a buffer line, inside the match, that has no counterpart in the query.
const INSERTION_COST: u32 = 3;
/// Cost of pairing a query line with a buffer line that is similar but not identical.
const REPLACEMENT_COST: u32 = 1;
230/// A fuzzy matcher that can process text chunks incrementally
231/// and return the best match found so far at each step.
232struct FuzzyMatcher<'a> {
233    snapshot: &'a BufferSnapshot,
234    query_lines: Vec<&'a str>,
235    matrix: SearchMatrix,
236}
237
238impl<'a> FuzzyMatcher<'a> {
239    fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
240        let query_lines = old_text.lines().collect();
241        Self {
242            snapshot,
243            query_lines,
244            matrix: SearchMatrix::new(0),
245        }
246    }
247
248    fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
249        let point_range = range.to_point(&self.snapshot);
250        let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
251
252        self.matrix
253            .reset(self.query_lines.len() + 1, buffer_line_count + 1);
254        let query_line_count = self.query_lines.len();
255
256        for row in 0..query_line_count {
257            let query_line = self.query_lines[row].trim();
258            let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
259
260            self.matrix.set(
261                row + 1,
262                0,
263                SearchState::new(leading_deletion_cost, SearchDirection::Up),
264            );
265
266            let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
267
268            let mut col = 0;
269            while let Some(buffer_line) = buffer_lines.next() {
270                let buffer_line = buffer_line.trim();
271                let up = SearchState::new(
272                    self.matrix
273                        .get(row, col + 1)
274                        .cost
275                        .saturating_add(DELETION_COST),
276                    SearchDirection::Up,
277                );
278                let left = SearchState::new(
279                    self.matrix
280                        .get(row + 1, col)
281                        .cost
282                        .saturating_add(INSERTION_COST),
283                    SearchDirection::Left,
284                );
285                let diagonal = SearchState::new(
286                    if query_line == buffer_line {
287                        self.matrix.get(row, col).cost
288                    } else if fuzzy_eq(query_line, buffer_line) {
289                        self.matrix.get(row, col).cost + REPLACEMENT_COST
290                    } else {
291                        self.matrix
292                            .get(row, col)
293                            .cost
294                            .saturating_add(DELETION_COST + INSERTION_COST)
295                    },
296                    SearchDirection::Diagonal,
297                );
298                self.matrix
299                    .set(row + 1, col + 1, up.min(left).min(diagonal));
300                col += 1;
301            }
302        }
303
304        // Find all matches with the best cost
305        let mut best_cost = u32::MAX;
306        let mut matches_with_best_cost = Vec::new();
307
308        for col in 1..=buffer_line_count {
309            let cost = self.matrix.get(query_line_count, col).cost;
310            if cost < best_cost {
311                best_cost = cost;
312                matches_with_best_cost.clear();
313                matches_with_best_cost.push(col as u32);
314            } else if cost == best_cost {
315                matches_with_best_cost.push(col as u32);
316            }
317        }
318
319        // Find ranges for the matches
320        for &match_end_col in &matches_with_best_cost {
321            let mut matched_lines = 0;
322            let mut query_row = query_line_count;
323            let mut match_start_col = match_end_col;
324            while query_row > 0 && match_start_col > 0 {
325                let current = self.matrix.get(query_row, match_start_col as usize);
326                match current.direction {
327                    SearchDirection::Diagonal => {
328                        query_row -= 1;
329                        match_start_col -= 1;
330                        matched_lines += 1;
331                    }
332                    SearchDirection::Up => {
333                        query_row -= 1;
334                    }
335                    SearchDirection::Left => {
336                        match_start_col -= 1;
337                    }
338                }
339            }
340
341            let buffer_row_start = match_start_col + point_range.start.row;
342            let buffer_row_end = match_end_col + point_range.start.row;
343
344            let matched_buffer_row_count = buffer_row_end - buffer_row_start;
345            let matched_ratio = matched_lines as f32
346                / (matched_buffer_row_count as f32).max(query_line_count as f32);
347            if matched_ratio >= 0.8 {
348                let buffer_start_ix = self
349                    .snapshot
350                    .point_to_offset(Point::new(buffer_row_start, 0));
351                let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
352                    buffer_row_end - 1,
353                    self.snapshot.line_len(buffer_row_end - 1),
354                ));
355                return Some((best_cost, buffer_start_ix..buffer_end_ix));
356            }
357        }
358
359        None
360    }
361}
362
363fn fuzzy_eq(left: &str, right: &str) -> bool {
364    const THRESHOLD: f64 = 0.8;
365
366    let min_levenshtein = left.len().abs_diff(right.len());
367    let min_normalized_levenshtein =
368        1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
369    if min_normalized_levenshtein < THRESHOLD {
370        return false;
371    }
372
373    strsim::normalized_levenshtein(left, right) >= THRESHOLD
374}
375
/// The DP move that produced a `SearchMatrix` cell; it is followed when backtracking a match.
///
/// Variant order is significant. The derived `Ord` ranks `Up < Left < Diagonal`, and
/// `SearchState` relies on that ranking to break cost ties.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    /// Consumed a query line without a buffer line.
    Up,
    /// Consumed a buffer line without a query line.
    Left,
    /// Paired a query line with a buffer line.
    Diagonal,
}

/// One DP cell: the cheapest cost of reaching it and the move that achieved that cost.
///
/// Field order is significant: the derived `Ord` compares `cost` first, then `direction`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
    cost: u32,
    direction: SearchDirection,
}

impl SearchState {
    fn new(cost: u32, direction: SearchDirection) -> Self {
        SearchState { direction, cost }
    }
}

/// Dense, row-major grid of `SearchState`s whose backing storage survives `reset`.
struct SearchMatrix {
    rows: usize,
    cols: usize,
    data: Vec<SearchState>,
}

impl SearchMatrix {
    /// Creates an empty matrix with zero rows; call `reset` before reading or writing cells.
    fn new(cols: usize) -> Self {
        Self {
            rows: 0,
            cols,
            data: Vec::new(),
        }
    }

    /// Resizes to `rows` x `cols` and sets every cell to cost 0 / `Diagonal`.
    fn reset(&mut self, rows: usize, cols: usize) {
        let blank = SearchState::new(0, SearchDirection::Diagonal);
        self.rows = rows;
        self.cols = cols;
        // `clear` keeps the capacity, so repeated resets reuse the same allocation.
        self.data.clear();
        self.data.resize(rows * cols, blank);
    }

    fn get(&self, row: usize, col: usize) -> SearchState {
        debug_assert!(row < self.rows);
        debug_assert!(col < self.cols);
        self.data[self.cols * row + col]
    }

    fn set(&mut self, row: usize, col: usize, state: SearchState) {
        debug_assert!(row < self.rows && col < self.cols);
        let index = self.cols * row + col;
        self.data[index] = state;
    }
}
432
433#[cfg(test)]
434mod tests {
435    use super::*;
436    use gpui::TestAppContext;
437    use indoc::indoc;
438    use language::Point;
439    use project::{FakeFs, Project};
440    use serde_json::json;
441    use settings::SettingsStore;
442    use util::path;
443
444    #[test]
445    fn test_extract_xml_edits() {
446        let input = indoc! {r#"
447            <edits path="test.rs">
448            <old_text>
449            old content
450            </old_text>
451            <new_text>
452            new content
453            </new_text>
454            </edits>
455        "#};
456
457        let result = extract_xml_replacements(input).unwrap();
458        assert_eq!(result.file_path, "test.rs");
459        assert_eq!(result.replacements.len(), 1);
460        assert_eq!(result.replacements[0].0, "old content");
461        assert_eq!(result.replacements[0].1, "new content");
462    }
463
464    #[test]
465    fn test_extract_xml_edits_with_wrong_closing_tags() {
466        let input = indoc! {r#"
467            <edits path="test.rs">
468            <old_text>
469            old content
470            </new_text>
471            <new_text>
472            new content
473            </old_text>
474            </ edits >
475        "#};
476
477        let result = extract_xml_replacements(input).unwrap();
478        assert_eq!(result.file_path, "test.rs");
479        assert_eq!(result.replacements.len(), 1);
480        assert_eq!(result.replacements[0].0, "old content");
481        assert_eq!(result.replacements[0].1, "new content");
482    }
483
484    #[test]
485    fn test_extract_xml_edits_with_xml_like_content() {
486        let input = indoc! {r#"
487            <edits path="component.tsx">
488            <old_text>
489            <foo><bar></bar></foo>
490            </old_text>
491            <new_text>
492            <foo><bar><baz></baz></bar></foo>
493            </new_text>
494            </edits>
495        "#};
496
497        let result = extract_xml_replacements(input).unwrap();
498        assert_eq!(result.file_path, "component.tsx");
499        assert_eq!(result.replacements.len(), 1);
500        assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
501        assert_eq!(
502            result.replacements[0].1,
503            "<foo><bar><baz></baz></bar></foo>"
504        );
505    }
506
507    #[test]
508    fn test_extract_xml_edits_with_conflicting_content() {
509        let input = indoc! {r#"
510            <edits path="component.tsx">
511            <old_text>
512            <new_text></new_text>
513            </old_text>
514            <new_text>
515            <old_text></old_text>
516            </new_text>
517            </edits>
518        "#};
519
520        let result = extract_xml_replacements(input).unwrap();
521        assert_eq!(result.file_path, "component.tsx");
522        assert_eq!(result.replacements.len(), 1);
523        assert_eq!(result.replacements[0].0, "<new_text></new_text>");
524        assert_eq!(result.replacements[0].1, "<old_text></old_text>");
525    }
526
527    #[test]
528    fn test_extract_xml_edits_multiple_pairs() {
529        let input = indoc! {r#"
530            Some reasoning before edits. Lots of thinking going on here
531
532            <edits path="test.rs">
533            <old_text>
534            first old
535            </old_text>
536            <new_text>
537            first new
538            </new_text>
539            <old_text>
540            second old
541            </edits>
542            <new_text>
543            second new
544            </old_text>
545            </edits>
546        "#};
547
548        let result = extract_xml_replacements(input).unwrap();
549        assert_eq!(result.file_path, "test.rs");
550        assert_eq!(result.replacements.len(), 2);
551        assert_eq!(result.replacements[0].0, "first old");
552        assert_eq!(result.replacements[0].1, "first new");
553        assert_eq!(result.replacements[1].0, "second old");
554        assert_eq!(result.replacements[1].1, "second new");
555    }
556
557    #[test]
558    fn test_extract_xml_edits_unexpected_eof() {
559        let input = indoc! {r#"
560            <edits path="test.rs">
561            <old_text>
562            first old
563            </
564        "#};
565
566        extract_xml_replacements(input).expect_err("Unexpected end of file");
567    }
568
569    #[gpui::test]
570    async fn test_parse_xml_edits(cx: &mut TestAppContext) {
571        let fs = init_test(cx);
572
573        let buffer_1_text = indoc! {r#"
574            one two three four
575            five six seven eight
576            nine ten eleven twelve
577            thirteen fourteen fifteen
578            sixteen seventeen eighteen
579        "#};
580
581        fs.insert_tree(
582            path!("/root"),
583            json!({
584                "file1": buffer_1_text,
585            }),
586        )
587        .await;
588
589        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
590        let buffer = project
591            .update(cx, |project, cx| {
592                project.open_local_buffer(path!("/root/file1"), cx)
593            })
594            .await
595            .unwrap();
596        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
597
598        let edits = indoc! {r#"
599            <edits path="root/file1">
600            <old_text>
601            nine ten eleven twelve
602            </old_text>
603            <new_text>
604            nine TEN eleven twelve!
605            </new_text>
606            </edits>
607        "#};
608
609        let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
610        let (buffer, edits) = parse_xml_edits(edits, |_path| {
611            Some((&buffer_snapshot, included_ranges.as_slice()))
612        })
613        .await
614        .unwrap();
615
616        let edits = edits
617            .into_iter()
618            .map(|(range, text)| (range.to_point(&buffer), text))
619            .collect::<Vec<_>>();
620        assert_eq!(
621            edits,
622            &[
623                (Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
624                (Point::new(2, 22)..Point::new(2, 22), "!".into())
625            ]
626        );
627    }
628
629    fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
630        cx.update(|cx| {
631            let settings_store = SettingsStore::test(cx);
632            cx.set_global(settings_store);
633        });
634
635        FakeFs::new(cx.background_executor.clone())
636    }
637}