conflict_set.rs

  1use gpui::{App, Context, Entity, EventEmitter};
  2use std::{cmp::Ordering, ops::Range, sync::Arc};
  3use text::{Anchor, BufferId, OffsetRangeExt as _};
  4
/// Tracks the merge conflicts of a single buffer.
///
/// Changes to the set are announced by emitting [`ConflictSetUpdate`] events.
pub struct ConflictSet {
    /// Flag supplied by the caller through [`ConflictSet::new`] and
    /// [`ConflictSet::set_has_conflict`]; clearing it also clears `snapshot`.
    pub has_conflict: bool,
    /// The most recently stored list of conflict regions.
    pub snapshot: ConflictSetSnapshot,
}
  9
/// Event describing how the list of conflicts changed from one snapshot to the next.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConflictSetUpdate {
    /// Buffer range covering the conflicts that changed, or `None` when no
    /// range is reported.
    pub buffer_range: Option<Range<Anchor>>,
    /// Index range, in the old conflict list, of the entries that were replaced.
    pub old_range: Range<usize>,
    /// Index range, in the new conflict list, of the entries that replaced them.
    pub new_range: Range<usize>,
}
 16
/// A cheaply clonable list of the conflict regions found in one buffer.
#[derive(Debug, Clone)]
pub struct ConflictSetSnapshot {
    /// Identifier of the buffer the conflicts belong to.
    pub buffer_id: BufferId,
    /// The conflict regions. `conflicts_in_range` binary-searches this list, so
    /// it is expected to be ordered by buffer position.
    pub conflicts: Arc<[ConflictRegion]>,
}
 22
 23impl ConflictSetSnapshot {
 24    pub fn conflicts_in_range(
 25        &self,
 26        range: Range<Anchor>,
 27        buffer: &text::BufferSnapshot,
 28    ) -> &[ConflictRegion] {
 29        let start_ix = self
 30            .conflicts
 31            .binary_search_by(|conflict| {
 32                conflict
 33                    .range
 34                    .end
 35                    .cmp(&range.start, buffer)
 36                    .then(Ordering::Greater)
 37            })
 38            .unwrap_err();
 39        let end_ix = start_ix
 40            + self.conflicts[start_ix..]
 41                .binary_search_by(|conflict| {
 42                    conflict
 43                        .range
 44                        .start
 45                        .cmp(&range.end, buffer)
 46                        .then(Ordering::Less)
 47                })
 48                .unwrap_err();
 49        &self.conflicts[start_ix..end_ix]
 50    }
 51
 52    pub fn compare(&self, other: &Self, buffer: &text::BufferSnapshot) -> ConflictSetUpdate {
 53        let common_prefix_len = self
 54            .conflicts
 55            .iter()
 56            .zip(other.conflicts.iter())
 57            .take_while(|(old, new)| old == new)
 58            .count();
 59        let common_suffix_len = self.conflicts[common_prefix_len..]
 60            .iter()
 61            .rev()
 62            .zip(other.conflicts[common_prefix_len..].iter().rev())
 63            .take_while(|(old, new)| old == new)
 64            .count();
 65        let old_conflicts =
 66            &self.conflicts[common_prefix_len..(self.conflicts.len() - common_suffix_len)];
 67        let new_conflicts =
 68            &other.conflicts[common_prefix_len..(other.conflicts.len() - common_suffix_len)];
 69        let old_range = common_prefix_len..(common_prefix_len + old_conflicts.len());
 70        let new_range = common_prefix_len..(common_prefix_len + new_conflicts.len());
 71        let start = match (old_conflicts.first(), new_conflicts.first()) {
 72            (None, None) => None,
 73            (None, Some(conflict)) => Some(conflict.range.start),
 74            (Some(conflict), None) => Some(conflict.range.start),
 75            (Some(first), Some(second)) => {
 76                Some(*first.range.start.min(&second.range.start, buffer))
 77            }
 78        };
 79        let end = match (old_conflicts.last(), new_conflicts.last()) {
 80            (None, None) => None,
 81            (None, Some(conflict)) => Some(conflict.range.end),
 82            (Some(first), None) => Some(first.range.end),
 83            (Some(first), Some(second)) => Some(*first.range.end.max(&second.range.end, buffer)),
 84        };
 85        ConflictSetUpdate {
 86            buffer_range: start.zip(end).map(|(start, end)| start..end),
 87            old_range,
 88            new_range,
 89        }
 90    }
 91}
 92
