// keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::{Debug, Write},
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
extern "C" {
    /// Returns the tree-sitter `Language` that `ContextPredicate::parse` loads
    /// into its parser to parse context-predicate expressions.
    /// NOTE(review): presumably the constructor generated by tree-sitter for the
    /// context-predicate grammar and linked in by the build — confirm.
    fn tree_sitter_context_predicate() -> Language;
}
 14
/// Resolves keystrokes against a [`Keymap`], keeping track of multi-keystroke
/// bindings that have been partially typed.
pub struct Matcher {
    /// The context each view had when one of its bindings became pending, keyed
    /// by view id. On the next keystroke a view whose context no longer matches
    /// is dropped from consideration.
    pending_views: HashMap<usize, Context>,
    /// Keystrokes typed so far in the current (possibly incomplete) sequence.
    pending_keystrokes: Vec<Keystroke>,
    /// The bindings being matched against.
    keymap: Keymap,
}
 20
/// An ordered collection of key bindings.
///
/// `Matcher` walks the bindings in reverse, so bindings added later take
/// precedence over earlier ones.
#[derive(Default)]
pub struct Keymap {
    /// All bindings, in insertion order.
    bindings: Vec<Binding>,
    /// Indices into `bindings` (ascending), grouped by the `TypeId` of each
    /// binding's concrete action type.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
 26
/// A keystroke sequence bound to an action, optionally restricted to contexts
/// that satisfy a predicate.
pub struct Binding {
    /// The keystrokes that trigger this binding, in the order they are typed.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// The action produced when the binding matches.
    action: Box<dyn Action>,
    /// When `None`, the binding applies in every context.
    context_predicate: Option<ContextPredicate>,
}
 32
/// A single key press together with its modifier state, e.g. as parsed by
/// `Keystroke::parse` from `"ctrl-p"` or `"shift-cmd--"`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// Set by the `fn` modifier.
    pub function: bool,
    /// The non-modifier key name, e.g. `"p"`, `"down"` or `"-"`.
    pub key: String,
}
 42
/// The key-binding context of a view, against which a binding's context
/// predicate is evaluated.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that satisfy bare-identifier predicates such as `a`.
    pub set: HashSet<String>,
    /// Key/value pairs consulted by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
 48
/// A parsed context expression such as `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// Holds when the identifier is present in `Context::set`.
    Identifier(String),
    /// Holds when `Context::map` maps the left name to the right value;
    /// false when the key is absent.
    Equal(String, String),
    /// Holds when `Context::map` does not map the left name to the right value;
    /// true when the key is absent.
    NotEqual(String, String),
    /// Logical negation.
    Not(Box<ContextPredicate>),
    /// Logical conjunction (short-circuiting).
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Logical disjunction (short-circuiting).
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
 58
 59trait ActionArg {
 60    fn boxed_clone(&self) -> Box<dyn Any>;
 61}
 62
 63impl<T> ActionArg for T
 64where
 65    T: 'static + Any + Clone,
 66{
 67    fn boxed_clone(&self) -> Box<dyn Any> {
 68        Box::new(self.clone())
 69    }
 70}
 71
