// conflict_set.rs

  1use gpui::{App, Context, Entity, EventEmitter};
  2use std::{cmp::Ordering, ops::Range, sync::Arc};
  3use text::{Anchor, BufferId, OffsetRangeExt as _};
  4
/// Tracks the git conflict regions found in a single buffer.
pub struct ConflictSet {
    /// Whether the buffer's file is currently considered to be in a conflicted state.
    pub has_conflict: bool,
    /// The most recently installed list of conflict regions for the buffer.
    pub snapshot: ConflictSetSnapshot,
}
  9
/// Describes a change to the conflicts of a [`ConflictSet`]; emitted as that entity's event.
///
/// The conflicts at indices `old_range` of the previous snapshot were replaced by those at
/// indices `new_range` of the new snapshot.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConflictSetUpdate {
    /// Buffer span covering every changed conflict, or `None` when it is not known or
    /// nothing changed.
    pub buffer_range: Option<Range<Anchor>>,
    /// Indices of the replaced conflicts in the previous snapshot.
    pub old_range: Range<usize>,
    /// Indices of the replacement conflicts in the new snapshot.
    pub new_range: Range<usize>,
}
 16
/// An immutable list of the conflict regions of one buffer.
///
/// The regions live behind an `Arc`, so cloning a snapshot does not copy them.
#[derive(Debug, Clone)]
pub struct ConflictSetSnapshot {
    /// The buffer the conflicts belong to.
    pub buffer_id: BufferId,
    /// Conflict regions, expected to be in buffer order (as `ConflictSet::parse` produces
    /// them); `conflicts_in_range` relies on this ordering.
    pub conflicts: Arc<[ConflictRegion]>,
}
 22
 23impl ConflictSetSnapshot {
 24    pub fn conflicts_in_range(
 25        &self,
 26        range: Range<Anchor>,
 27        buffer: &text::BufferSnapshot,
 28    ) -> &[ConflictRegion] {
 29        let start_ix = self
 30            .conflicts
 31            .binary_search_by(|conflict| {
 32                conflict
 33                    .range
 34                    .end
 35                    .cmp(&range.start, buffer)
 36                    .then(Ordering::Greater)
 37            })
 38            .unwrap_err();
 39        let end_ix = start_ix
 40            + self.conflicts[start_ix..]
 41                .binary_search_by(|conflict| {
 42                    conflict
 43                        .range
 44                        .start
 45                        .cmp(&range.end, buffer)
 46                        .then(Ordering::Less)
 47                })
 48                .unwrap_err();
 49        &self.conflicts[start_ix..end_ix]
 50    }
 51
 52    pub fn compare(&self, other: &Self, buffer: &text::BufferSnapshot) -> ConflictSetUpdate {
 53        let common_prefix_len = self
 54            .conflicts
 55            .iter()
 56            .zip(other.conflicts.iter())
 57            .take_while(|(old, new)| old == new)
 58            .count();
 59        let common_suffix_len = self.conflicts[common_prefix_len..]
 60            .iter()
 61            .rev()
 62            .zip(other.conflicts[common_prefix_len..].iter().rev())
 63            .take_while(|(old, new)| old == new)
 64            .count();
 65        let old_conflicts =
 66            &self.conflicts[common_prefix_len..(self.conflicts.len() - common_suffix_len)];
 67        let new_conflicts =
 68            &other.conflicts[common_prefix_len..(other.conflicts.len() - common_suffix_len)];
 69        let old_range = common_prefix_len..(common_prefix_len + old_conflicts.len());
 70        let new_range = common_prefix_len..(common_prefix_len + new_conflicts.len());
 71        let start = match (old_conflicts.first(), new_conflicts.first()) {
 72            (None, None) => None,
 73            (None, Some(conflict)) => Some(conflict.range.start),
 74            (Some(conflict), None) => Some(conflict.range.start),
 75            (Some(first), Some(second)) => Some(first.range.start.min(&second.range.start, buffer)),
 76        };
 77        let end = match (old_conflicts.last(), new_conflicts.last()) {
 78            (None, None) => None,
 79            (None, Some(conflict)) => Some(conflict.range.end),
 80            (Some(first), None) => Some(first.range.end),
 81            (Some(first), Some(second)) => Some(first.range.end.max(&second.range.end, buffer)),
 82        };
 83        ConflictSetUpdate {
 84            buffer_range: start.zip(end).map(|(start, end)| start..end),
 85            old_range,
 86            new_range,
 87        }
 88    }
 89}
 90
