1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
11extern "C" {
12 fn tree_sitter_context_predicate() -> Language;
13}
14
15pub struct Matcher {
16 pending_views: HashMap<usize, Context>,
17 pending_keystrokes: Vec<Keystroke>,
18 keymap: Keymap,
19}
20
21#[derive(Default)]
22pub struct Keymap {
23 bindings: Vec<Binding>,
24 binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
25}
26
27pub struct Binding {
28 keystrokes: SmallVec<[Keystroke; 2]>,
29 action: Box<dyn Action>,
30 context_predicate: Option<ContextPredicate>,
31}
32
/// A single key press together with the modifier keys held while pressing it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// `ctrl` modifier held.
    pub ctrl: bool,
    /// `alt` modifier held.
    pub alt: bool,
    /// `shift` modifier held.
    pub shift: bool,
    /// `cmd` modifier held.
    pub cmd: bool,
    /// `fn` modifier held.
    pub function: bool,
    /// Name of the key itself, e.g. `p`, `down`, or `-`.
    pub key: String,
}
42
/// The view state that context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers present in the context; tested by bare-identifier predicates.
    pub set: HashSet<String>,
    /// Key/value pairs; tested by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
48
/// A parsed boolean expression over a [`Context`].
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in the context's `set`.
    Identifier(String),
    /// `key == value` against the context's `map`.
    Equal(String, String),
    /// `key != value` against the context's `map`.
    NotEqual(String, String),
    /// Logical negation.
    Not(Box<Self>),
    /// Logical conjunction.
    And(Box<Self>, Box<Self>),
    /// Logical disjunction.
    Or(Box<Self>, Box<Self>),
}
58
/// Clones a value into a type-erased `Box<dyn Any>`.
trait ActionArg {
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// `Any` already implies `'static`, so no separate lifetime bound is needed.
impl<T: Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = self.clone();
        Box::new(copy)
    }
}
71
72pub enum MatchResult {
73 None,
74 Pending,
75 Matches(Vec<(usize, Box<dyn Action>)>),
76}
77
78impl Debug for MatchResult {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 MatchResult::None => f.debug_struct("MatchResult2::None").finish(),
82 MatchResult::Pending => f.debug_struct("MatchResult2::Pending").finish(),
83 MatchResult::Matches(matches) => f
84 .debug_list()
85 .entries(
86 matches
87 .iter()
88 .map(|(view_id, action)| format!("{view_id}, {}", action.name())),
89 )
90 .finish(),
91 }
92 }
93}
94
95impl PartialEq for MatchResult {
96 fn eq(&self, other: &Self) -> bool {
97 match (self, other) {
98 (MatchResult::None, MatchResult::None) => true,
99 (MatchResult::Pending, MatchResult::Pending) => true,
100 (MatchResult::Matches(matches), MatchResult::Matches(other_matches)) => {
101 matches.len() == other_matches.len()
102 && matches.iter().zip(other_matches.iter()).all(
103 |((view_id, action), (other_view_id, other_action))| {
104 view_id == other_view_id && action.eq(other_action.as_ref())
105 },
106 )
107 }
108 _ => false,
109 }
110 }
111}
112
113impl Eq for MatchResult {}
114
115impl Matcher {
116 pub fn new(keymap: Keymap) -> Self {
117 Self {
118 pending_views: HashMap::new(),
119 pending_keystrokes: Vec::new(),
120 keymap,
121 }
122 }
123
124 pub fn set_keymap(&mut self, keymap: Keymap) {
125 self.clear_pending();
126 self.keymap = keymap;
127 }
128
129 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
130 self.clear_pending();
131 self.keymap.add_bindings(bindings);
132 }
133
134 pub fn clear_bindings(&mut self) {
135 self.clear_pending();
136 self.keymap.clear();
137 }
138
139 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
140 self.keymap.bindings_for_action_type(action_type)
141 }
142
143 pub fn clear_pending(&mut self) {
144 self.pending_keystrokes.clear();
145 self.pending_views.clear();
146 }
147
148 pub fn has_pending_keystrokes(&self) -> bool {
149 !self.pending_keystrokes.is_empty()
150 }
151
152 pub fn push_keystroke(
153 &mut self,
154 keystroke: Keystroke,
155 dispatch_path: Vec<(usize, Context)>,
156 ) -> MatchResult {
157 let mut any_pending = false;
158 let mut matched_bindings = Vec::new();
159
160 let first_keystroke = self.pending_keystrokes.is_empty();
161 self.pending_keystrokes.push(keystroke);
162
163 for (view_id, context) in dispatch_path {
164 // Don't require pending view entry if there are no pending keystrokes
165 if !first_keystroke && !self.pending_views.contains_key(&view_id) {
166 continue;
167 }
168
169 // If there is a previous view context, invalidate that view if it
170 // has changed
171 if let Some(previous_view_context) = self.pending_views.remove(&view_id) {
172 if previous_view_context != context {
173 continue;
174 }
175 }
176
177 for binding in self.keymap.bindings.iter().rev() {
178 if binding.keystrokes.starts_with(&self.pending_keystrokes)
179 && binding
180 .context_predicate
181 .as_ref()
182 .map(|c| c.eval(&context))
183 .unwrap_or(true)
184 {
185 if binding.keystrokes.len() == self.pending_keystrokes.len() {
186 matched_bindings.push((view_id, binding.action.boxed_clone()));
187 } else {
188 self.pending_views.insert(view_id, context.clone());
189 any_pending = true;
190 }
191 }
192 }
193 }
194
195 if !matched_bindings.is_empty() {
196 self.pending_views.clear();
197 self.pending_keystrokes.clear();
198 MatchResult::Matches(matched_bindings)
199 } else if any_pending {
200 MatchResult::Pending
201 } else {
202 self.pending_keystrokes.clear();
203 MatchResult::None
204 }
205 }
206
207 pub fn keystrokes_for_action(
208 &self,
209 action: &dyn Action,
210 cx: &Context,
211 ) -> Option<SmallVec<[Keystroke; 2]>> {
212 for binding in self.keymap.bindings.iter().rev() {
213 if binding.action.eq(action)
214 && binding
215 .context_predicate
216 .as_ref()
217 .map_or(true, |predicate| predicate.eval(cx))
218 {
219 return Some(binding.keystrokes.clone());
220 }
221 }
222 None
223 }
224}
225
226impl Default for Matcher {
227 fn default() -> Self {
228 Self::new(Keymap::default())
229 }
230}
231
232impl Keymap {
233 pub fn new(bindings: Vec<Binding>) -> Self {
234 let mut binding_indices_by_action_type = HashMap::new();
235 for (ix, binding) in bindings.iter().enumerate() {
236 binding_indices_by_action_type
237 .entry(binding.action.as_any().type_id())
238 .or_insert_with(SmallVec::new)
239 .push(ix);
240 }
241 Self {
242 binding_indices_by_action_type,
243 bindings,
244 }
245 }
246
247 fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
248 self.binding_indices_by_action_type
249 .get(&action_type)
250 .map(SmallVec::as_slice)
251 .unwrap_or(&[])
252 .iter()
253 .map(|ix| &self.bindings[*ix])
254 }
255
256 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
257 for binding in bindings {
258 self.binding_indices_by_action_type
259 .entry(binding.action.as_any().type_id())
260 .or_default()
261 .push(self.bindings.len());
262 self.bindings.push(binding);
263 }
264 }
265
266 fn clear(&mut self) {
267 self.bindings.clear();
268 self.binding_indices_by_action_type.clear();
269 }
270}
271
272impl Binding {
273 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
274 Self::load(keystrokes, Box::new(action), context).unwrap()
275 }
276
277 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
278 let context = if let Some(context) = context {
279 Some(ContextPredicate::parse(context)?)
280 } else {
281 None
282 };
283
284 let keystrokes = keystrokes
285 .split_whitespace()
286 .map(Keystroke::parse)
287 .collect::<Result<_>>()?;
288
289 Ok(Self {
290 keystrokes,
291 action,
292 context_predicate: context,
293 })
294 }
295
296 pub fn keystrokes(&self) -> &[Keystroke] {
297 &self.keystrokes
298 }
299
300 pub fn action(&self) -> &dyn Action {
301 self.action.as_ref()
302 }
303}
304
305impl Keystroke {
306 pub fn parse(source: &str) -> anyhow::Result<Self> {
307 let mut ctrl = false;
308 let mut alt = false;
309 let mut shift = false;
310 let mut cmd = false;
311 let mut function = false;
312 let mut key = None;
313
314 let mut components = source.split('-').peekable();
315 while let Some(component) = components.next() {
316 match component {
317 "ctrl" => ctrl = true,
318 "alt" => alt = true,
319 "shift" => shift = true,
320 "cmd" => cmd = true,
321 "fn" => function = true,
322 _ => {
323 if let Some(component) = components.peek() {
324 if component.is_empty() && source.ends_with('-') {
325 key = Some(String::from("-"));
326 break;
327 } else {
328 return Err(anyhow!("Invalid keystroke `{}`", source));
329 }
330 } else {
331 key = Some(String::from(component));
332 }
333 }
334 }
335 }
336
337 let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
338
339 Ok(Keystroke {
340 ctrl,
341 alt,
342 shift,
343 cmd,
344 function,
345 key,
346 })
347 }
348
349 pub fn modified(&self) -> bool {
350 self.ctrl || self.alt || self.shift || self.cmd
351 }
352}
353
354impl std::fmt::Display for Keystroke {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 if self.ctrl {
357 f.write_char('^')?;
358 }
359 if self.alt {
360 f.write_char('⎇')?;
361 }
362 if self.cmd {
363 f.write_char('⌘')?;
364 }
365 if self.shift {
366 f.write_char('⇧')?;
367 }
368 let key = match self.key.as_str() {
369 "backspace" => '⌫',
370 "up" => '↑',
371 "down" => '↓',
372 "left" => '←',
373 "right" => '→',
374 "tab" => '⇥',
375 "escape" => '⎋',
376 key => {
377 if key.len() == 1 {
378 key.chars().next().unwrap().to_ascii_uppercase()
379 } else {
380 return f.write_str(key);
381 }
382 }
383 };
384 f.write_char(key)
385 }
386}
387
388impl Context {
389 pub fn extend(&mut self, other: &Context) {
390 for v in &other.set {
391 self.set.insert(v.clone());
392 }
393 for (k, v) in &other.map {
394 self.map.insert(k.clone(), v.clone());
395 }
396 }
397}
398
399impl ContextPredicate {
400 fn parse(source: &str) -> anyhow::Result<Self> {
401 let mut parser = Parser::new();
402 let language = unsafe { tree_sitter_context_predicate() };
403 parser.set_language(language).unwrap();
404 let source = source.as_bytes();
405 let tree = parser.parse(source, None).unwrap();
406 Self::from_node(tree.root_node(), source)
407 }
408
409 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
410 let parse_error = "error parsing context predicate";
411 let kind = node.kind();
412
413 match kind {
414 "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
415 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
416 "not" => {
417 let child = Self::from_node(
418 node.child_by_field_name("expression")
419 .ok_or_else(|| anyhow!(parse_error))?,
420 source,
421 )?;
422 Ok(Self::Not(Box::new(child)))
423 }
424 "and" | "or" => {
425 let left = Box::new(Self::from_node(
426 node.child_by_field_name("left")
427 .ok_or_else(|| anyhow!(parse_error))?,
428 source,
429 )?);
430 let right = Box::new(Self::from_node(
431 node.child_by_field_name("right")
432 .ok_or_else(|| anyhow!(parse_error))?,
433 source,
434 )?);
435 if kind == "and" {
436 Ok(Self::And(left, right))
437 } else {
438 Ok(Self::Or(left, right))
439 }
440 }
441 "equal" | "not_equal" => {
442 let left = node
443 .child_by_field_name("left")
444 .ok_or_else(|| anyhow!(parse_error))?
445 .utf8_text(source)?
446 .into();
447 let right = node
448 .child_by_field_name("right")
449 .ok_or_else(|| anyhow!(parse_error))?
450 .utf8_text(source)?
451 .into();
452 if kind == "equal" {
453 Ok(Self::Equal(left, right))
454 } else {
455 Ok(Self::NotEqual(left, right))
456 }
457 }
458 "parenthesized" => Self::from_node(
459 node.child_by_field_name("expression")
460 .ok_or_else(|| anyhow!(parse_error))?,
461 source,
462 ),
463 _ => Err(anyhow!(parse_error)),
464 }
465 }
466
467 fn eval(&self, cx: &Context) -> bool {
468 match self {
469 Self::Identifier(name) => cx.set.contains(name.as_str()),
470 Self::Equal(left, right) => cx
471 .map
472 .get(left)
473 .map(|value| value == right)
474 .unwrap_or(false),
475 Self::NotEqual(left, right) => {
476 cx.map.get(left).map(|value| value != right).unwrap_or(true)
477 }
478 Self::Not(pred) => !pred.eval(cx),
479 Self::And(left, right) => left.eval(cx) && right.eval(cx),
480 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
481 }
482 }
483}
484
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    /// Multi-keystroke sequences resolve along the dispatch path, and once a
    /// sequence is pending, views that did not partially match are ignored.
    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        assert_eq!(
            MatchResult::Matches(vec![(1, Box::new(AB))]),
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(!matcher.has_pending_keystrokes());
        assert_eq!(
            MatchResult::Matches(vec![(2, Box::new(B))]),
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(!matcher.has_pending_keystrokes());
        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        // `c` is bound for view 2, but only view 1 is pending on `a`, so nothing matches.
        assert_eq!(
            MatchResult::None,
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path)
        );
        assert!(!matcher.has_pending_keystrokes());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        // A bare `-` is the minus key; a lone modifier or two keys are rejected.
        assert_eq!(Keystroke::parse("-")?.key, "-");
        assert!(Keystroke::parse("ctrl").is_err());
        assert!(Keystroke::parse("a-b").is_err());

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher
            .test_keystroke("x", vec![(1, ctx_a.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes: the pending view
        // is invalidated, so the keystroke that follows the change matches nothing...
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert!(matcher
            .test_keystroke("b", vec![(1, ctx_a.clone())])
            .is_none());
        assert!(!matcher.has_pending_keystrokes());
        // ...and the next keystroke is matched from scratch in the new context.
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_a.clone())])),
            Some(&B)
        );

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone()), (2, ctx_c.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        Ok(())
    }

    /// Downcasts a matched action to a concrete action type, if it is one.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Pushes `keystroke` and returns the first matched action, or `None` when
        /// the result is `Pending` or `None`.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            dispatch_path: Vec<(usize, Context)>,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), dispatch_path) {
                MatchResult::Matches(matches) => {
                    matches.into_iter().next().map(|(_, action)| action)
                }
                _ => None,
            }
        }
    }
}