/// One merge conflict located in a buffer, as produced by `ConflictSet::parse`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConflictRegion {
    /// The whole conflict, from the start of the `<<<<<<<` line through the
    /// end of the `>>>>>>>` line.
    pub range: Range<Anchor>,
    /// The text following the `<<<<<<<` line, up to the next marker line.
    pub ours: Range<Anchor>,
    /// The text between the `=======` line and the `>>>>>>>` line.
    pub theirs: Range<Anchor>,
    /// The text between the `|||||||` line and the `=======` line, when the
    /// conflict contains a `|||||||` marker.
    pub base: Option<Range<Anchor>>,
}
100
101impl ConflictRegion {
102    pub fn resolve(
103        &self,
104        buffer: Entity<language::Buffer>,
105        ranges: &[Range<Anchor>],
106        cx: &mut App,
107    ) {
108        let buffer_snapshot = buffer.read(cx).snapshot();
109        let mut deletions = Vec::new();
110        let empty = "";
111        let outer_range = self.range.to_offset(&buffer_snapshot);
112        let mut offset = outer_range.start;
113        for kept_range in ranges {
114            let kept_range = kept_range.to_offset(&buffer_snapshot);
115            if kept_range.start > offset {
116                deletions.push((offset..kept_range.start, empty));
117            }
118            offset = kept_range.end;
119        }
120        if outer_range.end > offset {
121            deletions.push((offset..outer_range.end, empty));
122        }
123
124        buffer.update(cx, |buffer, cx| {
125            buffer.edit(deletions, None, cx);
126        });
127    }
128}
129
130impl ConflictSet {
131    pub fn new(buffer_id: BufferId, has_conflict: bool, _: &mut Context<Self>) -> Self {
132        Self {
133            has_conflict,
134            snapshot: ConflictSetSnapshot {
135                buffer_id,
136                conflicts: Default::default(),
137            },
138        }
139    }
140
141    pub fn set_has_conflict(&mut self, has_conflict: bool, cx: &mut Context<Self>) -> bool {
142        if has_conflict != self.has_conflict {
143            self.has_conflict = has_conflict;
144            if !self.has_conflict {
145                cx.emit(ConflictSetUpdate {
146                    buffer_range: None,
147                    old_range: 0..self.snapshot.conflicts.len(),
148                    new_range: 0..0,
149                });
150                self.snapshot.conflicts = Default::default();
151            }
152            true
153        } else {
154            false
155        }
156    }
157
    /// Returns a clone of the current snapshot.
    pub fn snapshot(&self) -> ConflictSetSnapshot {
        self.snapshot.clone()
    }
161
    /// Replaces the stored snapshot with `snapshot` and emits `update` to subscribers.
    ///
    /// `update` is forwarded as given; this method does not check that it
    /// matches the difference between the old and new snapshots.
    pub fn set_snapshot(
        &mut self,
        snapshot: ConflictSetSnapshot,
        update: ConflictSetUpdate,
        cx: &mut Context<Self>,
    ) {
        self.snapshot = snapshot;
        cx.emit(update);
    }
171
172    // Vec<(Range<usize>)>
173    // Vec<(Range<usize>, &str)>
174    //
175    // [(1..2, ""), (6..7, "")]
176    // {"hello": "world"}
177    // {hello: "world"}
178    //
179    // foo(bar);
180    // }
181
182    pub fn parse(buffer: &text::BufferSnapshot) -> ConflictSetSnapshot {
183        let mut conflicts = Vec::new();
184
185        let mut line_pos = 0;
186        let buffer_len = buffer.len();
187        let mut lines = buffer.text_for_range(0..buffer_len).lines();
188
189        let mut conflict_start: Option<usize> = None;
190        let mut ours_start: Option<usize> = None;
191        let mut ours_end: Option<usize> = None;
192        let mut base_start: Option<usize> = None;
193        let mut base_end: Option<usize> = None;
194        let mut theirs_start: Option<usize> = None;
195
196        while let Some(line) = lines.next() {
197            let line_end = line_pos + line.len();
198
199            if line.starts_with("<<<<<<< ") {
200                // If we see a new conflict marker while already parsing one,
201                // abandon the previous one and start a new one
202                conflict_start = Some(line_pos);
203                ours_start = Some(line_end + 1);
204            } else if line.starts_with("||||||| ")
205                && conflict_start.is_some()
206                && ours_start.is_some()
207            {
208                ours_end = Some(line_pos);
209                base_start = Some(line_end + 1);
210            } else if line.starts_with("=======")
211                && conflict_start.is_some()
212                && ours_start.is_some()
213            {
214                // Set ours_end if not already set (would be set if we have base markers)
215                if ours_end.is_none() {
216                    ours_end = Some(line_pos);
217                } else if base_start.is_some() {
218                    base_end = Some(line_pos);
219                }
220                theirs_start = Some(line_end + 1);
221            } else if line.starts_with(">>>>>>> ")
222                && conflict_start.is_some()
223                && ours_start.is_some()
224                && ours_end.is_some()
225                && theirs_start.is_some()
226            {
227                let theirs_end = line_pos;
228                let conflict_end = (line_end + 1).min(buffer_len);
229
230                let range = buffer.anchor_after(conflict_start.unwrap())
231                    ..buffer.anchor_before(conflict_end);
232                let ours = buffer.anchor_after(ours_start.unwrap())
233                    ..buffer.anchor_before(ours_end.unwrap());
234                let theirs =
235                    buffer.anchor_after(theirs_start.unwrap())..buffer.anchor_before(theirs_end);
236
237                let base = base_start
238                    .zip(base_end)
239                    .map(|(start, end)| buffer.anchor_after(start)..buffer.anchor_before(end));
240
241                conflicts.push(ConflictRegion {
242                    range,
243                    ours,
244                    theirs,
245                    base,
246                });
247
248                conflict_start = None;
249                ours_start = None;
250                ours_end = None;
251                base_start = None;
252                base_end = None;
253                theirs_start = None;
254            }
255
256            line_pos = line_end + 1;
257        }
258
259        ConflictSetSnapshot {
260            conflicts: conflicts.into(),
261            buffer_id: buffer.remote_id(),
262        }
263    }
264}
265
// Lets `ConflictSet` emit `ConflictSetUpdate` events via `cx.emit`.
impl EventEmitter<ConflictSetUpdate> for ConflictSet {}
267
268#[cfg(test)]
269mod tests {
270    use std::sync::mpsc;
271
272    use crate::Project;
273
274    use super::*;
275    use fs::FakeFs;
276    use git::{
277        repository::repo_path,
278        status::{UnmergedStatus, UnmergedStatusCode},
279    };
280    use gpui::{BackgroundExecutor, TestAppContext};
281    use language::language_settings::AllLanguageSettings;
282    use serde_json::json;
283    use settings::Settings as _;
284    use text::{Buffer, BufferId, Point, ReplicaId, ToOffset as _};
285    use unindent::Unindent as _;
286    use util::{path, rel_path::rel_path};
287    use worktree::WorktreeSettings;
288
289    #[test]
290    fn test_parse_conflicts_in_buffer() {
291        // Create a buffer with conflict markers
292        let test_content = r#"
293            This is some text before the conflict.
294            <<<<<<< HEAD
295            This is our version
296            =======
297            This is their version
298            >>>>>>> branch-name
299
300            Another conflict:
301            <<<<<<< HEAD
302            Our second change
303            ||||||| merged common ancestors
304            Original content
305            =======
306            Their second change
307            >>>>>>> branch-name
308        "#
309        .unindent();
310
311        let buffer_id = BufferId::new(1).unwrap();
312        let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, test_content);
313        let snapshot = buffer.snapshot();
314
315        let conflict_snapshot = ConflictSet::parse(&snapshot);
316        assert_eq!(conflict_snapshot.conflicts.len(), 2);
317
318        let first = &conflict_snapshot.conflicts[0];
319        assert!(first.base.is_none());
320        let our_text = snapshot
321            .text_for_range(first.ours.clone())
322            .collect::<String>();
323        let their_text = snapshot
324            .text_for_range(first.theirs.clone())
325            .collect::<String>();
326        assert_eq!(our_text, "This is our version\n");
327        assert_eq!(their_text, "This is their version\n");
328
329        let second = &conflict_snapshot.conflicts[1];
330        assert!(second.base.is_some());
331        let our_text = snapshot
332            .text_for_range(second.ours.clone())
333            .collect::<String>();
334        let their_text = snapshot
335            .text_for_range(second.theirs.clone())
336            .collect::<String>();
337        let base_text = snapshot
338            .text_for_range(second.base.as_ref().unwrap().clone())
339            .collect::<String>();
340        assert_eq!(our_text, "Our second change\n");
341        assert_eq!(their_text, "Their second change\n");
342        assert_eq!(base_text, "Original content\n");
343
344        // Test conflicts_in_range
345        let range = snapshot.anchor_before(0)..snapshot.anchor_before(snapshot.len());
346        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
347        assert_eq!(conflicts_in_range.len(), 2);
348
349        // Test with a range that includes only the first conflict
350        let first_conflict_end = conflict_snapshot.conflicts[0].range.end;
351        let range = snapshot.anchor_before(0)..first_conflict_end;
352        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
353        assert_eq!(conflicts_in_range.len(), 1);
354
355        // Test with a range that includes only the second conflict
356        let second_conflict_start = conflict_snapshot.conflicts[1].range.start;
357        let range = second_conflict_start..snapshot.anchor_before(snapshot.len());
358        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
359        assert_eq!(conflicts_in_range.len(), 1);
360
361        // Test with a range that doesn't include any conflicts
362        let range = buffer.anchor_after(first_conflict_end.to_next_offset(&buffer))
363            ..buffer.anchor_before(second_conflict_start.to_previous_offset(&buffer));
364        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
365        assert_eq!(conflicts_in_range.len(), 0);
366    }
367
368    #[test]
369    fn test_nested_conflict_markers() {
370        // Create a buffer with nested conflict markers
371        let test_content = r#"
372            This is some text before the conflict.
373            <<<<<<< HEAD
374            This is our version
375            <<<<<<< HEAD
376            This is a nested conflict marker
377            =======
378            This is their version in a nested conflict
379            >>>>>>> branch-nested
380            =======
381            This is their version
382            >>>>>>> branch-name
383        "#
384        .unindent();
385
386        let buffer_id = BufferId::new(1).unwrap();
387        let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, test_content);
388        let snapshot = buffer.snapshot();
389
390        let conflict_snapshot = ConflictSet::parse(&snapshot);
391
392        assert_eq!(conflict_snapshot.conflicts.len(), 1);
393
394        // The conflict should have our version, their version, but no base
395        let conflict = &conflict_snapshot.conflicts[0];
396        assert!(conflict.base.is_none());
397
398        // Check that the nested conflict was detected correctly
399        let our_text = snapshot
400            .text_for_range(conflict.ours.clone())
401            .collect::<String>();
402        assert_eq!(our_text, "This is a nested conflict marker\n");
403        let their_text = snapshot
404            .text_for_range(conflict.theirs.clone())
405            .collect::<String>();
406        assert_eq!(their_text, "This is their version in a nested conflict\n");
407    }
408
409    #[test]
410    fn test_conflict_markers_at_eof() {
411        let test_content = r#"
412            <<<<<<< ours
413            =======
414            This is their version
415            >>>>>>> "#
416            .unindent();
417        let buffer_id = BufferId::new(1).unwrap();
418        let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, test_content);
419        let snapshot = buffer.snapshot();
420
421        let conflict_snapshot = ConflictSet::parse(&snapshot);
422        assert_eq!(conflict_snapshot.conflicts.len(), 1);
423    }
424
425    #[test]
426    fn test_conflicts_in_range() {
427        // Create a buffer with conflict markers
428        let test_content = r#"
429            one
430            <<<<<<< HEAD1
431            two
432            =======
433            three
434            >>>>>>> branch1
435            four
436            five
437            <<<<<<< HEAD2
438            six
439            =======
440            seven
441            >>>>>>> branch2
442            eight
443            nine
444            <<<<<<< HEAD3
445            ten
446            =======
447            eleven
448            >>>>>>> branch3
449            twelve
450            <<<<<<< HEAD4
451            thirteen
452            =======
453            fourteen
454            >>>>>>> branch4
455            fifteen
456        "#
457        .unindent();
458
459        let buffer_id = BufferId::new(1).unwrap();
460        let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, test_content.clone());
461        let snapshot = buffer.snapshot();
462
463        let conflict_snapshot = ConflictSet::parse(&snapshot);
464        assert_eq!(conflict_snapshot.conflicts.len(), 4);
465
466        let range = test_content.find("seven").unwrap()..test_content.find("eleven").unwrap();
467        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
468        assert_eq!(
469            conflict_snapshot.conflicts_in_range(range, &snapshot),
470            &conflict_snapshot.conflicts[1..=2]
471        );
472
473        let range = test_content.find("one").unwrap()..test_content.find("<<<<<<< HEAD2").unwrap();
474        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
475        assert_eq!(
476            conflict_snapshot.conflicts_in_range(range, &snapshot),
477            &conflict_snapshot.conflicts[0..=1]
478        );
479
480        let range =
481            test_content.find("eight").unwrap() - 1..test_content.find(">>>>>>> branch3").unwrap();
482        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
483        assert_eq!(
484            conflict_snapshot.conflicts_in_range(range, &snapshot),
485            &conflict_snapshot.conflicts[1..=2]
486        );
487
488        let range = test_content.find("thirteen").unwrap() - 1..test_content.len();
489        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
490        assert_eq!(
491            conflict_snapshot.conflicts_in_range(range, &snapshot),
492            &conflict_snapshot.conflicts[3..=3]
493        );
494    }
495
    /// Checks that conflict markers typed into a buffer are only reported once
    /// the file is marked as unmerged, and that resolving the conflict reports
    /// its removal.
    #[gpui::test]
    async fn test_conflict_updates(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        zlog::init_test();
        cx.update(|cx| {
            settings::init(cx);
            WorktreeSettings::register(cx);
            Project::init_settings(cx);
            AllLanguageSettings::register(cx);
        });
        let initial_text = "
            one
            two
            three
            four
            five
        "
        .unindent();
        // A fake project containing a `.git` directory and a single text file.
        let fs = FakeFs::new(executor);
        fs.insert_tree(
            path!("/project"),
            json!({
                ".git": {},
                "a.txt": initial_text,
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let (git_store, buffer) = project.update(cx, |project, cx| {
            (
                project.git_store().clone(),
                project.open_local_buffer(path!("/project/a.txt"), cx),
            )
        });
        let buffer = buffer.await.unwrap();
        let conflict_set = git_store.update(cx, |git_store, cx| {
            git_store.open_conflict_set(buffer.clone(), cx)
        });
        // Forward every emitted update into a channel so the test can inspect it.
        let (events_tx, events_rx) = mpsc::channel::<ConflictSetUpdate>();
        let _conflict_set_subscription = cx.update(|cx| {
            cx.subscribe(&conflict_set, move |_, event, _| {
                events_tx.send(event.clone()).ok();
            })
        });
        let conflicts_snapshot =
            conflict_set.read_with(cx, |conflict_set, _| conflict_set.snapshot());
        assert!(conflicts_snapshot.conflicts.is_empty());

        // Insert conflict markers into the buffer text.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [
                    (4..4, "<<<<<<< HEAD\n"),
                    (14..14, "=======\nTWO\n>>>>>>> branch\n"),
                ],
                None,
                cx,
            );
        });

        cx.run_until_parked();
        events_rx.try_recv().expect_err(
            "no conflicts should be registered as long as the file's status is unchanged",
        );

        // Mark the file as unmerged in the fake git state.
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                repo_path("a.txt"),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            );
            // Cause the repository to emit MergeHeadsChanged.
            state.refs.insert("MERGE_HEAD".into(), "123".into())
        })
        .unwrap();

        cx.run_until_parked();
        let update = events_rx
            .try_recv()
            .expect("status change should trigger conflict parsing");
        // One conflict was added to a previously empty list.
        assert_eq!(update.old_range, 0..0);
        assert_eq!(update.new_range, 0..1);

        // Resolve the conflict by keeping only the "theirs" side.
        let conflict = conflict_set.read_with(cx, |conflict_set, _| {
            conflict_set.snapshot().conflicts[0].clone()
        });
        cx.update(|cx| {
            conflict.resolve(buffer.clone(), std::slice::from_ref(&conflict.theirs), cx);
        });

        cx.run_until_parked();
        let update = events_rx
            .try_recv()
            .expect("conflicts should be removed after resolution");
        // The single conflict is gone again.
        assert_eq!(update.old_range, 0..1);
        assert_eq!(update.new_range, 0..0);
    }
