1use gpui::{App, Context, Entity, EventEmitter};
  2use std::{cmp::Ordering, ops::Range, sync::Arc};
  3use text::{Anchor, BufferId, OffsetRangeExt as _};
  4
/// Tracks the merge-conflict regions found in a single buffer.
///
/// Emits [`ConflictSetUpdate`] events from [`ConflictSet::set_has_conflict`] and
/// [`ConflictSet::set_snapshot`].
pub struct ConflictSet {
    /// Whether the buffer is currently considered conflicted. When this flips to
    /// `false`, the stored conflicts are cleared.
    pub has_conflict: bool,
    /// The most recently stored set of conflict regions.
    pub snapshot: ConflictSetSnapshot,
}
  9
/// Describes a change to a [`ConflictSet`]'s list of conflicts.
///
/// The conflicts at indices `old_range` of the previous list were replaced by
/// the conflicts at indices `new_range` of the new list. Conflicts outside these
/// ranges are unchanged.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConflictSetUpdate {
    /// The buffer region spanned by the replaced and replacement conflicts, or
    /// `None` when no such region was computed (e.g. when all conflicts are
    /// cleared at once).
    pub buffer_range: Option<Range<Anchor>>,
    /// Indices of the replaced conflicts in the previous list.
    pub old_range: Range<usize>,
    /// Indices of the replacement conflicts in the new list.
    pub new_range: Range<usize>,
}
 16
/// An immutable view of the conflict regions in one buffer.
///
/// The regions live behind an `Arc`, so cloning a snapshot does not copy them.
/// They are expected to be in buffer order (as produced by
/// [`ConflictSet::parse`]); [`ConflictSetSnapshot::conflicts_in_range`] relies on
/// this ordering for its binary searches.
#[derive(Debug, Clone)]
pub struct ConflictSetSnapshot {
    /// The buffer these conflicts were parsed from.
    pub buffer_id: BufferId,
    /// The conflict regions, sorted by position in the buffer.
    pub conflicts: Arc<[ConflictRegion]>,
}
 22
 23impl ConflictSetSnapshot {
 24    pub fn conflicts_in_range(
 25        &self,
 26        range: Range<Anchor>,
 27        buffer: &text::BufferSnapshot,
 28    ) -> &[ConflictRegion] {
 29        let start_ix = self
 30            .conflicts
 31            .binary_search_by(|conflict| {
 32                conflict
 33                    .range
 34                    .end
 35                    .cmp(&range.start, buffer)
 36                    .then(Ordering::Greater)
 37            })
 38            .unwrap_err();
 39        let end_ix = start_ix
 40            + self.conflicts[start_ix..]
 41                .binary_search_by(|conflict| {
 42                    conflict
 43                        .range
 44                        .start
 45                        .cmp(&range.end, buffer)
 46                        .then(Ordering::Less)
 47                })
 48                .unwrap_err();
 49        &self.conflicts[start_ix..end_ix]
 50    }
 51
 52    pub fn compare(&self, other: &Self, buffer: &text::BufferSnapshot) -> ConflictSetUpdate {
 53        let common_prefix_len = self
 54            .conflicts
 55            .iter()
 56            .zip(other.conflicts.iter())
 57            .take_while(|(old, new)| old == new)
 58            .count();
 59        let common_suffix_len = self.conflicts[common_prefix_len..]
 60            .iter()
 61            .rev()
 62            .zip(other.conflicts[common_prefix_len..].iter().rev())
 63            .take_while(|(old, new)| old == new)
 64            .count();
 65        let old_conflicts =
 66            &self.conflicts[common_prefix_len..(self.conflicts.len() - common_suffix_len)];
 67        let new_conflicts =
 68            &other.conflicts[common_prefix_len..(other.conflicts.len() - common_suffix_len)];
 69        let old_range = common_prefix_len..(common_prefix_len + old_conflicts.len());
 70        let new_range = common_prefix_len..(common_prefix_len + new_conflicts.len());
 71        let start = match (old_conflicts.first(), new_conflicts.first()) {
 72            (None, None) => None,
 73            (None, Some(conflict)) => Some(conflict.range.start),
 74            (Some(conflict), None) => Some(conflict.range.start),
 75            (Some(first), Some(second)) => {
 76                Some(*first.range.start.min(&second.range.start, buffer))
 77            }
 78        };
 79        let end = match (old_conflicts.last(), new_conflicts.last()) {
 80            (None, None) => None,
 81            (None, Some(conflict)) => Some(conflict.range.end),
 82            (Some(first), None) => Some(first.range.end),
 83            (Some(first), Some(second)) => Some(*first.range.end.max(&second.range.end, buffer)),
 84        };
 85        ConflictSetUpdate {
 86            buffer_range: start.zip(end).map(|(start, end)| start..end),
 87            old_range,
 88            new_range,
 89        }
 90    }
 91}
 92
