keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::{Debug, Write},
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
// Returns the tree-sitter `Language` used by `ContextPredicate::parse` to parse
// binding context expressions. The symbol comes from native code linked into
// this crate (presumably a tree-sitter-generated grammar built by the build
// script — confirm).
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
 14
/// Matches incoming keystrokes against a [`Keymap`], tracking multi-keystroke
/// sequences that have only been partially typed.
pub struct Matcher {
    // For each view (by id) left with a partially matched binding, the context
    // it had at that moment. A later keystroke only continues the sequence for
    // that view if its context is still equal.
    pending_views: HashMap<usize, Context>,
    // Keystrokes typed so far in the current, incomplete sequence.
    pending_keystrokes: Vec<Keystroke>,
    keymap: Keymap,
}

/// An ordered list of [`Binding`]s. The matcher scans them from last to first.
#[derive(Default)]
pub struct Keymap {
    bindings: Vec<Binding>,
    // Indices into `bindings`, grouped by the `TypeId` of each binding's
    // action (`action.as_any().type_id()`), in insertion order.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}

/// A keystroke sequence mapped to an action, optionally restricted to views
/// whose [`Context`] satisfies a predicate.
pub struct Binding {
    keystrokes: SmallVec<[Keystroke; 2]>,
    action: Box<dyn Action>,
    // `None` means the binding applies in every context.
    context_predicate: Option<ContextPredicate>,
}
 32
/// A single key press with its modifiers, e.g. `ctrl-p` or `shift-cmd--`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// Whether the `fn` modifier is part of the keystroke.
    pub function: bool,
    /// The non-modifier key, e.g. `"p"`, `"down"` or `"-"`.
    pub key: String,
}

/// The state a view exposes for evaluating binding context predicates.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; tested by a bare identifier predicate.
    pub set: HashSet<String>,
    /// Key/value pairs; tested by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
 48
/// Parsed form of a binding's context expression, e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in `Context::set`.
    Identifier(String),
    /// True when `Context::map` maps the key to exactly this value; false when
    /// the key is absent.
    Equal(String, String),
    /// True when `Context::map` maps the key to a different value, or when the
    /// key is absent.
    NotEqual(String, String),
    Not(Box<ContextPredicate>),
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
 58
 59trait ActionArg {
 60    fn boxed_clone(&self) -> Box<dyn Any>;
 61}
 62
 63impl<T> ActionArg for T
 64where
 65    T: 'static + Any + Clone,
 66{
 67    fn boxed_clone(&self) -> Box<dyn Any> {
 68        Box::new(self.clone())
 69    }
 70}
 71
