keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::{Debug, Write},
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
// Entry point of the tree-sitter grammar used by `ContextPredicate::parse`.
// NOTE(review): presumably generated by tree-sitter and linked in by the build
// script — confirm. It is only ever called inside an `unsafe` block.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
 14
/// Matches incoming keystrokes against a [`Keymap`], tracking partially typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by view id.
    pending: HashMap<usize, Pending>,
    /// The bindings being matched against.
    keymap: Keymap,
}
 19
/// Per-view state for a keystroke sequence that prefixes at least one binding.
#[derive(Default)]
struct Pending {
    /// Keystrokes typed so far in this view.
    keystrokes: Vec<Keystroke>,
    /// Context in effect when the keystrokes were retained. If the view's
    /// context differs on the next keystroke, the sequence is restarted.
    context: Option<Context>,
}
 25
/// An ordered collection of key bindings. Later bindings take precedence
/// during matching.
#[derive(Default)]
pub struct Keymap {
    /// All bindings, in insertion order.
    bindings: Vec<Binding>,
    /// Positions in `bindings`, grouped by the `TypeId` of each binding's action.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
 31
/// A keystroke sequence bound to an action, optionally guarded by a context
/// predicate.
pub struct Binding {
    /// The keystrokes that must be typed, in order, to trigger the action.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// The action dispatched when the sequence completes.
    action: Box<dyn Action>,
    /// When present, the binding only applies in contexts satisfying it.
    context_predicate: Option<ContextPredicate>,
}
 37
/// A single key press together with its modifier state.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// Set by the `fn` modifier in keystroke descriptions.
    pub function: bool,
    /// The non-modifier key name, e.g. `"p"`, `"down"` or `"-"`.
    pub key: String,
}
 47
/// The state a [`ContextPredicate`] is evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers tested by bare-identifier predicates (e.g. `Editor`).
    pub set: HashSet<String>,
    /// Key/value pairs tested by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
 53
/// A boolean expression over a [`Context`], parsed from strings such as
/// `"a && (b == c || d != e)"`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in `Context::set`.
    Identifier(String),
    /// True when `Context::map[left] == right`; false when `left` is absent.
    Equal(String, String),
    /// True when `Context::map[left] != right`; true when `left` is absent.
    NotEqual(String, String),
    Not(Box<ContextPredicate>),
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
 63
 64trait ActionArg {
 65    fn boxed_clone(&self) -> Box<dyn Any>;
 66}
 67
 68impl<T> ActionArg for T
 69where
 70    T: 'static + Any + Clone,
 71{
 72    fn boxed_clone(&self) -> Box<dyn Any> {
 73        Box::new(self.clone())
 74    }
 75}
 76
/// Outcome of feeding one keystroke to a [`Matcher`].
pub enum MatchResult {
    /// No binding matches or prefixes the typed keystrokes.
    None,
    /// At least one view's keystrokes prefix a longer binding.
    Pending,
    /// A binding was completed in the given view.
    Match {
        view_id: usize,
        action: Box<dyn Action>,
    },
}
 85
 86impl Debug for MatchResult {
 87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 88        match self {
 89            MatchResult::None => f.debug_struct("MatchResult2::None").finish(),
 90            MatchResult::Pending => f.debug_struct("MatchResult2::Pending").finish(),
 91            MatchResult::Match { view_id, action } => f
 92                .debug_struct("MatchResult::Match")
 93                .field("view_id", view_id)
 94                .field("action", &action.name())
 95                .finish(),
 96        }
 97    }
 98}
 99