/// A single conflict delimited by `<<<<<<<` and `>>>>>>>` marker lines.
///
/// All ranges are buffer [`Anchor`] ranges rather than plain offsets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConflictRegion {
    /// The whole conflict: from the start of the `<<<<<<<` line through the end of the
    /// `>>>>>>>` line, including its newline when there is one.
    pub range: Range<Anchor>,
    /// Our side: the lines after `<<<<<<<` and before `|||||||` (or `=======` when there is
    /// no base section).
    pub ours: Range<Anchor>,
    /// Their side: the lines after `=======` and before `>>>>>>>`.
    pub theirs: Range<Anchor>,
    /// The common-ancestor lines between `|||||||` and `=======`; only present for
    /// diff3-style conflicts.
    pub base: Option<Range<Anchor>>,
}
 98
 99impl ConflictRegion {
100    pub fn resolve(
101        &self,
102        buffer: Entity<language::Buffer>,
103        ranges: &[Range<Anchor>],
104        cx: &mut App,
105    ) {
106        let buffer_snapshot = buffer.read(cx).snapshot();
107        let mut deletions = Vec::new();
108        let empty = "";
109        let outer_range = self.range.to_offset(&buffer_snapshot);
110        let mut offset = outer_range.start;
111        for kept_range in ranges {
112            let kept_range = kept_range.to_offset(&buffer_snapshot);
113            if kept_range.start > offset {
114                deletions.push((offset..kept_range.start, empty));
115            }
116            offset = kept_range.end;
117        }
118        if outer_range.end > offset {
119            deletions.push((offset..outer_range.end, empty));
120        }
121
122        buffer.update(cx, |buffer, cx| {
123            buffer.edit(deletions, None, cx);
124        });
125    }
126}
127
128impl ConflictSet {
129    pub fn new(buffer_id: BufferId, has_conflict: bool, _: &mut Context<Self>) -> Self {
130        Self {
131            has_conflict,
132            snapshot: ConflictSetSnapshot {
133                buffer_id,
134                conflicts: Default::default(),
135            },
136        }
137    }
138
139    pub fn set_has_conflict(&mut self, has_conflict: bool, cx: &mut Context<Self>) -> bool {
140        if has_conflict != self.has_conflict {
141            self.has_conflict = has_conflict;
142            if !self.has_conflict {
143                cx.emit(ConflictSetUpdate {
144                    buffer_range: None,
145                    old_range: 0..self.snapshot.conflicts.len(),
146                    new_range: 0..0,
147                });
148                self.snapshot.conflicts = Default::default();
149            }
150            true
151        } else {
152            false
153        }
154    }
155
156    pub fn snapshot(&self) -> ConflictSetSnapshot {
157        self.snapshot.clone()
158    }
159
160    pub fn set_snapshot(
161        &mut self,
162        snapshot: ConflictSetSnapshot,
163        update: ConflictSetUpdate,
164        cx: &mut Context<Self>,
165    ) {
166        self.snapshot = snapshot;
167        cx.emit(update);
168    }
169
170    pub fn parse(buffer: &text::BufferSnapshot) -> ConflictSetSnapshot {
171        let mut conflicts = Vec::new();
172
173        let mut line_pos = 0;
174        let buffer_len = buffer.len();
175        let mut lines = buffer.text_for_range(0..buffer_len).lines();
176
177        let mut conflict_start: Option<usize> = None;
178        let mut ours_start: Option<usize> = None;
179        let mut ours_end: Option<usize> = None;
180        let mut base_start: Option<usize> = None;
181        let mut base_end: Option<usize> = None;
182        let mut theirs_start: Option<usize> = None;
183
184        while let Some(line) = lines.next() {
185            let line_end = line_pos + line.len();
186
187            if line.starts_with("<<<<<<< ") {
188                // If we see a new conflict marker while already parsing one,
189                // abandon the previous one and start a new one
190                conflict_start = Some(line_pos);
191                ours_start = Some(line_end + 1);
192            } else if line.starts_with("||||||| ")
193                && conflict_start.is_some()
194                && ours_start.is_some()
195            {
196                ours_end = Some(line_pos);
197                base_start = Some(line_end + 1);
198            } else if line.starts_with("=======")
199                && conflict_start.is_some()
200                && ours_start.is_some()
201            {
202                // Set ours_end if not already set (would be set if we have base markers)
203                if ours_end.is_none() {
204                    ours_end = Some(line_pos);
205                } else if base_start.is_some() {
206                    base_end = Some(line_pos);
207                }
208                theirs_start = Some(line_end + 1);
209            } else if line.starts_with(">>>>>>> ")
210                && conflict_start.is_some()
211                && ours_start.is_some()
212                && ours_end.is_some()
213                && theirs_start.is_some()
214            {
215                let theirs_end = line_pos;
216                let conflict_end = (line_end + 1).min(buffer_len);
217
218                let range = buffer.anchor_after(conflict_start.unwrap())
219                    ..buffer.anchor_before(conflict_end);
220                let ours = buffer.anchor_after(ours_start.unwrap())
221                    ..buffer.anchor_before(ours_end.unwrap());
222                let theirs =
223                    buffer.anchor_after(theirs_start.unwrap())..buffer.anchor_before(theirs_end);
224
225                let base = base_start
226                    .zip(base_end)
227                    .map(|(start, end)| buffer.anchor_after(start)..buffer.anchor_before(end));
228
229                conflicts.push(ConflictRegion {
230                    range,
231                    ours,
232                    theirs,
233                    base,
234                });
235
236                conflict_start = None;
237                ours_start = None;
238                ours_end = None;
239                base_start = None;
240                base_end = None;
241                theirs_start = None;
242            }
243
244            line_pos = line_end + 1;
245        }
246
247        ConflictSetSnapshot {
248            conflicts: conflicts.into(),
249            buffer_id: buffer.remote_id(),
250        }
251    }
252}
253
// `ConflictSet` emits a `ConflictSetUpdate` whenever its conflicts are replaced or cleared.
impl EventEmitter<ConflictSetUpdate> for ConflictSet {}
255
// Tests for conflict parsing, range queries, and git-status-driven conflict updates.
#[cfg(test)]
mod tests {
    use std::sync::mpsc;

