1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
extern "C" {
    /// Constructor for the tree-sitter grammar that `ContextPredicate::parse` hands to
    /// `Parser::set_language`. Presumably provided by a C object compiled and linked by
    /// the crate's build — TODO confirm.
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Resolves keystrokes into actions using a [`Keymap`], tracking partially typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by view id.
    pending: HashMap<usize, Pending>,
    /// The bindings consulted when matching.
    keymap: Keymap,
}
19
/// A partially typed keystroke sequence for a single view.
#[derive(Default)]
struct Pending {
    /// Keystrokes typed so far; kept only while they prefix at least one binding.
    keystrokes: Vec<Keystroke>,
    /// Context the sequence was typed under. If the next keystroke arrives with a
    /// different context, `keystrokes` is reset before it is appended.
    context: Option<Context>,
}
25
/// An ordered list of key bindings plus an index from action type to binding positions.
/// Matching searches `bindings` from last to first, so later bindings take precedence.
#[derive(Default)]
pub struct Keymap {
    bindings: Vec<Binding>,
    /// Positions in `bindings`, grouped by the concrete `TypeId` of each binding's action.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
31
/// A keystroke sequence mapped to an action, optionally gated by a context predicate.
pub struct Binding {
    /// Keystrokes that must be typed, in order, to trigger the action.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// Action produced on a match; a clone of it is placed in `MatchResult::Match`.
    action: Box<dyn Action>,
    /// When present, the binding only applies in contexts where this evaluates to true.
    /// When absent, the binding applies in every context.
    context_predicate: Option<ContextPredicate>,
}
37
/// A single key press together with the modifier keys held during it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// The `fn` modifier (written `fn-` in keystroke strings).
    pub function: bool,
    /// Name of the non-modifier key, e.g. `"p"`, `"down"`, or `"-"`.
    pub key: String,
}
47
/// The state a binding's context predicate is evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that satisfy a bare-identifier predicate.
    pub set: HashSet<String>,
    /// Key/value pairs compared by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
