1use crate::{
2 Action, ActionRegistry, DispatchPhase, EntityId, FocusId, KeyBinding, KeyContext, KeyMatch,
3 Keymap, Keystroke, KeystrokeMatcher, WindowContext,
4};
5use collections::FxHashMap;
6use parking_lot::Mutex;
7use smallvec::SmallVec;
8use std::{
9 any::{Any, TypeId},
10 mem,
11 rc::Rc,
12 sync::Arc,
13};
14
/// Identifies a [`DispatchNode`] by its index in the owning [`DispatchTree`]'s
/// node list (ids are handed out in push order, starting at 0).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DispatchNodeId(usize);
17
18pub(crate) struct DispatchTree {
19 node_stack: Vec<DispatchNodeId>,
20 pub(crate) context_stack: Vec<KeyContext>,
21 nodes: Vec<DispatchNode>,
22 focusable_node_ids: FxHashMap<FocusId, DispatchNodeId>,
23 view_node_ids: FxHashMap<EntityId, DispatchNodeId>,
24 keystroke_matchers: FxHashMap<SmallVec<[KeyContext; 4]>, KeystrokeMatcher>,
25 keymap: Arc<Mutex<Keymap>>,
26 action_registry: Rc<ActionRegistry>,
27}
28
29#[derive(Default)]
30pub(crate) struct DispatchNode {
31 pub key_listeners: Vec<KeyListener>,
32 pub action_listeners: Vec<DispatchActionListener>,
33 pub context: Option<KeyContext>,
34 focus_id: Option<FocusId>,
35 view_id: Option<EntityId>,
36 parent: Option<DispatchNodeId>,
37}
38
39type KeyListener = Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>;
40
41#[derive(Clone)]
42pub(crate) struct DispatchActionListener {
43 pub(crate) action_type: TypeId,
44 pub(crate) listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
45}
46
47impl DispatchTree {
48 pub fn new(keymap: Arc<Mutex<Keymap>>, action_registry: Rc<ActionRegistry>) -> Self {
49 Self {
50 node_stack: Vec::new(),
51 context_stack: Vec::new(),
52 nodes: Vec::new(),
53 focusable_node_ids: FxHashMap::default(),
54 view_node_ids: FxHashMap::default(),
55 keystroke_matchers: FxHashMap::default(),
56 keymap,
57 action_registry,
58 }
59 }
60
61 pub fn clear(&mut self) {
62 self.node_stack.clear();
63 self.nodes.clear();
64 self.context_stack.clear();
65 self.focusable_node_ids.clear();
66 self.view_node_ids.clear();
67 self.keystroke_matchers.clear();
68 }
69
70 pub fn push_node(&mut self, context: Option<KeyContext>) {
71 let parent = self.node_stack.last().copied();
72 let node_id = DispatchNodeId(self.nodes.len());
73 self.nodes.push(DispatchNode {
74 parent,
75 ..Default::default()
76 });
77 self.node_stack.push(node_id);
78 if let Some(context) = context {
79 self.active_node().context = Some(context.clone());
80 self.context_stack.push(context);
81 }
82 }
83
84 pub fn pop_node(&mut self) {
85 let node_id = self.node_stack.pop().unwrap();
86 if self.nodes[node_id.0].context.is_some() {
87 self.context_stack.pop();
88 }
89 }
90
91 fn move_node(&mut self, source_node: &mut DispatchNode) {
92 self.push_node(source_node.context.take());
93 if let Some(focus_id) = source_node.focus_id {
94 self.make_focusable(focus_id);
95 }
96 if let Some(view_id) = source_node.view_id {
97 self.associate_view(view_id);
98 }
99
100 let target_node = self.active_node();
101 target_node.key_listeners = mem::take(&mut source_node.key_listeners);
102 target_node.action_listeners = mem::take(&mut source_node.action_listeners);
103 }
104
105 pub fn graft(&mut self, view_id: EntityId, source: &mut Self) {
106 let view_source_node_id = source
107 .view_node_ids
108 .get(&view_id)
109 .expect("view should exist in previous dispatch tree");
110 let view_source_node = &mut source.nodes[view_source_node_id.0];
111 self.move_node(view_source_node);
112
113 let mut source_stack = vec![*view_source_node_id];
114 for (source_node_id, source_node) in source
115 .nodes
116 .iter_mut()
117 .enumerate()
118 .skip(view_source_node_id.0 + 1)
119 {
120 let source_node_id = DispatchNodeId(source_node_id);
121 while let Some(source_ancestor) = source_stack.last() {
122 if source_node.parent != Some(*source_ancestor) {
123 source_stack.pop();
124 self.pop_node();
125 }
126 }
127
128 if source_stack.is_empty() {
129 break;
130 } else {
131 source_stack.push(source_node_id);
132 self.move_node(source_node);
133 }
134 }
135
136 while !source_stack.is_empty() {
137 self.pop_node();
138 }
139 }
140
141 pub fn clear_pending_keystrokes(&mut self) {
142 self.keystroke_matchers.clear();
143 }
144
145 /// Preserve keystroke matchers from previous frames to support multi-stroke
146 /// bindings across multiple frames.
147 pub fn preserve_pending_keystrokes(&mut self, old_tree: &mut Self, focus_id: Option<FocusId>) {
148 if let Some(node_id) = focus_id.and_then(|focus_id| self.focusable_node_id(focus_id)) {
149 let dispatch_path = self.dispatch_path(node_id);
150
151 self.context_stack.clear();
152 for node_id in dispatch_path {
153 let node = self.node(node_id);
154 if let Some(context) = node.context.clone() {
155 self.context_stack.push(context);
156 }
157
158 if let Some((context_stack, matcher)) = old_tree
159 .keystroke_matchers
160 .remove_entry(self.context_stack.as_slice())
161 {
162 self.keystroke_matchers.insert(context_stack, matcher);
163 }
164 }
165 }
166 }
167
168 pub fn on_key_event(&mut self, listener: KeyListener) {
169 self.active_node().key_listeners.push(listener);
170 }
171
172 pub fn on_action(
173 &mut self,
174 action_type: TypeId,
175 listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
176 ) {
177 self.active_node()
178 .action_listeners
179 .push(DispatchActionListener {
180 action_type,
181 listener,
182 });
183 }
184
185 pub fn make_focusable(&mut self, focus_id: FocusId) {
186 let node_id = self.active_node_id();
187 self.active_node().focus_id = Some(focus_id);
188 self.focusable_node_ids.insert(focus_id, node_id);
189 }
190
191 pub fn associate_view(&mut self, view_id: EntityId) {
192 let node_id = self.active_node_id();
193 self.active_node().view_id = Some(view_id);
194 self.view_node_ids.insert(view_id, node_id);
195 }
196
197 pub fn focus_contains(&self, parent: FocusId, child: FocusId) -> bool {
198 if parent == child {
199 return true;
200 }
201
202 if let Some(parent_node_id) = self.focusable_node_ids.get(&parent) {
203 let mut current_node_id = self.focusable_node_ids.get(&child).copied();
204 while let Some(node_id) = current_node_id {
205 if node_id == *parent_node_id {
206 return true;
207 }
208 current_node_id = self.nodes[node_id.0].parent;
209 }
210 }
211 false
212 }
213
214 pub fn available_actions(&self, target: DispatchNodeId) -> Vec<Box<dyn Action>> {
215 let mut actions = Vec::<Box<dyn Action>>::new();
216 for node_id in self.dispatch_path(target) {
217 let node = &self.nodes[node_id.0];
218 for DispatchActionListener { action_type, .. } in &node.action_listeners {
219 if let Err(ix) = actions.binary_search_by_key(action_type, |a| a.as_any().type_id())
220 {
221 // Intentionally silence these errors without logging.
222 // If an action cannot be built by default, it's not available.
223 let action = self.action_registry.build_action_type(action_type).ok();
224 if let Some(action) = action {
225 actions.insert(ix, action);
226 }
227 }
228 }
229 }
230 actions
231 }
232
233 pub fn is_action_available(&self, action: &dyn Action, target: DispatchNodeId) -> bool {
234 for node_id in self.dispatch_path(target) {
235 let node = &self.nodes[node_id.0];
236 if node
237 .action_listeners
238 .iter()
239 .any(|listener| listener.action_type == action.as_any().type_id())
240 {
241 return true;
242 }
243 }
244 false
245 }
246
247 pub fn bindings_for_action(
248 &self,
249 action: &dyn Action,
250 context_stack: &Vec<KeyContext>,
251 ) -> Vec<KeyBinding> {
252 let keymap = self.keymap.lock();
253 keymap
254 .bindings_for_action(action)
255 .filter(|binding| {
256 for i in 1..context_stack.len() {
257 let context = &context_stack[0..i];
258 if keymap.binding_enabled(binding, context) {
259 return true;
260 }
261 }
262 false
263 })
264 .cloned()
265 .collect()
266 }
267
268 pub fn dispatch_key(
269 &mut self,
270 keystroke: &Keystroke,
271 context: &[KeyContext],
272 ) -> Vec<Box<dyn Action>> {
273 if !self.keystroke_matchers.contains_key(context) {
274 let keystroke_contexts = context.iter().cloned().collect();
275 self.keystroke_matchers.insert(
276 keystroke_contexts,
277 KeystrokeMatcher::new(self.keymap.clone()),
278 );
279 }
280
281 let keystroke_matcher = self.keystroke_matchers.get_mut(context).unwrap();
282 if let KeyMatch::Some(actions) = keystroke_matcher.match_keystroke(keystroke, context) {
283 // Clear all pending keystrokes when an action has been found.
284 for keystroke_matcher in self.keystroke_matchers.values_mut() {
285 keystroke_matcher.clear_pending();
286 }
287
288 actions
289 } else {
290 vec![]
291 }
292 }
293
294 pub fn has_pending_keystrokes(&self) -> bool {
295 self.keystroke_matchers
296 .iter()
297 .any(|(_, matcher)| matcher.has_pending_keystrokes())
298 }
299
300 pub fn dispatch_path(&self, target: DispatchNodeId) -> SmallVec<[DispatchNodeId; 32]> {
301 let mut dispatch_path: SmallVec<[DispatchNodeId; 32]> = SmallVec::new();
302 let mut current_node_id = Some(target);
303 while let Some(node_id) = current_node_id {
304 dispatch_path.push(node_id);
305 current_node_id = self.nodes[node_id.0].parent;
306 }
307 dispatch_path.reverse(); // Reverse the path so it goes from the root to the focused node.
308 dispatch_path
309 }
310
311 pub fn focus_path(&self, focus_id: FocusId) -> SmallVec<[FocusId; 8]> {
312 let mut focus_path: SmallVec<[FocusId; 8]> = SmallVec::new();
313 let mut current_node_id = self.focusable_node_ids.get(&focus_id).copied();
314 while let Some(node_id) = current_node_id {
315 let node = self.node(node_id);
316 if let Some(focus_id) = node.focus_id {
317 focus_path.push(focus_id);
318 }
319 current_node_id = node.parent;
320 }
321 focus_path.reverse(); // Reverse the path so it goes from the root to the focused node.
322 focus_path
323 }
324
325 pub fn node(&self, node_id: DispatchNodeId) -> &DispatchNode {
326 &self.nodes[node_id.0]
327 }
328
329 fn active_node(&mut self) -> &mut DispatchNode {
330 let active_node_id = self.active_node_id();
331 &mut self.nodes[active_node_id.0]
332 }
333
334 pub fn focusable_node_id(&self, target: FocusId) -> Option<DispatchNodeId> {
335 self.focusable_node_ids.get(&target).copied()
336 }
337
338 pub fn root_node_id(&self) -> DispatchNodeId {
339 debug_assert!(!self.nodes.is_empty());
340 DispatchNodeId(0)
341 }
342
343 fn active_node_id(&self) -> DispatchNodeId {
344 *self.node_stack.last().unwrap()
345 }
346}