Update branch picker to use repo snapshot to fetch branches

Anthony Eid created

Change summary

crates/git_ui/src/branch_picker.rs | 178 +++++++++++++++++++------------
1 file changed, 107 insertions(+), 71 deletions(-)

Detailed changes

crates/git_ui/src/branch_picker.rs 🔗

@@ -11,7 +11,7 @@ use gpui::{
     SharedString, Styled, Subscription, Task, WeakEntity, Window, actions, rems,
 };
 use picker::{Picker, PickerDelegate, PickerEditorPosition};
-use project::git_store::Repository;
+use project::git_store::{Repository, RepositoryEvent};
 use project::project_settings::ProjectSettings;
 use settings::Settings;
 use std::sync::Arc;
@@ -113,7 +113,7 @@ pub struct BranchList {
     width: Rems,
     pub picker: Entity<Picker<BranchListDelegate>>,
     picker_focus_handle: FocusHandle,
-    _subscription: Option<Subscription>,
+    _subscriptions: Vec<Subscription>,
     embedded: bool,
 }
 
@@ -127,9 +127,10 @@ impl BranchList {
         cx: &mut Context<Self>,
     ) -> Self {
         let mut this = Self::new_inner(workspace, repository, style, width, false, window, cx);
-        this._subscription = Some(cx.subscribe(&this.picker, |_, _, _, cx| {
-            cx.emit(DismissEvent);
-        }));
+        this._subscriptions
+            .push(cx.subscribe(&this.picker, |_, _, _, cx| {
+                cx.emit(DismissEvent);
+            }));
         this
     }
 
