xml_edits.rs

  1use anyhow::{Context as _, Result};
  2use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
  3use std::{cmp, ops::Range, path::Path, sync::Arc};
  4
  5pub async fn parse_xml_edits<'a>(
  6    input: &'a str,
  7    get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
  8) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
  9    parse_xml_edits_inner(input, get_buffer)
 10        .await
 11        .with_context(|| format!("Failed to parse XML edits:\n{input}"))
 12}
 13
 14async fn parse_xml_edits_inner<'a>(
 15    mut input: &'a str,
 16    get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
 17) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
 18    let edits_tag = parse_tag(&mut input, "edits")?.context("No edits tag")?;
 19
 20    input = edits_tag.body;
 21
 22    let file_path = edits_tag
 23        .attributes
 24        .trim_start()
 25        .strip_prefix("path")
 26        .context("no file attribute on edits tag")?
 27        .trim_end()
 28        .strip_prefix('=')
 29        .context("no value for path attribute")?
 30        .trim()
 31        .trim_start_matches('"')
 32        .trim_end_matches('"');
 33
 34    let (buffer, context_ranges) = get_buffer(file_path.as_ref())
 35        .with_context(|| format!("no buffer for file {file_path}"))?;
 36
 37    let mut edits = vec![];
 38    while let Some(old_text_tag) = parse_tag(&mut input, "old_text")? {
 39        let new_text_tag =
 40            parse_tag(&mut input, "new_text")?.context("no new_text tag following old_text")?;
 41        let match_range = fuzzy_match_in_ranges(old_text_tag.body, buffer, context_ranges)?;
 42        let old_text = buffer
 43            .text_for_range(match_range.clone())
 44            .collect::<String>();
 45        let edits_within_hunk = language::text_diff(&old_text, &new_text_tag.body);
 46        edits.extend(
 47            edits_within_hunk
 48                .into_iter()
 49                .map(move |(inner_range, inner_text)| {
 50                    (
 51                        buffer.anchor_after(match_range.start + inner_range.start)
 52                            ..buffer.anchor_before(match_range.start + inner_range.end),
 53                        inner_text,
 54                    )
 55                }),
 56        );
 57    }
 58
 59    Ok((buffer, edits))
 60}
 61
 62fn fuzzy_match_in_ranges(
 63    old_text: &str,
 64    buffer: &BufferSnapshot,
 65    context_ranges: &[Range<Anchor>],
 66) -> Result<Range<usize>> {
 67    let mut state = FuzzyMatcher::new(buffer, old_text);
 68    let mut best_match = None;
 69    let mut tie_match_range = None;
 70
 71    for range in context_ranges {
 72        let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
 73        match (best_match_cost, state.match_range(range.to_offset(buffer))) {
 74            (Some(lowest_cost), Some((new_cost, new_range))) => {
 75                if new_cost == lowest_cost {
 76                    tie_match_range = Some(new_range);
 77                } else if new_cost < lowest_cost {
 78                    tie_match_range.take();
 79                    best_match = Some((new_cost, new_range));
 80                }
 81            }
 82            (None, Some(new_match)) => {
 83                best_match = Some(new_match);
 84            }
 85            (None, None) | (Some(_), None) => {}
 86        };
 87    }
 88
 89    if let Some((_, best_match_range)) = best_match {
 90        if let Some(tie_match_range) = tie_match_range {
 91            anyhow::bail!(
 92                "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
 93                best_match_range.clone(),
 94                buffer.text_for_range(best_match_range).collect::<String>(),
 95                tie_match_range.clone(),
 96                buffer.text_for_range(tie_match_range).collect::<String>()
 97            );
 98        }
 99        return Ok(best_match_range);
100    }
101
102    anyhow::bail!(
103        "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
104        old_text,
105        context_ranges
106            .iter()
107            .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
108            .collect::<Vec<String>>()
109            .join("```\n```")
110    );
111}
112
/// One XML-like element extracted by `parse_tag`; both fields borrow from the
/// parsed input.
struct ParsedTag<'a> {
    /// Trimmed text between the tag name and the `>` that ends the opening tag.
    attributes: &'a str,
    /// Text between the opening tag and the closing tag, with at most one
    /// leading and one trailing `\n` removed.
    body: &'a str,
}
117
118fn parse_tag<'a>(input: &mut &'a str, tag: &str) -> Result<Option<ParsedTag<'a>>> {
119    let open_tag = format!("<{}", tag);
120    let close_tag = format!("</{}>", tag);
121    let Some(start_ix) = input.find(&open_tag) else {
122        return Ok(None);
123    };
124    let start_ix = start_ix + open_tag.len();
125    let closing_bracket_ix = start_ix
126        + input[start_ix..]
127            .find('>')
128            .with_context(|| format!("missing > after {tag}"))?;
129    let attributes = &input[start_ix..closing_bracket_ix].trim();
130    let end_ix = closing_bracket_ix
131        + input[closing_bracket_ix..]
132            .find(&close_tag)
133            .with_context(|| format!("no `{close_tag}` tag"))?;
134    let body = &input[closing_bracket_ix + '>'.len_utf8()..end_ix];
135    let body = body.strip_prefix('\n').unwrap_or(body);
136    let body = body.strip_suffix('\n').unwrap_or(body);
137    *input = &input[end_ix + close_tag.len()..];
138    Ok(Some(ParsedTag { attributes, body }))
139}
140
// Per-line costs used by `FuzzyMatcher::match_range`.

