1use anyhow::{anyhow, Context as _, Result};
2use collections::HashMap;
3use editor::ProposedChangesEditor;
4use futures::{future, TryFutureExt as _};
5use gpui::{AppContext, AsyncAppContext, Model, SharedString};
6use language::{AutoindentMode, Buffer, BufferSnapshot};
7use project::{Project, ProjectPath};
8use std::{cmp, ops::Range, path::Path, sync::Arc};
9use text::{AnchorRangeExt as _, Bias, OffsetRangeExt as _, Point};
10
/// A patch proposed by the assistant: a titled list of file edits plus its status.
#[derive(Clone, Debug)]
pub(crate) struct AssistantPatch {
    /// Anchor range associated with this patch.
    /// NOTE(review): presumably the patch's location in the context buffer — confirm.
    pub range: Range<language::Anchor>,
    /// Title of the patch.
    pub title: SharedString,
    /// The patch's edits. `Err` entries are skipped by `AssistantPatch::resolve` and
    /// `AssistantPatch::paths`.
    pub edits: Arc<[Result<AssistantEdit>]>,
    /// Whether the patch is still pending or ready.
    pub status: AssistantPatchStatus,
}

/// Completion status of an [`AssistantPatch`].
/// NOTE(review): `Pending` presumably means the patch is still being produced — confirm.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum AssistantPatchStatus {
    Pending,
    Ready,
}
24
/// A single edit to one file, as described by the assistant (not yet located in a buffer).
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct AssistantEdit {
    /// Path of the file to edit, looked up in the project's worktrees when resolved.
    pub path: String,
    /// What to do to the file.
    pub kind: AssistantEditKind,
}

/// The operation an [`AssistantEdit`] performs. `old_text` is located in the buffer by
/// fuzzy matching and widened to whole lines.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AssistantEditKind {
    /// Replace the lines matching `old_text` with `new_text`.
    Update {
        old_text: String,
        new_text: String,
        description: String,
    },
    /// Replace the entire buffer contents with `new_text`.
    Create {
        new_text: String,
        description: String,
    },
    /// Insert `new_text` (followed by a newline) at the start of the lines matching `old_text`.
    InsertBefore {
        old_text: String,
        new_text: String,
        description: String,
    },
    /// Insert a newline followed by `new_text` at the end of the lines matching `old_text`.
    InsertAfter {
        old_text: String,
        new_text: String,
        description: String,
    },
    /// Remove the lines matching `old_text`.
    Delete {
        old_text: String,
    },
}
56
/// The result of resolving an [`AssistantPatch`] against concrete buffers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct ResolvedPatch {
    /// Resolved edits, grouped per buffer by overlapping context ranges.
    pub edit_groups: HashMap<Model<Buffer>, Vec<ResolvedEditGroup>>,
    /// Edits that could not be resolved (e.g. no worktree, buffer failed to open).
    pub errors: Vec<AssistantPatchResolutionError>,
}

/// Edits whose surrounding context ranges overlap, shown and applied together.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedEditGroup {
    /// The edits' ranges extended by up to 5 lines on each side, covering whole lines.
    pub context_range: Range<language::Anchor>,
    /// The edits in this group, in buffer order.
    pub edits: Vec<ResolvedEdit>,
}

/// An edit located in a specific buffer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedEdit {
    /// The buffer range to replace (empty for pure insertions).
    range: Range<language::Anchor>,
    /// Replacement text for `range`.
    new_text: String,
    /// Explanation of the edit; `None` for deletions.
    description: Option<String>,
}