/// A single merge-conflict region, as found by [`ConflictSet::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConflictRegion {
    /// The whole region: from the start of the `<<<<<<<` line through the end of
    /// the `>>>>>>>` line, including its newline when there is one.
    pub range: Range<Anchor>,
    /// The "ours" text between the `<<<<<<<` line and the following `|||||||`
    /// or `=======` line, excluding the marker lines themselves.
    pub ours: Range<Anchor>,
    /// The "theirs" text between the `=======` line and the `>>>>>>>` line.
    pub theirs: Range<Anchor>,
    /// The base text between the `|||||||` and `=======` lines. This is only
    /// present for diff3-style conflicts.
    pub base: Option<Range<Anchor>>,
}
100
101impl ConflictRegion {
102    pub fn resolve(
103        &self,
104        buffer: Entity<language::Buffer>,
105        ranges: &[Range<Anchor>],
106        cx: &mut App,
107    ) {
108        let buffer_snapshot = buffer.read(cx).snapshot();
109        let mut deletions = Vec::new();
110        let empty = "";
111        let outer_range = self.range.to_offset(&buffer_snapshot);
112        let mut offset = outer_range.start;
113        for kept_range in ranges {
114            let kept_range = kept_range.to_offset(&buffer_snapshot);
115            if kept_range.start > offset {
116                deletions.push((offset..kept_range.start, empty));
117            }
118            offset = kept_range.end;
119        }
120        if outer_range.end > offset {
121            deletions.push((offset..outer_range.end, empty));
122        }
123
124        buffer.update(cx, |buffer, cx| {
125            buffer.edit(deletions, None, cx);
126        });
127    }
128}
129
130impl ConflictSet {
131    pub fn new(buffer_id: BufferId, has_conflict: bool, _: &mut Context<Self>) -> Self {
132        Self {
133            has_conflict,
134            snapshot: ConflictSetSnapshot {
135                buffer_id,
136                conflicts: Default::default(),
137            },
138        }
139    }
140
141    pub fn set_has_conflict(&mut self, has_conflict: bool, cx: &mut Context<Self>) -> bool {
142        if has_conflict != self.has_conflict {
143            self.has_conflict = has_conflict;
144            if !self.has_conflict {
145                cx.emit(ConflictSetUpdate {
146                    buffer_range: None,
147                    old_range: 0..self.snapshot.conflicts.len(),
148                    new_range: 0..0,
149                });
150                self.snapshot.conflicts = Default::default();
151            }
152            true
153        } else {
154            false
155        }
156    }
157
158    pub fn snapshot(&self) -> ConflictSetSnapshot {
159        self.snapshot.clone()
160    }
161
162    pub fn set_snapshot(
163        &mut self,
164        snapshot: ConflictSetSnapshot,
165        update: ConflictSetUpdate,
166        cx: &mut Context<Self>,
167    ) {
168        self.snapshot = snapshot;
169        cx.emit(update);
170    }
171
172    pub fn parse(buffer: &text::BufferSnapshot) -> ConflictSetSnapshot {
173        let mut conflicts = Vec::new();
174
175        let mut line_pos = 0;
176        let buffer_len = buffer.len();
177        let mut lines = buffer.text_for_range(0..buffer_len).lines();
178
179        let mut conflict_start: Option<usize> = None;
180        let mut ours_start: Option<usize> = None;
181        let mut ours_end: Option<usize> = None;
182        let mut base_start: Option<usize> = None;
183        let mut base_end: Option<usize> = None;
184        let mut theirs_start: Option<usize> = None;
185
186        while let Some(line) = lines.next() {
187            let line_end = line_pos + line.len();
188
189            if line.starts_with("<<<<<<< ") {
190                // If we see a new conflict marker while already parsing one,
191                // abandon the previous one and start a new one
192                conflict_start = Some(line_pos);
193                ours_start = Some(line_end + 1);
194            } else if line.starts_with("||||||| ")
195                && conflict_start.is_some()
196                && ours_start.is_some()
197            {
198                ours_end = Some(line_pos);
199                base_start = Some(line_end + 1);
200            } else if line.starts_with("=======")
201                && conflict_start.is_some()
202                && ours_start.is_some()
203            {
204                // Set ours_end if not already set (would be set if we have base markers)
205                if ours_end.is_none() {
206                    ours_end = Some(line_pos);
207                } else if base_start.is_some() {
208                    base_end = Some(line_pos);
209                }
210                theirs_start = Some(line_end + 1);
211            } else if line.starts_with(">>>>>>> ")
212                && conflict_start.is_some()
213                && ours_start.is_some()
214                && ours_end.is_some()
215                && theirs_start.is_some()
216            {
217                let theirs_end = line_pos;
218                let conflict_end = (line_end + 1).min(buffer_len);
219
220                let range = buffer.anchor_after(conflict_start.unwrap())
221                    ..buffer.anchor_before(conflict_end);
222                let ours = buffer.anchor_after(ours_start.unwrap())
223                    ..buffer.anchor_before(ours_end.unwrap());
224                let theirs =
225                    buffer.anchor_after(theirs_start.unwrap())..buffer.anchor_before(theirs_end);
226
227                let base = base_start
228                    .zip(base_end)
229                    .map(|(start, end)| buffer.anchor_after(start)..buffer.anchor_before(end));
230
231                conflicts.push(ConflictRegion {
232                    range,
233                    ours,
234                    theirs,
235                    base,
236                });
237
238                conflict_start = None;
239                ours_start = None;
240                ours_end = None;
241                base_start = None;
242                base_end = None;
243                theirs_start = None;
244            }
245
246            line_pos = line_end + 1;
247        }
248
249        ConflictSetSnapshot {
250            conflicts: conflicts.into(),
251            buffer_id: buffer.remote_id(),
252        }
253    }
254}
255
// `ConflictSet` emits `ConflictSetUpdate` events from `set_has_conflict` and `set_snapshot`.
impl EventEmitter<ConflictSetUpdate> for ConflictSet {}
257
// Tests for conflict-marker parsing, `conflicts_in_range` queries, and conflict-set
// updates driven by the repository's unmerged status.
#[cfg(test)]
mod tests {
    use std::sync::mpsc;