@@ -142,18 +143,55 @@ impl BranchList {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
-        let all_branches_request = repository
-            .clone()
-            .map(|repository| repository.update(cx, |repository, _| repository.branches()));
+        let all_branches = repository
+            .as_ref()
+            .map(|repo| process_branches(&repo.read(cx).branch_list))
+            .unwrap_or_default();
 
         let default_branch_request = repository.clone().map(|repository| {
             repository.update(cx, |repository, _| repository.default_branch(false))
         });
 
+        let mut delegate = BranchListDelegate::new(workspace, repository.clone(), style, cx);
+        delegate.all_branches = all_branches;
+
+        let picker = cx.new(|cx| {
+            Picker::uniform_list(delegate, window, cx)
+                .show_scrollbar(true)
+                .modal(!embedded)
+        });
+        let picker_focus_handle = picker.focus_handle(cx);
+
+        picker.update(cx, |picker, _| {
+            picker.delegate.focus_handle = picker_focus_handle.clone();
+        });
+
+        let mut subscriptions = Vec::new();
+
+        if let Some(repo) = &repository {
+            subscriptions.push(cx.subscribe_in(
+                repo,
+                window,
+                move |this, repo, event, window, cx| {
+                    if matches!(event, RepositoryEvent::BranchListChanged) {
+                        let branch_list = repo.read(cx).branch_list.clone();
+                        let all_branches = process_branches(&branch_list);
+                        this.picker.update(cx, |picker, cx| {
+                            picker.delegate.restore_selected_branch = picker
+                                .delegate
+                                .matches
+                                .get(picker.delegate.selected_index)
+                                .and_then(|entry| entry.as_branch().map(|b| b.ref_name.clone()));
+                            picker.delegate.all_branches = all_branches;
+                            picker.refresh(window, cx);
+                        });
+                    }
+                },
+            ));
+        }
+
+        // Fetch default branch asynchronously since it requires a git operation
         cx.spawn_in(window, async move |this, cx| {
-            let mut all_branches = all_branches_request
-                .context("No active repository")?
-                .await??;
             let default_branch = default_branch_request
                 .context("No active repository")?
                 .await
@@ -162,64 +200,21 @@ impl BranchList {
                 .flatten()
                 .flatten();
 
-            let all_branches = cx
-                .background_spawn(async move {
-                    let remote_upstreams: HashSet<_> = all_branches
-                        .iter()
-                        .filter_map(|branch| {
-                            branch
-                                .upstream
-                                .as_ref()
-                                .filter(|upstream| upstream.is_remote())
-                                .map(|upstream| upstream.ref_name.clone())
-                        })
-                        .collect();
-
-                    all_branches.retain(|branch| !remote_upstreams.contains(&branch.ref_name));
-
-                    all_branches.sort_by_key(|branch| {
-                        (
-                            !branch.is_head, // Current branch (is_head=true) comes first
-                            branch
-                                .most_recent_commit
-                                .as_ref()
-                                .map(|commit| 0 - commit.commit_timestamp),
-                        )
-                    });
-
-                    all_branches
-                })
-                .await;
-
-            let _ = this.update_in(cx, |this, window, cx| {
-                this.picker.update(cx, |picker, cx| {
+            let _ = this.update_in(cx, |this, _window, cx| {
+                this.picker.update(cx, |picker, _cx| {
                     picker.delegate.default_branch = default_branch;
-                    picker.delegate.all_branches = Some(all_branches);
-                    picker.refresh(window, cx);
-                })
+                });
             });
 
             anyhow::Ok(())
         })
         .detach_and_log_err(cx);
 
-        let delegate = BranchListDelegate::new(workspace, repository, style, cx);
-        let picker = cx.new(|cx| {
-            Picker::uniform_list(delegate, window, cx)
-                .show_scrollbar(true)
-                .modal(!embedded)
-        });
-        let picker_focus_handle = picker.focus_handle(cx);
-
-        picker.update(cx, |picker, _| {
-            picker.delegate.focus_handle = picker_focus_handle.clone();
-        });
-
         Self {
             picker,
             picker_focus_handle,
             width,
-            _subscription: None,
+            _subscriptions: subscriptions,
             embedded,
         }
     }
@@ -240,9 +235,10 @@ impl BranchList {
             window,
             cx,
         );
-        this._subscription = Some(cx.subscribe(&this.picker, |_, _, _, cx| {
-            cx.emit(DismissEvent);
-        }));
+        this._subscriptions
+            .push(cx.subscribe(&this.picker, |_, _, _, cx| {
+                cx.emit(DismissEvent);
+            }));
         this
     }
 
@@ -379,7 +375,7 @@ impl BranchFilter {
 pub struct BranchListDelegate {
     workspace: WeakEntity<Workspace>,
     matches: Vec<Entry>,
-    all_branches: Option<Vec<Branch>>,
+    all_branches: Vec<Branch>,
     default_branch: Option<SharedString>,
     repo: Option<Entity<Repository>>,
     style: BranchListStyle,
@@ -389,6 +385,7 @@ pub struct BranchListDelegate {
     branch_filter: BranchFilter,
     state: PickerState,
     focus_handle: FocusHandle,
+    restore_selected_branch: Option<SharedString>,
 }
 
 #[derive(Debug)]
@@ -403,6 +400,37 @@ enum PickerState {
     NewBranch,
 }
 
+fn process_branches(branches: &Arc<[Branch]>) -> Vec<Branch> {
+    let remote_upstreams: HashSet<_> = branches
+        .iter()
+        .filter_map(|branch| {
+            branch
+                .upstream
+                .as_ref()
+                .filter(|upstream| upstream.is_remote())
+                .map(|upstream| upstream.ref_name.clone())
+        })
+        .collect();
+
+    let mut result: Vec<Branch> = branches
+        .iter()
+        .filter(|branch| !remote_upstreams.contains(&branch.ref_name))
+        .cloned()
+        .collect();
+
+    result.sort_by_key(|branch| {
+        (
+            !branch.is_head,
+            branch
+                .most_recent_commit
+                .as_ref()
+                .map(|commit| 0 - commit.commit_timestamp),
+        )
+    });
+
+    result
+}
+
 impl BranchListDelegate {
     fn new(
         workspace: WeakEntity<Workspace>,
@@ -415,7 +443,7 @@ impl BranchListDelegate {
             matches: vec![],
             repo,
             style,
-            all_branches: None,
+            all_branches: Vec::new(),
             default_branch: None,
             selected_index: 0,
             last_query: Default::default(),
@@ -423,6 +451,7 @@ impl BranchListDelegate {
             branch_filter: BranchFilter::All,
             state: PickerState::List,
             focus_handle: cx.focus_handle(),
+            restore_selected_branch: None,
         }
     }
 
@@ -536,9 +565,10 @@ impl BranchListDelegate {
                 picker.delegate.matches.retain(|e| e != &entry);
 
                 if let Entry::Branch { branch, .. } = &entry {
-                    if let Some(all_branches) = &mut picker.delegate.all_branches {
-                        all_branches.retain(|e| e.ref_name != branch.ref_name);
-                    }
+                    picker
+                        .delegate
+                        .all_branches
+                        .retain(|e| e.ref_name != branch.ref_name);
                 }
 
                 if picker.delegate.matches.is_empty() {
@@ -666,9 +696,7 @@ impl PickerDelegate for BranchListDelegate {
         window: &mut Window,
         cx: &mut Context<Picker<Self>>,
     ) -> Task<()> {
-        let Some(all_branches) = self.all_branches.clone() else {
-            return Task::ready(());
-        };
+        let all_branches = self.all_branches.clone();
 
         let branch_filter = self.branch_filter;
         cx.spawn_in(window, async move |picker, cx| {
@@ -770,6 +798,14 @@ impl PickerDelegate for BranchListDelegate {
                     delegate.matches = matches;
                     if delegate.matches.is_empty() {
                         delegate.selected_index = 0;
+                    } else if let Some(ref_name) = delegate.restore_selected_branch.take() {
+                        delegate.selected_index = delegate
+                            .matches
+                            .iter()
+                            .position(|entry| {
+                                entry.as_branch().is_some_and(|b| b.ref_name == ref_name)
+                            })
+                            .unwrap_or(0);
                     } else {
                         delegate.selected_index =
                             core::cmp::min(delegate.selected_index, delegate.matches.len() - 1);
@@ -1385,7 +1421,7 @@ mod tests {
                         BranchListStyle::Modal,
                         cx,
                     );
-                    delegate.all_branches = Some(branches);
+                    delegate.all_branches = branches;
                     let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
                     let picker_focus_handle = picker.focus_handle(cx);
                     picker.update(cx, |picker, _| {
@@ -1400,7 +1436,7 @@ mod tests {
                         picker,
                         picker_focus_handle,
                         width: rems(34.),
-                        _subscription: Some(_subscription),
+                        _subscriptions: vec![_subscription],
                         embedded: false,
                     }
                 })