/// An edit of an [`AssistantPatch`] that failed to resolve.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct AssistantPatchResolutionError {
    /// Index of the failing edit in `AssistantPatch::edits`.
    pub edit_ix: usize,
    /// Human-readable error message.
    pub message: String,
}
81
/// The step that produced a cell of the fuzzy-search matrix, used to walk an alignment
/// back from its end to its start.
///
/// Variant order is significant: [`SearchState`] derives `Ord` over `(score, direction)`,
/// so when candidates tie on score, `max` prefers `Diagonal`, then `Left`, then `Up`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    /// A query byte was skipped (advance the query, stay on the same buffer byte).
    Up,
    /// A buffer byte was skipped (advance the buffer, stay on the same query byte).
    Left,
    /// A query byte was aligned with a buffer byte (equal or substituted).
    Diagonal,
}

/// One cell of the fuzzy-search matrix: the best alignment score reaching the cell and
/// the step that produced it.
///
/// Higher scores are better. The derived `Ord` compares `score` first and `direction`
/// second (declaration order), which lets the search choose a cell's best predecessor
/// with `max`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
    score: u32,
    direction: SearchDirection,
}

impl SearchState {
    /// Creates a cell with the given score, reached via `direction`.
    fn new(score: u32, direction: SearchDirection) -> Self {
        Self { score, direction }
    }
}
104
105impl ResolvedPatch {
106 pub fn apply(&self, editor: &ProposedChangesEditor, cx: &mut AppContext) {
107 for (buffer, groups) in &self.edit_groups {
108 let branch = editor.branch_buffer_for_base(buffer).unwrap();
109 Self::apply_edit_groups(groups, &branch, cx);
110 }
111 editor.recalculate_all_buffer_diffs();
112 }
113
114 fn apply_edit_groups(
115 groups: &Vec<ResolvedEditGroup>,
116 buffer: &Model<Buffer>,
117 cx: &mut AppContext,
118 ) {
119 let mut edits = Vec::new();
120 for group in groups {
121 for suggestion in &group.edits {
122 edits.push((suggestion.range.clone(), suggestion.new_text.clone()));
123 }
124 }
125 buffer.update(cx, |buffer, cx| {
126 buffer.edit(
127 edits,
128 Some(AutoindentMode::Block {
129 original_indent_columns: Vec::new(),
130 }),
131 cx,
132 );
133 });
134 }
135}
136
137impl ResolvedEdit {
138 pub fn try_merge(&mut self, other: &Self, buffer: &text::BufferSnapshot) -> bool {
139 let range = &self.range;
140 let other_range = &other.range;
141
142 // Don't merge if we don't contain the other suggestion.
143 if range.start.cmp(&other_range.start, buffer).is_gt()
144 || range.end.cmp(&other_range.end, buffer).is_lt()
145 {
146 return false;
147 }
148
149 if let Some(description) = &mut self.description {
150 if let Some(other_description) = &other.description {
151 description.push('\n');
152 description.push_str(other_description);
153 }
154 }
155 true
156 }
157}
158
159impl AssistantEdit {
160 pub fn new(
161 path: Option<String>,
162 operation: Option<String>,
163 old_text: Option<String>,
164 new_text: Option<String>,
165 description: Option<String>,
166 ) -> Result<Self> {
167 let path = path.ok_or_else(|| anyhow!("missing path"))?;
168 let operation = operation.ok_or_else(|| anyhow!("missing operation"))?;
169
170 let kind = match operation.as_str() {
171 "update" => AssistantEditKind::Update {
172 old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
173 new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
174 description: description.ok_or_else(|| anyhow!("missing description"))?,
175 },
176 "insert_before" => AssistantEditKind::InsertBefore {
177 old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
178 new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
179 description: description.ok_or_else(|| anyhow!("missing description"))?,
180 },
181 "insert_after" => AssistantEditKind::InsertAfter {
182 old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
183 new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
184 description: description.ok_or_else(|| anyhow!("missing description"))?,
185 },
186 "delete" => AssistantEditKind::Delete {
187 old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
188 },
189 "create" => AssistantEditKind::Create {
190 description: description.ok_or_else(|| anyhow!("missing description"))?,
191 new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
192 },
193 _ => Err(anyhow!("unknown operation {operation:?}"))?,
194 };
195
196 Ok(Self { path, kind })
197 }
198
199 pub async fn resolve(
200 &self,
201 project: Model<Project>,
202 mut cx: AsyncAppContext,
203 ) -> Result<(Model<Buffer>, ResolvedEdit)> {
204 let path = self.path.clone();
205 let kind = self.kind.clone();
206 let buffer = project
207 .update(&mut cx, |project, cx| {
208 let project_path = project
209 .find_project_path(Path::new(&path), cx)
210 .or_else(|| {
211 // If we couldn't find a project path for it, put it in the active worktree
212 // so that when we create the buffer, it can be saved.
213 let worktree = project
214 .active_entry()
215 .and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
216 .or_else(|| project.worktrees(cx).next())?;
217 let worktree = worktree.read(cx);
218
219 Some(ProjectPath {
220 worktree_id: worktree.id(),
221 path: Arc::from(Path::new(&path)),
222 })
223 })
224 .with_context(|| format!("worktree not found for {:?}", path))?;
225 anyhow::Ok(project.open_buffer(project_path, cx))
226 })??
227 .await?;
228
229 let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
230 let suggestion = cx
231 .background_executor()
232 .spawn(async move { kind.resolve(&snapshot) })
233 .await;
234
235 Ok((buffer, suggestion))
236 }
237}
238
239impl AssistantEditKind {
240 fn resolve(self, snapshot: &BufferSnapshot) -> ResolvedEdit {
241 match self {
242 Self::Update {
243 old_text,
244 new_text,
245 description,
246 } => {
247 let range = Self::resolve_location(&snapshot, &old_text);
248 ResolvedEdit {
249 range,
250 new_text,
251 description: Some(description),
252 }
253 }
254 Self::Create {
255 new_text,
256 description,
257 } => ResolvedEdit {
258 range: text::Anchor::MIN..text::Anchor::MAX,
259 description: Some(description),
260 new_text,
261 },
262 Self::InsertBefore {
263 old_text,
264 mut new_text,
265 description,
266 } => {
267 let range = Self::resolve_location(&snapshot, &old_text);
268 new_text.push('\n');
269 ResolvedEdit {
270 range: range.start..range.start,
271 new_text,
272 description: Some(description),
273 }
274 }
275 Self::InsertAfter {
276 old_text,
277 mut new_text,
278 description,
279 } => {
280 let range = Self::resolve_location(&snapshot, &old_text);
281 new_text.insert(0, '\n');
282 ResolvedEdit {
283 range: range.end..range.end,
284 new_text,
285 description: Some(description),
286 }
287 }
288 Self::Delete { old_text } => {
289 let range = Self::resolve_location(&snapshot, &old_text);
290 ResolvedEdit {
291 range,
292 new_text: String::new(),
293 description: None,
294 }
295 }
296 }
297 }
298
299 fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range<text::Anchor> {
300 const INSERTION_COST: u32 = 3;
301 const WHITESPACE_INSERTION_COST: u32 = 1;
302 const DELETION_COST: u32 = 3;
303 const WHITESPACE_DELETION_COST: u32 = 1;
304 const EQUALITY_BONUS: u32 = 5;
305
306 struct Matrix {
307 cols: usize,
308 data: Vec<SearchState>,
309 }
310
311 impl Matrix {
312 fn new(rows: usize, cols: usize) -> Self {
313 Matrix {
314 cols,
315 data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
316 }
317 }
318
319 fn get(&self, row: usize, col: usize) -> SearchState {
320 self.data[row * self.cols + col]
321 }
322
323 fn set(&mut self, row: usize, col: usize, cost: SearchState) {
324 self.data[row * self.cols + col] = cost;
325 }
326 }
327
328 let buffer_len = buffer.len();
329 let query_len = search_query.len();
330 let mut matrix = Matrix::new(query_len + 1, buffer_len + 1);
331
332 for (row, query_byte) in search_query.bytes().enumerate() {
333 for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
334 let deletion_cost = if query_byte.is_ascii_whitespace() {
335 WHITESPACE_DELETION_COST
336 } else {
337 DELETION_COST
338 };
339 let insertion_cost = if buffer_byte.is_ascii_whitespace() {
340 WHITESPACE_INSERTION_COST
341 } else {
342 INSERTION_COST
343 };
344
345 let up = SearchState::new(
346 matrix.get(row, col + 1).score.saturating_sub(deletion_cost),
347 SearchDirection::Up,
348 );
349 let left = SearchState::new(
350 matrix
351 .get(row + 1, col)
352 .score
353 .saturating_sub(insertion_cost),
354 SearchDirection::Left,
355 );
356 let diagonal = SearchState::new(
357 if query_byte == *buffer_byte {
358 matrix.get(row, col).score.saturating_add(EQUALITY_BONUS)
359 } else {
360 matrix
361 .get(row, col)
362 .score
363 .saturating_sub(deletion_cost + insertion_cost)
364 },
365 SearchDirection::Diagonal,
366 );
367 matrix.set(row + 1, col + 1, up.max(left).max(diagonal));
368 }
369 }
370
371 // Traceback to find the best match
372 let mut best_buffer_end = buffer_len;
373 let mut best_score = 0;
374 for col in 1..=buffer_len {
375 let score = matrix.get(query_len, col).score;
376 if score > best_score {
377 best_score = score;
378 best_buffer_end = col;
379 }
380 }
381
382 let mut query_ix = query_len;
383 let mut buffer_ix = best_buffer_end;
384 while query_ix > 0 && buffer_ix > 0 {
385 let current = matrix.get(query_ix, buffer_ix);
386 match current.direction {
387 SearchDirection::Diagonal => {
388 query_ix -= 1;
389 buffer_ix -= 1;
390 }
391 SearchDirection::Up => {
392 query_ix -= 1;
393 }
394 SearchDirection::Left => {
395 buffer_ix -= 1;
396 }
397 }
398 }
399
400 let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
401 start.column = 0;
402 let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
403 if end.column > 0 {
404 end.column = buffer.line_len(end.row);
405 }
406
407 buffer.anchor_after(start)..buffer.anchor_before(end)
408 }
409}
410
411impl AssistantPatch {
412 pub(crate) async fn resolve(
413 &self,
414 project: Model<Project>,
415 cx: &mut AsyncAppContext,
416 ) -> ResolvedPatch {
417 let mut resolve_tasks = Vec::new();
418 for (ix, edit) in self.edits.iter().enumerate() {
419 if let Ok(edit) = edit.as_ref() {
420 resolve_tasks.push(
421 edit.resolve(project.clone(), cx.clone())
422 .map_err(move |error| (ix, error)),
423 );
424 }
425 }
426
427 let edits = future::join_all(resolve_tasks).await;
428 let mut errors = Vec::new();
429 let mut edits_by_buffer = HashMap::default();
430 for entry in edits {
431 match entry {
432 Ok((buffer, edit)) => {
433 edits_by_buffer
434 .entry(buffer)
435 .or_insert_with(Vec::new)
436 .push(edit);
437 }
438 Err((edit_ix, error)) => errors.push(AssistantPatchResolutionError {
439 edit_ix,
440 message: error.to_string(),
441 }),
442 }
443 }
444
445 // Expand the context ranges of each edit and group edits with overlapping context ranges.
446 let mut edit_groups_by_buffer = HashMap::default();
447 for (buffer, edits) in edits_by_buffer {
448 if let Ok(snapshot) = buffer.update(cx, |buffer, _| buffer.text_snapshot()) {
449 edit_groups_by_buffer.insert(buffer, Self::group_edits(edits, &snapshot));
450 }
451 }
452
453 ResolvedPatch {
454 edit_groups: edit_groups_by_buffer,
455 errors,
456 }
457 }
458
459 fn group_edits(
460 mut edits: Vec<ResolvedEdit>,
461 snapshot: &text::BufferSnapshot,
462 ) -> Vec<ResolvedEditGroup> {
463 let mut edit_groups = Vec::<ResolvedEditGroup>::new();
464 // Sort edits by their range so that earlier, larger ranges come first
465 edits.sort_by(|a, b| a.range.cmp(&b.range, &snapshot));
466
467 // Merge overlapping edits
468 edits.dedup_by(|a, b| b.try_merge(a, &snapshot));
469
470 // Create context ranges for each edit
471 for edit in edits {
472 let context_range = {
473 let edit_point_range = edit.range.to_point(&snapshot);
474 let start_row = edit_point_range.start.row.saturating_sub(5);
475 let end_row = cmp::min(edit_point_range.end.row + 5, snapshot.max_point().row);
476 let start = snapshot.anchor_before(Point::new(start_row, 0));
477 let end = snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
478 start..end
479 };
480
481 if let Some(last_group) = edit_groups.last_mut() {
482 if last_group
483 .context_range
484 .end
485 .cmp(&context_range.start, &snapshot)
486 .is_ge()
487 {
488 // Merge with the previous group if context ranges overlap
489 last_group.context_range.end = context_range.end;
490 last_group.edits.push(edit);
491 } else {
492 // Create a new group
493 edit_groups.push(ResolvedEditGroup {
494 context_range,
495 edits: vec![edit],
496 });
497 }
498 } else {
499 // Create the first group
500 edit_groups.push(ResolvedEditGroup {
501 context_range,
502 edits: vec![edit],
503 });
504 }
505 }
506
507 edit_groups
508 }
509
510 pub fn path_count(&self) -> usize {
511 self.paths().count()
512 }
513
514 pub fn paths(&self) -> impl '_ + Iterator<Item = &str> {
515 let mut prev_path = None;
516 self.edits.iter().filter_map(move |edit| {
517 if let Ok(edit) = edit {
518 let path = Some(edit.path.as_str());
519 if path != prev_path {
520 prev_path = path;
521 return path;
522 }
523 }
524 None
525 })
526 }
527}
528
529impl PartialEq for AssistantPatch {
530 fn eq(&self, other: &Self) -> bool {
531 self.range == other.range
532 && self.title == other.title
533 && Arc::ptr_eq(&self.edits, &other.edits)
534 }
535}
536
537impl Eq for AssistantPatch {}
538
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, Context};
    use language::{
        language_settings::AllLanguageSettings, Language, LanguageConfig, LanguageMatcher,
    };
    use settings::SettingsStore;
    use text::{OffsetRangeExt, Point};
    use ui::BorrowAppContext;
    use unindent::Unindent as _;

    /// `resolve_location` should tolerate inexact queries and widen matches to whole lines.
    #[gpui::test]
    fn test_resolve_location(cx: &mut AppContext) {
        // The query omits the buffer's indentation; the match still covers lines 1-2,
        // widened to full lines.
        {
            let buffer = cx.new_model(|cx| {
                Buffer::local(
                    concat!(
                        "    Lorem\n",
                        "    ipsum\n",
                        "    dolor sit amet\n",
                        "    consecteur",
                    ),
                    cx,
                )
            });
            let snapshot = buffer.read(cx).snapshot();
            assert_eq!(
                AssistantEditKind::resolve_location(&snapshot, "ipsum\ndolor").to_point(&snapshot),
                Point::new(1, 0)..Point::new(2, 18)
            );
        }

        // A query with the wrong parameter name and no return type still resolves to
        // `foo1` (lines 0-2), not to `foo2`.
        {
            let buffer = cx.new_model(|cx| {
                Buffer::local(
                    concat!(
                        "fn foo1(a: usize) -> usize {\n",
                        "    40\n",
                        "}\n",
                        "\n",
                        "fn foo2(b: usize) -> usize {\n",
                        "    42\n",
                        "}\n",
                    ),
                    cx,
                )
            });
            let snapshot = buffer.read(cx).snapshot();
            assert_eq!(
                AssistantEditKind::resolve_location(&snapshot, "fn foo1(b: usize) {\n40\n}")
                    .to_point(&snapshot),
                Point::new(0, 0)..Point::new(2, 1)
            );
        }

        // A single-line query with some `()` missing still matches the multi-line method
        // chain on lines 1-4.
        {
            let buffer = cx.new_model(|cx| {
                Buffer::local(
                    concat!(
                        "fn main() {\n",
                        "    Foo\n",
                        "        .bar()\n",
                        "        .baz()\n",
                        "        .qux()\n",
                        "}\n",
                        "\n",
                        "fn foo2(b: usize) -> usize {\n",
                        "    42\n",
                        "}\n",
                    ),
                    cx,
                )
            });
            let snapshot = buffer.read(cx).snapshot();
            assert_eq!(
                AssistantEditKind::resolve_location(&snapshot, "Foo.bar.baz.qux()")
                    .to_point(&snapshot),
                Point::new(1, 0)..Point::new(4, 14)
            );
        }
    }

    /// Two separate `Update` edits are applied together. The replacement text is written
    /// without indentation and is expected to come out indented.
    #[gpui::test]
    fn test_resolve_edits(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        language::init(cx);
        cx.update_global::<SettingsStore, _>(|settings, cx| {
            settings.update_user_settings::<AllLanguageSettings>(cx, |_| {});
        });

        assert_edits(
            "
                /// A person
                struct Person {
                    name: String,
                    age: usize,
                }

                /// A dog
                struct Dog {
                    weight: f32,
                }

                impl Person {
                    fn name(&self) -> &str {
                        &self.name
                    }
                }
            "
            .unindent(),
            vec![
                AssistantEditKind::Update {
                    old_text: "
                        name: String,
                    "
                    .unindent(),
                    new_text: "
                        first_name: String,
                        last_name: String,
                    "
                    .unindent(),
                    description: "".into(),
                },
                AssistantEditKind::Update {
                    old_text: "
                        fn name(&self) -> &str {
                            &self.name
                        }
                    "
                    .unindent(),
                    new_text: "
                        fn name(&self) -> String {
                            format!(\"{} {}\", self.first_name, self.last_name)
                        }
                    "
                    .unindent(),
                    description: "".into(),
                },
            ],
            "
                /// A person
                struct Person {
                    first_name: String,
                    last_name: String,
                    age: usize,
                }

                /// A dog
                struct Dog {
                    weight: f32,
                }

                impl Person {
                    fn name(&self) -> String {
                        format!(\"{} {}\", self.first_name, self.last_name)
                    }
                }
            "
            .unindent(),
            cx,
        );
    }

    /// Resolves `edits` against a Rust buffer containing `old_text`, groups them, applies
    /// them, and asserts that the buffer text then equals `new_text`.
    #[track_caller]
    fn assert_edits(
        old_text: String,
        edits: Vec<AssistantEditKind>,
        new_text: String,
        cx: &mut AppContext,
    ) {
        let buffer =
            cx.new_model(|cx| Buffer::local(old_text, cx).with_language(Arc::new(rust_lang()), cx));
        let snapshot = buffer.read(cx).snapshot();
        let resolved_edits = edits
            .into_iter()
            .map(|kind| kind.resolve(&snapshot))
            .collect();
        let edit_groups = AssistantPatch::group_edits(resolved_edits, &snapshot);
        ResolvedPatch::apply_edit_groups(&edit_groups, &buffer, cx);
        let actual_new_text = buffer.read(cx).text();
        pretty_assertions::assert_eq!(actual_new_text, new_text);
    }

    /// A minimal Rust language with an indents query, presumably so that the block
    /// auto-indentation used by `apply_edit_groups` has something to work with.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(language::tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}