    use crate::Project;

    use super::*;
    use fs::FakeFs;
    use git::{
        repository::repo_path,
        status::{UnmergedStatus, UnmergedStatusCode},
    };
    use gpui::{BackgroundExecutor, TestAppContext};
    use language::language_settings::AllLanguageSettings;
    use serde_json::json;
    use settings::Settings as _;
    use text::{Buffer, BufferId, Point, ReplicaId, ToOffset as _};
    use unindent::Unindent as _;
    use util::{path, rel_path::rel_path};
    use worktree::WorktreeSettings;

    // Parses one two-way conflict and one diff3-style conflict (with a base
    // section), then checks `conflicts_in_range` for whole-buffer, partial, and
    // empty query ranges.
    #[test]
    fn test_parse_conflicts_in_buffer() {
        // Create a buffer with conflict markers
        let test_content = r#"
            This is some text before the conflict.
            <<<<<<< HEAD
            This is our version
            =======
            This is their version
            >>>>>>> branch-name

            Another conflict:
            <<<<<<< HEAD
            Our second change
            ||||||| merged common ancestors
            Original content
            =======
            Their second change
            >>>>>>> branch-name
        "#
        .unindent();

        let buffer_id = BufferId::new(1).unwrap();
        let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, test_content);
        let snapshot = buffer.snapshot();