100impl PartialEq for MatchResult {
101    fn eq(&self, other: &Self) -> bool {
102        match (self, other) {
103            (MatchResult::None, MatchResult::None) => true,
104            (MatchResult::Pending, MatchResult::Pending) => true,
105            (
106                MatchResult::Match { view_id, action },
107                MatchResult::Match {
108                    view_id: other_view_id,
109                    action: other_action,
110                },
111            ) => view_id == other_view_id && action.eq(other_action.as_ref()),
112            _ => false,
113        }
114    }
115}
116
117impl Eq for MatchResult {}
118
119impl Matcher {
120    pub fn new(keymap: Keymap) -> Self {
121        Self {
122            pending: HashMap::new(),
123            keymap,
124        }
125    }
126
127    pub fn set_keymap(&mut self, keymap: Keymap) {
128        self.pending.clear();
129        self.keymap = keymap;
130    }
131
132    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
133        self.pending.clear();
134        self.keymap.add_bindings(bindings);
135    }
136
137    pub fn clear_bindings(&mut self) {
138        self.pending.clear();
139        self.keymap.clear();
140    }
141
142    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
143        self.keymap.bindings_for_action_type(action_type)
144    }
145
146    pub fn clear_pending(&mut self) {
147        self.pending.clear();
148    }
149
150    pub fn has_pending_keystrokes(&self) -> bool {
151        !self.pending.is_empty()
152    }
153
154    pub fn push_keystroke(
155        &mut self,
156        keystroke: Keystroke,
157        dispatch_path: Vec<(usize, Context)>,
158    ) -> MatchResult {
159        let mut any_pending = false;
160
161        let first_keystroke = self.pending.is_empty();
162        for (view_id, context) in dispatch_path {
163            if !first_keystroke && !self.pending.contains_key(&view_id) {
164                continue;
165            }
166
167            let pending = self.pending.entry(view_id).or_default();
168
169            if let Some(pending_context) = pending.context.as_ref() {
170                if pending_context != &context {
171                    pending.keystrokes.clear();
172                }
173            }
174
175            pending.keystrokes.push(keystroke.clone());
176
177            let mut retain_pending = false;
178            for binding in self.keymap.bindings.iter().rev() {
179                if binding.keystrokes.starts_with(&pending.keystrokes)
180                    && binding
181                        .context_predicate
182                        .as_ref()
183                        .map(|c| c.eval(&context))
184                        .unwrap_or(true)
185                {
186                    if binding.keystrokes.len() == pending.keystrokes.len() {
187                        self.pending.remove(&view_id);
188                        return MatchResult::Match {
189                            view_id,
190                            action: binding.action.boxed_clone(),
191                        };
192                    } else {
193                        retain_pending = true;
194                        pending.context = Some(context.clone());
195                    }
196                }
197            }
198
199            if retain_pending {
200                any_pending = true;
201            } else {
202                self.pending.remove(&view_id);
203            }
204        }
205
206        if any_pending {
207            MatchResult::Pending
208        } else {
209            MatchResult::None
210        }
211    }
212
213    pub fn keystrokes_for_action(
214        &self,
215        action: &dyn Action,
216        cx: &Context,
217    ) -> Option<SmallVec<[Keystroke; 2]>> {
218        for binding in self.keymap.bindings.iter().rev() {
219            if binding.action.eq(action)
220                && binding
221                    .context_predicate
222                    .as_ref()
223                    .map_or(true, |predicate| predicate.eval(cx))
224            {
225                return Some(binding.keystrokes.clone());
226            }
227        }
228        None
229    }
230}
231
232impl Default for Matcher {
233    fn default() -> Self {
234        Self::new(Keymap::default())
235    }
236}
237
238impl Keymap {
239    pub fn new(bindings: Vec<Binding>) -> Self {
240        let mut binding_indices_by_action_type = HashMap::new();
241        for (ix, binding) in bindings.iter().enumerate() {
242            binding_indices_by_action_type
243                .entry(binding.action.as_any().type_id())
244                .or_insert_with(SmallVec::new)
245                .push(ix);
246        }
247        Self {
248            binding_indices_by_action_type,
249            bindings,
250        }
251    }
252
253    fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
254        self.binding_indices_by_action_type
255            .get(&action_type)
256            .map(SmallVec::as_slice)
257            .unwrap_or(&[])
258            .iter()
259            .map(|ix| &self.bindings[*ix])
260    }
261
262    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
263        for binding in bindings {
264            self.binding_indices_by_action_type
265                .entry(binding.action.as_any().type_id())
266                .or_default()
267                .push(self.bindings.len());
268            self.bindings.push(binding);
269        }
270    }
271
272    fn clear(&mut self) {
273        self.bindings.clear();
274        self.binding_indices_by_action_type.clear();
275    }
276}
277
278impl Binding {
279    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
280        Self::load(keystrokes, Box::new(action), context).unwrap()
281    }
282
283    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
284        let context = if let Some(context) = context {
285            Some(ContextPredicate::parse(context)?)
286        } else {
287            None
288        };
289
290        let keystrokes = keystrokes
291            .split_whitespace()
292            .map(Keystroke::parse)
293            .collect::<Result<_>>()?;
294
295        Ok(Self {
296            keystrokes,
297            action,
298            context_predicate: context,
299        })
300    }
301
302    pub fn keystrokes(&self) -> &[Keystroke] {
303        &self.keystrokes
304    }
305
306    pub fn action(&self) -> &dyn Action {
307        self.action.as_ref()
308    }
309}
310
311impl Keystroke {
312    pub fn parse(source: &str) -> anyhow::Result<Self> {
313        let mut ctrl = false;
314        let mut alt = false;
315        let mut shift = false;
316        let mut cmd = false;
317        let mut function = false;
318        let mut key = None;
319
320        let mut components = source.split('-').peekable();
321        while let Some(component) = components.next() {
322            match component {
323                "ctrl" => ctrl = true,
324                "alt" => alt = true,
325                "shift" => shift = true,
326                "cmd" => cmd = true,
327                "fn" => function = true,
328                _ => {
329                    if let Some(component) = components.peek() {
330                        if component.is_empty() && source.ends_with('-') {
331                            key = Some(String::from("-"));
332                            break;
333                        } else {
334                            return Err(anyhow!("Invalid keystroke `{}`", source));
335                        }
336                    } else {
337                        key = Some(String::from(component));
338                    }
339                }
340            }
341        }
342
343        let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
344
345        Ok(Keystroke {
346            ctrl,
347            alt,
348            shift,
349            cmd,
350            function,
351            key,
352        })
353    }
354
355    pub fn modified(&self) -> bool {
356        self.ctrl || self.alt || self.shift || self.cmd
357    }
358}
359
360impl std::fmt::Display for Keystroke {
361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        if self.ctrl {
363            f.write_char('^')?;
364        }
365        if self.alt {
366            f.write_char('⎇')?;
367        }
368        if self.cmd {
369            f.write_char('⌘')?;
370        }
371        if self.shift {
372            f.write_char('⇧')?;
373        }
374        let key = match self.key.as_str() {
375            "backspace" => '⌫',
376            "up" => '↑',
377            "down" => '↓',
378            "left" => '←',
379            "right" => '→',
380            "tab" => '⇥',
381            "escape" => '⎋',
382            key => {
383                if key.len() == 1 {
384                    key.chars().next().unwrap().to_ascii_uppercase()
385                } else {
386                    return f.write_str(key);
387                }
388            }
389        };
390        f.write_char(key)
391    }
392}
393
394impl Context {
395    pub fn extend(&mut self, other: &Context) {
396        for v in &other.set {
397            self.set.insert(v.clone());
398        }
399        for (k, v) in &other.map {
400            self.map.insert(k.clone(), v.clone());
401        }
402    }
403}
404
405impl ContextPredicate {
406    fn parse(source: &str) -> anyhow::Result<Self> {
407        let mut parser = Parser::new();
408        let language = unsafe { tree_sitter_context_predicate() };
409        parser.set_language(language).unwrap();
410        let source = source.as_bytes();
411        let tree = parser.parse(source, None).unwrap();
412        Self::from_node(tree.root_node(), source)
413    }
414
415    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
416        let parse_error = "error parsing context predicate";
417        let kind = node.kind();
418
419        match kind {
420            "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
421            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
422            "not" => {
423                let child = Self::from_node(
424                    node.child_by_field_name("expression")
425                        .ok_or_else(|| anyhow!(parse_error))?,
426                    source,
427                )?;
428                Ok(Self::Not(Box::new(child)))
429            }
430            "and" | "or" => {
431                let left = Box::new(Self::from_node(
432                    node.child_by_field_name("left")
433                        .ok_or_else(|| anyhow!(parse_error))?,
434                    source,
435                )?);
436                let right = Box::new(Self::from_node(
437                    node.child_by_field_name("right")
438                        .ok_or_else(|| anyhow!(parse_error))?,
439                    source,
440                )?);
441                if kind == "and" {
442                    Ok(Self::And(left, right))
443                } else {
444                    Ok(Self::Or(left, right))
445                }
446            }
447            "equal" | "not_equal" => {
448                let left = node
449                    .child_by_field_name("left")
450                    .ok_or_else(|| anyhow!(parse_error))?
451                    .utf8_text(source)?
452                    .into();
453                let right = node
454                    .child_by_field_name("right")
455                    .ok_or_else(|| anyhow!(parse_error))?
456                    .utf8_text(source)?
457                    .into();
458                if kind == "equal" {
459                    Ok(Self::Equal(left, right))
460                } else {
461                    Ok(Self::NotEqual(left, right))
462                }
463            }
464            "parenthesized" => Self::from_node(
465                node.child_by_field_name("expression")
466                    .ok_or_else(|| anyhow!(parse_error))?,
467                source,
468            ),
469            _ => Err(anyhow!(parse_error)),
470        }
471    }
472
473    fn eval(&self, cx: &Context) -> bool {
474        match self {
475            Self::Identifier(name) => cx.set.contains(name.as_str()),
476            Self::Equal(left, right) => cx
477                .map
478                .get(left)
479                .map(|value| value == right)
480                .unwrap_or(false),
481            Self::NotEqual(left, right) => {
482                cx.map.get(left).map(|value| value != right).unwrap_or(true)
483            }
484            Self::Not(pred) => !pred.eval(cx),
485            Self::And(left, right) => left.eval(cx) && right.eval(cx),
486            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
487        }
488    }
489}
490
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        // View 2 is visited before view 1.
        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        // "a" only prefixes a binding in view 1's context.
        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        // After the first keystroke only views with pending keystrokes take
        // part, so "b" completes "a b" in view 1 rather than "b" in view 2.
        assert_eq!(
            MatchResult::Match {
                view_id: 1,
                action: Box::new(AB)
            },
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());
        // With nothing pending, "b" matches directly in view 2.
        assert_eq!(
            MatchResult::Match {
                view_id: 2,
                action: Box::new(B)
            },
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());
        // "c" doesn't continue view 1's pending "a", and view 2 has nothing
        // pending, so it isn't consulted.
        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        assert_eq!(
            MatchResult::None,
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        // A lone "-" is the minus key with no modifiers.
        assert_eq!(
            Keystroke::parse("-")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher
            .test_keystroke("x", vec![(1, ctx_a.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_a.clone())])),
            Some(&B)
        );

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone()), (2, ctx_c.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        Ok(())
    }

    /// Downcasts an optional boxed action to a concrete action type.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, feeds it to the matcher, and returns the matched
        /// action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            dispatch_path: Vec<(usize, Context)>,
        ) -> Option<Box<dyn Action>> {
            if let MatchResult::Match { action, .. } =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), dispatch_path)
            {
                Some(action.boxed_clone())
            } else {
                None
            }
        }
    }
}