593
    /// Checks that conflicts follow the file's unmerged status when only
    /// `unmerged_paths` changes (this test never inserts a `MERGE_HEAD` ref):
    /// present while unmerged, cleared when the status is removed, and present
    /// again when it is re-added.
    #[gpui::test]
    async fn test_conflict_updates_without_merge_head(
        executor: BackgroundExecutor,
        cx: &mut TestAppContext,
    ) {
        zlog::init_test();
        cx.update(|cx| {
            settings::init(cx);
            WorktreeSettings::register(cx);
            Project::init_settings(cx);
            AllLanguageSettings::register(cx);
        });

        // The file on disk already contains a conflict.
        let initial_text = "
            zero
            <<<<<<< HEAD
            one
            =======
            two
            >>>>>>> Stashed Changes
            three
        "
        .unindent();

        let fs = FakeFs::new(executor);
        fs.insert_tree(
            path!("/project"),
            json!({
                ".git": {},
                "a.txt": initial_text,
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let (git_store, buffer) = project.update(cx, |project, cx| {
            (
                project.git_store().clone(),
                project.open_local_buffer(path!("/project/a.txt"), cx),
            )
        });

        cx.run_until_parked();
        // Mark the file as unmerged before the conflict set is opened.
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                rel_path("a.txt").into(),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            )
        })
        .unwrap();

        let buffer = buffer.await.unwrap();

        // Open the conflict set for a file that currently has conflicts.
        let conflict_set = git_store.update(cx, |git_store, cx| {
            git_store.open_conflict_set(buffer.clone(), cx)
        });

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, cx| {
            let conflict_range = conflict_set.snapshot().conflicts[0]
                .range
                .to_point(buffer.read(cx));
            // The conflict spans rows 1 through 5, ending at the start of row 6.
            assert_eq!(conflict_range, Point::new(1, 0)..Point::new(6, 0));
        });

        // Simulate the conflict being removed by e.g. staging the file.
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.remove(&repo_path("a.txt"))
        })
        .unwrap();

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, _| {
            assert!(!conflict_set.has_conflict);
            assert_eq!(conflict_set.snapshot.conflicts.len(), 0);
        });

        // Simulate the conflict being re-added.
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                repo_path("a.txt"),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            )
        })
        .unwrap();

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, cx| {
            let conflict_range = conflict_set.snapshot().conflicts[0]
                .range
                .to_point(buffer.read(cx));
            // The same conflict is found again at the same position.
            assert_eq!(conflict_range, Point::new(1, 0)..Point::new(6, 0));
        });
    }
695}