1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
extern "C" {
    /// Returns the tree-sitter `Language` for the context-predicate grammar that
    /// `ContextPredicate::parse` uses. The symbol is resolved at link time; presumably it is
    /// compiled from a bundled grammar by the crate's build script (not visible here — confirm).
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Resolves keystrokes into actions using a [`Keymap`].
///
/// Multi-keystroke bindings (e.g. `"a b"`) are supported by remembering, per view id,
/// the keystrokes typed so far that are still a prefix of some applicable binding.
pub struct Matcher {
    /// Partially-typed keystroke sequences, keyed by view id.
    pending: HashMap<usize, Pending>,
    /// The bindings consulted by `push_keystroke` and `keystrokes_for_action`.
    keymap: Keymap,
}
19
/// Keystrokes typed so far in one view that may still complete a multi-keystroke binding.
#[derive(Default)]
struct Pending {
    /// Keystrokes accumulated for this view, oldest first.
    keystrokes: Vec<Keystroke>,
    /// The view's context when these keystrokes were last kept pending. If the view's
    /// context differs on the next keystroke, `keystrokes` is cleared before continuing.
    context: Option<Context>,
}
25
/// An ordered collection of key bindings.
///
/// Later bindings take precedence: matching iterates `bindings` in reverse.
#[derive(Default)]
pub struct Keymap {
    /// All bindings, in the order they were added.
    bindings: Vec<Binding>,
    /// Indices into `bindings`, grouped by the `TypeId` of each binding's action
    /// (obtained via `action.as_any().type_id()`).
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
31
/// Associates a keystroke sequence with an action, optionally restricted by a context predicate.
pub struct Binding {
    /// The keystrokes that must be typed, in order, to trigger `action`.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// The action handed out (via `boxed_clone`) when the binding matches.
    action: Box<dyn Action>,
    /// When present, the binding only applies in contexts for which this evaluates to true;
    /// `None` means the binding applies in every context.
    context_predicate: Option<ContextPredicate>,
}
37
/// A single key press together with the modifier keys held during it.
///
/// Parsed from strings such as `"ctrl-p"` or `"shift-cmd--"` by `Keystroke::parse`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// Set by the `fn` modifier.
    pub function: bool,
    /// The name of the non-modifier key, e.g. `"p"`, `"down"` or `"-"`.
    pub key: String,
}
47
/// The state a view exposes for evaluating binding context predicates.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers tested by bare-identifier predicates such as `a`.
    pub set: HashSet<String>,
    /// Key/value pairs tested by `key == value` and `key != value` predicates.
    pub map: HashMap<String, String>,
}
53
/// A parsed binding context expression such as `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is present in `Context::set`.
    Identifier(String),
    /// True when `Context::map` holds the key with exactly this value; false if the key is absent.
    Equal(String, String),
    /// True when `Context::map` holds the key with a different value, or the key is absent.
    NotEqual(String, String),
    /// Logical negation of the inner predicate.
    Not(Box<ContextPredicate>),
    /// Short-circuiting conjunction.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Short-circuiting disjunction.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
63
/// Clones a value into a type-erased `Box<dyn Any>`.
///
/// NOTE(review): no uses of this trait are visible in this file.
trait ActionArg {
    fn boxed_clone(&self) -> Box<dyn Any>;
}