53
/// Parsed form of a binding's context expression, e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in `Context::set`.
    Identifier(String),
    /// True when `Context::map` maps the left operand to the right; false if absent.
    Equal(String, String),
    /// True when `Context::map` does not map the left operand to the right,
    /// including when the left operand has no entry at all.
    NotEqual(String, String),
    /// Logical negation.
    Not(Box<ContextPredicate>),
    /// Logical conjunction (short-circuiting).
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Logical disjunction (short-circuiting).
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
63
/// Clones a value into a type-erased `Box<dyn Any>`.
///
/// NOTE(review): nothing in this module appears to call this trait (the test module
/// declares an unrelated local struct with the same name) — confirm whether it is
/// still needed.
trait ActionArg {
    fn boxed_clone(&self) -> Box<dyn Any>;
}
67
68impl<T> ActionArg for T
69where
70 T: 'static + Any + Clone,
71{
72 fn boxed_clone(&self) -> Box<dyn Any> {
73 Box::new(self.clone())
74 }
75}
76
/// Outcome of feeding one keystroke to `Matcher::push_keystroke`.
pub enum MatchResult {
    /// No binding matched and no view is waiting for more keystrokes.
    None,
    /// At least one view holds a partial sequence that could still complete a binding.
    Pending,
    /// A binding completed for `view_id`; `action` is a clone of the bound action.
    Match {
        view_id: usize,
        action: Box<dyn Action>,
    },
}
85
86impl Debug for MatchResult {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 match self {
89 MatchResult::None => f.debug_struct("MatchResult2::None").finish(),
90 MatchResult::Pending => f.debug_struct("MatchResult2::Pending").finish(),
91 MatchResult::Match { view_id, action } => f
92 .debug_struct("MatchResult::Match")
93 .field("view_id", view_id)
94 .field("action", &action.name())
95 .finish(),
96 }
97 }
98}
99
100impl PartialEq for MatchResult {
101 fn eq(&self, other: &Self) -> bool {
102 match (self, other) {
103 (MatchResult::None, MatchResult::None) => true,
104 (MatchResult::Pending, MatchResult::Pending) => true,
105 (
106 MatchResult::Match { view_id, action },
107 MatchResult::Match {
108 view_id: other_view_id,
109 action: other_action,
110 },
111 ) => view_id == other_view_id && action.eq(other_action.as_ref()),
112 _ => false,
113 }
114 }
115}
116
// Marker impl: the `PartialEq` impl for `MatchResult` is treated as a full equivalence.
// NOTE(review): this assumes `Action::eq` is reflexive for every action — confirm.
impl Eq for MatchResult {}
118
119impl Matcher {
120 pub fn new(keymap: Keymap) -> Self {
121 Self {
122 pending: HashMap::new(),
123 keymap,
124 }
125 }
126
127 pub fn set_keymap(&mut self, keymap: Keymap) {
128 self.pending.clear();
129 self.keymap = keymap;
130 }
131
132 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
133 self.pending.clear();
134 self.keymap.add_bindings(bindings);
135 }
136
137 pub fn clear_bindings(&mut self) {
138 self.pending.clear();
139 self.keymap.clear();
140 }
141
142 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
143 self.keymap.bindings_for_action_type(action_type)
144 }
145
146 pub fn clear_pending(&mut self) {
147 self.pending.clear();
148 }
149
150 pub fn has_pending_keystrokes(&self) -> bool {
151 !self.pending.is_empty()
152 }
153
154 pub fn push_keystroke(
155 &mut self,
156 keystroke: Keystroke,
157 dispatch_path: Vec<(usize, Context)>,
158 ) -> MatchResult {
159 let mut any_pending = false;
160
161 let first_keystroke = self.pending.is_empty();
162 for (view_id, context) in dispatch_path {
163 if !first_keystroke && !self.pending.contains_key(&view_id) {
164 continue;
165 }
166
167 let pending = self.pending.entry(view_id).or_default();
168
169 if let Some(pending_context) = pending.context.as_ref() {
170 if pending_context != &context {
171 pending.keystrokes.clear();
172 }
173 }
174
175 pending.keystrokes.push(keystroke.clone());
176
177 let mut retain_pending = false;
178 for binding in self.keymap.bindings.iter().rev() {
179 if binding.keystrokes.starts_with(&pending.keystrokes)
180 && binding
181 .context_predicate
182 .as_ref()
183 .map(|c| c.eval(&context))
184 .unwrap_or(true)
185 {
186 if binding.keystrokes.len() == pending.keystrokes.len() {
187 self.pending.remove(&view_id);
188 return MatchResult::Match {
189 view_id,
190 action: binding.action.boxed_clone(),
191 };
192 } else {
193 retain_pending = true;
194 pending.context = Some(context.clone());
195 }
196 }
197 }
198
199 if retain_pending {
200 any_pending = true;
201 } else {
202 self.pending.remove(&view_id);
203 }
204 }
205
206 if any_pending {
207 MatchResult::Pending
208 } else {
209 MatchResult::None
210 }
211 }
212
213 pub fn keystrokes_for_action(
214 &self,
215 action: &dyn Action,
216 cx: &Context,
217 ) -> Option<SmallVec<[Keystroke; 2]>> {
218 for binding in self.keymap.bindings.iter().rev() {
219 if binding.action.eq(action)
220 && binding
221 .context_predicate
222 .as_ref()
223 .map_or(true, |predicate| predicate.eval(cx))
224 {
225 return Some(binding.keystrokes.clone());
226 }
227 }
228 None
229 }
230}
231
232impl Default for Matcher {
233 fn default() -> Self {
234 Self::new(Keymap::default())
235 }
236}
237
238impl Keymap {
239 pub fn new(bindings: Vec<Binding>) -> Self {
240 let mut binding_indices_by_action_type = HashMap::new();
241 for (ix, binding) in bindings.iter().enumerate() {
242 binding_indices_by_action_type
243 .entry(binding.action.as_any().type_id())
244 .or_insert_with(SmallVec::new)
245 .push(ix);
246 }
247 Self {
248 binding_indices_by_action_type,
249 bindings,
250 }
251 }
252
253 fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
254 self.binding_indices_by_action_type
255 .get(&action_type)
256 .map(SmallVec::as_slice)
257 .unwrap_or(&[])
258 .iter()
259 .map(|ix| &self.bindings[*ix])
260 }
261
262 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
263 for binding in bindings {
264 self.binding_indices_by_action_type
265 .entry(binding.action.as_any().type_id())
266 .or_default()
267 .push(self.bindings.len());
268 self.bindings.push(binding);
269 }
270 }
271
272 fn clear(&mut self) {
273 self.bindings.clear();
274 self.binding_indices_by_action_type.clear();
275 }
276}
277
278impl Binding {
279 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
280 Self::load(keystrokes, Box::new(action), context).unwrap()
281 }
282
283 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
284 let context = if let Some(context) = context {
285 Some(ContextPredicate::parse(context)?)
286 } else {
287 None
288 };
289
290 let keystrokes = keystrokes
291 .split_whitespace()
292 .map(Keystroke::parse)
293 .collect::<Result<_>>()?;
294
295 Ok(Self {
296 keystrokes,
297 action,
298 context_predicate: context,
299 })
300 }
301
302 pub fn keystrokes(&self) -> &[Keystroke] {
303 &self.keystrokes
304 }
305
306 pub fn action(&self) -> &dyn Action {
307 self.action.as_ref()
308 }
309}
310
311impl Keystroke {
312 pub fn parse(source: &str) -> anyhow::Result<Self> {
313 let mut ctrl = false;
314 let mut alt = false;
315 let mut shift = false;
316 let mut cmd = false;
317 let mut function = false;
318 let mut key = None;
319
320 let mut components = source.split('-').peekable();
321 while let Some(component) = components.next() {
322 match component {
323 "ctrl" => ctrl = true,
324 "alt" => alt = true,
325 "shift" => shift = true,
326 "cmd" => cmd = true,
327 "fn" => function = true,
328 _ => {
329 if let Some(component) = components.peek() {
330 if component.is_empty() && source.ends_with('-') {
331 key = Some(String::from("-"));
332 break;
333 } else {
334 return Err(anyhow!("Invalid keystroke `{}`", source));
335 }
336 } else {
337 key = Some(String::from(component));
338 }
339 }
340 }
341 }
342
343 let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
344
345 Ok(Keystroke {
346 ctrl,
347 alt,
348 shift,
349 cmd,
350 function,
351 key,
352 })
353 }
354
355 pub fn modified(&self) -> bool {
356 self.ctrl || self.alt || self.shift || self.cmd
357 }
358}
359
360impl std::fmt::Display for Keystroke {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 if self.ctrl {
363 f.write_char('^')?;
364 }
365 if self.alt {
366 f.write_char('⎇')?;
367 }
368 if self.cmd {
369 f.write_char('⌘')?;
370 }
371 if self.shift {
372 f.write_char('⇧')?;
373 }
374 let key = match self.key.as_str() {
375 "backspace" => '⌫',
376 "up" => '↑',
377 "down" => '↓',
378 "left" => '←',
379 "right" => '→',
380 "tab" => '⇥',
381 "escape" => '⎋',
382 key => {
383 if key.len() == 1 {
384 key.chars().next().unwrap().to_ascii_uppercase()
385 } else {
386 return f.write_str(key);
387 }
388 }
389 };
390 f.write_char(key)
391 }
392}
393
394impl Context {
395 pub fn extend(&mut self, other: &Context) {
396 for v in &other.set {
397 self.set.insert(v.clone());
398 }
399 for (k, v) in &other.map {
400 self.map.insert(k.clone(), v.clone());
401 }
402 }
403}
404
405impl ContextPredicate {
406 fn parse(source: &str) -> anyhow::Result<Self> {
407 let mut parser = Parser::new();
408 let language = unsafe { tree_sitter_context_predicate() };
409 parser.set_language(language).unwrap();
410 let source = source.as_bytes();
411 let tree = parser.parse(source, None).unwrap();
412 Self::from_node(tree.root_node(), source)
413 }
414
415 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
416 let parse_error = "error parsing context predicate";
417 let kind = node.kind();
418
419 match kind {
420 "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
421 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
422 "not" => {
423 let child = Self::from_node(
424 node.child_by_field_name("expression")
425 .ok_or_else(|| anyhow!(parse_error))?,
426 source,
427 )?;
428 Ok(Self::Not(Box::new(child)))
429 }
430 "and" | "or" => {
431 let left = Box::new(Self::from_node(
432 node.child_by_field_name("left")
433 .ok_or_else(|| anyhow!(parse_error))?,
434 source,
435 )?);
436 let right = Box::new(Self::from_node(
437 node.child_by_field_name("right")
438 .ok_or_else(|| anyhow!(parse_error))?,
439 source,
440 )?);
441 if kind == "and" {
442 Ok(Self::And(left, right))
443 } else {
444 Ok(Self::Or(left, right))
445 }
446 }
447 "equal" | "not_equal" => {
448 let left = node
449 .child_by_field_name("left")
450 .ok_or_else(|| anyhow!(parse_error))?
451 .utf8_text(source)?
452 .into();
453 let right = node
454 .child_by_field_name("right")
455 .ok_or_else(|| anyhow!(parse_error))?
456 .utf8_text(source)?
457 .into();
458 if kind == "equal" {
459 Ok(Self::Equal(left, right))
460 } else {
461 Ok(Self::NotEqual(left, right))
462 }
463 }
464 "parenthesized" => Self::from_node(
465 node.child_by_field_name("expression")
466 .ok_or_else(|| anyhow!(parse_error))?,
467 source,
468 ),
469 _ => Err(anyhow!(parse_error)),
470 }
471 }
472
473 fn eval(&self, cx: &Context) -> bool {
474 match self {
475 Self::Identifier(name) => cx.set.contains(name.as_str()),
476 Self::Equal(left, right) => cx
477 .map
478 .get(left)
479 .map(|value| value == right)
480 .unwrap_or(false),
481 Self::NotEqual(left, right) => {
482 cx.map.get(left).map(|value| value != right).unwrap_or(true)
483 }
484 Self::Not(pred) => !pred.eval(cx),
485 Self::And(left, right) => left.eval(cx) && right.eval(cx),
486 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
487 }
488 }
489}
490
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        assert_eq!(
            MatchResult::Match {
                view_id: 1,
                action: Box::new(AB)
            },
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());
        assert_eq!(
            MatchResult::Match {
                view_id: 2,
                action: Box::new(B)
            },
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());
        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        assert_eq!(
            MatchResult::None,
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher
            .test_keystroke("x", vec![(1, ctx_a.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_a.clone())])),
            Some(&B)
        );

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone()), (2, ctx_c.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        Ok(())
    }

    /// Borrows the concrete action `A` out of an optional boxed action, if it is one.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, pushes it along `dispatch_path`, and returns the matched
        /// action, or `None` for both `Pending` and `None` results.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            dispatch_path: Vec<(usize, Context)>,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), dispatch_path) {
                MatchResult::Match { action, .. } => Some(action),
                _ => None,
            }
        }
    }
}