        let conflict_snapshot = ConflictSet::parse(&snapshot);
        assert_eq!(conflict_snapshot.conflicts.len(), 2);

        let first = &conflict_snapshot.conflicts[0];
        assert!(first.base.is_none());
        let our_text = snapshot
            .text_for_range(first.ours.clone())
            .collect::<String>();
        let their_text = snapshot
            .text_for_range(first.theirs.clone())
            .collect::<String>();
        assert_eq!(our_text, "This is our version\n");
        assert_eq!(their_text, "This is their version\n");

        let second = &conflict_snapshot.conflicts[1];
        assert!(second.base.is_some());
        let our_text = snapshot
            .text_for_range(second.ours.clone())
            .collect::<String>();
        let their_text = snapshot
            .text_for_range(second.theirs.clone())
            .collect::<String>();
        let base_text = snapshot
            .text_for_range(second.base.as_ref().unwrap().clone())
            .collect::<String>();
        assert_eq!(our_text, "Our second change\n");
        assert_eq!(their_text, "Their second change\n");
        assert_eq!(base_text, "Original content\n");

        // Test conflicts_in_range
        let range = snapshot.anchor_before(0)..snapshot.anchor_before(snapshot.len());
        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
        assert_eq!(conflicts_in_range.len(), 2);

        // Test with a range that includes only the first conflict
        let first_conflict_end = conflict_snapshot.conflicts[0].range.end;
        let range = snapshot.anchor_before(0)..first_conflict_end;
        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
        assert_eq!(conflicts_in_range.len(), 1);

        // Test with a range that includes only the second conflict
        let second_conflict_start = conflict_snapshot.conflicts[1].range.start;
        let range = second_conflict_start..snapshot.anchor_before(snapshot.len());
        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
        assert_eq!(conflicts_in_range.len(), 1);

