Make each branch have it's own unique color

Anthony Eid created

Change summary

crates/git_graph/src/git_graph.rs | 273 +++++++++++++++++++++++++++-----
1 file changed, 228 insertions(+), 45 deletions(-)

Detailed changes

crates/git_graph/src/git_graph.rs 🔗

@@ -274,8 +274,8 @@ fn accent_colors_count(accents: &AccentColors) -> usize {
     accents.0.len()
 }
 
-#[derive(Copy, Clone, Debug)]
-struct BranchColor(u8);
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+struct BranchId(usize);
 
 #[derive(Debug)]
 enum LaneState {
@@ -283,7 +283,7 @@ enum LaneState {
     Active {
         child: Oid,
         parent: Oid,
-        color: Option<BranchColor>,
+        branch_id: BranchId,
         starting_row: usize,
         starting_col: usize,
         destination_column: Option<usize>,
@@ -297,7 +297,6 @@ impl LaneState {
         ending_row: usize,
         lane_column: usize,
         parent_column: usize,
-        parent_color: BranchColor,
     ) -> Option<CommitLine> {
         let state = std::mem::replace(self, LaneState::Empty);
 
@@ -307,14 +306,13 @@ impl LaneState {
                 parent,
                 #[cfg_attr(not(test), allow(unused_variables))]
                 child,
-                color,
+                branch_id,
                 starting_row,
                 starting_col,
                 destination_column,
                 mut segments,
             } => {
                 let final_destination = destination_column.unwrap_or(parent_column);
-                let final_color = color.unwrap_or(parent_color);
 
                 Some(CommitLine {
                     #[cfg(test)]
@@ -323,7 +321,7 @@ impl LaneState {
                     parent,
                     child_column: starting_col,
                     full_interval: starting_row..ending_row,
-                    color_idx: final_color.0 as usize,
+                    branch_id,
                     segments: {
                         match segments.last_mut() {
                             Some(CommitLineSegment::Straight { to_row })
@@ -431,6 +429,13 @@ impl LaneState {
         }
     }
 
+    fn branch_id(&self) -> Option<BranchId> {
+        match self {
+            LaneState::Active { branch_id, .. } => Some(*branch_id),
+            LaneState::Empty => None,
+        }
+    }
+
     fn is_empty(&self) -> bool {
         match self {
             LaneState::Empty => true,
@@ -442,7 +447,7 @@ impl LaneState {
 struct CommitEntry {
     data: Arc<InitialGraphCommitData>,
     lane: usize,
-    color_idx: usize,
+    branch_id: BranchId,
 }
 
 type ActiveLaneIdx = usize;
@@ -478,7 +483,7 @@ struct CommitLine {
     parent: Oid,
     child_column: usize,
     full_interval: Range<usize>,
-    color_idx: usize,
+    branch_id: BranchId,
     segments: SmallVec<[CommitLineSegment; 1]>,
 }
 
@@ -522,10 +527,8 @@ struct CommitLineKey {
 
 struct GraphData {
     lane_states: SmallVec<[LaneState; 8]>,
-    lane_colors: HashMap<ActiveLaneIdx, BranchColor>,
     parent_to_lanes: HashMap<Oid, SmallVec<[usize; 1]>>,
-    next_color: BranchColor,
-    accent_colors_count: usize,
+    next_branch_id: BranchId,
     commits: Vec<Rc<CommitEntry>>,
     max_commit_count: AllCommitCount,
     max_lanes: usize,
@@ -535,13 +538,11 @@ struct GraphData {
 }
 
 impl GraphData {
-    fn new(accent_colors_count: usize) -> Self {
+    fn new(_accent_colors_count: usize) -> Self {
         GraphData {
             lane_states: SmallVec::default(),
-            lane_colors: HashMap::default(),
             parent_to_lanes: HashMap::default(),
-            next_color: BranchColor(0),
-            accent_colors_count,
+            next_branch_id: BranchId(0),
             commits: Vec::default(),
             max_commit_count: AllCommitCount::NotLoaded,
             max_lanes: 0,
@@ -553,13 +554,12 @@ impl GraphData {
 
     fn clear(&mut self) {
         self.lane_states.clear();
-        self.lane_colors.clear();
         self.parent_to_lanes.clear();
         self.commits.clear();
         self.lines.clear();
         self.active_commit_lines.clear();
         self.active_commit_lines_by_parent.clear();
-        self.next_color = BranchColor(0);
+        self.next_branch_id = BranchId(0);
         self.max_commit_count = AllCommitCount::NotLoaded;
         self.max_lanes = 0;
     }
@@ -574,13 +574,23 @@ impl GraphData {
             })
     }
 
-    fn get_lane_color(&mut self, lane_idx: ActiveLaneIdx) -> BranchColor {
-        let accent_colors_count = self.accent_colors_count;
-        *self.lane_colors.entry(lane_idx).or_insert_with(|| {
-            let color_idx = self.next_color;
-            self.next_color = BranchColor((self.next_color.0 + 1) % accent_colors_count as u8);
-            color_idx
-        })
+    fn allocate_branch_id(&mut self) -> BranchId {
+        let branch_id = self.next_branch_id;
+        self.next_branch_id = BranchId(self.next_branch_id.0 + 1);
+        branch_id
+    }
+
+    fn branch_id_for_lane(&self, lane_idx: ActiveLaneIdx) -> Option<BranchId> {
+        self.lane_states
+            .get(lane_idx)
+            .and_then(LaneState::branch_id)
+    }
+
+    fn origin_branch_id_for_parent(&self, parent: Oid) -> Option<BranchId> {
+        self.parent_to_lanes
+            .get(&parent)
+            .and_then(|lane_indices| lane_indices.first().copied())
+            .and_then(|lane_idx| self.branch_id_for_lane(lane_idx))
     }
 
     fn add_commits(&mut self, commits: &[Arc<InitialGraphCommitData>]) {
@@ -596,8 +606,9 @@ impl GraphData {
                 .and_then(|lanes| lanes.first().copied());
 
             let commit_lane = commit_lane.unwrap_or_else(|| self.first_empty_lane_idx());
-
-            let commit_color = self.get_lane_color(commit_lane);
+            let commit_branch_id = self
+                .branch_id_for_lane(commit_lane)
+                .unwrap_or_else(|| self.allocate_branch_id());
 
             if let Some(lanes) = self.parent_to_lanes.remove(&commit.sha) {
                 for lane_column in lanes {
@@ -632,7 +643,7 @@ impl GraphData {
                     }
 
                     if let Some(commit_line) =
-                        state.to_commit_lines(commit_row, lane_column, commit_lane, commit_color)
+                        state.to_commit_lines(commit_row, lane_column, commit_lane)
                     {
                         self.lines.push(Rc::new(commit_line));
                     }
@@ -648,7 +659,7 @@ impl GraphData {
                         self.lane_states[commit_lane] = LaneState::Active {
                             parent: *parent,
                             child: commit.sha,
-                            color: Some(commit_color),
+                            branch_id: commit_branch_id,
                             starting_col: commit_lane,
                             starting_row: commit_row,
                             destination_column: None,
@@ -661,11 +672,14 @@ impl GraphData {
                             .push(commit_lane);
                     } else {
                         let new_lane = self.first_empty_lane_idx();
+                        let branch_id = self
+                            .origin_branch_id_for_parent(*parent)
+                            .unwrap_or_else(|| self.allocate_branch_id());
 
                         self.lane_states[new_lane] = LaneState::Active {
                             parent: *parent,
                             child: commit.sha,
-                            color: None,
+                            branch_id,
                             starting_col: commit_lane,
                             starting_row: commit_row,
                             destination_column: None,
@@ -688,7 +702,7 @@ impl GraphData {
             self.commits.push(Rc::new(CommitEntry {
                 data: commit.clone(),
                 lane: commit_lane,
-                color_idx: commit_color.0 as usize,
+                branch_id: commit_branch_id,
             }));
         }
 
@@ -1121,11 +1135,7 @@ impl GitGraph {
                 }
 
                 let accent_colors = cx.theme().accents();
-                let accent_color = accent_colors
-                    .0
-                    .get(commit.color_idx)
-                    .copied()
-                    .unwrap_or_else(|| accent_colors.0.first().copied().unwrap_or_default());
+                let accent_color = accent_colors.color_for_index(commit.branch_id.0 as u32);
 
                 let is_selected = self.selected_entry_idx == Some(idx);
                 let column_label = |label: SharedString| {
@@ -1341,11 +1351,7 @@ impl GitGraph {
         let ref_names = commit_entry.data.ref_names.clone();
 
         let accent_colors = cx.theme().accents();
-        let accent_color = accent_colors
-            .0
-            .get(commit_entry.color_idx)
-            .copied()
-            .unwrap_or_else(|| accent_colors.0.first().copied().unwrap_or_default());
+        let accent_color = accent_colors.color_for_index(commit_entry.branch_id.0 as u32);
 
         let (author_name, author_email, commit_timestamp, subject) = match &data {
             CommitDataState::Loaded(data) => (
@@ -1761,7 +1767,7 @@ impl GitGraph {
                     }
 
                     for (row_idx, row) in rows.into_iter().enumerate() {
-                        let row_color = accent_colors.color_for_index(row.color_idx as u32);
+                        let row_color = accent_colors.color_for_index(row.branch_id.0 as u32);
                         let row_y_center =
                             bounds.origin.y + row_idx as f32 * row_height + row_height / 2.0
                                 - vertical_scroll_offset;
@@ -1899,11 +1905,11 @@ impl GitGraph {
                         }
 
                         builder.close();
-                        lines.entry(line.color_idx).or_default().push(builder);
+                        lines.entry(line.branch_id.0).or_default().push(builder);
                     }
 
-                    for (color_idx, builders) in lines {
-                        let line_color = accent_colors.color_for_index(color_idx as u32);
+                    for (branch_id, builders) in lines {
+                        let line_color = accent_colors.color_for_index(branch_id as u32);
 
                         for builder in builders {
                             if let Ok(path) = builder.build() {
@@ -2828,6 +2834,157 @@ mod tests {
         Ok(())
     }
 
+    fn expected_branch_assignments(
+        commits: &[Arc<InitialGraphCommitData>],
+    ) -> (HashMap<Oid, BranchId>, HashMap<(Oid, Oid), BranchId>) {
+        let mut parent_to_branch_ids: HashMap<Oid, SmallVec<[BranchId; 1]>> = HashMap::default();
+        let mut commit_branch_ids = HashMap::default();
+        let mut line_branch_ids = HashMap::default();
+        let mut next_branch_id = 0usize;
+
+        let mut allocate_branch_id = || {
+            let branch_id = BranchId(next_branch_id);
+            next_branch_id += 1;
+            branch_id
+        };
+
+        for commit in commits {
+            let commit_branch_id = parent_to_branch_ids
+                .remove(&commit.sha)
+                .and_then(|branch_ids| branch_ids.first().copied())
+                .unwrap_or_else(|| allocate_branch_id());
+
+            commit_branch_ids.insert(commit.sha, commit_branch_id);
+
+            for (parent_idx, parent) in commit.parents.iter().enumerate() {
+                let branch_id = if parent_idx == 0 {
+                    commit_branch_id
+                } else {
+                    parent_to_branch_ids
+                        .get(parent)
+                        .and_then(|branch_ids| branch_ids.first().copied())
+                        .unwrap_or_else(|| allocate_branch_id())
+                };
+
+                line_branch_ids.insert((commit.sha, *parent), branch_id);
+                parent_to_branch_ids
+                    .entry(*parent)
+                    .or_default()
+                    .push(branch_id);
+            }
+        }
+
+        (commit_branch_ids, line_branch_ids)
+    }
+
+    fn verify_branch_ids(graph: &GraphData, commits: &[Arc<InitialGraphCommitData>]) -> Result<()> {
+        let (expected_commit_branch_ids, expected_line_branch_ids) =
+            expected_branch_assignments(commits);
+        let mut seen_branch_ids = HashSet::default();
+        let actual_commit_branch_ids: HashMap<Oid, BranchId> = graph
+            .commits
+            .iter()
+            .map(|commit| (commit.data.sha, commit.branch_id))
+            .collect();
+        let actual_line_branch_ids: HashMap<(Oid, Oid), BranchId> = graph
+            .lines
+            .iter()
+            .map(|line| ((line.child, line.parent), line.branch_id))
+            .collect();
+        let mut parent_to_seen_branch_ids: HashMap<Oid, SmallVec<[BranchId; 1]>> =
+            HashMap::default();
+
+        for commit in &graph.commits {
+            let expected_branch_id = expected_commit_branch_ids
+                .get(&commit.data.sha)
+                .context("Commit is missing an expected branch id")?;
+
+            if &commit.branch_id != expected_branch_id {
+                bail!(
+                    "Commit {:?} has branch_id {:?}, expected {:?}",
+                    commit.data.sha,
+                    commit.branch_id,
+                    expected_branch_id
+                );
+            }
+
+            seen_branch_ids.insert(commit.branch_id.0);
+        }
+
+        for line in &graph.lines {
+            let expected_branch_id = expected_line_branch_ids
+                .get(&(line.child, line.parent))
+                .context("Line is missing an expected branch id")?;
+
+            if &line.branch_id != expected_branch_id {
+                bail!(
+                    "Line {:?} -> {:?} has branch_id {:?}, expected {:?}",
+                    line.child,
+                    line.parent,
+                    line.branch_id,
+                    expected_branch_id
+                );
+            }
+
+            seen_branch_ids.insert(line.branch_id.0);
+        }
+
+        for commit in commits {
+            let commit_branch_id = *actual_commit_branch_ids
+                .get(&commit.sha)
+                .context("Commit is missing an actual branch id")?;
+
+            for (parent_idx, parent) in commit.parents.iter().enumerate() {
+                let line_branch_id = *actual_line_branch_ids
+                    .get(&(commit.sha, *parent))
+                    .context("Line is missing an actual branch id")?;
+
+                if parent_idx > 0
+                    && let Some(origin_branch_id) = parent_to_seen_branch_ids
+                        .get(parent)
+                        .and_then(|branch_ids| branch_ids.first().copied())
+                {
+                    if line_branch_id != origin_branch_id {
+                        bail!(
+                            "Line {:?} -> {:?} has branch_id {:?}, expected origin branch id {:?}",
+                            commit.sha,
+                            parent,
+                            line_branch_id,
+                            origin_branch_id
+                        );
+                    }
+
+                    if line_branch_id == commit_branch_id {
+                        bail!(
+                            "Line {:?} -> {:?} reused merged-into branch id {:?} instead of origin branch id {:?}",
+                            commit.sha,
+                            parent,
+                            commit_branch_id,
+                            origin_branch_id
+                        );
+                    }
+                }
+
+                parent_to_seen_branch_ids
+                    .entry(*parent)
+                    .or_default()
+                    .push(line_branch_id);
+            }
+        }
+
+        let Some(max_branch_id) = seen_branch_ids.iter().max().copied() else {
+            return Ok(());
+        };
+
+        for expected_branch_id in 0..=max_branch_id {
+            if !seen_branch_ids.contains(&expected_branch_id) {
+                bail!("Missing branch id {}", expected_branch_id);
+            }
+        }
+
+        Ok(())
+    }
+
     fn verify_merge_line_optimality(
         graph: &GraphData,
         oid_to_row: &HashMap<Oid, usize>,
@@ -2928,6 +3085,7 @@ mod tests {
         let oid_to_row = build_oid_to_row_map(graph);
 
         verify_commit_order(graph, commits).context("commit order")?;
+        verify_branch_ids(graph, commits).context("branch ids")?;
         verify_line_endpoints(graph, &oid_to_row).context("line endpoints")?;
         verify_column_correctness(graph, &oid_to_row).context("column correctness")?;
         verify_segment_continuity(graph).context("segment continuity")?;
@@ -3044,6 +3202,31 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_git_graph_random_branch_ids() {
+        for seed in 0..100 {
+            let mut rng = StdRng::seed_from_u64(seed);
+
+            let adversarial = rng.random_bool(0.2);
+            let num_commits = if adversarial {
+                rng.random_range(10..100)
+            } else {
+                rng.random_range(5..50)
+            };
+
+            let commits = generate_random_commit_dag(&mut rng, num_commits, adversarial);
+            let mut graph_data = GraphData::new(8);
+            graph_data.add_commits(&commits);
+
+            if let Err(error) = verify_branch_ids(&graph_data, &commits) {
+                panic!(
+                    "Branch id invariant violation (seed={}, adversarial={}, num_commits={}):\n{:#}",
+                    seed, adversarial, num_commits, error
+                );
+            }
+        }
+    }
+
     // The full integration test has less iterations because it's significantly slower
     // than the random commit test
     #[gpui::test(iterations = 10)]