/// Cost of pairing a query line with a buffer line that is not identical but
/// passes `fuzzy_eq`.
const REPLACEMENT_COST: u32 = 1;
/// Cost of a buffer line that has no counterpart in the query (a `Left` step).
const INSERTION_COST: u32 = 3;
/// Cost of a query line that has no counterpart in the buffer (an `Up` step).
const DELETION_COST: u32 = 10;
144
/// A line-based fuzzy matcher that locates a query text inside ranges of a
/// buffer snapshot.
///
/// Each call to `match_range` recomputes the whole search matrix for the given
/// range; only the matrix allocation is reused between calls.
struct FuzzyMatcher<'a> {
    /// Buffer in which matches are searched.
    snapshot: &'a BufferSnapshot,
    /// The query text, split into lines.
    query_lines: Vec<&'a str>,
    /// Scratch cost matrix, reset at the start of every `match_range` call.
    matrix: SearchMatrix,
}
152
impl<'a> FuzzyMatcher<'a> {
    /// Creates a matcher for `old_text`, splitting it into lines up front.
    ///
    /// The search matrix starts out empty; it is sized by each call to
    /// `match_range`.
    fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
        let query_lines = old_text.lines().collect();
        Self {
            snapshot,
            query_lines,
            matrix: SearchMatrix::new(0),
        }
    }

    /// Searches the buffer lines covered by `range` (byte offsets) for the
    /// best line-wise alignment with the query.
    ///
    /// Lines are compared after trimming surrounding whitespace. Returns the
    /// alignment cost together with the byte range of the matched lines (from
    /// column 0 of the first matched row to the end of the last matched row),
    /// or `None` when no lowest-cost alignment covers enough lines.
    fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
        let point_range = range.to_point(&self.snapshot);
        // Number of buffer rows touched by `range`, inclusive of both ends.
        let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;

        // One extra row and column for the empty prefixes. `reset` zeroes every
        // cell and row 0 is never written afterwards, so an alignment may start
        // at any buffer line without paying for the lines before it.
        self.matrix
            .reset(self.query_lines.len() + 1, buffer_line_count + 1);
        let query_line_count = self.query_lines.len();

        for row in 0..query_line_count {
            let query_line = self.query_lines[row].trim();
            // Column 0: the first `row + 1` query lines matched against nothing.
            let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;

            self.matrix.set(
                row + 1,
                0,
                SearchState::new(leading_deletion_cost, SearchDirection::Up),
            );

            // The buffer text is re-read for every query line.
            let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();

            let mut col = 0;
            while let Some(buffer_line) = buffer_lines.next() {
                let buffer_line = buffer_line.trim();
                // Query line has no counterpart in the buffer.
                let up = SearchState::new(
                    self.matrix
                        .get(row, col + 1)
                        .cost
                        .saturating_add(DELETION_COST),
                    SearchDirection::Up,
                );
                // Buffer line has no counterpart in the query.
                let left = SearchState::new(
                    self.matrix
                        .get(row + 1, col)
                        .cost
                        .saturating_add(INSERTION_COST),
                    SearchDirection::Left,
                );
                // Query line paired with this buffer line: free when equal,
                // cheap when fuzzily equal, otherwise priced like a deletion
                // plus an insertion.
                let diagonal = SearchState::new(
                    if query_line == buffer_line {
                        self.matrix.get(row, col).cost
                    } else if fuzzy_eq(query_line, buffer_line) {
                        self.matrix.get(row, col).cost + REPLACEMENT_COST
                    } else {
                        self.matrix
                            .get(row, col)
                            .cost
                            .saturating_add(DELETION_COST + INSERTION_COST)
                    },
                    SearchDirection::Diagonal,
                );
                // `SearchState` orders by cost first, then by direction, so
                // ties on cost are broken by the `SearchDirection` ordering.
                self.matrix
                    .set(row + 1, col + 1, up.min(left).min(diagonal));
                col += 1;
            }
        }

        // Collect every end column in the last query row that shares the
        // lowest cost.
        let mut best_cost = u32::MAX;
        let mut matches_with_best_cost = Vec::new();

        for col in 1..=buffer_line_count {
            let cost = self.matrix.get(query_line_count, col).cost;
            if cost < best_cost {
                best_cost = cost;
                matches_with_best_cost.clear();
                matches_with_best_cost.push(col as u32);
            } else if cost == best_cost {
                matches_with_best_cost.push(col as u32);
            }
        }

        // Trace each candidate back to where its alignment starts and return
        // the first one (in ascending end-column order) that covers enough
        // lines.
        for &match_end_col in &matches_with_best_cost {
            // Number of diagonal steps, i.e. query lines paired with a buffer line.
            let mut matched_lines = 0;
            let mut query_row = query_line_count;
            let mut match_start_col = match_end_col;
            while query_row > 0 && match_start_col > 0 {
                let current = self.matrix.get(query_row, match_start_col as usize);
                match current.direction {
                    SearchDirection::Diagonal => {
                        query_row -= 1;
                        match_start_col -= 1;
                        matched_lines += 1;
                    }
                    SearchDirection::Up => {
                        query_row -= 1;
                    }
                    SearchDirection::Left => {
                        match_start_col -= 1;
                    }
                }
            }

            // Matrix columns are 1-based relative to the first row of `range`,
            // so these are the first matched buffer row and one past the last.
            let buffer_row_start = match_start_col + point_range.start.row;
            let buffer_row_end = match_end_col + point_range.start.row;

            let matched_buffer_row_count = buffer_row_end - buffer_row_start;
            // Require that at least 80% of the longer side (matched buffer
            // rows or query lines) consists of paired lines.
            let matched_ratio = matched_lines as f32
                / (matched_buffer_row_count as f32).max(query_line_count as f32);
            if matched_ratio >= 0.8 {
                let buffer_start_ix = self
                    .snapshot
                    .point_to_offset(Point::new(buffer_row_start, 0));
                // End of the last matched row, excluding its line terminator.
                let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
                    buffer_row_end - 1,
                    self.snapshot.line_len(buffer_row_end - 1),
                ));
                return Some((best_cost, buffer_start_ix..buffer_end_ix));
            }
        }

        None
    }
}
277
278fn fuzzy_eq(left: &str, right: &str) -> bool {
279    const THRESHOLD: f64 = 0.8;
280
281    let min_levenshtein = left.len().abs_diff(right.len());
282    let min_normalized_levenshtein =
283        1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
284    if min_normalized_levenshtein < THRESHOLD {
285        return false;
286    }
287
288    strsim::normalized_levenshtein(left, right) >= THRESHOLD
289}
290
/// The step through which a search-matrix cell was reached.
///
/// The variant order is significant: `SearchState` derives `Ord` with `cost`
/// first and `direction` second, so when candidate states tie on cost, `min`
/// prefers `Up`, then `Left`, then `Diagonal`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    /// Reached from the previous query row in the same buffer column.
    Up,
    /// Reached from the previous buffer column in the same query row.
    Left,
    /// Reached from the previous query row and previous buffer column.
    Diagonal,
}
297
298#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
299struct SearchState {
300    cost: u32,
301    direction: SearchDirection,
302}
303
304impl SearchState {
305    fn new(cost: u32, direction: SearchDirection) -> Self {
306        Self { cost, direction }
307    }
308}
309
310struct SearchMatrix {
311    cols: usize,
312    rows: usize,
313    data: Vec<SearchState>,
314}
315
316impl SearchMatrix {
317    fn new(cols: usize) -> Self {
318        SearchMatrix {
319            cols,
320            rows: 0,
321            data: Vec::new(),
322        }
323    }
324
325    fn reset(&mut self, rows: usize, cols: usize) {
326        self.rows = rows;
327        self.cols = cols;
328        self.data
329            .fill(SearchState::new(0, SearchDirection::Diagonal));
330        self.data.resize(
331            self.rows * self.cols,
332            SearchState::new(0, SearchDirection::Diagonal),
333        );
334    }
335
336    fn get(&self, row: usize, col: usize) -> SearchState {
337        debug_assert!(row < self.rows);
338        debug_assert!(col < self.cols);
339        self.data[row * self.cols + col]
340    }
341
342    fn set(&mut self, row: usize, col: usize, state: SearchState) {
343        debug_assert!(row < self.rows && col < self.cols);
344        self.data[row * self.cols + col] = state;
345    }
346}
347
348#[cfg(test)]
349mod tests {
350    use super::*;
351    use gpui::TestAppContext;
352    use indoc::indoc;
353    use language::Point;
354    use project::{FakeFs, Project};
355    use serde_json::json;
356    use settings::SettingsStore;
357    use util::path;
358
359    #[test]
360    fn test_parse_tags() {
361        let mut input = indoc! {r#"
362            Prelude
363            <tag attr="foo">
364            tag value
365            </tag>
366            "# };
367        let parsed = parse_tag(&mut input, "tag").unwrap().unwrap();
368        assert_eq!(parsed.attributes, "attr=\"foo\"");
369        assert_eq!(parsed.body, "tag value");
370        assert_eq!(input, "\n");
371    }
372
373    #[gpui::test]
374    async fn test_parse_xml_edits(cx: &mut TestAppContext) {
375        let fs = init_test(cx);
376
377        let buffer_1_text = indoc! {r#"
378            one two three four
379            five six seven eight
380            nine ten eleven twelve
381            thirteen fourteen fifteen
382            sixteen seventeen eighteen
383        "#};
384
385        fs.insert_tree(
386            path!("/root"),
387            json!({
388                "file1": buffer_1_text,
389            }),
390        )
391        .await;
392
393        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
394        let buffer = project
395            .update(cx, |project, cx| {
396                project.open_local_buffer(path!("/root/file1"), cx)
397            })
398            .await
399            .unwrap();
400        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
401
402        let edits = indoc! {r#"
403            <edits path="root/file1">
404            <old_text>
405            nine ten eleven twelve
406            </old_text>
407            <new_text>
408            nine TEN eleven twelve!
409            </new_text>
410            </edits>
411        "#};
412
413        let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
414        let (buffer, edits) = parse_xml_edits(edits, |_path| {
415            Some((&buffer_snapshot, included_ranges.as_slice()))
416        })
417        .await
418        .unwrap();
419
420        let edits = edits
421            .into_iter()
422            .map(|(range, text)| (range.to_point(&buffer), text))
423            .collect::<Vec<_>>();
424        assert_eq!(
425            edits,
426            &[
427                (Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
428                (Point::new(2, 22)..Point::new(2, 22), "!".into())
429            ]
430        );
431    }
432
433    fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
434        cx.update(|cx| {
435            let settings_store = SettingsStore::test(cx);
436            cx.set_global(settings_store);
437        });
438
439        FakeFs::new(cx.background_executor.clone())
440    }
441}