1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
// Native entry point for the context-predicate grammar, consumed by
// `ContextPredicate::parse` via `Parser::set_language`.
// NOTE(review): presumably a tree-sitter-generated parser linked in by the build
// script — confirm.
extern "C" {
    /// Returns the tree-sitter [`Language`] describing context-predicate syntax.
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Resolves keystrokes against a [`Keymap`], tracking multi-keystroke sequences
/// that have been started but not yet completed.
pub struct Matcher {
    /// Views in which the pending keystroke sequence may still complete, keyed by
    /// view id, together with the context each view had when it became pending.
    pending_views: HashMap<usize, Context>,
    /// Keystrokes typed so far in the current, not-yet-completed sequence.
    pending_keystrokes: Vec<Keystroke>,
    /// The bindings that keystrokes are matched against.
    keymap: Keymap,
}
20
/// An ordered collection of key bindings, indexed by the type of action each
/// binding dispatches.
#[derive(Default)]
pub struct Keymap {
    /// All bindings in insertion order. The matcher scans them in reverse, so
    /// later bindings are considered first.
    bindings: Vec<Binding>,
    /// Maps an action's `TypeId` to the positions in `bindings` of every binding
    /// whose action has that type.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
26
/// A keystroke sequence that dispatches an action, optionally restricted to
/// contexts satisfying a predicate.
pub struct Binding {
    /// The keystrokes that must be typed, in order, to trigger the binding.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// The action dispatched when the sequence completes.
    action: Box<dyn Action>,
    /// When present, the binding only applies in contexts where this predicate
    /// evaluates to `true`; when absent, it applies everywhere.
    context_predicate: Option<ContextPredicate>,
}
32
/// A single key press together with the modifier keys held while pressing it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Whether the control modifier was held.
    pub ctrl: bool,
    /// Whether the alt/option modifier was held.
    pub alt: bool,
    /// Whether the shift modifier was held.
    pub shift: bool,
    /// Whether the command modifier was held.
    pub cmd: bool,
    /// Whether the `fn` modifier was held.
    pub function: bool,
    /// The key name, e.g. `a`, `down`, `escape`, or `-`.
    pub key: String,
}
42
/// The state a view exposes to context predicates.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; tested by bare-identifier predicates like `a`.
    pub set: HashSet<String>,
    /// Key/value pairs; tested by `key == value` and `key != value` predicates.
    pub map: HashMap<String, String>,
}
48
/// A parsed boolean expression over a [`Context`], e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in [`Context::set`].
    Identifier(String),
    /// True when the key (first field) is in [`Context::map`] with exactly this
    /// value (second field).
    Equal(String, String),
    /// True when the key is absent from [`Context::map`] or maps to a different value.
    NotEqual(String, String),
    /// Logical negation of the inner predicate.
    Not(Box<ContextPredicate>),
    /// True when both sides are true (right side short-circuited).
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// True when either side is true (right side short-circuited).
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
58
/// Type-erased cloning: produces an owned copy of `self`, boxed as `dyn Any`.
///
/// NOTE(review): nothing in this file appears to call this trait; consider
/// removing it if no other module relies on it.
trait ActionArg {
    fn boxed_clone(&self) -> Box<dyn Any>;
}

