1use language::{Anchor, Bias, BufferSnapshot};
2use std::ops::Range;
3
/// The move that produced a cell of the alignment matrix; the traceback in
/// `resolve_search_block` follows these moves backwards from the end of the match.
///
/// The ordering `Up < Left < Diagonal` is load-bearing: when two candidate moves
/// have the same cost, the smaller direction wins. The discriminants are written
/// out so that this tie-break order is explicit rather than implied by position.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    /// Consumed a query byte without consuming a buffer byte.
    Up = 0,
    /// Consumed a buffer byte without consuming a query byte.
    Left = 1,
    /// Consumed one byte from both the query and the buffer.
    Diagonal = 2,
}
10
11#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
12struct SearchState {
13 cost: u32,
14 direction: SearchDirection,
15}
16
17impl SearchState {
18 fn new(cost: u32, direction: SearchDirection) -> Self {
19 Self { cost, direction }
20 }
21}
22
23struct SearchMatrix {
24 cols: usize,
25 data: Vec<SearchState>,
26}
27
28impl SearchMatrix {
29 fn new(rows: usize, cols: usize) -> Self {
30 SearchMatrix {
31 cols,
32 data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
33 }
34 }
35
36 fn get(&self, row: usize, col: usize) -> SearchState {
37 self.data[row * self.cols + col]
38 }
39
40 fn set(&mut self, row: usize, col: usize, cost: SearchState) {
41 self.data[row * self.cols + col] = cost;
42 }
43}
44
45pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
46 const INSERTION_COST: u32 = 3;
47 const DELETION_COST: u32 = 10;
48 const WHITESPACE_INSERTION_COST: u32 = 1;
49 const WHITESPACE_DELETION_COST: u32 = 1;
50
51 let buffer_len = buffer.len();
52 let query_len = search_query.len();
53 let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
54 let mut leading_deletion_cost = 0_u32;
55 for (row, query_byte) in search_query.bytes().enumerate() {
56 let deletion_cost = if query_byte.is_ascii_whitespace() {
57 WHITESPACE_DELETION_COST
58 } else {
59 DELETION_COST
60 };
61
62 leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
63 matrix.set(
64 row + 1,
65 0,
66 SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
67 );
68
69 for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
70 let insertion_cost = if buffer_byte.is_ascii_whitespace() {
71 WHITESPACE_INSERTION_COST
72 } else {
73 INSERTION_COST
74 };
75
76 let up = SearchState::new(
77 matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
78 SearchDirection::Up,
79 );
80 let left = SearchState::new(
81 matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
82 SearchDirection::Left,
83 );
84 let diagonal = SearchState::new(
85 if query_byte == *buffer_byte {
86 matrix.get(row, col).cost
87 } else {
88 matrix
89 .get(row, col)
90 .cost
91 .saturating_add(deletion_cost + insertion_cost)
92 },
93 SearchDirection::Diagonal,
94 );
95 matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
96 }
97 }
98
99 // Traceback to find the best match
100 let mut best_buffer_end = buffer_len;
101 let mut best_cost = u32::MAX;
102 for col in 1..=buffer_len {
103 let cost = matrix.get(query_len, col).cost;
104 if cost < best_cost {
105 best_cost = cost;
106 best_buffer_end = col;
107 }
108 }
109
110 let mut query_ix = query_len;
111 let mut buffer_ix = best_buffer_end;
112 while query_ix > 0 && buffer_ix > 0 {
113 let current = matrix.get(query_ix, buffer_ix);
114 match current.direction {
115 SearchDirection::Diagonal => {
116 query_ix -= 1;
117 buffer_ix -= 1;
118 }
119 SearchDirection::Up => {
120 query_ix -= 1;
121 }
122 SearchDirection::Left => {
123 buffer_ix -= 1;
124 }
125 }
126 }
127
128 let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
129 start.column = 0;
130 let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
131 if end.column > 0 {
132 end.column = buffer.line_len(end.row);
133 }
134
135 buffer.anchor_after(start)..buffer.anchor_before(end)
136}
137
#[cfg(test)]
mod tests {
    use crate::edit_files_tool::resolve_search_block::resolve_search_block;
    use gpui::{prelude::*, App};
    use language::{Buffer, OffsetRangeExt as _};
    use unindent::Unindent as _;
    use util::test::{generate_marked_text, marked_text_ranges};

    /// Each case pairs a buffer whose expected match is delimited with `«` and `»`
    /// with a query that only approximately matches that region.
    #[gpui::test]
    fn test_resolve_search_block(cx: &mut App) {
        // A query covering part of two lines resolves to both full lines.
        assert_resolved(
            concat!(
                " Lorem\n",
                "« ipsum\n",
                " dolor sit amet»\n",
                " consecteur",
            ),
            "ipsum\ndolor",
            cx,
        );

        // The query differs from `foo1` (parameter `b` instead of `a`, no return type)
        // and shares the parameter name with `foo2`; `foo1` must still be selected.
        assert_resolved(
            &"
            «fn foo1(a: usize) -> usize {
            40
            }»

            fn foo2(b: usize) -> usize {
            42
            }
            "
            .unindent(),
            "fn foo1(b: usize) {\n40\n}",
            cx,
        );

        // The query collapses a multi-line method chain onto one line and drops the
        // `()` on `bar`/`baz`; the match is the chain's lines inside `main`.
        assert_resolved(
            &"
            fn main() {
            « Foo
            .bar()
            .baz()
            .qux()»
            }

            fn foo2(b: usize) -> usize {
            42
            }
            "
            .unindent(),
            "Foo.bar.baz.qux()",
            cx,
        );

        // The query omits the `three()` line; the resolved range still spans from
        // `two()` through `six()`, including the omitted line.
        assert_resolved(
            &"
            class Something {
            one() { return 1; }
            « two() { return 2222; }
            three() { return 333; }
            four() { return 4444; }
            five() { return 5555; }
            six() { return 6666; }
            » seven() { return 7; }
            eight() { return 8; }
            }
            "
            .unindent(),
            &"
            two() { return 2222; }
            four() { return 4444; }
            five() { return 5555; }
            six() { return 6666; }
            "
            .unindent(),
            cx,
        );
    }

    /// Runs the resolution and compares the result against the expected range.
    ///
    /// - Strips the `«…»` markers from `text_with_expected_range` to get the buffer text.
    /// - Resolves `query` against a local buffer holding that text.
    /// - Re-renders the text with the resolved offset range marked.
    /// - Compares the rendering to the input.
    ///
    /// `#[track_caller]` makes a failing assertion report the calling line in the test.
    #[track_caller]
    fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
        let (text, _) = marked_text_ranges(text_with_expected_range, false);
        let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
        let snapshot = buffer.read(cx).snapshot();
        let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
        let text_with_actual_range = generate_marked_text(&text, &[range], false);
        pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
    }
}