key_dispatch.rs

  1use crate::{
  2    Action, ActionRegistry, DispatchPhase, EntityId, FocusId, KeyBinding, KeyContext, KeyMatch,
  3    Keymap, Keystroke, KeystrokeMatcher, WindowContext,
  4};
  5use collections::FxHashMap;
  6use parking_lot::Mutex;
  7use smallvec::{smallvec, SmallVec};
  8use std::{
  9    any::{Any, TypeId},
 10    mem,
 11    rc::Rc,
 12    sync::Arc,
 13};
 14
 15#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
 16pub struct DispatchNodeId(usize);
 17
/// Tree of dispatch nodes used to route key events and actions.
///
/// Nodes are appended with `push_node` and closed with `pop_node`. While the
/// tree is being built, `node_stack` tracks the currently open path.
pub(crate) struct DispatchTree {
    /// Ids of nodes opened by `push_node` and not yet closed by `pop_node`;
    /// the last entry is the active node.
    node_stack: Vec<DispatchNodeId>,
    /// Key contexts of the open nodes that carry one, outermost first.
    /// `preserve_pending_keystrokes` also rebuilds this stack.
    pub(crate) context_stack: Vec<KeyContext>,
    /// Every node in push order; a `DispatchNodeId` is an index into this vector.
    nodes: Vec<DispatchNode>,
    /// Node registered for each focus id.
    focusable_node_ids: FxHashMap<FocusId, DispatchNodeId>,
    /// Root node of each view, i.e. the first node pushed with that view id.
    view_node_ids: FxHashMap<EntityId, DispatchNodeId>,
    /// One keystroke matcher per context stack, created lazily by `dispatch_key`.
    keystroke_matchers: FxHashMap<SmallVec<[KeyContext; 4]>, KeystrokeMatcher>,
    /// Keymap shared with every keystroke matcher this tree creates.
    keymap: Arc<Mutex<Keymap>>,
    /// Registry used to build actions from their type ids.
    action_registry: Rc<ActionRegistry>,
}
 28
