// resolve_search_block.rs

use language::{Anchor, Bias, BufferSnapshot};
use std::ops::Range;

  4#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
  5enum SearchDirection {
  6    Up,
  7    Left,
  8    Diagonal,
  9}
 10
 11#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
 12struct SearchState {
 13    cost: u32,
 14    direction: SearchDirection,
 15}
 16
 17impl SearchState {
 18    fn new(cost: u32, direction: SearchDirection) -> Self {
 19        Self { cost, direction }
 20    }
 21}
 22
 23struct SearchMatrix {
 24    cols: usize,
 25    data: Vec<SearchState>,
 26}
 27
 28impl SearchMatrix {
 29    fn new(rows: usize, cols: usize) -> Self {
 30        SearchMatrix {
 31            cols,
 32            data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
 33        }
 34    }
 35
 36    fn get(&self, row: usize, col: usize) -> SearchState {
 37        self.data[row * self.cols + col]
 38    }
 39
 40    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
 41        self.data[row * self.cols + col] = cost;
 42    }
 43}
 44
 45pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
 46    const INSERTION_COST: u32 = 3;
 47    const DELETION_COST: u32 = 10;
 48    const WHITESPACE_INSERTION_COST: u32 = 1;
 49    const WHITESPACE_DELETION_COST: u32 = 1;
 50
 51    let buffer_len = buffer.len();
 52    let query_len = search_query.len();
 53    let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
 54    let mut leading_deletion_cost = 0_u32;
 55    for (row, query_byte) in search_query.bytes().enumerate() {
 56        let deletion_cost = if query_byte.is_ascii_whitespace() {
 57            WHITESPACE_DELETION_COST
 58        } else {
 59            DELETION_COST
 60        };
 61
 62        leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
 63        matrix.set(
 64            row + 1,
 65            0,
 66            SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
 67        );
 68
 69        for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
 70            let insertion_cost = if buffer_byte.is_ascii_whitespace() {
 71                WHITESPACE_INSERTION_COST
 72            } else {
 73                INSERTION_COST
 74            };
 75
 76            let up = SearchState::new(
 77                matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
 78                SearchDirection::Up,
 79            );
 80            let left = SearchState::new(
 81                matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
 82                SearchDirection::Left,
 83            );
 84            let diagonal = SearchState::new(
 85                if query_byte == *buffer_byte {
 86                    matrix.get(row, col).cost
 87                } else {
 88                    matrix
 89                        .get(row, col)
 90                        .cost
 91                        .saturating_add(deletion_cost + insertion_cost)
 92                },
 93                SearchDirection::Diagonal,
 94            );
 95            matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
 96        }
 97    }
 98
 99    // Traceback to find the best match
100    let mut best_buffer_end = buffer_len;
101    let mut best_cost = u32::MAX;
102    for col in 1..=buffer_len {
103        let cost = matrix.get(query_len, col).cost;
104        if cost < best_cost {
105            best_cost = cost;
106            best_buffer_end = col;
107        }
108    }
109
110    let mut query_ix = query_len;
111    let mut buffer_ix = best_buffer_end;
112    while query_ix > 0 && buffer_ix > 0 {
113        let current = matrix.get(query_ix, buffer_ix);
114        match current.direction {
115            SearchDirection::Diagonal => {
116                query_ix -= 1;
117                buffer_ix -= 1;
118            }
119            SearchDirection::Up => {
120                query_ix -= 1;
121            }
122            SearchDirection::Left => {
123                buffer_ix -= 1;
124            }
125        }
126    }
127
128    let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
129    start.column = 0;
130    let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
131    if end.column > 0 {
132        end.column = buffer.line_len(end.row);
133    }
134
135    buffer.anchor_after(start)..buffer.anchor_before(end)
136}
137
#[cfg(test)]
mod tests {
    use crate::edit_files_tool::resolve_search_block::resolve_search_block;
    use gpui::{prelude::*, App};
    use language::{Buffer, OffsetRangeExt as _};
    use unindent::Unindent as _;
    use util::test::{generate_marked_text, marked_text_ranges};

    #[gpui::test]
    fn test_resolve_search_block(cx: &mut App) {
        // Indentation missing from the query; the range widens to whole lines.
        assert_resolved(
            concat!(
                "    Lorem\n",
                "«    ipsum\n",
                "    dolor sit amet»\n",
                "    consecteur",
            ),
            "ipsum\ndolor",
            cx,
        );

        // A slightly wrong signature still resolves to the right function.
        assert_resolved(
            &"
            «fn foo1(a: usize) -> usize {
                40
            }»

            fn foo2(b: usize) -> usize {
                42
            }
            "
            .unindent(),
            "fn foo1(b: usize) {\n40\n}",
            cx,
        );

        // Line breaks and some characters missing from the query.
        assert_resolved(
            &"
            fn main() {
            «    Foo
                    .bar()
                    .baz()
                    .qux()»
            }

            fn foo2(b: usize) -> usize {
                42
            }
            "
            .unindent(),
            "Foo.bar.baz.qux()",
            cx,
        );

        // A whole line missing from the middle of the query.
        assert_resolved(
            &"
            class Something {
                one() { return 1; }
            «    two() { return 2222; }
                three() { return 333; }
                four() { return 4444; }
                five() { return 5555; }
                six() { return 6666; }
            »    seven() { return 7; }
                eight() { return 8; }
            }
            "
            .unindent(),
            &"
                two() { return 2222; }
                four() { return 4444; }
                five() { return 5555; }
                six() { return 6666; }
            "
            .unindent(),
            cx,
        );
    }

    /// Strips the range markers from `text_with_expected_range`, resolves
    /// `query` against the resulting text, and asserts that re-marking the text
    /// with the resolved range (via `generate_marked_text`) reproduces
    /// `text_with_expected_range`.
    #[track_caller]
    fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
        let (text, _) = marked_text_ranges(text_with_expected_range, false);
        let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
        let snapshot = buffer.read(cx).snapshot();
        let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
        let text_with_actual_range = generate_marked_text(&text, &[range], false);
        pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
    }
}
226}