/// Outcome of [`Matcher::push_keystroke`].
pub enum MatchResult {
    /// No binding matched; the pending keystrokes have been cleared.
    None,
    /// At least one binding is partially matched; more keystrokes are needed.
    Pending,
    /// Completed bindings as `(view_id, action)` pairs: views in dispatch-path
    /// order, and within each view, bindings from last-added to first-added.
    Matches(Vec<(usize, Box<dyn Action>)>),
}
 77
 78impl Debug for MatchResult {
 79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 80        match self {
 81            MatchResult::None => f.debug_struct("MatchResult2::None").finish(),
 82            MatchResult::Pending => f.debug_struct("MatchResult2::Pending").finish(),
 83            MatchResult::Matches(matches) => f
 84                .debug_list()
 85                .entries(
 86                    matches
 87                        .iter()
 88                        .map(|(view_id, action)| format!("{view_id}, {}", action.name())),
 89                )
 90                .finish(),
 91        }
 92    }
 93}
 94
 95impl PartialEq for MatchResult {
 96    fn eq(&self, other: &Self) -> bool {
 97        match (self, other) {
 98            (MatchResult::None, MatchResult::None) => true,
 99            (MatchResult::Pending, MatchResult::Pending) => true,
100            (MatchResult::Matches(matches), MatchResult::Matches(other_matches)) => {
101                matches.len() == other_matches.len()
102                    && matches.iter().zip(other_matches.iter()).all(
103                        |((view_id, action), (other_view_id, other_action))| {
104                            view_id == other_view_id && action.eq(other_action.as_ref())
105                        },
106                    )
107            }
108            _ => false,
109        }
110    }
111}
112
113impl Eq for MatchResult {}
114
115impl Matcher {
116    pub fn new(keymap: Keymap) -> Self {
117        Self {
118            pending_views: HashMap::new(),
119            pending_keystrokes: Vec::new(),
120            keymap,
121        }
122    }
123
124    pub fn set_keymap(&mut self, keymap: Keymap) {
125        self.clear_pending();
126        self.keymap = keymap;
127    }
128
129    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
130        self.clear_pending();
131        self.keymap.add_bindings(bindings);
132    }
133
134    pub fn clear_bindings(&mut self) {
135        self.clear_pending();
136        self.keymap.clear();
137    }
138
139    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
140        self.keymap.bindings_for_action_type(action_type)
141    }
142
143    pub fn clear_pending(&mut self) {
144        self.pending_keystrokes.clear();
145        self.pending_views.clear();
146    }
147
148    pub fn has_pending_keystrokes(&self) -> bool {
149        !self.pending_keystrokes.is_empty()
150    }
151
152    pub fn push_keystroke(
153        &mut self,
154        keystroke: Keystroke,
155        dispatch_path: Vec<(usize, Context)>,
156    ) -> MatchResult {
157        let mut any_pending = false;
158        let mut matched_bindings = Vec::new();
159
160        let first_keystroke = self.pending_keystrokes.is_empty();
161        self.pending_keystrokes.push(keystroke);
162
163        for (view_id, context) in dispatch_path {
164            // Don't require pending view entry if there are no pending keystrokes
165            if !first_keystroke && !self.pending_views.contains_key(&view_id) {
166                continue;
167            }
168
169            // If there is a previous view context, invalidate that view if it
170            // has changed
171            if let Some(previous_view_context) = self.pending_views.remove(&view_id) {
172                if previous_view_context != context {
173                    continue;
174                }
175            }
176
177            for binding in self.keymap.bindings.iter().rev() {
178                if binding.keystrokes.starts_with(&self.pending_keystrokes)
179                    && binding
180                        .context_predicate
181                        .as_ref()
182                        .map(|c| c.eval(&context))
183                        .unwrap_or(true)
184                {
185                    if binding.keystrokes.len() == self.pending_keystrokes.len() {
186                        matched_bindings.push((view_id, binding.action.boxed_clone()));
187                    } else {
188                        self.pending_views.insert(view_id, context.clone());
189                        any_pending = true;
190                    }
191                }
192            }
193        }
194
195        if !matched_bindings.is_empty() {
196            self.pending_views.clear();
197            self.pending_keystrokes.clear();
198            MatchResult::Matches(matched_bindings)
199        } else if any_pending {
200            MatchResult::Pending
201        } else {
202            self.pending_keystrokes.clear();
203            MatchResult::None
204        }
205    }
206
207    pub fn keystrokes_for_action(
208        &self,
209        action: &dyn Action,
210        cx: &Context,
211    ) -> Option<SmallVec<[Keystroke; 2]>> {
212        for binding in self.keymap.bindings.iter().rev() {
213            if binding.action.eq(action)
214                && binding
215                    .context_predicate
216                    .as_ref()
217                    .map_or(true, |predicate| predicate.eval(cx))
218            {
219                return Some(binding.keystrokes.clone());
220            }
221        }
222        None
223    }
224}
225
226impl Default for Matcher {
227    fn default() -> Self {
228        Self::new(Keymap::default())
229    }
230}
231
232impl Keymap {
233    pub fn new(bindings: Vec<Binding>) -> Self {
234        let mut binding_indices_by_action_type = HashMap::new();
235        for (ix, binding) in bindings.iter().enumerate() {
236            binding_indices_by_action_type
237                .entry(binding.action.as_any().type_id())
238                .or_insert_with(SmallVec::new)
239                .push(ix);
240        }
241        Self {
242            binding_indices_by_action_type,
243            bindings,
244        }
245    }
246
247    fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
248        self.binding_indices_by_action_type
249            .get(&action_type)
250            .map(SmallVec::as_slice)
251            .unwrap_or(&[])
252            .iter()
253            .map(|ix| &self.bindings[*ix])
254    }
255
256    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
257        for binding in bindings {
258            self.binding_indices_by_action_type
259                .entry(binding.action.as_any().type_id())
260                .or_default()
261                .push(self.bindings.len());
262            self.bindings.push(binding);
263        }
264    }
265
266    fn clear(&mut self) {
267        self.bindings.clear();
268        self.binding_indices_by_action_type.clear();
269    }
270}
271
272impl Binding {
273    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
274        Self::load(keystrokes, Box::new(action), context).unwrap()
275    }
276
277    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
278        let context = if let Some(context) = context {
279            Some(ContextPredicate::parse(context)?)
280        } else {
281            None
282        };
283
284        let keystrokes = keystrokes
285            .split_whitespace()
286            .map(Keystroke::parse)
287            .collect::<Result<_>>()?;
288
289        Ok(Self {
290            keystrokes,
291            action,
292            context_predicate: context,
293        })
294    }
295
296    pub fn keystrokes(&self) -> &[Keystroke] {
297        &self.keystrokes
298    }
299
300    pub fn action(&self) -> &dyn Action {
301        self.action.as_ref()
302    }
303}
304
305impl Keystroke {
306    pub fn parse(source: &str) -> anyhow::Result<Self> {
307        let mut ctrl = false;
308        let mut alt = false;
309        let mut shift = false;
310        let mut cmd = false;
311        let mut function = false;
312        let mut key = None;
313
314        let mut components = source.split('-').peekable();
315        while let Some(component) = components.next() {
316            match component {
317                "ctrl" => ctrl = true,
318                "alt" => alt = true,
319                "shift" => shift = true,
320                "cmd" => cmd = true,
321                "fn" => function = true,
322                _ => {
323                    if let Some(component) = components.peek() {
324                        if component.is_empty() && source.ends_with('-') {
325                            key = Some(String::from("-"));
326                            break;
327                        } else {
328                            return Err(anyhow!("Invalid keystroke `{}`", source));
329                        }
330                    } else {
331                        key = Some(String::from(component));
332                    }
333                }
334            }
335        }
336
337        let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
338
339        Ok(Keystroke {
340            ctrl,
341            alt,
342            shift,
343            cmd,
344            function,
345            key,
346        })
347    }
348
349    pub fn modified(&self) -> bool {
350        self.ctrl || self.alt || self.shift || self.cmd
351    }
352}
353
354impl std::fmt::Display for Keystroke {
355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        if self.ctrl {
357            f.write_char('^')?;
358        }
359        if self.alt {
360            f.write_char('⎇')?;
361        }
362        if self.cmd {
363            f.write_char('⌘')?;
364        }
365        if self.shift {
366            f.write_char('⇧')?;
367        }
368        let key = match self.key.as_str() {
369            "backspace" => '⌫',
370            "up" => '↑',
371            "down" => '↓',
372            "left" => '←',
373            "right" => '→',
374            "tab" => '⇥',
375            "escape" => '⎋',
376            key => {
377                if key.len() == 1 {
378                    key.chars().next().unwrap().to_ascii_uppercase()
379                } else {
380                    return f.write_str(key);
381                }
382            }
383        };
384        f.write_char(key)
385    }
386}
387
388impl Context {
389    pub fn extend(&mut self, other: &Context) {
390        for v in &other.set {
391            self.set.insert(v.clone());
392        }
393        for (k, v) in &other.map {
394            self.map.insert(k.clone(), v.clone());
395        }
396    }
397}
398
399impl ContextPredicate {
400    fn parse(source: &str) -> anyhow::Result<Self> {
401        let mut parser = Parser::new();
402        let language = unsafe { tree_sitter_context_predicate() };
403        parser.set_language(language).unwrap();
404        let source = source.as_bytes();
405        let tree = parser.parse(source, None).unwrap();
406        Self::from_node(tree.root_node(), source)
407    }
408
409    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
410        let parse_error = "error parsing context predicate";
411        let kind = node.kind();
412
413        match kind {
414            "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
415            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
416            "not" => {
417                let child = Self::from_node(
418                    node.child_by_field_name("expression")
419                        .ok_or_else(|| anyhow!(parse_error))?,
420                    source,
421                )?;
422                Ok(Self::Not(Box::new(child)))
423            }
424            "and" | "or" => {
425                let left = Box::new(Self::from_node(
426                    node.child_by_field_name("left")
427                        .ok_or_else(|| anyhow!(parse_error))?,
428                    source,
429                )?);
430                let right = Box::new(Self::from_node(
431                    node.child_by_field_name("right")
432                        .ok_or_else(|| anyhow!(parse_error))?,
433                    source,
434                )?);
435                if kind == "and" {
436                    Ok(Self::And(left, right))
437                } else {
438                    Ok(Self::Or(left, right))
439                }
440            }
441            "equal" | "not_equal" => {
442                let left = node
443                    .child_by_field_name("left")
444                    .ok_or_else(|| anyhow!(parse_error))?
445                    .utf8_text(source)?
446                    .into();
447                let right = node
448                    .child_by_field_name("right")
449                    .ok_or_else(|| anyhow!(parse_error))?
450                    .utf8_text(source)?
451                    .into();
452                if kind == "equal" {
453                    Ok(Self::Equal(left, right))
454                } else {
455                    Ok(Self::NotEqual(left, right))
456                }
457            }
458            "parenthesized" => Self::from_node(
459                node.child_by_field_name("expression")
460                    .ok_or_else(|| anyhow!(parse_error))?,
461                source,
462            ),
463            _ => Err(anyhow!(parse_error)),
464        }
465    }
466
467    fn eval(&self, cx: &Context) -> bool {
468        match self {
469            Self::Identifier(name) => cx.set.contains(name.as_str()),
470            Self::Equal(left, right) => cx
471                .map
472                .get(left)
473                .map(|value| value == right)
474                .unwrap_or(false),
475            Self::NotEqual(left, right) => {
476                cx.map.get(left).map(|value| value != right).unwrap_or(true)
477            }
478            Self::Not(pred) => !pred.eval(cx),
479            Self::And(left, right) => left.eval(cx) && right.eval(cx),
480            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
481        }
482    }
483}
484
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        assert_eq!(
            MatchResult::Matches(vec![(1, Box::new(AB))]),
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(!matcher.has_pending_keystrokes());
        assert_eq!(
            MatchResult::Matches(vec![(2, Box::new(B))]),
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(!matcher.has_pending_keystrokes());
        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        assert_eq!(
            MatchResult::None,
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path.clone())
        );
        assert!(!matcher.has_pending_keystrokes());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher
            .test_keystroke("x", vec![(1, ctx_a.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_a.clone())])),
            Some(&B)
        );

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone()), (2, ctx_c.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        Ok(())
    }

    /// Downcasts an optional matched action to the concrete action type `A`.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Pushes one parsed keystroke and returns the first matched action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            dispatch_path: Vec<(usize, Context)>,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), dispatch_path) {
                MatchResult::Matches(matches) => {
                    matches.into_iter().next().map(|(_, action)| action)
                }
                _ => None,
            }
        }
    }
}