/// A single node of a `DispatchTree`.
#[derive(Default)]
pub(crate) struct DispatchNode {
    /// Listeners registered via `on_key_event` while this node was active.
    pub key_listeners: Vec<KeyListener>,
    /// Listeners registered via `on_action` while this node was active.
    pub action_listeners: Vec<DispatchActionListener>,
    /// Key context pushed by this node, if any.
    pub context: Option<KeyContext>,
    /// Focus id registered for this node, if any.
    focus_id: Option<FocusId>,
    /// Set only on the first (root) node pushed for a given view.
    view_id: Option<EntityId>,
    /// Enclosing node; `None` for a node pushed while the node stack was empty.
    parent: Option<DispatchNodeId>,
}
 38
 39type KeyListener = Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>;
 40
 41#[derive(Clone)]
 42pub(crate) struct DispatchActionListener {
 43    pub(crate) action_type: TypeId,
 44    pub(crate) listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
 45}
 46
 47impl DispatchTree {
 48    pub fn new(keymap: Arc<Mutex<Keymap>>, action_registry: Rc<ActionRegistry>) -> Self {
 49        Self {
 50            node_stack: Vec::new(),
 51            context_stack: Vec::new(),
 52            nodes: Vec::new(),
 53            focusable_node_ids: FxHashMap::default(),
 54            view_node_ids: FxHashMap::default(),
 55            keystroke_matchers: FxHashMap::default(),
 56            keymap,
 57            action_registry,
 58        }
 59    }
 60
 61    pub fn clear(&mut self) {
 62        self.node_stack.clear();
 63        self.context_stack.clear();
 64        self.nodes.clear();
 65        self.focusable_node_ids.clear();
 66        self.view_node_ids.clear();
 67        self.keystroke_matchers.clear();
 68    }
 69
 70    pub fn push_node(
 71        &mut self,
 72        context: Option<KeyContext>,
 73        focus_id: Option<FocusId>,
 74        view_id: Option<EntityId>,
 75    ) {
 76        // Associate a view id to this only if it is the root node for the view.
 77        let view_id = view_id.and_then(|view_id| {
 78            if self.view_node_ids.contains_key(&view_id) {
 79                None
 80            } else {
 81                Some(view_id)
 82            }
 83        });
 84
 85        let parent = self.node_stack.last().copied();
 86        let node_id = DispatchNodeId(self.nodes.len());
 87        self.nodes.push(DispatchNode {
 88            parent,
 89            focus_id,
 90            view_id,
 91            ..Default::default()
 92        });
 93        self.node_stack.push(node_id);
 94
 95        if let Some(context) = context {
 96            self.active_node().context = Some(context.clone());
 97            self.context_stack.push(context);
 98        }
 99
100        if let Some(focus_id) = focus_id {
101            self.focusable_node_ids.insert(focus_id, node_id);
102        }
103
104        if let Some(view_id) = view_id {
105            self.view_node_ids.insert(view_id, node_id);
106        }
107    }
108
109    pub fn pop_node(&mut self) {
110        let node = &self.nodes[self.active_node_id().0];
111        if node.context.is_some() {
112            self.context_stack.pop();
113        }
114        self.node_stack.pop();
115    }
116
117    fn move_node(&mut self, source: &mut DispatchNode) {
118        self.push_node(source.context.take(), source.focus_id, source.view_id);
119        let target = self.active_node();
120        target.key_listeners = mem::take(&mut source.key_listeners);
121        target.action_listeners = mem::take(&mut source.action_listeners);
122    }
123
124    pub fn graft(&mut self, view_id: EntityId, source: &mut Self) -> SmallVec<[EntityId; 8]> {
125        let view_source_node_id = source
126            .view_node_ids
127            .get(&view_id)
128            .expect("view should exist in previous dispatch tree");
129        let view_source_node = &mut source.nodes[view_source_node_id.0];
130        self.move_node(view_source_node);
131
132        let mut grafted_view_ids = smallvec![view_id];
133        let mut source_stack = vec![*view_source_node_id];
134        for (source_node_id, source_node) in source
135            .nodes
136            .iter_mut()
137            .enumerate()
138            .skip(view_source_node_id.0 + 1)
139        {
140            let source_node_id = DispatchNodeId(source_node_id);
141            while let Some(source_ancestor) = source_stack.last() {
142                if source_node.parent != Some(*source_ancestor) {
143                    source_stack.pop();
144                    self.pop_node();
145                } else {
146                    break;
147                }
148            }
149
150            if source_stack.is_empty() {
151                break;
152            } else {
153                source_stack.push(source_node_id);
154                self.move_node(source_node);
155                if let Some(view_id) = source_node.view_id {
156                    grafted_view_ids.push(view_id);
157                }
158            }
159        }
160
161        while !source_stack.is_empty() {
162            source_stack.pop();
163            self.pop_node();
164        }
165
166        grafted_view_ids
167    }
168
169    pub fn clear_pending_keystrokes(&mut self) {
170        self.keystroke_matchers.clear();
171    }
172
173    /// Preserve keystroke matchers from previous frames to support multi-stroke
174    /// bindings across multiple frames.
175    pub fn preserve_pending_keystrokes(&mut self, old_tree: &mut Self, focus_id: Option<FocusId>) {
176        if let Some(node_id) = focus_id.and_then(|focus_id| self.focusable_node_id(focus_id)) {
177            let dispatch_path = self.dispatch_path(node_id);
178
179            self.context_stack.clear();
180            for node_id in dispatch_path {
181                let node = self.node(node_id);
182                if let Some(context) = node.context.clone() {
183                    self.context_stack.push(context);
184                }
185
186                if let Some((context_stack, matcher)) = old_tree
187                    .keystroke_matchers
188                    .remove_entry(self.context_stack.as_slice())
189                {
190                    self.keystroke_matchers.insert(context_stack, matcher);
191                }
192            }
193        }
194    }
195
196    pub fn on_key_event(&mut self, listener: KeyListener) {
197        self.active_node().key_listeners.push(listener);
198    }
199
200    pub fn on_action(
201        &mut self,
202        action_type: TypeId,
203        listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
204    ) {
205        self.active_node()
206            .action_listeners
207            .push(DispatchActionListener {
208                action_type,
209                listener,
210            });
211    }
212
213    pub fn focus_contains(&self, parent: FocusId, child: FocusId) -> bool {
214        if parent == child {
215            return true;
216        }
217
218        if let Some(parent_node_id) = self.focusable_node_ids.get(&parent) {
219            let mut current_node_id = self.focusable_node_ids.get(&child).copied();
220            while let Some(node_id) = current_node_id {
221                if node_id == *parent_node_id {
222                    return true;
223                }
224                current_node_id = self.nodes[node_id.0].parent;
225            }
226        }
227        false
228    }
229
230    pub fn available_actions(&self, target: DispatchNodeId) -> Vec<Box<dyn Action>> {
231        let mut actions = Vec::<Box<dyn Action>>::new();
232        for node_id in self.dispatch_path(target) {
233            let node = &self.nodes[node_id.0];
234            for DispatchActionListener { action_type, .. } in &node.action_listeners {
235                if let Err(ix) = actions.binary_search_by_key(action_type, |a| a.as_any().type_id())
236                {
237                    // Intentionally silence these errors without logging.
238                    // If an action cannot be built by default, it's not available.
239                    let action = self.action_registry.build_action_type(action_type).ok();
240                    if let Some(action) = action {
241                        actions.insert(ix, action);
242                    }
243                }
244            }
245        }
246        actions
247    }
248
249    pub fn is_action_available(&self, action: &dyn Action, target: DispatchNodeId) -> bool {
250        for node_id in self.dispatch_path(target) {
251            let node = &self.nodes[node_id.0];
252            if node
253                .action_listeners
254                .iter()
255                .any(|listener| listener.action_type == action.as_any().type_id())
256            {
257                return true;
258            }
259        }
260        false
261    }
262
263    pub fn bindings_for_action(
264        &self,
265        action: &dyn Action,
266        context_stack: &Vec<KeyContext>,
267    ) -> Vec<KeyBinding> {
268        let keymap = self.keymap.lock();
269        keymap
270            .bindings_for_action(action)
271            .filter(|binding| {
272                for i in 0..context_stack.len() {
273                    let context = &context_stack[0..=i];
274                    if keymap.binding_enabled(binding, context) {
275                        return true;
276                    }
277                }
278                false
279            })
280            .cloned()
281            .collect()
282    }
283
284    pub fn dispatch_key(
285        &mut self,
286        keystroke: &Keystroke,
287        context: &[KeyContext],
288    ) -> Vec<Box<dyn Action>> {
289        if !self.keystroke_matchers.contains_key(context) {
290            let keystroke_contexts = context.iter().cloned().collect();
291            self.keystroke_matchers.insert(
292                keystroke_contexts,
293                KeystrokeMatcher::new(self.keymap.clone()),
294            );
295        }
296
297        let keystroke_matcher = self.keystroke_matchers.get_mut(context).unwrap();
298        if let KeyMatch::Some(actions) = keystroke_matcher.match_keystroke(keystroke, context) {
299            // Clear all pending keystrokes when an action has been found.
300            for keystroke_matcher in self.keystroke_matchers.values_mut() {
301                keystroke_matcher.clear_pending();
302            }
303
304            actions
305        } else {
306            vec![]
307        }
308    }
309
310    pub fn has_pending_keystrokes(&self) -> bool {
311        self.keystroke_matchers
312            .iter()
313            .any(|(_, matcher)| matcher.has_pending_keystrokes())
314    }
315
316    pub fn dispatch_path(&self, target: DispatchNodeId) -> SmallVec<[DispatchNodeId; 32]> {
317        let mut dispatch_path: SmallVec<[DispatchNodeId; 32]> = SmallVec::new();
318        let mut current_node_id = Some(target);
319        while let Some(node_id) = current_node_id {
320            dispatch_path.push(node_id);
321            current_node_id = self.nodes[node_id.0].parent;
322        }
323        dispatch_path.reverse(); // Reverse the path so it goes from the root to the focused node.
324        dispatch_path
325    }
326
327    pub fn focus_path(&self, focus_id: FocusId) -> SmallVec<[FocusId; 8]> {
328        let mut focus_path: SmallVec<[FocusId; 8]> = SmallVec::new();
329        let mut current_node_id = self.focusable_node_ids.get(&focus_id).copied();
330        while let Some(node_id) = current_node_id {
331            let node = self.node(node_id);
332            if let Some(focus_id) = node.focus_id {
333                focus_path.push(focus_id);
334            }
335            current_node_id = node.parent;
336        }
337        focus_path.reverse(); // Reverse the path so it goes from the root to the focused node.
338        focus_path
339    }
340
341    pub fn view_path(&self, view_id: EntityId) -> SmallVec<[EntityId; 8]> {
342        let mut view_path: SmallVec<[EntityId; 8]> = SmallVec::new();
343        let mut current_node_id = self.view_node_ids.get(&view_id).copied();
344        while let Some(node_id) = current_node_id {
345            let node = self.node(node_id);
346            if let Some(view_id) = node.view_id {
347                view_path.push(view_id);
348            }
349            current_node_id = node.parent;
350        }
351        view_path.reverse(); // Reverse the path so it goes from the root to the view node.
352        view_path
353    }
354
355    pub fn node(&self, node_id: DispatchNodeId) -> &DispatchNode {
356        &self.nodes[node_id.0]
357    }
358
359    fn active_node(&mut self) -> &mut DispatchNode {
360        let active_node_id = self.active_node_id();
361        &mut self.nodes[active_node_id.0]
362    }
363
364    pub fn focusable_node_id(&self, target: FocusId) -> Option<DispatchNodeId> {
365        self.focusable_node_ids.get(&target).copied()
366    }
367
368    pub fn root_node_id(&self) -> DispatchNodeId {
369        debug_assert!(!self.nodes.is_empty());
370        DispatchNodeId(0)
371    }
372
373    fn active_node_id(&self) -> DispatchNodeId {
374        *self.node_stack.last().unwrap()
375    }
376}
377
#[cfg(test)]
mod tests {
    use std::{rc::Rc, sync::Arc};