    use crate::Project;

    use super::*;
    use fs::FakeFs;
    use git::{
        repository::repo_path,
        status::{UnmergedStatus, UnmergedStatusCode},
    };
    use gpui::{BackgroundExecutor, TestAppContext};
    use language::language_settings::AllLanguageSettings;
    use serde_json::json;
    use settings::Settings as _;
    use text::{Buffer, BufferId, Point, ToOffset as _};
    use unindent::Unindent as _;
    use util::{path, rel_path::rel_path};
    use worktree::WorktreeSettings;

    /// Parses one two-way and one diff3-style conflict, checks the text of each side, and
    /// queries `conflicts_in_range` with ranges covering both, one, and neither conflict.
    #[test]
    fn test_parse_conflicts_in_buffer() {
        // Create a buffer with conflict markers
        let test_content = r#"
            This is some text before the conflict.
            <<<<<<< HEAD
            This is our version
            =======
            This is their version
            >>>>>>> branch-name

            Another conflict:
            <<<<<<< HEAD
            Our second change
            ||||||| merged common ancestors
            Original content
            =======
            Their second change
            >>>>>>> branch-name
        "#
        .unindent();

        let buffer_id = BufferId::new(1).unwrap();
        let buffer = Buffer::new(0, buffer_id, test_content);
        let snapshot = buffer.snapshot();

        let conflict_snapshot = ConflictSet::parse(&snapshot);
        assert_eq!(conflict_snapshot.conflicts.len(), 2);

        let first = &conflict_snapshot.conflicts[0];
        assert!(first.base.is_none());
        let our_text = snapshot
            .text_for_range(first.ours.clone())
            .collect::<String>();
        let their_text = snapshot
            .text_for_range(first.theirs.clone())
            .collect::<String>();
        assert_eq!(our_text, "This is our version\n");
        assert_eq!(their_text, "This is their version\n");

        let second = &conflict_snapshot.conflicts[1];
        assert!(second.base.is_some());
        let our_text = snapshot
            .text_for_range(second.ours.clone())
            .collect::<String>();
        let their_text = snapshot
            .text_for_range(second.theirs.clone())
            .collect::<String>();
        let base_text = snapshot
            .text_for_range(second.base.as_ref().unwrap().clone())
            .collect::<String>();
        assert_eq!(our_text, "Our second change\n");
        assert_eq!(their_text, "Their second change\n");
        assert_eq!(base_text, "Original content\n");

        // Test conflicts_in_range
        let range = snapshot.anchor_before(0)..snapshot.anchor_before(snapshot.len());
        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
        assert_eq!(conflicts_in_range.len(), 2);

        // Test with a range that includes only the first conflict
        let first_conflict_end = conflict_snapshot.conflicts[0].range.end;
        let range = snapshot.anchor_before(0)..first_conflict_end;
        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
        assert_eq!(conflicts_in_range.len(), 1);

        // Test with a range that includes only the second conflict
        let second_conflict_start = conflict_snapshot.conflicts[1].range.start;
        let range = second_conflict_start..snapshot.anchor_before(snapshot.len());
        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
        assert_eq!(conflicts_in_range.len(), 1);

        // Test with a range that doesn't include any conflicts
        let range = buffer.anchor_after(first_conflict_end.to_next_offset(&buffer))
            ..buffer.anchor_before(second_conflict_start.to_previous_offset(&buffer));
        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
        assert_eq!(conflicts_in_range.len(), 0);
    }

