Restore focus to previously focused view when dismissing a modal (#2680)

Antonio Scandurra created

Fixes
https://linear.app/zed-industries/issue/Z-2500/focus-is-moved-from-the-assistant-panel-when-opening-and-closing

Release Notes:

- Fixed a bug that caused modals (such as the command palette) to not
restore focus when dismissing them.

Change summary

crates/context_menu/src/context_menu.rs |  3 -
crates/go_to_line/src/go_to_line.rs     | 11 +++++
crates/gpui/src/app.rs                  |  9 +---
crates/gpui/src/app/window.rs           |  4 ++
crates/picker/src/picker.rs             | 11 +++++
crates/workspace/src/workspace.rs       | 52 ++++++++++++++++++++++----
6 files changed, 73 insertions(+), 17 deletions(-)

Detailed changes

crates/context_menu/src/context_menu.rs 🔗

@@ -244,8 +244,7 @@ impl ContextMenu {
             let show_count = self.show_count;
             cx.defer(move |this, cx| {
                 if cx.handle().is_focused(cx) && this.show_count == show_count {
-                    let window_id = cx.window_id();
-                    (**cx).focus(window_id, this.previously_focused_view_id.take());
+                    (**cx).focus(this.previously_focused_view_id.take());
                 }
             });
         } else {

crates/go_to_line/src/go_to_line.rs 🔗

@@ -24,6 +24,7 @@ pub struct GoToLine {
     prev_scroll_position: Option<Vector2F>,
     cursor_point: Point,
     max_point: Point,
+    has_focus: bool,
 }
 
 pub enum Event {
@@ -57,6 +58,7 @@ impl GoToLine {
             prev_scroll_position: scroll_position,
             cursor_point,
             max_point,
+            has_focus: false,
         }
     }
 
@@ -178,11 +180,20 @@ impl View for GoToLine {
     }
 
     fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
+        self.has_focus = true;
         cx.focus(&self.line_editor);
     }
+
+    fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
+        self.has_focus = false;
+    }
 }
 
 impl Modal for GoToLine {
+    fn has_focus(&self) -> bool {
+        self.has_focus
+    }
+
     fn dismiss_on_event(event: &Self::Event) -> bool {
         matches!(event, Event::Dismissed)
     }

crates/gpui/src/app.rs 🔗

@@ -2971,14 +2971,12 @@ impl<'a, 'b, V: View> ViewContext<'a, 'b, V> {
     }
 
     pub fn focus(&mut self, handle: &AnyViewHandle) {
-        self.window_context
-            .focus(handle.window_id, Some(handle.view_id));
+        self.window_context.focus(Some(handle.view_id));
     }
 
     pub fn focus_self(&mut self) {
-        let window_id = self.window_id;
         let view_id = self.view_id;
-        self.window_context.focus(window_id, Some(view_id));
+        self.window_context.focus(Some(view_id));
     }
 
     pub fn is_self_focused(&self) -> bool {
@@ -2997,8 +2995,7 @@ impl<'a, 'b, V: View> ViewContext<'a, 'b, V> {
     }
 
     pub fn blur(&mut self) {
-        let window_id = self.window_id;
-        self.window_context.focus(window_id, None);
+        self.window_context.focus(None);
     }
 
     pub fn on_window_should_close<F>(&mut self, mut callback: F)

crates/gpui/src/app/window.rs 🔗

@@ -1096,6 +1096,10 @@ impl<'a> WindowContext<'a> {
         self.window.focused_view_id
     }
 
+    pub fn focus(&mut self, view_id: Option<usize>) {
+        self.app_context.focus(self.window_id, view_id);
+    }
+
     pub fn window_bounds(&self) -> WindowBounds {
         self.window.platform_window.bounds()
     }

crates/picker/src/picker.rs 🔗

@@ -25,6 +25,7 @@ pub struct Picker<D: PickerDelegate> {
     theme: Arc<Mutex<Box<dyn Fn(&theme::Theme) -> theme::Picker>>>,
     confirmed: bool,
     pending_update_matches: Task<Option<()>>,
+    has_focus: bool,
 }
 
 pub trait PickerDelegate: Sized + 'static {
@@ -140,13 +141,22 @@ impl<D: PickerDelegate> View for Picker<D> {
     }
 
     fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
+        self.has_focus = true;
         if cx.is_self_focused() {
             cx.focus(&self.query_editor);
         }
     }
+
+    fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
+        self.has_focus = false;
+    }
 }
 
 impl<D: PickerDelegate> Modal for Picker<D> {
+    fn has_focus(&self) -> bool {
+        self.has_focus
+    }
+
     fn dismiss_on_event(event: &Self::Event) -> bool {
         matches!(event, PickerEvent::Dismiss)
     }
@@ -191,6 +201,7 @@ impl<D: PickerDelegate> Picker<D> {
             theme,
             confirmed: false,
             pending_update_matches: Task::ready(None),
+            has_focus: false,
         };
         this.update_matches(String::new(), cx);
         this

crates/workspace/src/workspace.rs 🔗

@@ -97,9 +97,25 @@ lazy_static! {
 }
 
 pub trait Modal: View {
+    fn has_focus(&self) -> bool;
     fn dismiss_on_event(event: &Self::Event) -> bool;
 }
 
+trait ModalHandle {
+    fn as_any(&self) -> &AnyViewHandle;
+    fn has_focus(&self, cx: &WindowContext) -> bool;
+}
+
+impl<T: Modal> ModalHandle for ViewHandle<T> {
+    fn as_any(&self) -> &AnyViewHandle {
+        self
+    }
+
+    fn has_focus(&self, cx: &WindowContext) -> bool {
+        self.read(cx).has_focus()
+    }
+}
+
 #[derive(Clone, PartialEq)]
 pub struct RemoveWorktreeFromProject(pub WorktreeId);
 
@@ -466,7 +482,7 @@ pub enum Event {
 pub struct Workspace {
     weak_self: WeakViewHandle<Self>,
     remote_entity_subscription: Option<client::Subscription>,
-    modal: Option<AnyViewHandle>,
+    modal: Option<ActiveModal>,
     zoomed: Option<AnyWeakViewHandle>,
     zoomed_position: Option<DockPosition>,
     center: PaneGroup,
@@ -495,6 +511,11 @@ pub struct Workspace {
     pane_history_timestamp: Arc<AtomicUsize>,
 }
 
+struct ActiveModal {
+    view: Box<dyn ModalHandle>,
+    previously_focused_view_id: Option<usize>,
+}
+
 #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
 pub struct ViewId {
     pub creator: PeerId,
@@ -1482,8 +1503,10 @@ impl Workspace {
         cx.notify();
         // Whatever modal was visible is getting clobbered. If its the same type as V, then return
         // it. Otherwise, create a new modal and set it as active.
-        let already_open_modal = self.modal.take().and_then(|modal| modal.downcast::<V>());
-        if let Some(already_open_modal) = already_open_modal {
+        if let Some(already_open_modal) = self
+            .dismiss_modal(cx)
+            .and_then(|modal| modal.downcast::<V>())
+        {
             cx.focus_self();
             Some(already_open_modal)
         } else {
@@ -1494,8 +1517,12 @@ impl Workspace {
                 }
             })
             .detach();
+            let previously_focused_view_id = cx.focused_view_id();
             cx.focus(&modal);
-            self.modal = Some(modal.into_any());
+            self.modal = Some(ActiveModal {
+                view: Box::new(modal),
+                previously_focused_view_id,
+            });
             None
         }
     }
@@ -1503,13 +1530,20 @@ impl Workspace {
     pub fn modal<V: 'static + View>(&self) -> Option<ViewHandle<V>> {
         self.modal
             .as_ref()
-            .and_then(|modal| modal.clone().downcast::<V>())
+            .and_then(|modal| modal.view.as_any().clone().downcast::<V>())
     }
 
-    pub fn dismiss_modal(&mut self, cx: &mut ViewContext<Self>) {
-        if self.modal.take().is_some() {
-            cx.focus(&self.active_pane);
+    pub fn dismiss_modal(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyViewHandle> {
+        if let Some(modal) = self.modal.take() {
+            if let Some(previously_focused_view_id) = modal.previously_focused_view_id {
+                if modal.view.has_focus(cx) {
+                    cx.window_context().focus(Some(previously_focused_view_id));
+                }
+            }
             cx.notify();
+            Some(modal.view.as_any().clone())
+        } else {
+            None
         }
     }
 
@@ -3496,7 +3530,7 @@ impl View for Workspace {
                                         )
                                     }))
                                     .with_children(self.modal.as_ref().map(|modal| {
-                                        ChildView::new(modal, cx)
+                                        ChildView::new(modal.view.as_any(), cx)
                                             .contained()
                                             .with_style(theme.workspace.modal)
                                             .aligned()