patch.rs

use anyhow::{anyhow, Context as _, Result};
use collections::{HashMap, HashSet};
use editor::ProposedChangesEditor;
use futures::{future, TryFutureExt as _};
use gpui::{AppContext, AsyncAppContext, Model, SharedString};
use language::{AutoindentMode, Buffer, BufferSnapshot};
use project::{Project, ProjectPath};
use std::{cmp, ops::Range, path::Path, sync::Arc};
use text::{AnchorRangeExt as _, Bias, OffsetRangeExt as _, Point};
 10
 11#[derive(Clone, Debug)]
 12pub(crate) struct AssistantPatch {
 13    pub range: Range<language::Anchor>,
 14    pub title: SharedString,
 15    pub edits: Arc<[Result<AssistantEdit>]>,
 16    pub status: AssistantPatchStatus,
 17}
 18
 19#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 20pub(crate) enum AssistantPatchStatus {
 21    Pending,
 22    Ready,
 23}
 24
 25#[derive(Clone, Debug, PartialEq, Eq)]
 26pub(crate) struct AssistantEdit {
 27    pub path: String,
 28    pub kind: AssistantEditKind,
 29}
 30
 31#[derive(Clone, Debug, PartialEq, Eq)]
 32pub enum AssistantEditKind {
 33    Update {
 34        old_text: String,
 35        new_text: String,
 36        description: String,
 37    },
 38    Create {
 39        new_text: String,
 40        description: String,
 41    },
 42    InsertBefore {
 43        old_text: String,
 44        new_text: String,
 45        description: String,
 46    },
 47    InsertAfter {
 48        old_text: String,
 49        new_text: String,
 50        description: String,
 51    },
 52    Delete {
 53        old_text: String,
 54    },
 55}
 56
 57#[derive(Clone, Debug, Eq, PartialEq)]
 58pub(crate) struct ResolvedPatch {
 59    pub edit_groups: HashMap<Model<Buffer>, Vec<ResolvedEditGroup>>,
 60    pub errors: Vec<AssistantPatchResolutionError>,
 61}
 62
 63#[derive(Clone, Debug, Eq, PartialEq)]
 64pub struct ResolvedEditGroup {
 65    pub context_range: Range<language::Anchor>,
 66    pub edits: Vec<ResolvedEdit>,
 67}
 68
 69#[derive(Clone, Debug, Eq, PartialEq)]
 70pub struct ResolvedEdit {
 71    range: Range<language::Anchor>,
 72    new_text: String,
 73    description: Option<String>,
 74}
 75
 76#[derive(Clone, Debug, Eq, PartialEq)]
 77pub(crate) struct AssistantPatchResolutionError {
 78    pub edit_ix: usize,
 79    pub message: String,
 80}
 81
 82#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
 83enum SearchDirection {
 84    Up,
 85    Left,
 86    Diagonal,
 87}
 88
 89// A measure of the currently quality of an in-progress fuzzy search.
 90//
 91// Uses 60 bits to store a numeric cost, and 4 bits to store the preceding
 92// operation in the search.
 93#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
 94struct SearchState {
 95    score: u32,
 96    direction: SearchDirection,
 97}
 98
 99impl SearchState {
100    fn new(score: u32, direction: SearchDirection) -> Self {
101        Self { score, direction }
102    }
103}
104
105impl ResolvedPatch {
106    pub fn apply(&self, editor: &ProposedChangesEditor, cx: &mut AppContext) {
107        for (buffer, groups) in &self.edit_groups {
108            let branch = editor.branch_buffer_for_base(buffer).unwrap();
109            Self::apply_edit_groups(groups, &branch, cx);
110        }
111        editor.recalculate_all_buffer_diffs();
112    }
113
114    fn apply_edit_groups(
115        groups: &Vec<ResolvedEditGroup>,
116        buffer: &Model<Buffer>,
117        cx: &mut AppContext,
118    ) {
119        let mut edits = Vec::new();
120        for group in groups {
121            for suggestion in &group.edits {
122                edits.push((suggestion.range.clone(), suggestion.new_text.clone()));
123            }
124        }
125        buffer.update(cx, |buffer, cx| {
126            buffer.edit(
127                edits,
128                Some(AutoindentMode::Block {
129                    original_indent_columns: Vec::new(),
130                }),
131                cx,
132            );
133        });
134    }
135}
136
137impl ResolvedEdit {
138    pub fn try_merge(&mut self, other: &Self, buffer: &text::BufferSnapshot) -> bool {
139        let range = &self.range;
140        let other_range = &other.range;
141
142        // Don't merge if we don't contain the other suggestion.
143        if range.start.cmp(&other_range.start, buffer).is_gt()
144            || range.end.cmp(&other_range.end, buffer).is_lt()
145        {
146            return false;
147        }
148
149        if let Some(description) = &mut self.description {
150            if let Some(other_description) = &other.description {
151                description.push('\n');
152                description.push_str(other_description);
153            }
154        }
155        true
156    }
157}
158
159impl AssistantEdit {
160    pub fn new(
161        path: Option<String>,
162        operation: Option<String>,
163        old_text: Option<String>,
164        new_text: Option<String>,
165        description: Option<String>,
166    ) -> Result<Self> {
167        let path = path.ok_or_else(|| anyhow!("missing path"))?;
168        let operation = operation.ok_or_else(|| anyhow!("missing operation"))?;
169
170        let kind = match operation.as_str() {
171            "update" => AssistantEditKind::Update {
172                old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
173                new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
174                description: description.ok_or_else(|| anyhow!("missing description"))?,
175            },
176            "insert_before" => AssistantEditKind::InsertBefore {
177                old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
178                new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
179                description: description.ok_or_else(|| anyhow!("missing description"))?,
180            },
181            "insert_after" => AssistantEditKind::InsertAfter {
182                old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
183                new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
184                description: description.ok_or_else(|| anyhow!("missing description"))?,
185            },
186            "delete" => AssistantEditKind::Delete {
187                old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
188            },
189            "create" => AssistantEditKind::Create {
190                description: description.ok_or_else(|| anyhow!("missing description"))?,
191                new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
192            },
193            _ => Err(anyhow!("unknown operation {operation:?}"))?,
194        };
195
196        Ok(Self { path, kind })
197    }
198
199    pub async fn resolve(
200        &self,
201        project: Model<Project>,
202        mut cx: AsyncAppContext,
203    ) -> Result<(Model<Buffer>, ResolvedEdit)> {
204        let path = self.path.clone();
205        let kind = self.kind.clone();
206        let buffer = project
207            .update(&mut cx, |project, cx| {
208                let project_path = project
209                    .find_project_path(Path::new(&path), cx)
210                    .or_else(|| {
211                        // If we couldn't find a project path for it, put it in the active worktree
212                        // so that when we create the buffer, it can be saved.
213                        let worktree = project
214                            .active_entry()
215                            .and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
216                            .or_else(|| project.worktrees(cx).next())?;
217                        let worktree = worktree.read(cx);
218
219                        Some(ProjectPath {
220                            worktree_id: worktree.id(),
221                            path: Arc::from(Path::new(&path)),
222                        })
223                    })
224                    .with_context(|| format!("worktree not found for {:?}", path))?;
225                anyhow::Ok(project.open_buffer(project_path, cx))
226            })??
227            .await?;
228
229        let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
230        let suggestion = cx
231            .background_executor()
232            .spawn(async move { kind.resolve(&snapshot) })
233            .await;
234
235        Ok((buffer, suggestion))
236    }
237}
238
239impl AssistantEditKind {
240    fn resolve(self, snapshot: &BufferSnapshot) -> ResolvedEdit {
241        match self {
242            Self::Update {
243                old_text,
244                new_text,
245                description,
246            } => {
247                let range = Self::resolve_location(&snapshot, &old_text);
248                ResolvedEdit {
249                    range,
250                    new_text,
251                    description: Some(description),
252                }
253            }
254            Self::Create {
255                new_text,
256                description,
257            } => ResolvedEdit {
258                range: text::Anchor::MIN..text::Anchor::MAX,
259                description: Some(description),
260                new_text,
261            },
262            Self::InsertBefore {
263                old_text,
264                mut new_text,
265                description,
266            } => {
267                let range = Self::resolve_location(&snapshot, &old_text);
268                new_text.push('\n');
269                ResolvedEdit {
270                    range: range.start..range.start,
271                    new_text,
272                    description: Some(description),
273                }
274            }
275            Self::InsertAfter {
276                old_text,
277                mut new_text,
278                description,
279            } => {
280                let range = Self::resolve_location(&snapshot, &old_text);
281                new_text.insert(0, '\n');
282                ResolvedEdit {
283                    range: range.end..range.end,
284                    new_text,
285                    description: Some(description),
286                }
287            }
288            Self::Delete { old_text } => {
289                let range = Self::resolve_location(&snapshot, &old_text);
290                ResolvedEdit {
291                    range,
292                    new_text: String::new(),
293                    description: None,
294                }
295            }
296        }
297    }
298
299    fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range<text::Anchor> {
300        const INSERTION_COST: u32 = 3;
301        const WHITESPACE_INSERTION_COST: u32 = 1;
302        const DELETION_COST: u32 = 3;
303        const WHITESPACE_DELETION_COST: u32 = 1;
304        const EQUALITY_BONUS: u32 = 5;
305
306        struct Matrix {
307            cols: usize,
308            data: Vec<SearchState>,
309        }
310
311        impl Matrix {
312            fn new(rows: usize, cols: usize) -> Self {
313                Matrix {
314                    cols,
315                    data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
316                }
317            }
318
319            fn get(&self, row: usize, col: usize) -> SearchState {
320                self.data[row * self.cols + col]
321            }
322
323            fn set(&mut self, row: usize, col: usize, cost: SearchState) {
324                self.data[row * self.cols + col] = cost;
325            }
326        }
327
328        let buffer_len = buffer.len();
329        let query_len = search_query.len();
330        let mut matrix = Matrix::new(query_len + 1, buffer_len + 1);
331
332        for (row, query_byte) in search_query.bytes().enumerate() {
333            for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
334                let deletion_cost = if query_byte.is_ascii_whitespace() {
335                    WHITESPACE_DELETION_COST
336                } else {
337                    DELETION_COST
338                };
339                let insertion_cost = if buffer_byte.is_ascii_whitespace() {
340                    WHITESPACE_INSERTION_COST
341                } else {
342                    INSERTION_COST
343                };
344
345                let up = SearchState::new(
346                    matrix.get(row, col + 1).score.saturating_sub(deletion_cost),
347                    SearchDirection::Up,
348                );
349                let left = SearchState::new(
350                    matrix
351                        .get(row + 1, col)
352                        .score
353                        .saturating_sub(insertion_cost),
354                    SearchDirection::Left,
355                );
356                let diagonal = SearchState::new(
357                    if query_byte == *buffer_byte {
358                        matrix.get(row, col).score.saturating_add(EQUALITY_BONUS)
359                    } else {
360                        matrix
361                            .get(row, col)
362                            .score
363                            .saturating_sub(deletion_cost + insertion_cost)
364                    },
365                    SearchDirection::Diagonal,
366                );
367                matrix.set(row + 1, col + 1, up.max(left).max(diagonal));
368            }
369        }
370
371        // Traceback to find the best match
372        let mut best_buffer_end = buffer_len;
373        let mut best_score = 0;
374        for col in 1..=buffer_len {
375            let score = matrix.get(query_len, col).score;
376            if score > best_score {
377                best_score = score;
378                best_buffer_end = col;
379            }
380        }
381
382        let mut query_ix = query_len;
383        let mut buffer_ix = best_buffer_end;
384        while query_ix > 0 && buffer_ix > 0 {
385            let current = matrix.get(query_ix, buffer_ix);
386            match current.direction {
387                SearchDirection::Diagonal => {
388                    query_ix -= 1;
389                    buffer_ix -= 1;
390                }
391                SearchDirection::Up => {
392                    query_ix -= 1;
393                }
394                SearchDirection::Left => {
395                    buffer_ix -= 1;
396                }
397            }
398        }
399
400        let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
401        start.column = 0;
402        let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
403        if end.column > 0 {
404            end.column = buffer.line_len(end.row);
405        }
406
407        buffer.anchor_after(start)..buffer.anchor_before(end)
408    }
409}
410
411impl AssistantPatch {
412    pub(crate) async fn resolve(
413        &self,
414        project: Model<Project>,
415        cx: &mut AsyncAppContext,
416    ) -> ResolvedPatch {
417        let mut resolve_tasks = Vec::new();
418        for (ix, edit) in self.edits.iter().enumerate() {
419            if let Ok(edit) = edit.as_ref() {
420                resolve_tasks.push(
421                    edit.resolve(project.clone(), cx.clone())
422                        .map_err(move |error| (ix, error)),
423                );
424            }
425        }
426
427        let edits = future::join_all(resolve_tasks).await;
428        let mut errors = Vec::new();
429        let mut edits_by_buffer = HashMap::default();
430        for entry in edits {
431            match entry {
432                Ok((buffer, edit)) => {
433                    edits_by_buffer
434                        .entry(buffer)
435                        .or_insert_with(Vec::new)
436                        .push(edit);
437                }
438                Err((edit_ix, error)) => errors.push(AssistantPatchResolutionError {
439                    edit_ix,
440                    message: error.to_string(),
441                }),
442            }
443        }
444
445        // Expand the context ranges of each edit and group edits with overlapping context ranges.
446        let mut edit_groups_by_buffer = HashMap::default();
447        for (buffer, edits) in edits_by_buffer {
448            if let Ok(snapshot) = buffer.update(cx, |buffer, _| buffer.text_snapshot()) {
449                edit_groups_by_buffer.insert(buffer, Self::group_edits(edits, &snapshot));
450            }
451        }
452
453        ResolvedPatch {
454            edit_groups: edit_groups_by_buffer,
455            errors,
456        }
457    }
458
459    fn group_edits(
460        mut edits: Vec<ResolvedEdit>,
461        snapshot: &text::BufferSnapshot,
462    ) -> Vec<ResolvedEditGroup> {
463        let mut edit_groups = Vec::<ResolvedEditGroup>::new();
464        // Sort edits by their range so that earlier, larger ranges come first
465        edits.sort_by(|a, b| a.range.cmp(&b.range, &snapshot));
466
467        // Merge overlapping edits
468        edits.dedup_by(|a, b| b.try_merge(a, &snapshot));
469
470        // Create context ranges for each edit
471        for edit in edits {
472            let context_range = {
473                let edit_point_range = edit.range.to_point(&snapshot);
474                let start_row = edit_point_range.start.row.saturating_sub(5);
475                let end_row = cmp::min(edit_point_range.end.row + 5, snapshot.max_point().row);
476                let start = snapshot.anchor_before(Point::new(start_row, 0));
477                let end = snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
478                start..end
479            };
480
481            if let Some(last_group) = edit_groups.last_mut() {
482                if last_group
483                    .context_range
484                    .end
485                    .cmp(&context_range.start, &snapshot)
486                    .is_ge()
487                {
488                    // Merge with the previous group if context ranges overlap
489                    last_group.context_range.end = context_range.end;
490                    last_group.edits.push(edit);
491                } else {
492                    // Create a new group
493                    edit_groups.push(ResolvedEditGroup {
494                        context_range,
495                        edits: vec![edit],
496                    });
497                }
498            } else {
499                // Create the first group
500                edit_groups.push(ResolvedEditGroup {
501                    context_range,
502                    edits: vec![edit],
503                });
504            }
505        }
506
507        edit_groups
508    }
509
510    pub fn path_count(&self) -> usize {
511        self.paths().count()
512    }
513
514    pub fn paths(&self) -> impl '_ + Iterator<Item = &str> {
515        let mut prev_path = None;
516        self.edits.iter().filter_map(move |edit| {
517            if let Ok(edit) = edit {
518                let path = Some(edit.path.as_str());
519                if path != prev_path {
520                    prev_path = path;
521                    return path;
522                }
523            }
524            None
525        })
526    }
527}
528
529impl PartialEq for AssistantPatch {
530    fn eq(&self, other: &Self) -> bool {
531        self.range == other.range
532            && self.title == other.title
533            && Arc::ptr_eq(&self.edits, &other.edits)
534    }
535}
536
537impl Eq for AssistantPatch {}
538
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, Context};
    use language::{
        language_settings::AllLanguageSettings, Language, LanguageConfig, LanguageMatcher,
    };
    use settings::SettingsStore;
    use text::{OffsetRangeExt, Point};
    use ui::BorrowAppContext;
    use unindent::Unindent as _;

    #[gpui::test]
    fn test_resolve_location(cx: &mut AppContext) {
        // Leading indentation in the buffer is not required in the query.
        assert_location(
            concat!(
                "    Lorem\n",
                "    ipsum\n",
                "    dolor sit amet\n",
                "    consecteur",
            ),
            "ipsum\ndolor",
            Point::new(1, 0)..Point::new(2, 18),
            cx,
        );

        // A slightly wrong signature still resolves to the closest function.
        assert_location(
            concat!(
                "fn foo1(a: usize) -> usize {\n",
                "    40\n",
                "}\n",
                "\n",
                "fn foo2(b: usize) -> usize {\n",
                "    42\n",
                "}\n",
            ),
            "fn foo1(b: usize) {\n40\n}",
            Point::new(0, 0)..Point::new(2, 1),
            cx,
        );

        // A multi-line method chain matched by a single-line query.
        assert_location(
            concat!(
                "fn main() {\n",
                "    Foo\n",
                "        .bar()\n",
                "        .baz()\n",
                "        .qux()\n",
                "}\n",
                "\n",
                "fn foo2(b: usize) -> usize {\n",
                "    42\n",
                "}\n",
            ),
            "Foo.bar.baz.qux()",
            Point::new(1, 0)..Point::new(4, 14),
            cx,
        );
    }

    #[gpui::test]
    fn test_resolve_edits(cx: &mut AppContext) {
        init_test(cx);

        let original = "
            /// A person
            struct Person {
                name: String,
                age: usize,
            }

            /// A dog
            struct Dog {
                weight: f32,
            }

            impl Person {
                fn name(&self) -> &str {
                    &self.name
                }
            }
        "
        .unindent();

        let edits = vec![
            AssistantEditKind::Update {
                old_text: "
                    name: String,
                "
                .unindent(),
                new_text: "
                    first_name: String,
                    last_name: String,
                "
                .unindent(),
                description: "".into(),
            },
            AssistantEditKind::Update {
                old_text: "
                    fn name(&self) -> &str {
                        &self.name
                    }
                "
                .unindent(),
                new_text: "
                    fn name(&self) -> String {
                        format!(\"{} {}\", self.first_name, self.last_name)
                    }
                "
                .unindent(),
                description: "".into(),
            },
        ];

        let expected = "
            /// A person
            struct Person {
                first_name: String,
                last_name: String,
                age: usize,
            }

            /// A dog
            struct Dog {
                weight: f32,
            }

            impl Person {
                fn name(&self) -> String {
                    format!(\"{} {}\", self.first_name, self.last_name)
                }
            }
        "
        .unindent();

        assert_edits(original, edits, expected, cx);
    }

    /// Installs the test settings store and initializes the language subsystem.
    fn init_test(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        language::init(cx);
        cx.update_global::<SettingsStore, _>(|settings, cx| {
            settings.update_user_settings::<AllLanguageSettings>(cx, |_| {});
        });
    }

    /// Asserts that `query` resolves to `expected` within a buffer containing `text`.
    #[track_caller]
    fn assert_location(
        text: &'static str,
        query: &str,
        expected: Range<Point>,
        cx: &mut AppContext,
    ) {
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();
        let actual = AssistantEditKind::resolve_location(&snapshot, query).to_point(&snapshot);
        assert_eq!(actual, expected);
    }

    /// Resolves, groups and applies `edits` to a Rust buffer holding `old_text`, then
    /// asserts that the buffer's text equals `new_text`.
    #[track_caller]
    fn assert_edits(
        old_text: String,
        edits: Vec<AssistantEditKind>,
        new_text: String,
        cx: &mut AppContext,
    ) {
        let language = Arc::new(rust_lang());
        let buffer = cx.new_model(|cx| Buffer::local(old_text, cx).with_language(language, cx));
        let snapshot = buffer.read(cx).snapshot();
        let resolved_edits: Vec<ResolvedEdit> = edits
            .into_iter()
            .map(|kind| kind.resolve(&snapshot))
            .collect();
        let edit_groups = AssistantPatch::group_edits(resolved_edits, &snapshot);
        ResolvedPatch::apply_edit_groups(&edit_groups, &buffer, cx);
        pretty_assertions::assert_eq!(buffer.read(cx).text(), new_text);
    }

    /// A minimal Rust language with an indents query, used for auto-indentation.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(language::tree_sitter_rust::LANGUAGE.into()))
            .with_indents_query(
                r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
            )
            .unwrap()
    }
}