    /// An opening marker inside an unfinished conflict restarts parsing, so only the inner
    /// conflict is reported and the outer markers are ignored.
    #[test]
    fn test_nested_conflict_markers() {
        // Create a buffer with nested conflict markers
        let test_content = r#"
            This is some text before the conflict.
            <<<<<<< HEAD
            This is our version
            <<<<<<< HEAD
            This is a nested conflict marker
            =======
            This is their version in a nested conflict
            >>>>>>> branch-nested
            =======
            This is their version
            >>>>>>> branch-name
        "#
        .unindent();

        let buffer_id = BufferId::new(1).unwrap();
        let buffer = Buffer::new(0, buffer_id, test_content);
        let snapshot = buffer.snapshot();

        let conflict_snapshot = ConflictSet::parse(&snapshot);

        assert_eq!(conflict_snapshot.conflicts.len(), 1);

        // The conflict should have our version, their version, but no base
        let conflict = &conflict_snapshot.conflicts[0];
        assert!(conflict.base.is_none());

        // Check that the nested conflict was detected correctly
        let our_text = snapshot
            .text_for_range(conflict.ours.clone())
            .collect::<String>();
        assert_eq!(our_text, "This is a nested conflict marker\n");
        let their_text = snapshot
            .text_for_range(conflict.theirs.clone())
            .collect::<String>();
        assert_eq!(their_text, "This is their version in a nested conflict\n");
    }