/// The outcome of feeding one keystroke to `Matcher::push_keystroke`.
pub enum MatchResult {
    /// No binding completed and no longer binding is still pending.
    None,
    /// No binding completed, but the keystrokes so far are a strict prefix of
    /// at least one applicable binding.
    Pending,
    /// Completed bindings as `(view_id, action)` pairs, ordered by position in
    /// the dispatch path and, within a view, most recently added binding first.
    /// Keystrokes may still be pending if a longer binding shares this prefix.
    Matches(Vec<(usize, Box<dyn Action>)>),
}
 77
 78impl Debug for MatchResult {
 79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 80        match self {
 81            MatchResult::None => f.debug_struct("MatchResult::None").finish(),
 82            MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
 83            MatchResult::Matches(matches) => f
 84                .debug_list()
 85                .entries(
 86                    matches
 87                        .iter()
 88                        .map(|(view_id, action)| format!("{view_id}, {}", action.name())),
 89                )
 90                .finish(),
 91        }
 92    }
 93}
 94
 95impl PartialEq for MatchResult {
 96    fn eq(&self, other: &Self) -> bool {
 97        match (self, other) {
 98            (MatchResult::None, MatchResult::None) => true,
 99            (MatchResult::Pending, MatchResult::Pending) => true,
100            (MatchResult::Matches(matches), MatchResult::Matches(other_matches)) => {
101                matches.len() == other_matches.len()
102                    && matches.iter().zip(other_matches.iter()).all(
103                        |((view_id, action), (other_view_id, other_action))| {
104                            view_id == other_view_id && action.eq(other_action.as_ref())
105                        },
106                    )
107            }
108            _ => false,
109        }
110    }
111}
112
113impl Eq for MatchResult {}
114
115impl Matcher {
116    pub fn new(keymap: Keymap) -> Self {
117        Self {
118            pending_views: HashMap::new(),
119            pending_keystrokes: Vec::new(),
120            keymap,
121        }
122    }
123
124    pub fn set_keymap(&mut self, keymap: Keymap) {
125        self.clear_pending();
126        self.keymap = keymap;
127    }
128
129    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
130        self.clear_pending();
131        self.keymap.add_bindings(bindings);
132    }
133
134    pub fn clear_bindings(&mut self) {
135        self.clear_pending();
136        self.keymap.clear();
137    }
138
139    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
140        self.keymap.bindings_for_action_type(action_type)
141    }
142
143    pub fn clear_pending(&mut self) {
144        self.pending_keystrokes.clear();
145        self.pending_views.clear();
146    }
147
148    pub fn has_pending_keystrokes(&self) -> bool {
149        !self.pending_keystrokes.is_empty()
150    }
151
152    pub fn push_keystroke(
153        &mut self,
154        keystroke: Keystroke,
155        dispatch_path: Vec<(usize, Context)>,
156    ) -> MatchResult {
157        let mut any_pending = false;
158        let mut matched_bindings = Vec::new();
159
160        let first_keystroke = self.pending_keystrokes.is_empty();
161        self.pending_keystrokes.push(keystroke);
162
163        for (view_id, context) in dispatch_path {
164            // Don't require pending view entry if there are no pending keystrokes
165            if !first_keystroke && !self.pending_views.contains_key(&view_id) {
166                continue;
167            }
168
169            // If there is a previous view context, invalidate that view if it
170            // has changed
171            if let Some(previous_view_context) = self.pending_views.remove(&view_id) {
172                if previous_view_context != context {
173                    continue;
174                }
175            }
176
177            // Find the bindings which map the pending keystrokes and current context
178            for binding in self.keymap.bindings.iter().rev() {
179                if binding.keystrokes.starts_with(&self.pending_keystrokes)
180                    && binding
181                        .context_predicate
182                        .as_ref()
183                        .map(|c| c.eval(&context))
184                        .unwrap_or(true)
185                {
186                    // If the binding is completed, push it onto the matches list
187                    if binding.keystrokes.len() == self.pending_keystrokes.len() {
188                        matched_bindings.push((view_id, binding.action.boxed_clone()));
189                    } else {
190                        // Otherwise, the binding is still pending
191                        self.pending_views.insert(view_id, context.clone());
192                        any_pending = true;
193                    }
194                }
195            }
196        }
197
198        if !any_pending {
199            self.clear_pending();
200        }
201
202        if !matched_bindings.is_empty() {
203            MatchResult::Matches(matched_bindings)
204        } else if any_pending {
205            MatchResult::Pending
206        } else {
207            MatchResult::None
208        }
209    }
210
211    pub fn keystrokes_for_action(
212        &self,
213        action: &dyn Action,
214        cx: &Context,
215    ) -> Option<SmallVec<[Keystroke; 2]>> {
216        for binding in self.keymap.bindings.iter().rev() {
217            if binding.action.eq(action)
218                && binding
219                    .context_predicate
220                    .as_ref()
221                    .map_or(true, |predicate| predicate.eval(cx))
222            {
223                return Some(binding.keystrokes.clone());
224            }
225        }
226        None
227    }
228}
229
230impl Default for Matcher {
231    fn default() -> Self {
232        Self::new(Keymap::default())
233    }
234}
235
236impl Keymap {
237    pub fn new(bindings: Vec<Binding>) -> Self {
238        let mut binding_indices_by_action_type = HashMap::new();
239        for (ix, binding) in bindings.iter().enumerate() {
240            binding_indices_by_action_type
241                .entry(binding.action.as_any().type_id())
242                .or_insert_with(SmallVec::new)
243                .push(ix);
244        }
245        Self {
246            binding_indices_by_action_type,
247            bindings,
248        }
249    }
250
251    fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
252        self.binding_indices_by_action_type
253            .get(&action_type)
254            .map(SmallVec::as_slice)
255            .unwrap_or(&[])
256            .iter()
257            .map(|ix| &self.bindings[*ix])
258    }
259
260    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
261        for binding in bindings {
262            self.binding_indices_by_action_type
263                .entry(binding.action.as_any().type_id())
264                .or_default()
265                .push(self.bindings.len());
266            self.bindings.push(binding);
267        }
268    }
269
270    fn clear(&mut self) {
271        self.bindings.clear();
272        self.binding_indices_by_action_type.clear();
273    }
274}
275
276impl Binding {
277    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
278        Self::load(keystrokes, Box::new(action), context).unwrap()
279    }
280
281    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
282        let context = if let Some(context) = context {
283            Some(ContextPredicate::parse(context)?)
284        } else {
285            None
286        };
287
288        let keystrokes = keystrokes
289            .split_whitespace()
290            .map(Keystroke::parse)
291            .collect::<Result<_>>()?;
292
293        Ok(Self {
294            keystrokes,
295            action,
296            context_predicate: context,
297        })
298    }
299
300    pub fn keystrokes(&self) -> &[Keystroke] {
301        &self.keystrokes
302    }
303
304    pub fn action(&self) -> &dyn Action {
305        self.action.as_ref()
306    }
307}
308
309impl Keystroke {
310    pub fn parse(source: &str) -> anyhow::Result<Self> {
311        let mut ctrl = false;
312        let mut alt = false;
313        let mut shift = false;
314        let mut cmd = false;
315        let mut function = false;
316        let mut key = None;
317
318        let mut components = source.split('-').peekable();
319        while let Some(component) = components.next() {
320            match component {
321                "ctrl" => ctrl = true,
322                "alt" => alt = true,
323                "shift" => shift = true,
324                "cmd" => cmd = true,
325                "fn" => function = true,
326                _ => {
327                    if let Some(component) = components.peek() {
328                        if component.is_empty() && source.ends_with('-') {
329                            key = Some(String::from("-"));
330                            break;
331                        } else {
332                            return Err(anyhow!("Invalid keystroke `{}`", source));
333                        }
334                    } else {
335                        key = Some(String::from(component));
336                    }
337                }
338            }
339        }
340
341        let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
342
343        Ok(Keystroke {
344            ctrl,
345            alt,
346            shift,
347            cmd,
348            function,
349            key,
350        })
351    }
352
353    pub fn modified(&self) -> bool {
354        self.ctrl || self.alt || self.shift || self.cmd
355    }
356}
357
358impl std::fmt::Display for Keystroke {
359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        if self.ctrl {
361            f.write_char('^')?;
362        }
363        if self.alt {
364            f.write_char('⎇')?;
365        }
366        if self.cmd {
367            f.write_char('⌘')?;
368        }
369        if self.shift {
370            f.write_char('⇧')?;
371        }
372        let key = match self.key.as_str() {
373            "backspace" => '⌫',
374            "up" => '↑',
375            "down" => '↓',
376            "left" => '←',
377            "right" => '→',
378            "tab" => '⇥',
379            "escape" => '⎋',
380            key => {
381                if key.len() == 1 {
382                    key.chars().next().unwrap().to_ascii_uppercase()
383                } else {
384                    return f.write_str(key);
385                }
386            }
387        };
388        f.write_char(key)
389    }
390}
391
392impl Context {
393    pub fn extend(&mut self, other: &Context) {
394        for v in &other.set {
395            self.set.insert(v.clone());
396        }
397        for (k, v) in &other.map {
398            self.map.insert(k.clone(), v.clone());
399        }
400    }
401}
402
403impl ContextPredicate {
404    fn parse(source: &str) -> anyhow::Result<Self> {
405        let mut parser = Parser::new();
406        let language = unsafe { tree_sitter_context_predicate() };
407        parser.set_language(language).unwrap();
408        let source = source.as_bytes();
409        let tree = parser.parse(source, None).unwrap();
410        Self::from_node(tree.root_node(), source)
411    }
412
413    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
414        let parse_error = "error parsing context predicate";
415        let kind = node.kind();
416
417        match kind {
418            "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
419            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
420            "not" => {
421                let child = Self::from_node(
422                    node.child_by_field_name("expression")
423                        .ok_or_else(|| anyhow!(parse_error))?,
424                    source,
425                )?;
426                Ok(Self::Not(Box::new(child)))
427            }
428            "and" | "or" => {
429                let left = Box::new(Self::from_node(
430                    node.child_by_field_name("left")
431                        .ok_or_else(|| anyhow!(parse_error))?,
432                    source,
433                )?);
434                let right = Box::new(Self::from_node(
435                    node.child_by_field_name("right")
436                        .ok_or_else(|| anyhow!(parse_error))?,
437                    source,
438                )?);
439                if kind == "and" {
440                    Ok(Self::And(left, right))
441                } else {
442                    Ok(Self::Or(left, right))
443                }
444            }
445            "equal" | "not_equal" => {
446                let left = node
447                    .child_by_field_name("left")
448                    .ok_or_else(|| anyhow!(parse_error))?
449                    .utf8_text(source)?
450                    .into();
451                let right = node
452                    .child_by_field_name("right")
453                    .ok_or_else(|| anyhow!(parse_error))?
454                    .utf8_text(source)?
455                    .into();
456                if kind == "equal" {
457                    Ok(Self::Equal(left, right))
458                } else {
459                    Ok(Self::NotEqual(left, right))
460                }
461            }
462            "parenthesized" => Self::from_node(
463                node.child_by_field_name("expression")
464                    .ok_or_else(|| anyhow!(parse_error))?,
465                source,
466            ),
467            _ => Err(anyhow!(parse_error)),
468        }
469    }
470
471    fn eval(&self, cx: &Context) -> bool {
472        match self {
473            Self::Identifier(name) => cx.set.contains(name.as_str()),
474            Self::Equal(left, right) => cx
475                .map
476                .get(left)
477                .map(|value| value == right)
478                .unwrap_or(false),
479            Self::NotEqual(left, right) => {
480                cx.map.get(left).map(|value| value != right).unwrap_or(true)
481            }
482            Self::Not(pred) => !pred.eval(cx),
483            Self::And(left, right) => left.eval(cx) && right.eval(cx),
484            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
485        }
486    }
487}
488
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    /// Exercises pending prefixes, fallback matches across the dispatch path,
    /// and clearing of pending state.
    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C, D, DA]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
            Binding::new("d", D, Some("1")),
            Binding::new("d", D, Some("2")),
            Binding::new("d a", DA, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        // Binding with pending prefix always takes precedence
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Pending,
        );
        // B alone doesn't match because a was pending, so AB is returned instead
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(1, Box::new(AB))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        // Without an a prefix, B is dispatched like expected
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(B))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        // If a is prefixed, C will not be dispatched because there
        // was a pending binding for it
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Pending,
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path.clone()),
            MatchResult::None,
        );
        assert!(!matcher.has_pending_keystrokes());

        // If a single keystroke matches multiple bindings in the tree
        // all of them are returned so that we can fallback if the action
        // handler decides to propagate the action
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("d")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(D)), (1, Box::new(D))]),
        );
        // If none of the d action handlers consume the binding, a pending
        // binding may then be used
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(DA))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_a.clone())]),
            MatchResult::Matches(vec![(1, Box::new(A("x".to_string())))])
        );
        matcher.clear_pending();

        // Multi-keystroke match
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_b.clone())]),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_b.clone())]),
            MatchResult::Matches(vec![(1, Box::new(Ab))])
        );
        matcher.clear_pending();

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("x")?, vec![(1, ctx_a.clone())]),
            MatchResult::None
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_a.clone())]),
            MatchResult::Matches(vec![(1, Box::new(A("x".to_string())))])
        );
        matcher.clear_pending();

        // Pending keystrokes are cleared when the context changes
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_b.clone())]),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_a.clone())]),
            MatchResult::None
        );
        matcher.clear_pending();

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert_eq!(
            matcher.push_keystroke(
                Keystroke::parse("a")?,
                vec![(1, ctx_b.clone()), (2, ctx_c.clone())]
            ),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_b.clone())]),
            MatchResult::Matches(vec![(1, Box::new(Ab))])
        );

        Ok(())
    }

    /// The most recently added applicable binding wins, context predicates are
    /// honored, and unbound actions yield `None`.
    #[test]
    fn test_keystrokes_for_action() -> Result<()> {
        actions!(test, [X, Y, Z]);

        let keymap = Keymap::new(vec![
            Binding::new("a", X, None),
            Binding::new("b", X, Some("ctx")),
            Binding::new("c", Y, None),
        ]);
        let matcher = Matcher::new(keymap);

        let mut ctx = Context::default();
        ctx.set.insert("ctx".into());

        assert_eq!(
            matcher.keystrokes_for_action(&X, &ctx).as_deref(),
            Some(&[Keystroke::parse("b")?][..])
        );
        assert_eq!(
            matcher
                .keystrokes_for_action(&X, &Context::default())
                .as_deref(),
            Some(&[Keystroke::parse("a")?][..])
        );
        assert_eq!(
            matcher.keystrokes_for_action(&Y, &ctx).as_deref(),
            Some(&[Keystroke::parse("c")?][..])
        );
        assert!(matcher.keystrokes_for_action(&Z, &ctx).is_none());

        Ok(())
    }
}