Fix context_stack race in KeyContextView (#29324)

Conrad Irwin created

cc @notpeter

Before this change we used our own copy of `cx.key_context()` when
matching.
This led to races where the context queried could be either before (or
after) the
context used in dispatching.

To avoid the race, gpui now passes out the context stack actually used
instead.

Release Notes:

- Fixed a bug where the Key Context View could show the incorrect
context,
  causing confusing results.

Change summary

crates/gpui/src/app.rs                        | 15 ++++++++-----
crates/gpui/src/key_dispatch.rs               | 22 +++++++++++++------
crates/gpui/src/window.rs                     | 23 ++++++++++++++++----
crates/language_tools/src/key_context_view.rs |  7 +++--
crates/workspace/src/pane_group.rs            |  8 +++++-
5 files changed, 52 insertions(+), 23 deletions(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -33,12 +33,12 @@ use util::ResultExt;
 use crate::{
     Action, ActionBuildError, ActionRegistry, Any, AnyView, AnyWindowHandle, AppContext, Asset,
     AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
-    EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, Keymap, Keystroke,
-    LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform, PlatformDisplay,
-    PlatformKeyboardLayout, Point, PromptBuilder, PromptHandle, PromptLevel, Render, RenderImage,
-    RenderablePromptHandle, Reservation, ScreenCaptureSource, SharedString, SubscriberSet,
-    Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance, WindowHandle, WindowId,
-    WindowInvalidator, current_platform, hash, init_app_menus,
+    EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
+    Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
+    PlatformDisplay, PlatformKeyboardLayout, Point, PromptBuilder, PromptHandle, PromptLevel,
+    Render, RenderImage, RenderablePromptHandle, Reservation, ScreenCaptureSource, SharedString,
+    SubscriberSet, Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance,
+    WindowHandle, WindowId, WindowInvalidator, current_platform, hash, init_app_menus,
 };
 
 mod async_context;
@@ -1859,6 +1859,9 @@ pub struct KeystrokeEvent {
 
     /// The action that was resolved for the keystroke, if any
     pub action: Option<Box<dyn Action>>,
+
+    /// The context stack at the time
+    pub context_stack: Vec<KeyContext>,
 }
 
 struct NullHttpClient;

crates/gpui/src/key_dispatch.rs 🔗

@@ -121,6 +121,7 @@ pub(crate) struct DispatchResult {
     pub(crate) pending: SmallVec<[Keystroke; 1]>,
     pub(crate) bindings: SmallVec<[KeyBinding; 1]>,
     pub(crate) to_replay: SmallVec<[Replay; 1]>,
+    pub(crate) context_stack: Vec<KeyContext>,
 }
 
 type KeyListener = Rc<dyn Fn(&dyn Any, DispatchPhase, &mut Window, &mut App)>;
@@ -411,15 +412,17 @@ impl DispatchTree {
         &self,
         input: &[Keystroke],
         dispatch_path: &SmallVec<[DispatchNodeId; 32]>,
-    ) -> (SmallVec<[KeyBinding; 1]>, bool) {
-        let context_stack: SmallVec<[KeyContext; 4]> = dispatch_path
+    ) -> (SmallVec<[KeyBinding; 1]>, bool, Vec<KeyContext>) {
+        let context_stack: Vec<KeyContext> = dispatch_path
             .iter()
             .filter_map(|node_id| self.node(*node_id).context.clone())
             .collect();
 
-        self.keymap
+        let (bindings, partial) = self
+            .keymap
             .borrow()
-            .bindings_for_input(input, &context_stack)
+            .bindings_for_input(input, &context_stack);
+        return (bindings, partial, context_stack);
     }
 
     /// dispatch_key processes the keystroke
@@ -436,20 +439,25 @@ impl DispatchTree {
         dispatch_path: &SmallVec<[DispatchNodeId; 32]>,
     ) -> DispatchResult {
         input.push(keystroke.clone());
-        let (bindings, pending) = self.bindings_for_input(&input, dispatch_path);
+        let (bindings, pending, context_stack) = self.bindings_for_input(&input, dispatch_path);
 
         if pending {
             return DispatchResult {
                 pending: input,
+                context_stack,
                 ..Default::default()
             };
         } else if !bindings.is_empty() {
             return DispatchResult {
                 bindings,
+                context_stack,
                 ..Default::default()
             };
         } else if input.len() == 1 {
-            return DispatchResult::default();
+            return DispatchResult {
+                context_stack,
+                ..Default::default()
+            };
         }
         input.pop();
 
@@ -485,7 +493,7 @@ impl DispatchTree {
     ) -> (SmallVec<[Keystroke; 1]>, SmallVec<[Replay; 1]>) {
         let mut to_replay: SmallVec<[Replay; 1]> = Default::default();
         for last in (0..input.len()).rev() {
-            let (bindings, _) = self.bindings_for_input(&input[0..=last], dispatch_path);
+            let (bindings, _, _) = self.bindings_for_input(&input[0..=last], dispatch_path);
             if !bindings.is_empty() {
                 to_replay.push(Replay {
                     keystroke: input.drain(0..=last).next_back().unwrap(),

crates/gpui/src/window.rs 🔗

@@ -1154,6 +1154,7 @@ impl Window {
         &mut self,
         event: &dyn Any,
         action: Option<Box<dyn Action>>,
+        context_stack: Vec<KeyContext>,
         cx: &mut App,
     ) {
         let Some(key_down_event) = event.downcast_ref::<KeyDownEvent>() else {
@@ -1165,6 +1166,7 @@ impl Window {
                 &KeystrokeEvent {
                     keystroke: key_down_event.keystroke.clone(),
                     action: action.as_ref().map(|action| action.boxed_clone()),
+                    context_stack: context_stack.clone(),
                 },
                 self,
                 cx,
@@ -3275,7 +3277,7 @@ impl Window {
         }
 
         let Some(keystroke) = keystroke else {
-            self.finish_dispatch_key_event(event, dispatch_path, cx);
+            self.finish_dispatch_key_event(event, dispatch_path, self.context_stack(), cx);
             return;
         };
 
@@ -3329,13 +3331,18 @@ impl Window {
         for binding in match_result.bindings {
             self.dispatch_action_on_node(node_id, binding.action.as_ref(), cx);
             if !cx.propagate_event {
-                self.dispatch_keystroke_observers(event, Some(binding.action), cx);
+                self.dispatch_keystroke_observers(
+                    event,
+                    Some(binding.action),
+                    match_result.context_stack.clone(),
+                    cx,
+                );
                 self.pending_input_changed(cx);
                 return;
             }
         }
 
-        self.finish_dispatch_key_event(event, dispatch_path, cx);
+        self.finish_dispatch_key_event(event, dispatch_path, match_result.context_stack, cx);
         self.pending_input_changed(cx);
     }
 
@@ -3343,6 +3350,7 @@ impl Window {
         &mut self,
         event: &dyn Any,
         dispatch_path: SmallVec<[DispatchNodeId; 32]>,
+        context_stack: Vec<KeyContext>,
         cx: &mut App,
     ) {
         self.dispatch_key_down_up_event(event, &dispatch_path, cx);
@@ -3355,7 +3363,7 @@ impl Window {
             return;
         }
 
-        self.dispatch_keystroke_observers(event, None, cx);
+        self.dispatch_keystroke_observers(event, None, context_stack, cx);
     }
 
     fn pending_input_changed(&mut self, cx: &mut App) {
@@ -3453,7 +3461,12 @@ impl Window {
             for binding in replay.bindings {
                 self.dispatch_action_on_node(node_id, binding.action.as_ref(), cx);
                 if !cx.propagate_event {
-                    self.dispatch_keystroke_observers(&event, Some(binding.action), cx);
+                    self.dispatch_keystroke_observers(
+                        &event,
+                        Some(binding.action),
+                        Vec::default(),
+                        cx,
+                    );
                     continue 'replay;
                 }
             }

crates/language_tools/src/key_context_view.rs 🔗

@@ -41,17 +41,17 @@ struct KeyContextView {
 
 impl KeyContextView {
     pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
-        let sub1 = cx.observe_keystrokes(|this, e, window, cx| {
+        let sub1 = cx.observe_keystrokes(|this, e, _, cx| {
             let mut pending = this.pending_keystrokes.take().unwrap_or_default();
             pending.push(e.keystroke.clone());
             let mut possibilities = cx.all_bindings_for_input(&pending);
             possibilities.reverse();
-            this.context_stack = window.context_stack();
             this.last_keystrokes = Some(
                 json!(pending.iter().map(|p| p.unparse()).join(" "))
                     .to_string()
                     .into(),
             );
+            this.context_stack = e.context_stack.clone();
             this.last_possibilities = possibilities
                 .into_iter()
                 .map(|binding| {
@@ -89,6 +89,7 @@ impl KeyContextView {
                     )
                 })
                 .collect();
+            cx.notify();
         });
         let sub2 = cx.observe_pending_input(window, |this, window, cx| {
             this.pending_keystrokes = window
@@ -237,7 +238,7 @@ impl Render for KeyContextView {
                     .mt_8(),
             )
             .children({
-                window.context_stack().into_iter().enumerate().map(|(i, context)| {
+                self.context_stack.iter().enumerate().map(|(i, context)| {
                     let primary = context.primary().map(|e| e.key.clone()).unwrap_or_default();
                     let secondary = context
                         .secondary()

crates/workspace/src/pane_group.rs 🔗

@@ -437,8 +437,12 @@ impl PaneAxis {
     }
 
     pub fn load(axis: Axis, members: Vec<Member>, flexes: Option<Vec<f32>>) -> Self {
-        let flexes = flexes.unwrap_or_else(|| vec![1.; members.len()]);
-        // debug_assert!(members.len() == flexes.len());
+        let mut flexes = flexes.unwrap_or_else(|| vec![1.; members.len()]);
+        if flexes.len() != members.len()
+            || (flexes.iter().copied().sum::<f32>() - flexes.len() as f32).abs() >= 0.001
+        {
+            flexes = vec![1.; members.len()];
+        }
 
         let flexes = Arc::new(Mutex::new(flexes));
         let bounding_boxes = Arc::new(Mutex::new(vec![None; members.len()]));