        // Test with a range that doesn't include any conflicts
        let range = buffer.anchor_after(first_conflict_end.to_next_offset(&buffer))
            ..buffer.anchor_before(second_conflict_start.to_previous_offset(&buffer));
        let conflicts_in_range = conflict_snapshot.conflicts_in_range(range, &snapshot);
        assert_eq!(conflicts_in_range.len(), 0);
    }

    // A second `<<<<<<<` inside an open region restarts parsing, so only the
    // inner region is reported; the outer `=======`/`>>>>>>>` lines that follow
    // it have no open region and are ignored.
    #[test]
    fn test_nested_conflict_markers() {
        // Create a buffer with nested conflict markers
        let test_content = r#"
            This is some text before the conflict.
            <<<<<<< HEAD
            This is our version
            <<<<<<< HEAD
            This is a nested conflict marker
            =======
            This is their version in a nested conflict
            >>>>>>> branch-nested
            =======
            This is their version
            >>>>>>> branch-name
        "#
        .unindent();

        let buffer_id = BufferId::new(1).unwrap();
        let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, test_content);
        let snapshot = buffer.snapshot();

        let conflict_snapshot = ConflictSet::parse(&snapshot);

        assert_eq!(conflict_snapshot.conflicts.len(), 1);

        // The conflict should have our version, their version, but no base
        let conflict = &conflict_snapshot.conflicts[0];
        assert!(conflict.base.is_none());

        // Check that the nested conflict was detected correctly
        let our_text = snapshot
            .text_for_range(conflict.ours.clone())
            .collect::<String>();
        assert_eq!(our_text, "This is a nested conflict marker\n");
        let their_text = snapshot
            .text_for_range(conflict.theirs.clone())
            .collect::<String>();
        assert_eq!(their_text, "This is their version in a nested conflict\n");
    }

    // A closing `>>>>>>> ` marker on the final line (no trailing newline) still
    // closes the region, and an empty "ours" section is accepted.
    #[test]
    fn test_conflict_markers_at_eof() {
        let test_content = r#"
            <<<<<<< ours
            =======
            This is their version
            >>>>>>> "#
            .unindent();
        let buffer_id = BufferId::new(1).unwrap();
        let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, test_content);
        let snapshot = buffer.snapshot();

        let conflict_snapshot = ConflictSet::parse(&snapshot);
        assert_eq!(conflict_snapshot.conflicts.len(), 1);
    }

    // `conflicts_in_range` returns the contiguous run of overlapping conflicts,
    // including a conflict that only touches an endpoint of the query range.
    #[test]
    fn test_conflicts_in_range() {
        // Create a buffer with conflict markers
        let test_content = r#"
            one
            <<<<<<< HEAD1
            two
            =======
            three
            >>>>>>> branch1
            four
            five
            <<<<<<< HEAD2
            six
            =======
            seven
            >>>>>>> branch2
            eight
            nine
            <<<<<<< HEAD3
            ten
            =======
            eleven
            >>>>>>> branch3
            twelve
            <<<<<<< HEAD4
            thirteen
            =======
            fourteen
            >>>>>>> branch4
            fifteen
        "#
        .unindent();

        let buffer_id = BufferId::new(1).unwrap();
        let buffer = Buffer::new(ReplicaId::LOCAL, buffer_id, test_content.clone());
        let snapshot = buffer.snapshot();

        let conflict_snapshot = ConflictSet::parse(&snapshot);
        assert_eq!(conflict_snapshot.conflicts.len(), 4);

        let range = test_content.find("seven").unwrap()..test_content.find("eleven").unwrap();
        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
        assert_eq!(
            conflict_snapshot.conflicts_in_range(range, &snapshot),
            &conflict_snapshot.conflicts[1..=2]
        );

        let range = test_content.find("one").unwrap()..test_content.find("<<<<<<< HEAD2").unwrap();
        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
        assert_eq!(
            conflict_snapshot.conflicts_in_range(range, &snapshot),
            &conflict_snapshot.conflicts[0..=1]
        );

        let range =
            test_content.find("eight").unwrap() - 1..test_content.find(">>>>>>> branch3").unwrap();
        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
        assert_eq!(
            conflict_snapshot.conflicts_in_range(range, &snapshot),
            &conflict_snapshot.conflicts[1..=2]
        );

        let range = test_content.find("thirteen").unwrap() - 1..test_content.len();
        let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
        assert_eq!(
            conflict_snapshot.conflicts_in_range(range, &snapshot),
            &conflict_snapshot.conflicts[3..=3]
        );
    }

    // Conflict markers typed into a buffer are not reported until git marks the
    // file as unmerged; resolving the conflict then produces an update that
    // removes it.
    #[gpui::test]
    async fn test_conflict_updates(executor: BackgroundExecutor, cx: &mut TestAppContext) {
        zlog::init_test();
        cx.update(|cx| {
            settings::init(cx);
            WorktreeSettings::register(cx);
            Project::init_settings(cx);
            AllLanguageSettings::register(cx);
        });
        let initial_text = "
            one
            two
            three
            four
            five
        "
        .unindent();
        let fs = FakeFs::new(executor);
        fs.insert_tree(
            path!("/project"),
            json!({
                ".git": {},
                "a.txt": initial_text,
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let (git_store, buffer) = project.update(cx, |project, cx| {
            (
                project.git_store().clone(),
                project.open_local_buffer(path!("/project/a.txt"), cx),
            )
        });
        let buffer = buffer.await.unwrap();
        let conflict_set = git_store.update(cx, |git_store, cx| {
            git_store.open_conflict_set(buffer.clone(), cx)
        });
        let (events_tx, events_rx) = mpsc::channel::<ConflictSetUpdate>();
        let _conflict_set_subscription = cx.update(|cx| {
            cx.subscribe(&conflict_set, move |_, event, _| {
                events_tx.send(event.clone()).ok();
            })
        });
        let conflicts_snapshot =
            conflict_set.read_with(cx, |conflict_set, _| conflict_set.snapshot());
        assert!(conflicts_snapshot.conflicts.is_empty());

        // Wrap the `two` and `three` lines in conflict markers: offset 4 is the
        // start of `two` and offset 14 the start of `four`.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(
                [
                    (4..4, "<<<<<<< HEAD\n"),
                    (14..14, "=======\nTWO\n>>>>>>> branch\n"),
                ],
                None,
                cx,
            );
        });

        cx.run_until_parked();
        events_rx.try_recv().expect_err(
            "no conflicts should be registered as long as the file's status is unchanged",
        );

        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                repo_path("a.txt"),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            );
            // Cause the repository to emit MergeHeadsChanged.
            state.refs.insert("MERGE_HEAD".into(), "123".into())
        })
        .unwrap();

        cx.run_until_parked();
        let update = events_rx
            .try_recv()
            .expect("status change should trigger conflict parsing");
        assert_eq!(update.old_range, 0..0);
        assert_eq!(update.new_range, 0..1);

        // Keep only "theirs"; the markers and the "ours" text are deleted.
        let conflict = conflict_set.read_with(cx, |conflict_set, _| {
            conflict_set.snapshot().conflicts[0].clone()
        });
        cx.update(|cx| {
            conflict.resolve(buffer.clone(), std::slice::from_ref(&conflict.theirs), cx);
        });

        cx.run_until_parked();
        let update = events_rx
            .try_recv()
            .expect("conflicts should be removed after resolution");
        assert_eq!(update.old_range, 0..1);
        assert_eq!(update.new_range, 0..0);
    }

    // Conflicts are tracked from the unmerged status alone, without MERGE_HEAD
    // (presumably e.g. a conflicted stash apply, per the `Stashed Changes` label).
    // They are cleared when the path leaves the unmerged set and parsed again
    // when it returns.
    #[gpui::test]
    async fn test_conflict_updates_without_merge_head(
        executor: BackgroundExecutor,
        cx: &mut TestAppContext,
    ) {
        zlog::init_test();
        cx.update(|cx| {
            settings::init(cx);
            WorktreeSettings::register(cx);
            Project::init_settings(cx);
            AllLanguageSettings::register(cx);
        });

        let initial_text = "
            zero
            <<<<<<< HEAD
            one
            =======
            two
            >>>>>>> Stashed Changes
            three
        "
        .unindent();

        let fs = FakeFs::new(executor);
        fs.insert_tree(
            path!("/project"),
            json!({
                ".git": {},
                "a.txt": initial_text,
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let (git_store, buffer) = project.update(cx, |project, cx| {
            (
                project.git_store().clone(),
                project.open_local_buffer(path!("/project/a.txt"), cx),
            )
        });

        cx.run_until_parked();
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                rel_path("a.txt").into(),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            )
        })
        .unwrap();

        let buffer = buffer.await.unwrap();

        // Open the conflict set for a file that currently has conflicts.
        let conflict_set = git_store.update(cx, |git_store, cx| {
            git_store.open_conflict_set(buffer.clone(), cx)
        });

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, cx| {
            let conflict_range = conflict_set.snapshot().conflicts[0]
                .range
                .to_point(buffer.read(cx));
            assert_eq!(conflict_range, Point::new(1, 0)..Point::new(6, 0));
        });

        // Simulate the conflict being removed by e.g. staging the file.
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.remove(&repo_path("a.txt"))
        })
        .unwrap();

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, _| {
            assert!(!conflict_set.has_conflict);
            assert_eq!(conflict_set.snapshot.conflicts.len(), 0);
        });

        // Simulate the conflict being re-added.
        fs.with_git_state(path!("/project/.git").as_ref(), true, |state| {
            state.unmerged_paths.insert(
                repo_path("a.txt"),
                UnmergedStatus {
                    first_head: UnmergedStatusCode::Updated,
                    second_head: UnmergedStatusCode::Updated,
                },
            )
        })
        .unwrap();

        cx.run_until_parked();
        conflict_set.update(cx, |conflict_set, cx| {
            let conflict_range = conflict_set.snapshot().conflicts[0]
                .range
                .to_point(buffer.read(cx));
            assert_eq!(conflict_range, Point::new(1, 0)..Point::new(6, 0));
        });
    }
}