key_dispatch.rs

  1use crate::{
  2    build_action_from_type, Action, Bounds, DispatchPhase, Element, FocusEvent, FocusHandle,
  3    FocusId, KeyContext, KeyMatch, Keymap, Keystroke, KeystrokeMatcher, MouseDownEvent, Pixels,
  4    Style, StyleRefinement, ViewContext, WindowContext,
  5};
  6use collections::HashMap;
  7use parking_lot::Mutex;
  8use refineable::Refineable;
  9use smallvec::SmallVec;
 10use std::{
 11    any::{Any, TypeId},
 12    rc::Rc,
 13    sync::Arc,
 14};
 15use util::ResultExt;
 16
 17pub type FocusListeners<V> = SmallVec<[FocusListener<V>; 2]>;
 18pub type FocusListener<V> =
 19    Box<dyn Fn(&mut V, &FocusHandle, &FocusEvent, &mut ViewContext<V>) + 'static>;
 20
 21#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
 22pub struct DispatchNodeId(usize);
 23
 24pub(crate) struct DispatchTree {
 25    node_stack: Vec<DispatchNodeId>,
 26    context_stack: Vec<KeyContext>,
 27    nodes: Vec<DispatchNode>,
 28    focusable_node_ids: HashMap<FocusId, DispatchNodeId>,
 29    keystroke_matchers: HashMap<SmallVec<[KeyContext; 4]>, KeystrokeMatcher>,
 30    keymap: Arc<Mutex<Keymap>>,
 31}
 32
 33#[derive(Default)]
 34pub(crate) struct DispatchNode {
 35    pub key_listeners: SmallVec<[KeyListener; 2]>,
 36    pub action_listeners: SmallVec<[ActionListener; 16]>,
 37    pub context: KeyContext,
 38    parent: Option<DispatchNodeId>,
 39}
 40
 41type KeyListener = Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>;
 42
 43#[derive(Clone)]
 44pub(crate) struct ActionListener {
 45    pub(crate) action_type: TypeId,
 46    pub(crate) listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
 47}
 48
 49impl DispatchTree {
 50    pub fn new(keymap: Arc<Mutex<Keymap>>) -> Self {
 51        Self {
 52            node_stack: Vec::new(),
 53            context_stack: Vec::new(),
 54            nodes: Vec::new(),
 55            focusable_node_ids: HashMap::default(),
 56            keystroke_matchers: HashMap::default(),
 57            keymap,
 58        }
 59    }
 60
 61    pub fn clear(&mut self) {
 62        self.node_stack.clear();
 63        self.nodes.clear();
 64        self.context_stack.clear();
 65        self.focusable_node_ids.clear();
 66        self.keystroke_matchers.clear();
 67    }
 68
 69    pub fn push_node(&mut self, context: KeyContext, old_dispatcher: &mut Self) {
 70        let parent = self.node_stack.last().copied();
 71        let node_id = DispatchNodeId(self.nodes.len());
 72        self.nodes.push(DispatchNode {
 73            parent,
 74            ..Default::default()
 75        });
 76        self.node_stack.push(node_id);
 77        if !context.is_empty() {
 78            self.active_node().context = context.clone();
 79            self.context_stack.push(context);
 80            if let Some((context_stack, matcher)) = old_dispatcher
 81                .keystroke_matchers
 82                .remove_entry(self.context_stack.as_slice())
 83            {
 84                self.keystroke_matchers.insert(context_stack, matcher);
 85            }
 86        }
 87    }
 88
 89    pub fn pop_node(&mut self) {
 90        let node_id = self.node_stack.pop().unwrap();
 91        if !self.nodes[node_id.0].context.is_empty() {
 92            self.context_stack.pop();
 93        }
 94    }
 95
 96    pub fn on_key_event(&mut self, listener: KeyListener) {
 97        self.active_node().key_listeners.push(listener);
 98    }
 99
100    pub fn on_action(
101        &mut self,
102        action_type: TypeId,
103        listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
104    ) {
105        self.active_node().action_listeners.push(ActionListener {
106            action_type,
107            listener,
108        });
109    }
110
111    pub fn make_focusable(&mut self, focus_id: FocusId) {
112        self.focusable_node_ids
113            .insert(focus_id, self.active_node_id());
114    }
115
116    pub fn focus_contains(&self, parent: FocusId, child: FocusId) -> bool {
117        if parent == child {
118            return true;
119        }
120
121        if let Some(parent_node_id) = self.focusable_node_ids.get(&parent) {
122            let mut current_node_id = self.focusable_node_ids.get(&child).copied();
123            while let Some(node_id) = current_node_id {
124                if node_id == *parent_node_id {
125                    return true;
126                }
127                current_node_id = self.nodes[node_id.0].parent;
128            }
129        }
130        false
131    }
132
133    pub fn available_actions(&self, target: FocusId) -> Vec<Box<dyn Action>> {
134        let mut actions = Vec::new();
135        if let Some(node) = self.focusable_node_ids.get(&target) {
136            for node_id in self.dispatch_path(*node) {
137                let node = &self.nodes[node_id.0];
138                for ActionListener { action_type, .. } in &node.action_listeners {
139                    actions.extend(build_action_from_type(action_type).log_err());
140                }
141            }
142        }
143        actions
144    }
145
146    pub fn dispatch_key(
147        &mut self,
148        keystroke: &Keystroke,
149        context: &[KeyContext],
150    ) -> Option<Box<dyn Action>> {
151        if !self
152            .keystroke_matchers
153            .contains_key(self.context_stack.as_slice())
154        {
155            let keystroke_contexts = self.context_stack.iter().cloned().collect();
156            self.keystroke_matchers.insert(
157                keystroke_contexts,
158                KeystrokeMatcher::new(self.keymap.clone()),
159            );
160        }
161
162        let keystroke_matcher = self
163            .keystroke_matchers
164            .get_mut(self.context_stack.as_slice())
165            .unwrap();
166        if let KeyMatch::Some(action) = keystroke_matcher.match_keystroke(keystroke, context) {
167            // Clear all pending keystrokes when an action has been found.
168            for keystroke_matcher in self.keystroke_matchers.values_mut() {
169                keystroke_matcher.clear_pending();
170            }
171
172            Some(action)
173        } else {
174            None
175        }
176    }
177
178    pub fn dispatch_path(&self, target: DispatchNodeId) -> SmallVec<[DispatchNodeId; 32]> {
179        let mut dispatch_path: SmallVec<[DispatchNodeId; 32]> = SmallVec::new();
180        let mut current_node_id = Some(target);
181        while let Some(node_id) = current_node_id {
182            dispatch_path.push(node_id);
183            current_node_id = self.nodes[node_id.0].parent;
184        }
185        dispatch_path.reverse(); // Reverse the path so it goes from the root to the focused node.
186        dispatch_path
187    }
188
189    pub fn node(&self, node_id: DispatchNodeId) -> &DispatchNode {
190        &self.nodes[node_id.0]
191    }
192
193    fn active_node(&mut self) -> &mut DispatchNode {
194        let active_node_id = self.active_node_id();
195        &mut self.nodes[active_node_id.0]
196    }
197
198    pub fn focusable_node_id(&self, target: FocusId) -> Option<DispatchNodeId> {
199        self.focusable_node_ids.get(&target).copied()
200    }
201
202    fn active_node_id(&self) -> DispatchNodeId {
203        *self.node_stack.last().unwrap()
204    }
205}
206
207pub trait KeyDispatch<V: 'static>: 'static {
208    fn as_focusable(&self) -> Option<&FocusableKeyDispatch<V>>;
209    fn as_focusable_mut(&mut self) -> Option<&mut FocusableKeyDispatch<V>>;
210    fn key_context(&self) -> &KeyContext;
211    fn key_context_mut(&mut self) -> &mut KeyContext;
212
213    fn initialize<R>(
214        &mut self,
215        focus_handle: Option<FocusHandle>,
216        cx: &mut ViewContext<V>,
217        f: impl FnOnce(Option<FocusHandle>, &mut ViewContext<V>) -> R,
218    ) -> R {
219        let focus_handle = if let Some(focusable) = self.as_focusable_mut() {
220            let focus_handle = focusable
221                .focus_handle
222                .get_or_insert_with(|| focus_handle.unwrap_or_else(|| cx.focus_handle()))
223                .clone();
224            for listener in focusable.focus_listeners.drain(..) {
225                let focus_handle = focus_handle.clone();
226                cx.on_focus_changed(move |view, event, cx| {
227                    listener(view, &focus_handle, event, cx)
228                });
229            }
230            Some(focus_handle)
231        } else {
232            None
233        };
234
235        cx.with_key_dispatch(self.key_context().clone(), focus_handle, f)
236    }
237
238    fn refine_style(&self, style: &mut Style, cx: &WindowContext) {
239        if let Some(focusable) = self.as_focusable() {
240            let focus_handle = focusable
241                .focus_handle
242                .as_ref()
243                .expect("must call initialize before refine_style");
244            if focus_handle.contains_focused(cx) {
245                style.refine(&focusable.focus_in_style);
246            }
247
248            if focus_handle.within_focused(cx) {
249                style.refine(&focusable.in_focus_style);
250            }
251
252            if focus_handle.is_focused(cx) {
253                style.refine(&focusable.focus_style);
254            }
255        }
256    }
257
258    fn paint(&self, bounds: Bounds<Pixels>, cx: &mut WindowContext) {
259        if let Some(focusable) = self.as_focusable() {
260            let focus_handle = focusable
261                .focus_handle
262                .clone()
263                .expect("must call initialize before paint");
264            cx.on_mouse_event(move |event: &MouseDownEvent, phase, cx| {
265                if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
266                    if !cx.default_prevented() {
267                        cx.focus(&focus_handle);
268                        cx.prevent_default();
269                    }
270                }
271            })
272        }
273    }
274}
275
276pub struct FocusableKeyDispatch<V> {
277    pub key_context: KeyContext,
278    pub focus_handle: Option<FocusHandle>,
279    pub focus_listeners: FocusListeners<V>,
280    pub focus_style: StyleRefinement,
281    pub focus_in_style: StyleRefinement,
282    pub in_focus_style: StyleRefinement,
283}
284
285impl<V> FocusableKeyDispatch<V> {
286    pub fn new() -> Self {
287        Self {
288            key_context: KeyContext::default(),
289            focus_handle: None,
290            focus_listeners: FocusListeners::default(),
291            focus_style: StyleRefinement::default(),
292            focus_in_style: StyleRefinement::default(),
293            in_focus_style: StyleRefinement::default(),
294        }
295    }
296
297    pub fn tracked(handle: &FocusHandle) -> Self {
298        Self {
299            key_context: KeyContext::default(),
300            focus_handle: Some(handle.clone()),
301            focus_listeners: FocusListeners::default(),
302            focus_style: StyleRefinement::default(),
303            focus_in_style: StyleRefinement::default(),
304            in_focus_style: StyleRefinement::default(),
305        }
306    }
307}
308
309impl<V: 'static> KeyDispatch<V> for FocusableKeyDispatch<V> {
310    fn as_focusable(&self) -> Option<&FocusableKeyDispatch<V>> {
311        Some(self)
312    }
313
314    fn as_focusable_mut(&mut self) -> Option<&mut FocusableKeyDispatch<V>> {
315        Some(self)
316    }
317
318    fn key_context(&self) -> &KeyContext {
319        &self.key_context
320    }
321
322    fn key_context_mut(&mut self) -> &mut KeyContext {
323        &mut self.key_context
324    }
325}
326
327impl<V> From<FocusHandle> for FocusableKeyDispatch<V> {
328    fn from(value: FocusHandle) -> Self {
329        Self {
330            key_context: KeyContext::default(),
331            focus_handle: Some(value),
332            focus_listeners: FocusListeners::default(),
333            focus_style: StyleRefinement::default(),
334            focus_in_style: StyleRefinement::default(),
335            in_focus_style: StyleRefinement::default(),
336        }
337    }
338}
339
340#[derive(Default)]
341pub struct NonFocusableKeyDispatch {
342    pub(crate) key_context: KeyContext,
343}
344
345impl<V: 'static> KeyDispatch<V> for NonFocusableKeyDispatch {
346    fn as_focusable(&self) -> Option<&FocusableKeyDispatch<V>> {
347        None
348    }
349
350    fn as_focusable_mut(&mut self) -> Option<&mut FocusableKeyDispatch<V>> {
351        None
352    }
353
354    fn key_context(&self) -> &KeyContext {
355        &self.key_context
356    }
357
358    fn key_context_mut(&mut self) -> &mut KeyContext {
359        &mut self.key_context
360    }
361}
362
363pub trait Focusable<V: 'static>: Element<V> {
364    fn focus_listeners(&mut self) -> &mut FocusListeners<V>;
365    fn set_focus_style(&mut self, style: StyleRefinement);
366    fn set_focus_in_style(&mut self, style: StyleRefinement);
367    fn set_in_focus_style(&mut self, style: StyleRefinement);
368
369    fn focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
370    where
371        Self: Sized,
372    {
373        self.set_focus_style(f(StyleRefinement::default()));
374        self
375    }
376
377    fn focus_in(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
378    where
379        Self: Sized,
380    {
381        self.set_focus_in_style(f(StyleRefinement::default()));
382        self
383    }
384
385    fn in_focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
386    where
387        Self: Sized,
388    {
389        self.set_in_focus_style(f(StyleRefinement::default()));
390        self
391    }
392
393    fn on_focus(
394        mut self,
395        listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + 'static,
396    ) -> Self
397    where
398        Self: Sized,
399    {
400        self.focus_listeners()
401            .push(Box::new(move |view, focus_handle, event, cx| {
402                if event.focused.as_ref() == Some(focus_handle) {
403                    listener(view, event, cx)
404                }
405            }));
406        self
407    }
408
409    fn on_blur(
410        mut self,
411        listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + 'static,
412    ) -> Self
413    where
414        Self: Sized,
415    {
416        self.focus_listeners()
417            .push(Box::new(move |view, focus_handle, event, cx| {
418                if event.blurred.as_ref() == Some(focus_handle) {
419                    listener(view, event, cx)
420                }
421            }));
422        self
423    }
424
425    fn on_focus_in(
426        mut self,
427        listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + 'static,
428    ) -> Self
429    where
430        Self: Sized,
431    {
432        self.focus_listeners()
433            .push(Box::new(move |view, focus_handle, event, cx| {
434                let descendant_blurred = event
435                    .blurred
436                    .as_ref()
437                    .map_or(false, |blurred| focus_handle.contains(blurred, cx));
438                let descendant_focused = event
439                    .focused
440                    .as_ref()
441                    .map_or(false, |focused| focus_handle.contains(focused, cx));
442
443                if !descendant_blurred && descendant_focused {
444                    listener(view, event, cx)
445                }
446            }));
447        self
448    }
449
450    fn on_focus_out(
451        mut self,
452        listener: impl Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + 'static,
453    ) -> Self
454    where
455        Self: Sized,
456    {
457        self.focus_listeners()
458            .push(Box::new(move |view, focus_handle, event, cx| {
459                let descendant_blurred = event
460                    .blurred
461                    .as_ref()
462                    .map_or(false, |blurred| focus_handle.contains(blurred, cx));
463                let descendant_focused = event
464                    .focused
465                    .as_ref()
466                    .map_or(false, |focused| focus_handle.contains(focused, cx));
467                if descendant_blurred && !descendant_focused {
468                    listener(view, event, cx)
469                }
470            }));
471        self
472    }
473}