/// Every `'static` cloneable value can be cloned into a `Box<dyn Any>`.
impl<T> ActionArg for T
where
    T: 'static + Any + Clone,
{
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
76
77pub enum MatchResult {
78 None,
79 Pending,
80 Match(Vec<(usize, Box<dyn Action>)>),
81}
82
83impl Debug for MatchResult {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 MatchResult::None => f.debug_struct("MatchResult2::None").finish(),
87 MatchResult::Pending => f.debug_struct("MatchResult2::Pending").finish(),
88 MatchResult::Match { view_id, action } => f
89 .debug_struct("MatchResult::Match")
90 .field("view_id", view_id)
91 .field("action", &action.name())
92 .finish(),
93 }
94 }
95}
96
97impl PartialEq for MatchResult {
98 fn eq(&self, other: &Self) -> bool {
99 match (self, other) {
100 (MatchResult::None, MatchResult::None) => true,
101 (MatchResult::Pending, MatchResult::Pending) => true,
102 (
103 MatchResult::Match { view_id, action },
104 MatchResult::Match {
105 view_id: other_view_id,
106 action: other_action,
107 },
108 ) => view_id == other_view_id && action.eq(other_action.as_ref()),
109 _ => false,
110 }
111 }
112}
113
114impl Eq for MatchResult {}
115
116impl Matcher {
117 pub fn new(keymap: Keymap) -> Self {
118 Self {
119 pending: HashMap::new(),
120 keymap,
121 }
122 }
123
124 pub fn set_keymap(&mut self, keymap: Keymap) {
125 self.pending.clear();
126 self.keymap = keymap;
127 }
128
129 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
130 self.pending.clear();
131 self.keymap.add_bindings(bindings);
132 }
133
134 pub fn clear_bindings(&mut self) {
135 self.pending.clear();
136 self.keymap.clear();
137 }
138
139 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
140 self.keymap.bindings_for_action_type(action_type)
141 }
142
143 pub fn clear_pending(&mut self) {
144 self.pending.clear();
145 }
146
147 pub fn has_pending_keystrokes(&self) -> bool {
148 !self.pending.is_empty()
149 }
150
151 pub fn push_keystroke(
152 &mut self,
153 keystroke: Keystroke,
154 dispatch_path: Vec<(usize, Context)>,
155 ) -> MatchResult {
156 let mut any_pending = false;
157 let mut matched_bindings = Vec::new();
158
159 let first_keystroke = self.pending.is_empty();
160 dbg!(&dispatch_path);
161 for (view_id, context) in dispatch_path {
162 if !first_keystroke && !self.pending.contains_key(&view_id) {
163 continue;
164 }
165
166 let pending = self.pending.entry(view_id).or_default();
167
168 if let Some(pending_context) = pending.context.as_ref() {
169 if pending_context != &context {
170 pending.keystrokes.clear();
171 }
172 }
173
174 pending.keystrokes.push(keystroke.clone());
175
176 let mut retain_pending = false;
177 for binding in self.keymap.bindings.iter().rev() {
178 if binding.keystrokes.starts_with(&pending.keystrokes)
179 && binding
180 .context_predicate
181 .as_ref()
182 .map(|c| c.eval(&context))
183 .unwrap_or(true)
184 {
185 if binding.keystrokes.len() == pending.keystrokes.len() {
186 self.pending.remove(&view_id);
187 matched_bindings.push((view_id, binding.action.boxed_clone()));
188 } else {
189 retain_pending = true;
190 pending.context = Some(context.clone());
191 }
192 }
193 }
194
195 if retain_pending {
196 any_pending = true;
197 } else {
198 self.pending.remove(&view_id);
199 }
200 }
201
202 if !matched_bindings.is_empty() {
203 MatchResult::Match(matched_bindings)
204 } else if any_pending {
205 MatchResult::Pending
206 } else {
207 MatchResult::None
208 }
209 }
210
211 pub fn keystrokes_for_action(
212 &self,
213 action: &dyn Action,
214 cx: &Context,
215 ) -> Option<SmallVec<[Keystroke; 2]>> {
216 for binding in self.keymap.bindings.iter().rev() {
217 if binding.action.eq(action)
218 && binding
219 .context_predicate
220 .as_ref()
221 .map_or(true, |predicate| predicate.eval(cx))
222 {
223 return Some(binding.keystrokes.clone());
224 }
225 }
226 None
227 }
228}
229
230impl Default for Matcher {
231 fn default() -> Self {
232 Self::new(Keymap::default())
233 }
234}
235
236impl Keymap {
237 pub fn new(bindings: Vec<Binding>) -> Self {
238 let mut binding_indices_by_action_type = HashMap::new();
239 for (ix, binding) in bindings.iter().enumerate() {
240 binding_indices_by_action_type
241 .entry(binding.action.as_any().type_id())
242 .or_insert_with(SmallVec::new)
243 .push(ix);
244 }
245 Self {
246 binding_indices_by_action_type,
247 bindings,
248 }
249 }
250
251 fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
252 self.binding_indices_by_action_type
253 .get(&action_type)
254 .map(SmallVec::as_slice)
255 .unwrap_or(&[])
256 .iter()
257 .map(|ix| &self.bindings[*ix])
258 }
259
260 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
261 for binding in bindings {
262 self.binding_indices_by_action_type
263 .entry(binding.action.as_any().type_id())
264 .or_default()
265 .push(self.bindings.len());
266 self.bindings.push(binding);
267 }
268 }
269
270 fn clear(&mut self) {
271 self.bindings.clear();
272 self.binding_indices_by_action_type.clear();
273 }
274}
275
276impl Binding {
277 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
278 Self::load(keystrokes, Box::new(action), context).unwrap()
279 }
280
281 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
282 let context = if let Some(context) = context {
283 Some(ContextPredicate::parse(context)?)
284 } else {
285 None
286 };
287
288 let keystrokes = keystrokes
289 .split_whitespace()
290 .map(Keystroke::parse)
291 .collect::<Result<_>>()?;
292
293 Ok(Self {
294 keystrokes,
295 action,
296 context_predicate: context,
297 })
298 }
299
300 pub fn keystrokes(&self) -> &[Keystroke] {
301 &self.keystrokes
302 }
303
304 pub fn action(&self) -> &dyn Action {
305 self.action.as_ref()
306 }
307}
308
309impl Keystroke {
310 pub fn parse(source: &str) -> anyhow::Result<Self> {
311 let mut ctrl = false;
312 let mut alt = false;
313 let mut shift = false;
314 let mut cmd = false;
315 let mut function = false;
316 let mut key = None;
317
318 let mut components = source.split('-').peekable();
319 while let Some(component) = components.next() {
320 match component {
321 "ctrl" => ctrl = true,
322 "alt" => alt = true,
323 "shift" => shift = true,
324 "cmd" => cmd = true,
325 "fn" => function = true,
326 _ => {
327 if let Some(component) = components.peek() {
328 if component.is_empty() && source.ends_with('-') {
329 key = Some(String::from("-"));
330 break;
331 } else {
332 return Err(anyhow!("Invalid keystroke `{}`", source));
333 }
334 } else {
335 key = Some(String::from(component));
336 }
337 }
338 }
339 }
340
341 let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
342
343 Ok(Keystroke {
344 ctrl,
345 alt,
346 shift,
347 cmd,
348 function,
349 key,
350 })
351 }
352
353 pub fn modified(&self) -> bool {
354 self.ctrl || self.alt || self.shift || self.cmd
355 }
356}
357
358impl std::fmt::Display for Keystroke {
359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 if self.ctrl {
361 f.write_char('^')?;
362 }
363 if self.alt {
364 f.write_char('⎇')?;
365 }
366 if self.cmd {
367 f.write_char('⌘')?;
368 }
369 if self.shift {
370 f.write_char('⇧')?;
371 }
372 let key = match self.key.as_str() {
373 "backspace" => '⌫',
374 "up" => '↑',
375 "down" => '↓',
376 "left" => '←',
377 "right" => '→',
378 "tab" => '⇥',
379 "escape" => '⎋',
380 key => {
381 if key.len() == 1 {
382 key.chars().next().unwrap().to_ascii_uppercase()
383 } else {
384 return f.write_str(key);
385 }
386 }
387 };
388 f.write_char(key)
389 }
390}
391
392impl Context {
393 pub fn extend(&mut self, other: &Context) {
394 for v in &other.set {
395 self.set.insert(v.clone());
396 }
397 for (k, v) in &other.map {
398 self.map.insert(k.clone(), v.clone());
399 }
400 }
401}
402
403impl ContextPredicate {
404 fn parse(source: &str) -> anyhow::Result<Self> {
405 let mut parser = Parser::new();
406 let language = unsafe { tree_sitter_context_predicate() };
407 parser.set_language(language).unwrap();
408 let source = source.as_bytes();
409 let tree = parser.parse(source, None).unwrap();
410 Self::from_node(tree.root_node(), source)
411 }
412
413 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
414 let parse_error = "error parsing context predicate";
415 let kind = node.kind();
416
417 match kind {
418 "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
419 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
420 "not" => {
421 let child = Self::from_node(
422 node.child_by_field_name("expression")
423 .ok_or_else(|| anyhow!(parse_error))?,
424 source,
425 )?;
426 Ok(Self::Not(Box::new(child)))
427 }
428 "and" | "or" => {
429 let left = Box::new(Self::from_node(
430 node.child_by_field_name("left")
431 .ok_or_else(|| anyhow!(parse_error))?,
432 source,
433 )?);
434 let right = Box::new(Self::from_node(
435 node.child_by_field_name("right")
436 .ok_or_else(|| anyhow!(parse_error))?,
437 source,
438 )?);
439 if kind == "and" {
440 Ok(Self::And(left, right))
441 } else {
442 Ok(Self::Or(left, right))
443 }
444 }
445 "equal" | "not_equal" => {
446 let left = node
447 .child_by_field_name("left")
448 .ok_or_else(|| anyhow!(parse_error))?
449 .utf8_text(source)?
450 .into();
451 let right = node
452 .child_by_field_name("right")
453 .ok_or_else(|| anyhow!(parse_error))?
454 .utf8_text(source)?
455 .into();
456 if kind == "equal" {
457 Ok(Self::Equal(left, right))
458 } else {
459 Ok(Self::NotEqual(left, right))
460 }
461 }
462 "parenthesized" => Self::from_node(
463 node.child_by_field_name("expression")
464 .ok_or_else(|| anyhow!(parse_error))?,
465 source,
466 ),
467 _ => Err(anyhow!(parse_error)),
468 }
469 }
470
471 fn eval(&self, cx: &Context) -> bool {
472 match self {
473 Self::Identifier(name) => cx.set.contains(name.as_str()),
474 Self::Equal(left, right) => cx
475 .map
476 .get(left)
477 .map(|value| value == right)
478 .unwrap_or(false),
479 Self::NotEqual(left, right) => {
480 cx.map.get(left).map(|value| value != right).unwrap_or(true)
481 }
482 Self::Not(pred) => !pred.eval(cx),
483 Self::And(left, right) => left.eval(cx) && right.eval(cx),
484 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
485 }
486 }
487}
488
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        // "a" is a prefix of "a b", which only applies in view 1.
        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        // Completing "a b" matches in view 1. View 2 had no pending keystrokes, so its "b"
        // binding is not consulted.
        assert_eq!(
            MatchResult::Match(vec![(1, Box::new(AB) as Box<dyn Action>)]),
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());
        // With nothing pending, "b" is matched by view 2's binding.
        assert_eq!(
            MatchResult::Match(vec![(2, Box::new(B) as Box<dyn Action>)]),
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());
        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        // "a c" matches nothing in view 1, and view 2 is skipped because it had nothing pending.
        assert_eq!(
            MatchResult::None,
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());

        Ok(())
    }

    #[test]
    fn test_push_keystroke_prefers_latest_binding() -> Result<()> {
        actions!(test, [First, Second]);

        let keymap = Keymap::new(vec![
            Binding::new("a", First, None),
            Binding::new("a", Second, None),
        ]);
        let mut matcher = Matcher::new(keymap);

        // Only the most recently added binding matches; a view yields at most one action.
        assert_eq!(
            MatchResult::Match(vec![(1, Box::new(Second) as Box<dyn Action>)]),
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, Context::default())])
        );
        assert!(matcher.pending.is_empty());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher
            .test_keystroke("x", vec![(1, ctx_a.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_a.clone())])),
            Some(&B)
        );

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone()), (2, ctx_c.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        Ok(())
    }

    /// Downcasts an optional boxed action to a concrete action type.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Pushes a single parsed keystroke and returns the action of the first match, if any.
        /// `Pending` and `None` both map to `None`.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            dispatch_path: Vec<(usize, Context)>,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), dispatch_path) {
                MatchResult::Match(matches) => matches.into_iter().next().map(|(_, action)| action),
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}