    use parking_lot::Mutex;

    use crate::{Action, ActionRegistry, DispatchTree, KeyBinding, KeyContext, Keymap};

    /// Minimal unit action used to exercise keybinding lookup.
    #[derive(PartialEq, Eq)]
    struct TestAction;

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test::TestAction"
        }

        fn debug_name() -> &'static str
        where
            Self: ::std::marker::Sized,
        {
            "test::TestAction"
        }

        fn partial_eq(&self, action: &dyn Action) -> bool {
            matches!(action.as_any().downcast_ref::<Self>(), Some(other) if other == self)
        }

        fn boxed_clone(&self) -> Box<dyn Action> {
            Box::new(Self)
        }

        fn as_any(&self) -> &dyn ::std::any::Any {
            self
        }

        fn build(_value: serde_json::Value) -> anyhow::Result<Box<dyn Action>>
        where
            Self: Sized,
        {
            Ok(Box::new(Self))
        }
    }

    /// A binding scoped to `ProjectPanel` must be found when that context is
    /// nested inside `Workspace`.
    #[test]
    fn test_keybinding_for_action_bounds() {
        let bindings = vec![KeyBinding::new("cmd-n", TestAction, Some("ProjectPanel"))];
        let keymap = Keymap::new(bindings);

        let mut registry = ActionRegistry::default();
        registry.load_action::<TestAction>();

        let tree = DispatchTree::new(Arc::new(Mutex::new(keymap)), Rc::new(registry));

        let contexts: Vec<KeyContext> = ["Workspace", "ProjectPanel"]
            .iter()
            .map(|source| KeyContext::parse(source).unwrap())
            .collect();

        let found = tree.bindings_for_action(&TestAction, &contexts);

        assert!(found[0].action.partial_eq(&TestAction))
    }
}