1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
11extern "C" {
12 fn tree_sitter_context_predicate() -> Language;
13}
14
15pub struct Matcher {
16 pending_views: HashMap<usize, Context>,
17 pending_keystrokes: Vec<Keystroke>,
18 keymap: Keymap,
19}
20
21#[derive(Default)]
22pub struct Keymap {
23 bindings: Vec<Binding>,
24 binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
25}
26
27pub struct Binding {
28 keystrokes: SmallVec<[Keystroke; 2]>,
29 action: Box<dyn Action>,
30 context_predicate: Option<ContextPredicate>,
31}
32
/// A single key press: a key name plus the modifiers held with it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Whether the `ctrl` modifier is part of the keystroke.
    pub ctrl: bool,
    /// Whether the `alt` modifier is part of the keystroke.
    pub alt: bool,
    /// Whether the `shift` modifier is part of the keystroke.
    pub shift: bool,
    /// Whether the `cmd` modifier is part of the keystroke.
    pub cmd: bool,
    /// Whether the `fn` modifier is part of the keystroke.
    pub function: bool,
    /// The name of the key itself, e.g. `"p"`, `"down"` or `"-"`.
    pub key: String,
}
42
/// The facts about a view that context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Bare identifiers that are present in this context.
    pub set: HashSet<String>,
    /// Key/value pairs that equality predicates compare against.
    pub map: HashMap<String, String>,
}
48
/// A boolean expression over a [`Context`].
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in the context's `set`.
    Identifier(String),
    /// True when the context's `map` holds the key (first field) with exactly
    /// the given value (second field).
    Equal(String, String),
    /// True when the context's `map` lacks the key (first field) or holds it
    /// with a value other than the second field.
    NotEqual(String, String),
    /// Logical negation of the inner predicate.
    Not(Box<ContextPredicate>),
    /// True when both operands are true.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// True when at least one operand is true.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
