key_dispatch.rs

  1use crate::{
  2    build_action_from_type, Action, Bounds, DispatchPhase, Element, FocusEvent, FocusHandle,
  3    FocusId, KeyContext, KeyDownEvent, KeyMatch, Keymap, KeystrokeMatcher, MouseDownEvent, Pixels,
  4    Style, StyleRefinement, ViewContext, WindowContext,
  5};
  6use collections::HashMap;
  7use parking_lot::Mutex;
  8use refineable::Refineable;
  9use smallvec::SmallVec;
 10use std::{
 11    any::{Any, TypeId},
 12    sync::Arc,
 13};
 14use util::ResultExt;
 15
 16type KeyListener = Box<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>;
 17pub type FocusListeners<V> = SmallVec<[FocusListener<V>; 2]>;
 18pub type FocusListener<V> =
 19    Box<dyn Fn(&mut V, &FocusHandle, &FocusEvent, &mut ViewContext<V>) + 'static>;
 20
 21#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
 22pub struct DispatchNodeId(usize);
 23
 24pub struct KeyDispatcher {
 25    node_stack: Vec<DispatchNodeId>,
 26    context_stack: Vec<KeyContext>,
 27    nodes: Vec<DispatchNode>,
 28    focusable_node_ids: HashMap<FocusId, DispatchNodeId>,
 29    keystroke_matchers: HashMap<SmallVec<[KeyContext; 4]>, KeystrokeMatcher>,
 30    keymap: Arc<Mutex<Keymap>>,
 31}
 32
 33#[derive(Default)]
 34pub struct DispatchNode {
 35    key_listeners: SmallVec<[KeyListener; 2]>,
 36    action_listeners: SmallVec<[ActionListener; 16]>,
 37    context: KeyContext,
 38    parent: Option<DispatchNodeId>,
 39}
 40
 41struct ActionListener {
 42    action_type: TypeId,
 43    listener: Box<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
 44}
 45
 46impl KeyDispatcher {
 47    pub fn new(keymap: Arc<Mutex<Keymap>>) -> Self {
 48        Self {
 49            node_stack: Vec::new(),
 50            context_stack: Vec::new(),
 51            nodes: Vec::new(),
 52            focusable_node_ids: HashMap::default(),
 53            keystroke_matchers: HashMap::default(),
 54            keymap,
 55        }
 56    }
 57
 58    pub fn clear(&mut self) {
 59        self.node_stack.clear();
 60        self.nodes.clear();
 61        self.context_stack.clear();
 62        self.focusable_node_ids.clear();
 63        self.keystroke_matchers.clear();
 64    }
 65
 66    pub fn push_node(&mut self, context: KeyContext, old_dispatcher: &mut Self) {
 67        let parent = self.node_stack.last().copied();
 68        let node_id = DispatchNodeId(self.nodes.len());
 69        self.nodes.push(DispatchNode {
 70            parent,
 71            ..Default::default()
 72        });
 73        self.node_stack.push(node_id);
 74        if !context.is_empty() {
 75            self.active_node().context = context.clone();
 76            self.context_stack.push(context);
 77            if let Some((context_stack, matcher)) = old_dispatcher
 78                .keystroke_matchers
 79                .remove_entry(self.context_stack.as_slice())
 80            {
 81                self.keystroke_matchers.insert(context_stack, matcher);
 82            }
 83        }
 84    }
 85
 86    pub fn pop_node(&mut self) {
 87        let node_id = self.node_stack.pop().unwrap();
 88        if !self.nodes[node_id.0].context.is_empty() {
 89            self.context_stack.pop();
 90        }
 91    }
 92
 93    pub fn on_key_event(&mut self, listener: KeyListener) {
 94        self.active_node().key_listeners.push(listener);
 95    }
 96
 97    pub fn on_action(
 98        &mut self,
 99        action_type: TypeId,
100        listener: Box<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
101    ) {
102        self.active_node().action_listeners.push(ActionListener {
103            action_type,
104            listener,
105        });
106    }
107
108    pub fn make_focusable(&mut self, focus_id: FocusId) {
109        self.focusable_node_ids
110            .insert(focus_id, self.active_node_id());
111    }
112
113    pub fn focus_contains(&self, parent: FocusId, child: FocusId) -> bool {
114        if parent == child {
115            return true;
116        }
117
118        if let Some(parent_node_id) = self.focusable_node_ids.get(&parent) {
119            let mut current_node_id = self.focusable_node_ids.get(&child).copied();
120            while let Some(node_id) = current_node_id {
121                if node_id == *parent_node_id {
122                    return true;
123                }
124                current_node_id = self.nodes[node_id.0].parent;
125            }
126        }
127        false
128    }
129
130    pub fn available_actions(&self, target: FocusId) -> Vec<Box<dyn Action>> {
131        let mut actions = Vec::new();
132        if let Some(node) = self.focusable_node_ids.get(&target) {
133            for node_id in self.dispatch_path(*node) {
134                let node = &self.nodes[node_id.0];
135                for ActionListener { action_type, .. } in &node.action_listeners {
136                    actions.extend(build_action_from_type(action_type).log_err());
137                }
138            }
139        }
140        actions
141    }
142
143    pub fn dispatch_key(&mut self, target: FocusId, event: &dyn Any, cx: &mut WindowContext) {
144        if let Some(target_node_id) = self.focusable_node_ids.get(&target).copied() {
145            self.dispatch_key_on_node(target_node_id, event, cx);
146        }
147    }
148
149    fn dispatch_key_on_node(
150        &mut self,
151        node_id: DispatchNodeId,
152        event: &dyn Any,
153        cx: &mut WindowContext,
154    ) {
155        let dispatch_path = self.dispatch_path(node_id);
156
157        // Capture phase
158        self.context_stack.clear();
159        cx.propagate_event = true;
160
161        for node_id in &dispatch_path {
162            let node = &self.nodes[node_id.0];
163            if !node.context.is_empty() {
164                self.context_stack.push(node.context.clone());
165            }
166
167            for key_listener in &node.key_listeners {
168                key_listener(event, DispatchPhase::Capture, cx);
169                if !cx.propagate_event {
170                    return;
171                }
172            }
173        }
174
175        // Bubble phase
176        for node_id in dispatch_path.iter().rev() {
177            let node = &self.nodes[node_id.0];
178
179            // Handle low level key events
180            for key_listener in &node.key_listeners {
181                key_listener(event, DispatchPhase::Bubble, cx);
182                if !cx.propagate_event {
183                    return;
184                }
185            }
186
187            // Match keystrokes
188            if !node.context.is_empty() {
189                if let Some(key_down_event) = event.downcast_ref::<KeyDownEvent>() {
190                    if !self
191                        .keystroke_matchers
192                        .contains_key(self.context_stack.as_slice())
193                    {
194                        let keystroke_contexts = self.context_stack.iter().cloned().collect();
195                        self.keystroke_matchers.insert(
196                            keystroke_contexts,
197                            KeystrokeMatcher::new(self.keymap.clone()),
198                        );
199                    }
200
201                    let keystroke_matcher = self
202                        .keystroke_matchers
203                        .get_mut(self.context_stack.as_slice())
204                        .unwrap();
205                    if let KeyMatch::Some(action) = keystroke_matcher
206                        .match_keystroke(&key_down_event.keystroke, self.context_stack.as_slice())
207                    {
208                        // Clear all pending keystrokes when an action has been found.
209                        for keystroke_matcher in self.keystroke_matchers.values_mut() {
210                            keystroke_matcher.clear_pending();
211                        }
212
213                        self.dispatch_action_on_node(*node_id, action, cx);
214                        if !cx.propagate_event {
215                            return;
216                        }
217                    }
218                }
219
220                self.context_stack.pop();
221            }
222        }
223    }
224
225    pub fn dispatch_action(
226        &self,
227        target: FocusId,
228        action: Box<dyn Action>,
229        cx: &mut WindowContext,
230    ) {
231        if let Some(target_node_id) = self.focusable_node_ids.get(&target).copied() {
232            self.dispatch_action_on_node(target_node_id, action, cx);
233        }
234    }
235
236    fn dispatch_action_on_node(
237        &self,
238        node_id: DispatchNodeId,
239        action: Box<dyn Action>,
240        cx: &mut WindowContext,
241    ) {
242        let dispatch_path = self.dispatch_path(node_id);
243
244        // Capture phase
245        for node_id in &dispatch_path {
246            let node = &self.nodes[node_id.0];
247            for ActionListener {
248                action_type,
249                listener,
250            } in &node.action_listeners
251            {
252                let any_action = action.as_any();
253                if *action_type == any_action.type_id() {
254                    listener(any_action, DispatchPhase::Capture, cx);
255                    if !cx.propagate_event {
256                        return;
257                    }
258                }
259            }
260        }
261
262        // Bubble phase
263        for node_id in dispatch_path.iter().rev() {
264            let node = &self.nodes[node_id.0];
265            for ActionListener {
266                action_type,
267                listener,
268            } in &node.action_listeners
269            {
270                let any_action = action.as_any();
271                if *action_type == any_action.type_id() {
272                    cx.propagate_event = false; // Actions stop propagation by default during the bubble phase
273                    listener(any_action, DispatchPhase::Bubble, cx);
274                    if !cx.propagate_event {
275                        return;
276                    }
277                }
278            }
279        }
280    }
281
282    fn active_node(&mut self) -> &mut DispatchNode {
283        let active_node_id = self.active_node_id();
284        &mut self.nodes[active_node_id.0]
285    }
286
287    fn active_node_id(&self) -> DispatchNodeId {
288        *self.node_stack.last().unwrap()
289    }
290
291    /// Returns the DispatchNodeIds from the root of the tree to the given target node id.
292    fn dispatch_path(&self, target: DispatchNodeId) -> SmallVec<[DispatchNodeId; 32]> {
293        let mut dispatch_path: SmallVec<[DispatchNodeId; 32]> = SmallVec::new();
294        let mut current_node_id = Some(target);
295        while let Some(node_id) = current_node_id {
296            dispatch_path.push(node_id);
297            current_node_id = self.nodes[node_id.0].parent;
298        }
299        dispatch_path.reverse(); // Reverse the path so it goes from the root to the focused node.
300        dispatch_path
301    }
302}
303
304pub trait KeyDispatch<V: 'static>: 'static {
305    fn as_focusable(&self) -> Option<&FocusableKeyDispatch<V>>;
306    fn as_focusable_mut(&mut self) -> Option<&mut FocusableKeyDispatch<V>>;
307    fn key_context(&self) -> &KeyContext;
308    fn key_context_mut(&mut self) -> &mut KeyContext;
309
310    fn initialize<R>(
311        &mut self,
312        focus_handle: Option<FocusHandle>,
313        cx: &mut ViewContext<V>,
314        f: impl FnOnce(Option<FocusHandle>, &mut ViewContext<V>) -> R,
315    ) -> R {
316        let focus_handle = if let Some(focusable) = self.as_focusable_mut() {
317            let focus_handle = focusable
318                .focus_handle
319                .get_or_insert_with(|| focus_handle.unwrap_or_else(|| cx.focus_handle()))
320                .clone();
321            for listener in focusable.focus_listeners.drain(..) {
322                let focus_handle = focus_handle.clone();
323                cx.on_focus_changed(move |view, event, cx| {
324                    listener(view, &focus_handle, event, cx)
325                });
326            }
327            Some(focus_handle)
328        } else {
329            None
330        };
331
332        cx.with_key_dispatch(self.key_context().clone(), focus_handle, f)
333    }
334
335    fn refine_style(&self, style: &mut Style, cx: &WindowContext) {
336        if let Some(focusable) = self.as_focusable() {
337            let focus_handle = focusable
338                .focus_handle
339                .as_ref()
340                .expect("must call initialize before refine_style");
341            if focus_handle.contains_focused(cx) {
342                style.refine(&focusable.focus_in_style);
343            }
344
345            if focus_handle.within_focused(cx) {
346                style.refine(&focusable.in_focus_style);
347            }
348
349            if focus_handle.is_focused(cx) {
350                style.refine(&focusable.focus_style);
351            }
352        }
353    }
354
355    fn paint(&self, bounds: Bounds<Pixels>, cx: &mut WindowContext) {
356        if let Some(focusable) = self.as_focusable() {
357            let focus_handle = focusable
358                .focus_handle
359                .clone()
360                .expect("must call initialize before paint");
361            cx.on_mouse_event(move |event: &MouseDownEvent, phase, cx| {
362                if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
363                    if !cx.default_prevented() {
364                        cx.focus(&focus_handle);
365                        cx.prevent_default();
366                    }
367                }
368            })
369        }
370    }
371}
372
373pub struct FocusableKeyDispatch<V> {
374    pub key_context: KeyContext,
375    pub focus_handle: Option<FocusHandle>,
376    pub focus_listeners: FocusListeners<V>,
377    pub focus_style: StyleRefinement,
378    pub focus_in_style: StyleRefinement,
379    pub in_focus_style: StyleRefinement,
380}
381
382impl<V> FocusableKeyDispatch<V> {
383    pub fn new() -> Self {
384        Self {
385            key_context: KeyContext::default(),
386            focus_handle: None,
387            focus_listeners: FocusListeners::default(),
388            focus_style: StyleRefinement::default(),
389            focus_in_style: StyleRefinement::default(),
390            in_focus_style: StyleRefinement::default(),
391        }
392    }
393
394    pub fn tracked(handle: &FocusHandle) -> Self {
395        Self {
396            key_context: KeyContext::default(),
397            focus_handle: Some(handle.clone()),
398            focus_listeners: FocusListeners::default(),
399            focus_style: StyleRefinement::default(),
400            focus_in_style: StyleRefinement::default(),
401            in_focus_style: StyleRefinement::default(),
402        }
403    }
404}
405
406impl<V: 'static> KeyDispatch<V> for FocusableKeyDispatch<V> {
407    fn as_focusable(&self) -> Option<&FocusableKeyDispatch<V>> {
408        Some(self)
409    }
410
411    fn as_focusable_mut(&mut self) -> Option<&mut FocusableKeyDispatch<V>> {
412        Some(self)
413    }
414
415    fn key_context(&self) -> &KeyContext {
416        &self.key_context
417    }
418
419    fn key_context_mut(&mut self) -> &mut KeyContext {
420        &mut self.key_context
421    }
422}
423
424impl<V> From<FocusHandle> for FocusableKeyDispatch<V> {
425    fn from(value: FocusHandle) -> Self {
426        Self {
427            key_context: KeyContext::default(),
428            focus_handle: Some(value),
429            focus_listeners: FocusListeners::default(),
430            focus_style: StyleRefinement::default(),
431            focus_in_style: StyleRefinement::default(),
432            in_focus_style: StyleRefinement::default(),
433        }
434    }
435}
436
437#[derive(Default)]
438pub struct NonFocusableKeyDispatch {
439    pub(crate) key_context: KeyContext,
440}
441
442impl<V: 'static> KeyDispatch<V> for NonFocusableKeyDispatch {
443    fn as_focusable(&self) -> Option<&FocusableKeyDispatch<V>> {
444        None
445    }
446
447    fn as_focusable_mut(&mut self) -> Option<&mut FocusableKeyDispatch<V>> {
448        None
449    }
450
451    fn key_context(&self) -> &KeyContext {
452        &self.key_context
453    }
454
455    fn key_context_mut(&mut self) -> &mut KeyContext {
456        &mut self.key_context
457    }
458}
459
460pub trait Focusable<V: 'static>: Element<V> {
461    fn focus_listeners(&mut self) -> &mut FocusListeners<V>;
462    fn set_focus_style(&mut self, style: StyleRefinement);
463    fn set_focus_in_style(&mut self, style: StyleRefinement);
464    fn set_in_focus_style(&mut self, style: StyleRefinement);
465
466    fn focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
467    where
468        Self: Sized,
469    {
470        self.set_focus_style(f(StyleRefinement::default()));
471        self
472    }
473
474    fn focus_in(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
475    where
476        Self: Sized,
477    {
478        self.set_focus_in_style(f(StyleRefinement::default()));
479        self
480    }
481
482    fn in_focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
483    where
484        Self: Sized,
485    {
486        self.set_in_focus_style(f(StyleRefinement::default()));
487        self
488    }
489
490    fn on_focus(
491        mut self,
492        listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + 'static,
493    ) -> Self
494    where
495        Self: Sized,
496    {
497        self.focus_listeners()
498            .push(Box::new(move |view, focus_handle, event, cx| {
499                if event.focused.as_ref() == Some(focus_handle) {
500                    listener(view, event, cx)
501                }
502            }));
503        self
504    }
505
506    fn on_blur(
507        mut self,
508        listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + 'static,
509    ) -> Self
510    where
511        Self: Sized,
512    {
513        self.focus_listeners()
514            .push(Box::new(move |view, focus_handle, event, cx| {
515                if event.blurred.as_ref() == Some(focus_handle) {
516                    listener(view, event, cx)
517                }
518            }));
519        self
520    }
521
522    fn on_focus_in(
523        mut self,
524        listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + 'static,
525    ) -> Self
526    where
527        Self: Sized,
528    {
529        self.focus_listeners()
530            .push(Box::new(move |view, focus_handle, event, cx| {
531                let descendant_blurred = event
532                    .blurred
533                    .as_ref()
534                    .map_or(false, |blurred| focus_handle.contains(blurred, cx));
535                let descendant_focused = event
536                    .focused
537                    .as_ref()
538                    .map_or(false, |focused| focus_handle.contains(focused, cx));
539
540                if !descendant_blurred && descendant_focused {
541                    listener(view, event, cx)
542                }
543            }));
544        self
545    }
546
547    fn on_focus_out(
548        mut self,
549        listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + 'static,
550    ) -> Self
551    where
552        Self: Sized,
553    {
554        self.focus_listeners()
555            .push(Box::new(move |view, focus_handle, event, cx| {
556                let descendant_blurred = event
557                    .blurred
558                    .as_ref()
559                    .map_or(false, |blurred| focus_handle.contains(blurred, cx));
560                let descendant_focused = event
561                    .focused
562                    .as_ref()
563                    .map_or(false, |focused| focus_handle.contains(focused, cx));
564                if descendant_blurred && !descendant_focused {
565                    listener(view, event, cx)
566                }
567            }));
568        self
569    }
570}