key_dispatch.rs

  1/// KeyDispatch is where GPUI deals with binding actions to key events.
  2///
  3/// The key pieces to making a key binding work are to define an action,
  4/// implement a method that takes that action as a type parameter,
  5/// and then to register the action during render on a focused node
  6/// with a keymap context:
  7///
/// ```rust
/// actions!(editor, [Undo, Redo]);
///
/// impl Editor {
///   fn undo(&mut self, _: &Undo, _cx: &mut ViewContext<Self>) { ... }
///   fn redo(&mut self, _: &Redo, _cx: &mut ViewContext<Self>) { ... }
/// }
///
/// impl Render for Editor {
///   fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
///     div()
///       .track_focus(&self.focus_handle)
///       .keymap_context("Editor")
///       .on_action(cx.listener(Editor::undo))
///       .on_action(cx.listener(Editor::redo))
///     ...
///   }
/// }
/// ```
 27///
/// The keybindings themselves are managed independently by calling `cx.bind_keys()`.
/// (When developing Zed itself, you usually just need to add a new line to
/// `assets/keymaps/default.json`.)
 31///
/// ```rust
/// cx.bind_keys([
///   KeyBinding::new("cmd-z", Undo, Some("Editor")),
///   KeyBinding::new("cmd-shift-z", Redo, Some("Editor")),
/// ])
/// ```
 38///
 39/// With all of this in place, GPUI will ensure that if you have an Editor that contains
 40/// the focus, hitting cmd-z will Undo.
 41///
/// In real apps it is a little more complicated than this, because typically you have
/// several nested views that each register keyboard handlers. In this case action matching
/// bubbles up from the bottom. For example, in Zed the Workspace is the top-level view,
/// which contains Panes, which contain Editors. If conflicting keybindings are defined,
/// the Editor's bindings take precedence over the Pane's bindings, which take precedence
/// over the Workspace's.
 46///
/// In GPUI, keybindings are not limited to single keystrokes; you can define
/// sequences by separating the keys with a space:
///
/// `KeyBinding::new("cmd-k left", pane::SplitLeft, Some("Pane"))`
 51///
 52use crate::{
 53    Action, ActionRegistry, DispatchPhase, EntityId, FocusId, KeyBinding, KeyContext, Keymap,
 54    KeymatchResult, Keystroke, KeystrokeMatcher, ModifiersChangedEvent, WindowContext,
 55};
 56use collections::FxHashMap;
 57use smallvec::SmallVec;
 58use std::{
 59    any::{Any, TypeId},
 60    cell::RefCell,
 61    mem,
 62    ops::Range,
 63    rc::Rc,
 64};
 65
 66#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
 67pub(crate) struct DispatchNodeId(usize);
 68
/// Tree of dispatch nodes used to route key events, modifier changes and
/// actions, together with the stacks used while the tree is being built.
pub(crate) struct DispatchTree {
    // Ids of the currently open nodes; the last entry is the active node.
    node_stack: Vec<DispatchNodeId>,
    // Key contexts of the open nodes that carry one, outermost first.
    pub(crate) context_stack: Vec<KeyContext>,
    // View ids of the open nodes that carry one, outermost first.
    view_stack: Vec<EntityId>,
    // All nodes, indexed by `DispatchNodeId`.
    nodes: Vec<DispatchNode>,
    // Focus id -> node that most recently registered it via `set_focus_id`.
    focusable_node_ids: FxHashMap<FocusId, DispatchNodeId>,
    // View id -> node that most recently registered it via `set_view_id`.
    view_node_ids: FxHashMap<EntityId, DispatchNodeId>,
    // One keystroke matcher per distinct context stack; these hold the state
    // of partially typed multi-keystroke bindings.
    keystroke_matchers: FxHashMap<SmallVec<[KeyContext; 4]>, KeystrokeMatcher>,
    // Keymap consulted by the matchers and by `bindings_for_action`.
    keymap: Rc<RefCell<Keymap>>,
    // Used by `available_actions` to build default instances of action types.
    action_registry: Rc<ActionRegistry>,
}
 80