58
/// A value that can produce a type-erased, heap-allocated copy of itself.
trait ActionArg {
    /// Clones `self` and returns the copy as a `Box<dyn Any>`.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// `Any` already requires `'static`, so `Any + Clone` covers every type the
// blanket implementation applies to.
impl<T: Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = self.clone();
        Box::new(copy)
    }
}
71
72pub enum MatchResult {
73 None,
74 Pending,
75 Matches(Vec<(usize, Box<dyn Action>)>),
76}
77
78impl Debug for MatchResult {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
82 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
83 MatchResult::Matches(matches) => f
84 .debug_list()
85 .entries(
86 matches
87 .iter()
88 .map(|(view_id, action)| format!("{view_id}, {}", action.name())),
89 )
90 .finish(),
91 }
92 }
93}
94
95impl PartialEq for MatchResult {
96 fn eq(&self, other: &Self) -> bool {
97 match (self, other) {
98 (MatchResult::None, MatchResult::None) => true,
99 (MatchResult::Pending, MatchResult::Pending) => true,
100 (MatchResult::Matches(matches), MatchResult::Matches(other_matches)) => {
101 matches.len() == other_matches.len()
102 && matches.iter().zip(other_matches.iter()).all(
103 |((view_id, action), (other_view_id, other_action))| {
104 view_id == other_view_id && action.eq(other_action.as_ref())
105 },
106 )
107 }
108 _ => false,
109 }
110 }
111}
112
113impl Eq for MatchResult {}
114
115impl Matcher {
116 pub fn new(keymap: Keymap) -> Self {
117 Self {
118 pending_views: HashMap::new(),
119 pending_keystrokes: Vec::new(),
120 keymap,
121 }
122 }
123
124 pub fn set_keymap(&mut self, keymap: Keymap) {
125 self.clear_pending();
126 self.keymap = keymap;
127 }
128
129 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
130 self.clear_pending();
131 self.keymap.add_bindings(bindings);
132 }
133
134 pub fn clear_bindings(&mut self) {
135 self.clear_pending();
136 self.keymap.clear();
137 }
138
139 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
140 self.keymap.bindings_for_action_type(action_type)
141 }
142
143 pub fn clear_pending(&mut self) {
144 self.pending_keystrokes.clear();
145 self.pending_views.clear();
146 }
147
148 pub fn has_pending_keystrokes(&self) -> bool {
149 !self.pending_keystrokes.is_empty()
150 }
151
152 pub fn push_keystroke(
153 &mut self,
154 keystroke: Keystroke,
155 dispatch_path: Vec<(usize, Context)>,
156 ) -> MatchResult {
157 let mut any_pending = false;
158 let mut matched_bindings = Vec::new();
159
160 let first_keystroke = self.pending_keystrokes.is_empty();
161 self.pending_keystrokes.push(keystroke);
162
163 for (view_id, context) in dispatch_path {
164 // Don't require pending view entry if there are no pending keystrokes
165 if !first_keystroke && !self.pending_views.contains_key(&view_id) {
166 continue;
167 }
168
169 // If there is a previous view context, invalidate that view if it
170 // has changed
171 if let Some(previous_view_context) = self.pending_views.remove(&view_id) {
172 if previous_view_context != context {
173 continue;
174 }
175 }
176
177 // Find the bindings which map the pending keystrokes and current context
178 for binding in self.keymap.bindings.iter().rev() {
179 if binding.keystrokes.starts_with(&self.pending_keystrokes)
180 && binding
181 .context_predicate
182 .as_ref()
183 .map(|c| c.eval(&context))
184 .unwrap_or(true)
185 {
186 // If the binding is completed, push it onto the matches list
187 if binding.keystrokes.len() == self.pending_keystrokes.len() {
188 matched_bindings.push((view_id, binding.action.boxed_clone()));
189 } else {
190 // Otherwise, the binding is still pending
191 self.pending_views.insert(view_id, context.clone());
192 any_pending = true;
193 }
194 }
195 }
196 }
197
198 if !any_pending {
199 self.clear_pending();
200 }
201
202 if !matched_bindings.is_empty() {
203 MatchResult::Matches(matched_bindings)
204 } else if any_pending {
205 MatchResult::Pending
206 } else {
207 MatchResult::None
208 }
209 }
210
211 pub fn keystrokes_for_action(
212 &self,
213 action: &dyn Action,
214 cx: &Context,
215 ) -> Option<SmallVec<[Keystroke; 2]>> {
216 for binding in self.keymap.bindings.iter().rev() {
217 if binding.action.eq(action)
218 && binding
219 .context_predicate
220 .as_ref()
221 .map_or(true, |predicate| predicate.eval(cx))
222 {
223 return Some(binding.keystrokes.clone());
224 }
225 }
226 None
227 }
228}
229
230impl Default for Matcher {
231 fn default() -> Self {
232 Self::new(Keymap::default())
233 }
234}
235
236impl Keymap {
237 pub fn new(bindings: Vec<Binding>) -> Self {
238 let mut binding_indices_by_action_type = HashMap::new();
239 for (ix, binding) in bindings.iter().enumerate() {
240 binding_indices_by_action_type
241 .entry(binding.action.as_any().type_id())
242 .or_insert_with(SmallVec::new)
243 .push(ix);
244 }
245 Self {
246 binding_indices_by_action_type,
247 bindings,
248 }
249 }
250
251 fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
252 self.binding_indices_by_action_type
253 .get(&action_type)
254 .map(SmallVec::as_slice)
255 .unwrap_or(&[])
256 .iter()
257 .map(|ix| &self.bindings[*ix])
258 }
259
260 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
261 for binding in bindings {
262 self.binding_indices_by_action_type
263 .entry(binding.action.as_any().type_id())
264 .or_default()
265 .push(self.bindings.len());
266 self.bindings.push(binding);
267 }
268 }
269
270 fn clear(&mut self) {
271 self.bindings.clear();
272 self.binding_indices_by_action_type.clear();
273 }
274}
275
276impl Binding {
277 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
278 Self::load(keystrokes, Box::new(action), context).unwrap()
279 }
280
281 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
282 let context = if let Some(context) = context {
283 Some(ContextPredicate::parse(context)?)
284 } else {
285 None
286 };
287
288 let keystrokes = keystrokes
289 .split_whitespace()
290 .map(Keystroke::parse)
291 .collect::<Result<_>>()?;
292
293 Ok(Self {
294 keystrokes,
295 action,
296 context_predicate: context,
297 })
298 }
299
300 pub fn keystrokes(&self) -> &[Keystroke] {
301 &self.keystrokes
302 }
303
304 pub fn action(&self) -> &dyn Action {
305 self.action.as_ref()
306 }
307}
308
309impl Keystroke {
310 pub fn parse(source: &str) -> anyhow::Result<Self> {
311 let mut ctrl = false;
312 let mut alt = false;
313 let mut shift = false;
314 let mut cmd = false;
315 let mut function = false;
316 let mut key = None;
317
318 let mut components = source.split('-').peekable();
319 while let Some(component) = components.next() {
320 match component {
321 "ctrl" => ctrl = true,
322 "alt" => alt = true,
323 "shift" => shift = true,
324 "cmd" => cmd = true,
325 "fn" => function = true,
326 _ => {
327 if let Some(component) = components.peek() {
328 if component.is_empty() && source.ends_with('-') {
329 key = Some(String::from("-"));
330 break;
331 } else {
332 return Err(anyhow!("Invalid keystroke `{}`", source));
333 }
334 } else {
335 key = Some(String::from(component));
336 }
337 }
338 }
339 }
340
341 let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
342
343 Ok(Keystroke {
344 ctrl,
345 alt,
346 shift,
347 cmd,
348 function,
349 key,
350 })
351 }
352
353 pub fn modified(&self) -> bool {
354 self.ctrl || self.alt || self.shift || self.cmd
355 }
356}
357
358impl std::fmt::Display for Keystroke {
359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 if self.ctrl {
361 f.write_char('^')?;
362 }
363 if self.alt {
364 f.write_char('⎇')?;
365 }
366 if self.cmd {
367 f.write_char('⌘')?;
368 }
369 if self.shift {
370 f.write_char('⇧')?;
371 }
372 let key = match self.key.as_str() {
373 "backspace" => '⌫',
374 "up" => '↑',
375 "down" => '↓',
376 "left" => '←',
377 "right" => '→',
378 "tab" => '⇥',
379 "escape" => '⎋',
380 key => {
381 if key.len() == 1 {
382 key.chars().next().unwrap().to_ascii_uppercase()
383 } else {
384 return f.write_str(key);
385 }
386 }
387 };
388 f.write_char(key)
389 }
390}
391
392impl Context {
393 pub fn extend(&mut self, other: &Context) {
394 for v in &other.set {
395 self.set.insert(v.clone());
396 }
397 for (k, v) in &other.map {
398 self.map.insert(k.clone(), v.clone());
399 }
400 }
401}
402
403impl ContextPredicate {
404 fn parse(source: &str) -> anyhow::Result<Self> {
405 let mut parser = Parser::new();
406 let language = unsafe { tree_sitter_context_predicate() };
407 parser.set_language(language).unwrap();
408 let source = source.as_bytes();
409 let tree = parser.parse(source, None).unwrap();
410 Self::from_node(tree.root_node(), source)
411 }
412
413 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
414 let parse_error = "error parsing context predicate";
415 let kind = node.kind();
416
417 match kind {
418 "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
419 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
420 "not" => {
421 let child = Self::from_node(
422 node.child_by_field_name("expression")
423 .ok_or_else(|| anyhow!(parse_error))?,
424 source,
425 )?;
426 Ok(Self::Not(Box::new(child)))
427 }
428 "and" | "or" => {
429 let left = Box::new(Self::from_node(
430 node.child_by_field_name("left")
431 .ok_or_else(|| anyhow!(parse_error))?,
432 source,
433 )?);
434 let right = Box::new(Self::from_node(
435 node.child_by_field_name("right")
436 .ok_or_else(|| anyhow!(parse_error))?,
437 source,
438 )?);
439 if kind == "and" {
440 Ok(Self::And(left, right))
441 } else {
442 Ok(Self::Or(left, right))
443 }
444 }
445 "equal" | "not_equal" => {
446 let left = node
447 .child_by_field_name("left")
448 .ok_or_else(|| anyhow!(parse_error))?
449 .utf8_text(source)?
450 .into();
451 let right = node
452 .child_by_field_name("right")
453 .ok_or_else(|| anyhow!(parse_error))?
454 .utf8_text(source)?
455 .into();
456 if kind == "equal" {
457 Ok(Self::Equal(left, right))
458 } else {
459 Ok(Self::NotEqual(left, right))
460 }
461 }
462 "parenthesized" => Self::from_node(
463 node.child_by_field_name("expression")
464 .ok_or_else(|| anyhow!(parse_error))?,
465 source,
466 ),
467 _ => Err(anyhow!(parse_error)),
468 }
469 }
470
471 fn eval(&self, cx: &Context) -> bool {
472 match self {
473 Self::Identifier(name) => cx.set.contains(name.as_str()),
474 Self::Equal(left, right) => cx
475 .map
476 .get(left)
477 .map(|value| value == right)
478 .unwrap_or(false),
479 Self::NotEqual(left, right) => {
480 cx.map.get(left).map(|value| value != right).unwrap_or(true)
481 }
482 Self::Not(pred) => !pred.eval(cx),
483 Self::And(left, right) => left.eval(cx) && right.eval(cx),
484 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
485 }
486 }
487}
488
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    /// Builds a context whose identifier set contains only `identifier`.
    fn context_with(identifier: &str) -> Context {
        let mut context = Context::default();
        context.set.insert(identifier.into());
        context
    }

    /// Parses `key` and pushes it to `matcher` along with a copy of
    /// `dispatch_path`.
    fn press(
        matcher: &mut Matcher,
        key: &str,
        dispatch_path: &[(usize, Context)],
    ) -> Result<MatchResult> {
        Ok(matcher.push_keystroke(Keystroke::parse(key)?, dispatch_path.to_vec()))
    }

    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C, D, DA]);

        let dispatch_path = vec![(2, context_with("2")), (1, context_with("1"))];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
            Binding::new("d", D, Some("1")),
            Binding::new("d", D, Some("2")),
            Binding::new("d a", DA, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        // A binding with a pending prefix always takes precedence.
        assert_eq!(
            press(&mut matcher, "a", &dispatch_path)?,
            MatchResult::Pending,
        );
        // `b` alone does not match because `a` was pending, so AB is returned.
        assert_eq!(
            press(&mut matcher, "b", &dispatch_path)?,
            MatchResult::Matches(vec![(1, Box::new(AB))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        // Without the `a` prefix, B is dispatched as expected.
        assert_eq!(
            press(&mut matcher, "b", &dispatch_path)?,
            MatchResult::Matches(vec![(2, Box::new(B))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        // When `a` is typed first, C is not dispatched because a binding was
        // pending on that prefix.
        assert_eq!(
            press(&mut matcher, "a", &dispatch_path)?,
            MatchResult::Pending,
        );
        assert_eq!(
            press(&mut matcher, "c", &dispatch_path)?,
            MatchResult::None,
        );
        assert!(!matcher.has_pending_keystrokes());

        // When one keystroke matches several bindings in the tree, all of
        // them are returned so the caller can fall back if an action handler
        // decides to propagate the action.
        assert_eq!(
            press(&mut matcher, "d", &dispatch_path)?,
            MatchResult::Matches(vec![(2, Box::new(D)), (1, Box::new(D))]),
        );
        // If none of the `d` handlers consume the binding, a pending binding
        // may then be used.
        assert_eq!(
            press(&mut matcher, "a", &dispatch_path)?,
            MatchResult::Matches(vec![(2, Box::new(DA))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        let cases = [
            (
                "ctrl-p",
                Keystroke {
                    key: "p".into(),
                    ctrl: true,
                    alt: false,
                    shift: false,
                    cmd: false,
                    function: false,
                },
            ),
            (
                "alt-shift-down",
                Keystroke {
                    key: "down".into(),
                    ctrl: false,
                    alt: true,
                    shift: true,
                    cmd: false,
                    function: false,
                },
            ),
            (
                "shift-cmd--",
                Keystroke {
                    key: "-".into(),
                    ctrl: false,
                    alt: false,
                    shift: true,
                    cmd: true,
                    function: false,
                },
            ),
        ];

        for (source, expected) in cases {
            assert_eq!(Keystroke::parse(source)?, expected);
        }

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        let comparison = Or(
            Box::new(Equal("b".into(), "c".into())),
            Box::new(NotEqual("d".into(), "e".into())),
        );
        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(Box::new(Identifier("a".into())), Box::new(comparison))
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        // Only `a` is set: neither side of the `||` holds.
        let mut cx = context_with("a");
        assert!(!predicate.eval(&cx));

        // `a && b` now holds.
        cx.set.insert("b".into());
        assert!(predicate.eval(&cx));

        // `b` is gone and `c` has the wrong value.
        cx.set.remove("b");
        cx.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&cx));

        // `c == d` now holds.
        cx.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&cx));

        let negation = ContextPredicate::parse("!a")?;
        assert!(negation.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        #[derive(Clone, Debug, Eq, PartialEq)]
        struct ActionArg {
            a: &'static str,
        }

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let ctx_a = context_with("a");
        let ctx_b = context_with("b");

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            press(&mut matcher, "a", &[(1, ctx_a.clone())])?,
            MatchResult::Matches(vec![(1, Box::new(A("x".to_string())))])
        );
        matcher.clear_pending();

        // Multi-keystroke match
        assert_eq!(
            press(&mut matcher, "a", &[(1, ctx_b.clone())])?,
            MatchResult::Pending
        );
        assert_eq!(
            press(&mut matcher, "b", &[(1, ctx_b.clone())])?,
            MatchResult::Matches(vec![(1, Box::new(Ab))])
        );
        matcher.clear_pending();

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(
            press(&mut matcher, "x", &[(1, ctx_a.clone())])?,
            MatchResult::None
        );
        assert_eq!(
            press(&mut matcher, "a", &[(1, ctx_a.clone())])?,
            MatchResult::Matches(vec![(1, Box::new(A("x".to_string())))])
        );
        matcher.clear_pending();

        // Pending keystrokes are cleared when the context changes
        assert_eq!(
            press(&mut matcher, "a", &[(1, ctx_b.clone())])?,
            MatchResult::Pending
        );
        assert_eq!(
            press(&mut matcher, "b", &[(1, ctx_a.clone())])?,
            MatchResult::None
        );
        matcher.clear_pending();

        let ctx_c = context_with("c");

        // Pending keystrokes are maintained per-view
        assert_eq!(
            press(
                &mut matcher,
                "a",
                &[(1, ctx_b.clone()), (2, ctx_c.clone())]
            )?,
            MatchResult::Pending
        );
        assert_eq!(
            press(&mut matcher, "b", &[(1, ctx_b.clone())])?,
            MatchResult::Matches(vec![(1, Box::new(Ab))])
        );

        Ok(())
    }
}