    /// A closing marker on the buffer's last line (no trailing newline) and an empty "ours"
    /// side still produce a complete conflict.
    #[test]
    fn test_conflict_markers_at_eof() {
        let test_content = r#"
            <<<<<<< ours
            =======
            This is their version
            >>>>>>> "#
            .unindent();
        let buffer_id = BufferId::new(1).unwrap();
        let buffer = Buffer::new(0, buffer_id, test_content);
        let snapshot = buffer.snapshot();

        let conflict_snapshot = ConflictSet::parse(&snapshot);
        assert_eq!(conflict_snapshot.conflicts.len(), 1);
    }

    /// `conflicts_in_range` returns the conflicts that partially overlap or merely touch the
    /// queried range.
    #[test]
    fn test_conflicts_in_range() {
        // Create a buffer with conflict markers
        let test_content = r#"
            one
            <<<<<<< HEAD1
            two
            =======
            three
            >>>>>>> branch1
            four
            five
            <<<<<<< HEAD2
            six
            =======
            seven
            >>>>>>> branch2
            eight
            nine
            <<<<<<< HEAD3
            ten
            =======
            eleven
            >>>>>>> branch3
            twelve
            <<<<<<< HEAD4
            thirteen
            =======
            fourteen
            >>>>>>> branch4
            fifteen
        "#
        .unindent();

        let buffer_id = BufferId::new(1).unwrap();
        let buffer = Buffer::new(0, buffer_id, test_content.clone());
        let snapshot = buffer.snapshot();

        let conflict_snapshot = ConflictSet::parse(&snapshot);
        assert_eq!(conflict_snapshot.conflicts.len(), 4);

        let range = test_content.find("seven").unwrap()..test_content.find("eleven").unwrap();
        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
        assert_eq!(
            conflict_snapshot.conflicts_in_range(range, &snapshot),
            &conflict_snapshot.conflicts[1..=2]
        );

        let range = test_content.find("one").unwrap()..test_content.find("<<<<<<< HEAD2").unwrap();
        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
        assert_eq!(
            conflict_snapshot.conflicts_in_range(range, &snapshot),
            &conflict_snapshot.conflicts[0..=1]
        );

        let range =
            test_content.find("eight").unwrap() - 1..test_content.find(">>>>>>> branch3").unwrap();
        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
        assert_eq!(
            conflict_snapshot.conflicts_in_range(range, &snapshot),
            &conflict_snapshot.conflicts[1..=2]
        );

        let range = test_content.find("thirteen").unwrap() - 1..test_content.len();
        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
        assert_eq!(
            conflict_snapshot.conflicts_in_range(range, &snapshot),
            &conflict_snapshot.conflicts[3..=3]
        );
    }