/// Any `'static` type that can be cloned can be cloned into a `Box<dyn Any>`.
impl<T: Any + Clone + 'static> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
71
/// The outcome of feeding one keystroke to [`Matcher::push_keystroke`].
pub enum MatchResult {
    /// No binding was completed and none is still partially typed.
    None,
    /// No binding was completed, but at least one multi-keystroke binding is
    /// still partially typed.
    Pending,
    /// At least one binding was completed. Each entry pairs the id of the view
    /// whose context the binding matched with a boxed clone of its action.
    Matches(Vec<(usize, Box<dyn Action>)>),
}
77
78impl Debug for MatchResult {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
82 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
83 MatchResult::Matches(matches) => f
84 .debug_list()
85 .entries(
86 matches
87 .iter()
88 .map(|(view_id, action)| format!("{view_id}, {}", action.name())),
89 )
90 .finish(),
91 }
92 }
93}
94
95impl PartialEq for MatchResult {
96 fn eq(&self, other: &Self) -> bool {
97 match (self, other) {
98 (MatchResult::None, MatchResult::None) => true,
99 (MatchResult::Pending, MatchResult::Pending) => true,
100 (MatchResult::Matches(matches), MatchResult::Matches(other_matches)) => {
101 matches.len() == other_matches.len()
102 && matches.iter().zip(other_matches.iter()).all(
103 |((view_id, action), (other_view_id, other_action))| {
104 view_id == other_view_id && action.eq(other_action.as_ref())
105 },
106 )
107 }
108 _ => false,
109 }
110 }
111}
112
// Marks `MatchResult` equality as a full equivalence relation. Equality of matched
// actions is delegated to `Action::eq`.
// NOTE(review): `Eq` promises reflexivity; confirm `Action::eq` is reflexive for
// every action type.
impl Eq for MatchResult {}
114
115impl Clone for MatchResult {
116 fn clone(&self) -> Self {
117 match self {
118 MatchResult::None => MatchResult::None,
119 MatchResult::Pending => MatchResult::Pending,
120 MatchResult::Matches(matches) => MatchResult::Matches(
121 matches
122 .iter()
123 .map(|(view_id, action)| (*view_id, Action::boxed_clone(action.as_ref())))
124 .collect(),
125 ),
126 }
127 }
128}
129
130impl Matcher {
131 pub fn new(keymap: Keymap) -> Self {
132 Self {
133 pending_views: HashMap::new(),
134 pending_keystrokes: Vec::new(),
135 keymap,
136 }
137 }
138
139 pub fn set_keymap(&mut self, keymap: Keymap) {
140 self.clear_pending();
141 self.keymap = keymap;
142 }
143
144 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
145 self.clear_pending();
146 self.keymap.add_bindings(bindings);
147 }
148
149 pub fn clear_bindings(&mut self) {
150 self.clear_pending();
151 self.keymap.clear();
152 }
153
154 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
155 self.keymap.bindings_for_action_type(action_type)
156 }
157
158 pub fn clear_pending(&mut self) {
159 self.pending_keystrokes.clear();
160 self.pending_views.clear();
161 }
162
163 pub fn has_pending_keystrokes(&self) -> bool {
164 !self.pending_keystrokes.is_empty()
165 }
166
167 pub fn push_keystroke(
168 &mut self,
169 keystroke: Keystroke,
170 dispatch_path: Vec<(usize, Context)>,
171 ) -> MatchResult {
172 let mut any_pending = false;
173 let mut matched_bindings = Vec::new();
174
175 let first_keystroke = self.pending_keystrokes.is_empty();
176 self.pending_keystrokes.push(keystroke);
177
178 for (view_id, context) in dispatch_path {
179 // Don't require pending view entry if there are no pending keystrokes
180 if !first_keystroke && !self.pending_views.contains_key(&view_id) {
181 continue;
182 }
183
184 // If there is a previous view context, invalidate that view if it
185 // has changed
186 if let Some(previous_view_context) = self.pending_views.remove(&view_id) {
187 if previous_view_context != context {
188 continue;
189 }
190 }
191
192 // Find the bindings which map the pending keystrokes and current context
193 for binding in self.keymap.bindings.iter().rev() {
194 if binding.keystrokes.starts_with(&self.pending_keystrokes)
195 && binding
196 .context_predicate
197 .as_ref()
198 .map(|c| c.eval(&context))
199 .unwrap_or(true)
200 {
201 // If the binding is completed, push it onto the matches list
202 if binding.keystrokes.len() == self.pending_keystrokes.len() {
203 matched_bindings.push((view_id, binding.action.boxed_clone()));
204 } else {
205 // Otherwise, the binding is still pending
206 self.pending_views.insert(view_id, context.clone());
207 any_pending = true;
208 }
209 }
210 }
211 }
212
213 if !any_pending {
214 self.clear_pending();
215 }
216
217 if !matched_bindings.is_empty() {
218 MatchResult::Matches(matched_bindings)
219 } else if any_pending {
220 MatchResult::Pending
221 } else {
222 MatchResult::None
223 }
224 }
225
226 pub fn keystrokes_for_action(
227 &self,
228 action: &dyn Action,
229 cx: &Context,
230 ) -> Option<SmallVec<[Keystroke; 2]>> {
231 for binding in self.keymap.bindings.iter().rev() {
232 if binding.action.eq(action)
233 && binding
234 .context_predicate
235 .as_ref()
236 .map_or(true, |predicate| predicate.eval(cx))
237 {
238 return Some(binding.keystrokes.clone());
239 }
240 }
241 None
242 }
243}
244
245impl Default for Matcher {
246 fn default() -> Self {
247 Self::new(Keymap::default())
248 }
249}
250
251impl Keymap {
252 pub fn new(bindings: Vec<Binding>) -> Self {
253 let mut binding_indices_by_action_type = HashMap::new();
254 for (ix, binding) in bindings.iter().enumerate() {
255 binding_indices_by_action_type
256 .entry(binding.action.as_any().type_id())
257 .or_insert_with(SmallVec::new)
258 .push(ix);
259 }
260 Self {
261 binding_indices_by_action_type,
262 bindings,
263 }
264 }
265
266 fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
267 self.binding_indices_by_action_type
268 .get(&action_type)
269 .map(SmallVec::as_slice)
270 .unwrap_or(&[])
271 .iter()
272 .map(|ix| &self.bindings[*ix])
273 }
274
275 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
276 for binding in bindings {
277 self.binding_indices_by_action_type
278 .entry(binding.action.as_any().type_id())
279 .or_default()
280 .push(self.bindings.len());
281 self.bindings.push(binding);
282 }
283 }
284
285 fn clear(&mut self) {
286 self.bindings.clear();
287 self.binding_indices_by_action_type.clear();
288 }
289}
290
291impl Binding {
292 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
293 Self::load(keystrokes, Box::new(action), context).unwrap()
294 }
295
296 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
297 let context = if let Some(context) = context {
298 Some(ContextPredicate::parse(context)?)
299 } else {
300 None
301 };
302
303 let keystrokes = keystrokes
304 .split_whitespace()
305 .map(Keystroke::parse)
306 .collect::<Result<_>>()?;
307
308 Ok(Self {
309 keystrokes,
310 action,
311 context_predicate: context,
312 })
313 }
314
315 pub fn keystrokes(&self) -> &[Keystroke] {
316 &self.keystrokes
317 }
318
319 pub fn action(&self) -> &dyn Action {
320 self.action.as_ref()
321 }
322}
323
324impl Keystroke {
325 pub fn parse(source: &str) -> anyhow::Result<Self> {
326 let mut ctrl = false;
327 let mut alt = false;
328 let mut shift = false;
329 let mut cmd = false;
330 let mut function = false;
331 let mut key = None;
332
333 let mut components = source.split('-').peekable();
334 while let Some(component) = components.next() {
335 match component {
336 "ctrl" => ctrl = true,
337 "alt" => alt = true,
338 "shift" => shift = true,
339 "cmd" => cmd = true,
340 "fn" => function = true,
341 _ => {
342 if let Some(component) = components.peek() {
343 if component.is_empty() && source.ends_with('-') {
344 key = Some(String::from("-"));
345 break;
346 } else {
347 return Err(anyhow!("Invalid keystroke `{}`", source));
348 }
349 } else {
350 key = Some(String::from(component));
351 }
352 }
353 }
354 }
355
356 let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
357
358 Ok(Keystroke {
359 ctrl,
360 alt,
361 shift,
362 cmd,
363 function,
364 key,
365 })
366 }
367
368 pub fn modified(&self) -> bool {
369 self.ctrl || self.alt || self.shift || self.cmd
370 }
371}
372
373impl std::fmt::Display for Keystroke {
374 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375 if self.ctrl {
376 f.write_char('^')?;
377 }
378 if self.alt {
379 f.write_char('⎇')?;
380 }
381 if self.cmd {
382 f.write_char('⌘')?;
383 }
384 if self.shift {
385 f.write_char('⇧')?;
386 }
387 let key = match self.key.as_str() {
388 "backspace" => '⌫',
389 "up" => '↑',
390 "down" => '↓',
391 "left" => '←',
392 "right" => '→',
393 "tab" => '⇥',
394 "escape" => '⎋',
395 key => {
396 if key.len() == 1 {
397 key.chars().next().unwrap().to_ascii_uppercase()
398 } else {
399 return f.write_str(key);
400 }
401 }
402 };
403 f.write_char(key)
404 }
405}
406
407impl Context {
408 pub fn extend(&mut self, other: &Context) {
409 for v in &other.set {
410 self.set.insert(v.clone());
411 }
412 for (k, v) in &other.map {
413 self.map.insert(k.clone(), v.clone());
414 }
415 }
416}
417
418impl ContextPredicate {
419 fn parse(source: &str) -> anyhow::Result<Self> {
420 let mut parser = Parser::new();
421 let language = unsafe { tree_sitter_context_predicate() };
422 parser.set_language(language).unwrap();
423 let source = source.as_bytes();
424 let tree = parser.parse(source, None).unwrap();
425 Self::from_node(tree.root_node(), source)
426 }
427
428 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
429 let parse_error = "error parsing context predicate";
430 let kind = node.kind();
431
432 match kind {
433 "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
434 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
435 "not" => {
436 let child = Self::from_node(
437 node.child_by_field_name("expression")
438 .ok_or_else(|| anyhow!(parse_error))?,
439 source,
440 )?;
441 Ok(Self::Not(Box::new(child)))
442 }
443 "and" | "or" => {
444 let left = Box::new(Self::from_node(
445 node.child_by_field_name("left")
446 .ok_or_else(|| anyhow!(parse_error))?,
447 source,
448 )?);
449 let right = Box::new(Self::from_node(
450 node.child_by_field_name("right")
451 .ok_or_else(|| anyhow!(parse_error))?,
452 source,
453 )?);
454 if kind == "and" {
455 Ok(Self::And(left, right))
456 } else {
457 Ok(Self::Or(left, right))
458 }
459 }
460 "equal" | "not_equal" => {
461 let left = node
462 .child_by_field_name("left")
463 .ok_or_else(|| anyhow!(parse_error))?
464 .utf8_text(source)?
465 .into();
466 let right = node
467 .child_by_field_name("right")
468 .ok_or_else(|| anyhow!(parse_error))?
469 .utf8_text(source)?
470 .into();
471 if kind == "equal" {
472 Ok(Self::Equal(left, right))
473 } else {
474 Ok(Self::NotEqual(left, right))
475 }
476 }
477 "parenthesized" => Self::from_node(
478 node.child_by_field_name("expression")
479 .ok_or_else(|| anyhow!(parse_error))?,
480 source,
481 ),
482 _ => Err(anyhow!(parse_error)),
483 }
484 }
485
486 fn eval(&self, cx: &Context) -> bool {
487 match self {
488 Self::Identifier(name) => cx.set.contains(name.as_str()),
489 Self::Equal(left, right) => cx
490 .map
491 .get(left)
492 .map(|value| value == right)
493 .unwrap_or(false),
494 Self::NotEqual(left, right) => {
495 cx.map.get(left).map(|value| value != right).unwrap_or(true)
496 }
497 Self::Not(pred) => !pred.eval(cx),
498 Self::And(left, right) => left.eval(cx) && right.eval(cx),
499 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
500 }
501 }
502}
503
// Unit tests for keystroke parsing, context predicates, and the keystroke matcher.
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C, D, DA]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        // Views are tried in this order: view 2 first, then view 1.
        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
            Binding::new("d", D, Some("1")),
            Binding::new("d", D, Some("2")),
            Binding::new("d a", DA, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        // A binding with a pending prefix always takes precedence
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Pending,
        );
        // B alone doesn't match because a was pending, so AB is returned instead
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(1, Box::new(AB))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        // Without an a prefix, B is dispatched as expected
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(B))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        // If a is prefixed, C will not be dispatched because there
        // was a pending binding for it
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Pending,
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path.clone()),
            MatchResult::None,
        );
        assert!(!matcher.has_pending_keystrokes());

        // If a single keystroke matches multiple bindings in the tree,
        // all of them are returned so that we can fall back if the action
        // handler decides to propagate the action
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("d")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(D)), (1, Box::new(D))]),
        );
        // If none of the d action handlers consume the binding, a pending
        // binding may then be used
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(DA))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_a.clone())]),
            MatchResult::Matches(vec![(1, Box::new(A("x".to_string())))])
        );
        matcher.clear_pending();

        // Multi-keystroke match
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_b.clone())]),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_b.clone())]),
            MatchResult::Matches(vec![(1, Box::new(Ab))])
        );
        matcher.clear_pending();

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("x")?, vec![(1, ctx_a.clone())]),
            MatchResult::None
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_a.clone())]),
            MatchResult::Matches(vec![(1, Box::new(A("x".to_string())))])
        );
        matcher.clear_pending();

        // Pending keystrokes are cleared when the context changes
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_b.clone())]),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_a.clone())]),
            MatchResult::None
        );
        matcher.clear_pending();

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert_eq!(
            matcher.push_keystroke(
                Keystroke::parse("a")?,
                vec![(1, ctx_b.clone()), (2, ctx_c.clone())]
            ),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_b.clone())]),
            MatchResult::Matches(vec![(1, Box::new(Ab))])
        );

        Ok(())
    }
}