/// A single node of the dispatch tree: the listeners registered on it plus the
/// optional key context, focus id and view id it carries.
#[derive(Default)]
pub(crate) struct DispatchNode {
    // Raw key-event listeners; each receives the event as `&dyn Any` and the dispatch phase.
    pub key_listeners: Vec<KeyListener>,
    // Action listeners, each tagged with the `TypeId` of the action it handles.
    pub action_listeners: Vec<DispatchActionListener>,
    // Listeners invoked with a `ModifiersChangedEvent`.
    pub modifiers_changed_listeners: Vec<ModifiersChangedListener>,
    // Key context pushed onto the tree's context stack while this node is open.
    pub context: Option<KeyContext>,
    // Focus id registered on this node, if any.
    pub focus_id: Option<FocusId>,
    // View this node starts; only set when it differs from the enclosing open view.
    view_id: Option<EntityId>,
    // Parent node, or `None` for a root.
    parent: Option<DispatchNodeId>,
}
 91
/// Result of `DispatchTree::reuse_subtree`: where a run of nodes lived in the
/// source tree and where it was copied to in the destination tree.
pub(crate) struct ReusedSubtree {
    // Node indices of the copied run in the source tree.
    old_range: Range<usize>,
    // Node indices of the run in the destination tree; same length as `old_range`.
    new_range: Range<usize>,
    // True if a copied node's focus id equalled the `focus` passed to `reuse_subtree`.
    contains_focus: bool,
}
 97
 98impl ReusedSubtree {
 99    pub fn refresh_node_id(&self, node_id: DispatchNodeId) -> DispatchNodeId {
100        debug_assert!(
101            self.old_range.contains(&node_id.0),
102            "node {} was not part of the reused subtree {:?}",
103            node_id.0,
104            self.old_range
105        );
106        DispatchNodeId((node_id.0 - self.old_range.start) + self.new_range.start)
107    }
108
109    pub fn contains_focus(&self) -> bool {
110        self.contains_focus
111    }
112}
113
/// Listener for raw key events: receives the event as `&dyn Any` plus the dispatch phase.
type KeyListener = Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>;
/// Listener invoked with a `ModifiersChangedEvent`.
type ModifiersChangedListener = Rc<dyn Fn(&ModifiersChangedEvent, &mut WindowContext)>;
116
/// An action handler registered on a dispatch node.
#[derive(Clone)]
pub(crate) struct DispatchActionListener {
    // `TypeId` of the handled action; compared against `action.as_any().type_id()`.
    pub(crate) action_type: TypeId,
    // Callback receiving the action as `&dyn Any` together with the dispatch phase.
    pub(crate) listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
}
122
123impl DispatchTree {
124    pub fn new(keymap: Rc<RefCell<Keymap>>, action_registry: Rc<ActionRegistry>) -> Self {
125        Self {
126            node_stack: Vec::new(),
127            context_stack: Vec::new(),
128            view_stack: Vec::new(),
129            nodes: Vec::new(),
130            focusable_node_ids: FxHashMap::default(),
131            view_node_ids: FxHashMap::default(),
132            keystroke_matchers: FxHashMap::default(),
133            keymap,
134            action_registry,
135        }
136    }
137
138    pub fn clear(&mut self) {
139        self.node_stack.clear();
140        self.context_stack.clear();
141        self.view_stack.clear();
142        self.nodes.clear();
143        self.focusable_node_ids.clear();
144        self.view_node_ids.clear();
145        self.keystroke_matchers.clear();
146    }
147
148    pub fn len(&self) -> usize {
149        self.nodes.len()
150    }
151
152    pub fn push_node(&mut self) -> DispatchNodeId {
153        let parent = self.node_stack.last().copied();
154        let node_id = DispatchNodeId(self.nodes.len());
155
156        self.nodes.push(DispatchNode {
157            parent,
158            ..Default::default()
159        });
160        self.node_stack.push(node_id);
161        node_id
162    }
163
164    pub fn set_active_node(&mut self, node_id: DispatchNodeId) {
165        let next_node_parent = self.nodes[node_id.0].parent;
166        while self.node_stack.last().copied() != next_node_parent && !self.node_stack.is_empty() {
167            self.pop_node();
168        }
169
170        if self.node_stack.last().copied() == next_node_parent {
171            self.node_stack.push(node_id);
172            let active_node = &self.nodes[node_id.0];
173            if let Some(view_id) = active_node.view_id {
174                self.view_stack.push(view_id)
175            }
176            if let Some(context) = active_node.context.clone() {
177                self.context_stack.push(context);
178            }
179        } else {
180            debug_assert_eq!(self.node_stack.len(), 0);
181
182            let mut current_node_id = Some(node_id);
183            while let Some(node_id) = current_node_id {
184                let node = &self.nodes[node_id.0];
185                if let Some(context) = node.context.clone() {
186                    self.context_stack.push(context);
187                }
188                if node.view_id.is_some() {
189                    self.view_stack.push(node.view_id.unwrap());
190                }
191                self.node_stack.push(node_id);
192                current_node_id = node.parent;
193            }
194
195            self.context_stack.reverse();
196            self.view_stack.reverse();
197            self.node_stack.reverse();
198        }
199    }
200
201    pub fn set_key_context(&mut self, context: KeyContext) {
202        self.active_node().context = Some(context.clone());
203        self.context_stack.push(context);
204    }
205
206    pub fn set_focus_id(&mut self, focus_id: FocusId) {
207        let node_id = *self.node_stack.last().unwrap();
208        self.nodes[node_id.0].focus_id = Some(focus_id);
209        self.focusable_node_ids.insert(focus_id, node_id);
210    }
211
212    pub fn parent_view_id(&mut self) -> Option<EntityId> {
213        self.view_stack.last().copied()
214    }
215
216    pub fn set_view_id(&mut self, view_id: EntityId) {
217        if self.view_stack.last().copied() != Some(view_id) {
218            let node_id = *self.node_stack.last().unwrap();
219            self.nodes[node_id.0].view_id = Some(view_id);
220            self.view_node_ids.insert(view_id, node_id);
221            self.view_stack.push(view_id);
222        }
223    }
224
225    pub fn pop_node(&mut self) {
226        let node = &self.nodes[self.active_node_id().unwrap().0];
227        if node.context.is_some() {
228            self.context_stack.pop();
229        }
230        if node.view_id.is_some() {
231            self.view_stack.pop();
232        }
233        self.node_stack.pop();
234    }
235
236    fn move_node(&mut self, source: &mut DispatchNode) {
237        self.push_node();
238        if let Some(context) = source.context.clone() {
239            self.set_key_context(context);
240        }
241        if let Some(focus_id) = source.focus_id {
242            self.set_focus_id(focus_id);
243        }
244        if let Some(view_id) = source.view_id {
245            self.set_view_id(view_id);
246        }
247
248        let target = self.active_node();
249        target.key_listeners = mem::take(&mut source.key_listeners);
250        target.action_listeners = mem::take(&mut source.action_listeners);
251        target.modifiers_changed_listeners = mem::take(&mut source.modifiers_changed_listeners);
252    }
253
254    pub fn reuse_subtree(
255        &mut self,
256        old_range: Range<usize>,
257        source: &mut Self,
258        focus: Option<FocusId>,
259    ) -> ReusedSubtree {
260        let new_range = self.nodes.len()..self.nodes.len() + old_range.len();
261
262        let mut contains_focus = false;
263        let mut source_stack = vec![];
264        for (source_node_id, source_node) in source
265            .nodes
266            .iter_mut()
267            .enumerate()
268            .skip(old_range.start)
269            .take(old_range.len())
270        {
271            let source_node_id = DispatchNodeId(source_node_id);
272            while let Some(source_ancestor) = source_stack.last() {
273                if source_node.parent == Some(*source_ancestor) {
274                    break;
275                } else {
276                    source_stack.pop();
277                    self.pop_node();
278                }
279            }
280
281            source_stack.push(source_node_id);
282            if source_node.focus_id.is_some() && source_node.focus_id == focus {
283                contains_focus = true;
284            }
285            self.move_node(source_node);
286        }
287
288        while !source_stack.is_empty() {
289            source_stack.pop();
290            self.pop_node();
291        }
292
293        ReusedSubtree {
294            old_range,
295            new_range,
296            contains_focus,
297        }
298    }
299
300    pub fn truncate(&mut self, index: usize) {
301        for node in &self.nodes[index..] {
302            if let Some(focus_id) = node.focus_id {
303                self.focusable_node_ids.remove(&focus_id);
304            }
305
306            if let Some(view_id) = node.view_id {
307                self.view_node_ids.remove(&view_id);
308            }
309        }
310        self.nodes.truncate(index);
311    }
312
313    pub fn clear_pending_keystrokes(&mut self) {
314        self.keystroke_matchers.clear();
315    }
316
317    /// Preserve keystroke matchers from previous frames to support multi-stroke
318    /// bindings across multiple frames.
319    pub fn preserve_pending_keystrokes(&mut self, old_tree: &mut Self, focus_id: Option<FocusId>) {
320        if let Some(node_id) = focus_id.and_then(|focus_id| self.focusable_node_id(focus_id)) {
321            let dispatch_path = self.dispatch_path(node_id);
322
323            self.context_stack.clear();
324            for node_id in dispatch_path {
325                let node = self.node(node_id);
326                if let Some(context) = node.context.clone() {
327                    self.context_stack.push(context);
328                }
329
330                if let Some((context_stack, matcher)) = old_tree
331                    .keystroke_matchers
332                    .remove_entry(self.context_stack.as_slice())
333                {
334                    self.keystroke_matchers.insert(context_stack, matcher);
335                }
336            }
337        }
338    }
339
340    pub fn on_key_event(&mut self, listener: KeyListener) {
341        self.active_node().key_listeners.push(listener);
342    }
343
344    pub fn on_modifiers_changed(&mut self, listener: ModifiersChangedListener) {
345        self.active_node()
346            .modifiers_changed_listeners
347            .push(listener);
348    }
349
350    pub fn on_action(
351        &mut self,
352        action_type: TypeId,
353        listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
354    ) {
355        self.active_node()
356            .action_listeners
357            .push(DispatchActionListener {
358                action_type,
359                listener,
360            });
361    }
362
363    pub fn focus_contains(&self, parent: FocusId, child: FocusId) -> bool {
364        if parent == child {
365            return true;
366        }
367
368        if let Some(parent_node_id) = self.focusable_node_ids.get(&parent) {
369            let mut current_node_id = self.focusable_node_ids.get(&child).copied();
370            while let Some(node_id) = current_node_id {
371                if node_id == *parent_node_id {
372                    return true;
373                }
374                current_node_id = self.nodes[node_id.0].parent;
375            }
376        }
377        false
378    }
379
380    pub fn available_actions(&self, target: DispatchNodeId) -> Vec<Box<dyn Action>> {
381        let mut actions = Vec::<Box<dyn Action>>::new();
382        for node_id in self.dispatch_path(target) {
383            let node = &self.nodes[node_id.0];
384            for DispatchActionListener { action_type, .. } in &node.action_listeners {
385                if let Err(ix) = actions.binary_search_by_key(action_type, |a| a.as_any().type_id())
386                {
387                    // Intentionally silence these errors without logging.
388                    // If an action cannot be built by default, it's not available.
389                    let action = self.action_registry.build_action_type(action_type).ok();
390                    if let Some(action) = action {
391                        actions.insert(ix, action);
392                    }
393                }
394            }
395        }
396        actions
397    }
398
399    pub fn is_action_available(&self, action: &dyn Action, target: DispatchNodeId) -> bool {
400        for node_id in self.dispatch_path(target) {
401            let node = &self.nodes[node_id.0];
402            if node
403                .action_listeners
404                .iter()
405                .any(|listener| listener.action_type == action.as_any().type_id())
406            {
407                return true;
408            }
409        }
410        false
411    }
412
413    pub fn bindings_for_action(
414        &self,
415        action: &dyn Action,
416        context_stack: &[KeyContext],
417    ) -> Vec<KeyBinding> {
418        let keymap = self.keymap.borrow();
419        keymap
420            .bindings_for_action(action)
421            .filter(|binding| {
422                for i in 0..context_stack.len() {
423                    let context = &context_stack[0..=i];
424                    if keymap.binding_enabled(binding, context) {
425                        return true;
426                    }
427                }
428                false
429            })
430            .cloned()
431            .collect()
432    }
433
434    // dispatch_key pushes the next keystroke into any key binding matchers.
435    // any matching bindings are returned in the order that they should be dispatched:
436    // * First by length of binding (so if you have a binding for "b" and "ab", the "ab" binding fires first)
437    // * Secondly by depth in the tree (so if Editor has a binding for "b" and workspace a
438    // binding for "b", the Editor action fires first).
439    pub fn dispatch_key(
440        &mut self,
441        keystroke: &Keystroke,
442        dispatch_path: &SmallVec<[DispatchNodeId; 32]>,
443    ) -> KeymatchResult {
444        let mut bindings = SmallVec::<[KeyBinding; 1]>::new();
445        let mut pending = false;
446
447        let mut context_stack: SmallVec<[KeyContext; 4]> = SmallVec::new();
448        for node_id in dispatch_path {
449            let node = self.node(*node_id);
450
451            if let Some(context) = node.context.clone() {
452                context_stack.push(context);
453            }
454        }
455
456        while !context_stack.is_empty() {
457            let keystroke_matcher = self
458                .keystroke_matchers
459                .entry(context_stack.clone())
460                .or_insert_with(|| KeystrokeMatcher::new(self.keymap.clone()));
461
462            let result = keystroke_matcher.match_keystroke(keystroke, &context_stack);
463            if result.pending && !pending && !bindings.is_empty() {
464                context_stack.pop();
465                continue;
466            }
467
468            pending = result.pending || pending;
469            for new_binding in result.bindings {
470                match bindings
471                    .iter()
472                    .position(|el| el.keystrokes.len() < new_binding.keystrokes.len())
473                {
474                    Some(idx) => {
475                        bindings.insert(idx, new_binding);
476                    }
477                    None => bindings.push(new_binding),
478                }
479            }
480            context_stack.pop();
481        }
482
483        KeymatchResult { bindings, pending }
484    }
485
486    pub fn has_pending_keystrokes(&self) -> bool {
487        self.keystroke_matchers
488            .iter()
489            .any(|(_, matcher)| matcher.has_pending_keystrokes())
490    }
491
492    pub fn dispatch_path(&self, target: DispatchNodeId) -> SmallVec<[DispatchNodeId; 32]> {
493        let mut dispatch_path: SmallVec<[DispatchNodeId; 32]> = SmallVec::new();
494        let mut current_node_id = Some(target);
495        while let Some(node_id) = current_node_id {
496            dispatch_path.push(node_id);
497            current_node_id = self.nodes[node_id.0].parent;
498        }
499        dispatch_path.reverse(); // Reverse the path so it goes from the root to the focused node.
500        dispatch_path
501    }
502
503    pub fn focus_path(&self, focus_id: FocusId) -> SmallVec<[FocusId; 8]> {
504        let mut focus_path: SmallVec<[FocusId; 8]> = SmallVec::new();
505        let mut current_node_id = self.focusable_node_ids.get(&focus_id).copied();
506        while let Some(node_id) = current_node_id {
507            let node = self.node(node_id);
508            if let Some(focus_id) = node.focus_id {
509                focus_path.push(focus_id);
510            }
511            current_node_id = node.parent;
512        }
513        focus_path.reverse(); // Reverse the path so it goes from the root to the focused node.
514        focus_path
515    }
516
517    pub fn view_path(&self, view_id: EntityId) -> SmallVec<[EntityId; 8]> {
518        let mut view_path: SmallVec<[EntityId; 8]> = SmallVec::new();
519        let mut current_node_id = self.view_node_ids.get(&view_id).copied();
520        while let Some(node_id) = current_node_id {
521            let node = self.node(node_id);
522            if let Some(view_id) = node.view_id {
523                view_path.push(view_id);
524            }
525            current_node_id = node.parent;
526        }
527        view_path.reverse(); // Reverse the path so it goes from the root to the view node.
528        view_path
529    }
530
531    pub fn node(&self, node_id: DispatchNodeId) -> &DispatchNode {
532        &self.nodes[node_id.0]
533    }
534
535    fn active_node(&mut self) -> &mut DispatchNode {
536        let active_node_id = self.active_node_id().unwrap();
537        &mut self.nodes[active_node_id.0]
538    }
539
540    pub fn focusable_node_id(&self, target: FocusId) -> Option<DispatchNodeId> {
541        self.focusable_node_ids.get(&target).copied()
542    }
543
544    pub fn root_node_id(&self) -> DispatchNodeId {
545        debug_assert!(!self.nodes.is_empty());
546        DispatchNodeId(0)
547    }
548
549    pub fn active_node_id(&self) -> Option<DispatchNodeId> {
550        self.node_stack.last().copied()
551    }
552}
553
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, rc::Rc};

    use crate::{Action, ActionRegistry, DispatchTree, KeyBinding, KeyContext, Keymap};

    /// Minimal hand-written action used to exercise binding lookup.
    #[derive(PartialEq, Eq)]
    struct TestAction;

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test::TestAction"
        }

        fn debug_name() -> &'static str
        where
            Self: ::std::marker::Sized,
        {
            "test::TestAction"
        }

        fn partial_eq(&self, action: &dyn Action) -> bool {
            action
                .as_any()
                .downcast_ref::<Self>()
                .map_or(false, |a| self == a)
        }

        fn boxed_clone(&self) -> std::boxed::Box<dyn Action> {
            Box::new(TestAction)
        }

        fn as_any(&self) -> &dyn ::std::any::Any {
            self
        }

        fn build(_value: serde_json::Value) -> anyhow::Result<Box<dyn Action>>
        where
            Self: Sized,
        {
            Ok(Box::new(TestAction))
        }
    }

    /// A binding registered for the `ProjectPanel` context must be returned
    /// when the context stack is `[Workspace, ProjectPanel]`.
    #[test]
    fn test_keybinding_for_action_bounds() {
        let keymap = Keymap::new(vec![KeyBinding::new(
            "cmd-n",
            TestAction,
            Some("ProjectPanel"),
        )]);

        let mut registry = ActionRegistry::default();

        registry.load_action::<TestAction>();

        let keymap = Rc::new(RefCell::new(keymap));

        let tree = DispatchTree::new(keymap, Rc::new(registry));

        let contexts = vec![
            KeyContext::parse("Workspace").unwrap(),
            KeyContext::parse("ProjectPanel").unwrap(),
        ];

        let keybindings = tree.bindings_for_action(&TestAction, &contexts);

        // Fail with a descriptive message instead of an index-out-of-bounds panic
        // when no binding is found.
        let binding = keybindings
            .first()
            .expect("expected a binding for TestAction in the ProjectPanel context");
        assert!(binding.action.partial_eq(&TestAction));
    }
}