    /// Conflict markers are only turned into conflicts once git reports the file as
    /// unmerged; resolving the conflict via `ConflictRegion::resolve` then emits an update
    /// removing it.
    #[gpui::test]
    async fn test_conflict_updates(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        zlog::init_test();
        cx.update(|cx| {
            settings::init(cx);
            WorktreeSettings::register(cx);
            Project::init_settings(cx);
            AllLanguageSettings::register(cx);
        });
        let initial_text = "
            one
            two
            three
            four
            five
        "
        .unindent();
        let fs = FakeFs::new(executor);
        fs.insert_tree(
            path!("/project"),
            json!({
                ".git": {},
                "a.txt": initial_text,
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let (git_store, buffer) = project.update(cx, |project, cx| {
            (
                project.git_store().clone(),
                project.open_local_buffer(path!("/project/a.txt"), cx),
            )
        });
        let buffer = buffer.await.unwrap();
        let conflict_set = git_store.update(cx, |git_store, cx| {
            git_store.open_conflict_set(buffer.clone(), cx)
        });
        // Forward every emitted `ConflictSetUpdate` into a channel so the test can inspect it.
        let (events_tx, events_rx) = mpsc::channel::<ConflictSetUpdate>();
        let _conflict_set_subscription = cx.update(|cx| {
            cx.subscribe(&conflict_set, move |_, event, _| {
                events_tx.send(event.clone()).ok();
            })
        });
        let conflicts_snapshot =
            conflict_set.read_with(cx, |conflict_set, _| conflict_set.snapshot());
        assert!(conflicts_snapshot.conflicts.is_empty());

        // Wrap "two\nthree\n" in conflict markers: offset 4 is after "one\n" and offset 14
        // is after "three\n".
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [
                    (4..4, "<<<<<<< HEAD\n"),
                    (14..14, "=======\nTWO\n>>>>>>> branch\n"),
                ],
                None,
                cx,
            );
        });

        cx.run_until_parked();
        events_rx.try_recv().expect_err(
            "no conflicts should be registered as long as the file's status is unchanged",
        );

        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                repo_path("a.txt"),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            );
            // Cause the repository to emit MergeHeadsChanged.
            state.refs.insert("MERGE_HEAD".into(), "123".into())
        })
        .unwrap();

        cx.run_until_parked();
        let update = events_rx
            .try_recv()
            .expect("status change should trigger conflict parsing");
        assert_eq!(update.old_range, 0..0);
        assert_eq!(update.new_range, 0..1);

        let conflict = conflict_set.read_with(cx, |conflict_set, _| {
            conflict_set.snapshot().conflicts[0].clone()
        });
        cx.update(|cx| {
            conflict.resolve(buffer.clone(), std::slice::from_ref(&conflict.theirs), cx);
        });

        cx.run_until_parked();
        let update = events_rx
            .try_recv()
            .expect("conflicts should be removed after resolution");
        assert_eq!(update.old_range, 0..1);
        assert_eq!(update.new_range, 0..0);
    }

    /// Conflict tracking follows the file's unmerged status even when no MERGE_HEAD exists
    /// (as with the `Stashed Changes` conflict used here): clearing the status drops the
    /// conflicts and restoring it brings them back.
    #[gpui::test]
    async fn test_conflict_updates_without_merge_head(
        executor: BackgroundExecutor,
        cx: &mut TestAppContext,
    ) {
        zlog::init_test();
        cx.update(|cx| {
            settings::init(cx);
            WorktreeSettings::register(cx);
            Project::init_settings(cx);
            AllLanguageSettings::register(cx);
        });

        let initial_text = "
            zero
            <<<<<<< HEAD
            one
            =======
            two
            >>>>>>> Stashed Changes
            three
        "
        .unindent();

        let fs = FakeFs::new(executor);
        fs.insert_tree(
            path!("/project"),
            json!({
                ".git": {},
                "a.txt": initial_text,
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let (git_store, buffer) = project.update(cx, |project, cx| {
            (
                project.git_store().clone(),
                project.open_local_buffer(path!("/project/a.txt"), cx),
            )
        });

        cx.run_until_parked();
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                rel_path("a.txt").into(),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            )
        })
        .unwrap();

        let buffer = buffer.await.unwrap();

        // Open the conflict set for a file that currently has conflicts.
        let conflict_set = git_store.update(cx, |git_store, cx| {
            git_store.open_conflict_set(buffer.clone(), cx)
        });

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, cx| {
            let conflict_range = conflict_set.snapshot().conflicts[0]
                .range
                .to_point(buffer.read(cx));
            assert_eq!(conflict_range, Point::new(1, 0)..Point::new(6, 0));
        });

        // Simulate the conflict being removed by e.g. staging the file.
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.remove(&repo_path("a.txt"))
        })
        .unwrap();

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, _| {
            assert!(!conflict_set.has_conflict);
            assert_eq!(conflict_set.snapshot.conflicts.len(), 0);
        });

        // Simulate the conflict being re-added.
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                repo_path("a.txt"),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            )
        })
        .unwrap();

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, cx| {
            let conflict_range = conflict_set.snapshot().conflicts[0]
                .range
                .to_point(buffer.read(cx));
            assert_eq!(conflict_range, Point::new(1, 0)..Point::new(6, 0));
        });
    }
}