// keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::{Debug, Write},
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
extern "C" {
    /// Returns the tree-sitter `Language` for the context-predicate grammar used by
    /// `ContextPredicate::parse`.
    ///
    /// NOTE(review): presumably provided by a generated tree-sitter parser that is
    /// compiled and linked into this crate (e.g. by a build script) — TODO confirm.
    fn tree_sitter_context_predicate() -> Language;
}
 14
/// Matches keystrokes against a [`Keymap`] one at a time, remembering partially
/// typed multi-keystroke bindings between calls to `push_keystroke`.
pub struct Matcher {
    /// Views that still have a partially typed binding, mapped to the context the
    /// view had when that binding became pending. If a view's context differs on
    /// the next keystroke, the view drops out of the sequence.
    pending_views: HashMap<usize, Context>,
    /// Keystrokes typed so far in the current sequence.
    pending_keystrokes: Vec<Keystroke>,
    /// The bindings being matched against.
    keymap: Keymap,
}
 20
/// An ordered collection of key bindings, additionally indexed by action type.
#[derive(Default)]
pub struct Keymap {
    /// All bindings, in insertion order.
    bindings: Vec<Binding>,
    /// For each action `TypeId`, the indices into `bindings` of the bindings whose
    /// action has that type, in ascending order.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
 26
/// A keystroke sequence bound to an action, optionally restricted to contexts in
/// which a predicate holds.
pub struct Binding {
    /// The keystrokes that must be typed, in order, to trigger the action.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// The action produced when the sequence completes.
    action: Box<dyn Action>,
    /// When `Some`, the binding only applies in contexts satisfying this
    /// predicate; `None` applies in every context.
    context_predicate: Option<ContextPredicate>,
}
 32
/// A single key press together with the modifier keys held during it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// Set by the `fn` modifier in the textual form.
    pub function: bool,
    /// The key name, e.g. `a`, `down`, or `-`.
    pub key: String,
}
 42
/// The state a context predicate is evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; tested by bare identifier predicates like `a`.
    pub set: HashSet<String>,
    /// Key/value pairs; tested by `key == value` and `key != value` predicates.
    pub map: HashMap<String, String>,
}
 48
/// A parsed context predicate expression, evaluated against a [`Context`].
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// `name`: true when `name` is in the context's set.
    Identifier(String),
    /// `key == value`: true when the context maps `key` to exactly `value`.
    Equal(String, String),
    /// `key != value`: true unless the context maps `key` to `value`; an absent
    /// key counts as "not equal".
    NotEqual(String, String),
    /// `!expr`
    Not(Box<ContextPredicate>),
    /// `left && right`
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// `left || right`
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
 58
 59trait ActionArg {
 60    fn boxed_clone(&self) -> Box<dyn Any>;
 61}
 62
 63impl<T> ActionArg for T
 64where
 65    T: 'static + Any + Clone,
 66{
 67    fn boxed_clone(&self) -> Box<dyn Any> {
 68        Box::new(self.clone())
 69    }
 70}
 71
/// The outcome of feeding one keystroke to a [`Matcher`].
pub enum MatchResult {
    /// No binding matched, either completely or as a prefix.
    None,
    /// No binding completed, but at least one is still a prefix match; more
    /// keystrokes are needed.
    Pending,
    /// The bindings completed by the keystroke, as `(view_id, action)` pairs.
    Matches(Vec<(usize, Box<dyn Action>)>),
}
 77
 78impl Debug for MatchResult {
 79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 80        match self {
 81            MatchResult::None => f.debug_struct("MatchResult::None").finish(),
 82            MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
 83            MatchResult::Matches(matches) => f
 84                .debug_list()
 85                .entries(
 86                    matches
 87                        .iter()
 88                        .map(|(view_id, action)| format!("{view_id}, {}", action.name())),
 89                )
 90                .finish(),
 91        }
 92    }
 93}
 94
 95impl PartialEq for MatchResult {
 96    fn eq(&self, other: &Self) -> bool {
 97        match (self, other) {
 98            (MatchResult::None, MatchResult::None) => true,
 99            (MatchResult::Pending, MatchResult::Pending) => true,
100            (MatchResult::Matches(matches), MatchResult::Matches(other_matches)) => {
101                matches.len() == other_matches.len()
102                    && matches.iter().zip(other_matches.iter()).all(
103                        |((view_id, action), (other_view_id, other_action))| {
104                            view_id == other_view_id && action.eq(other_action.as_ref())
105                        },
106                    )
107            }
108            _ => false,
109        }
110    }
111}
112
113impl Eq for MatchResult {}
114
115impl Clone for MatchResult {
116    fn clone(&self) -> Self {
117        match self {
118            MatchResult::None => MatchResult::None,
119            MatchResult::Pending => MatchResult::Pending,
120            MatchResult::Matches(matches) => MatchResult::Matches(
121                matches
122                    .iter()
123                    .map(|(view_id, action)| (*view_id, Action::boxed_clone(action.as_ref())))
124                    .collect(),
125            ),
126        }
127    }
128}
129
130impl Matcher {
131    pub fn new(keymap: Keymap) -> Self {
132        Self {
133            pending_views: HashMap::new(),
134            pending_keystrokes: Vec::new(),
135            keymap,
136        }
137    }
138
139    pub fn set_keymap(&mut self, keymap: Keymap) {
140        self.clear_pending();
141        self.keymap = keymap;
142    }
143
144    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
145        self.clear_pending();
146        self.keymap.add_bindings(bindings);
147    }
148
149    pub fn clear_bindings(&mut self) {
150        self.clear_pending();
151        self.keymap.clear();
152    }
153
154    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
155        self.keymap.bindings_for_action_type(action_type)
156    }
157
158    pub fn clear_pending(&mut self) {
159        self.pending_keystrokes.clear();
160        self.pending_views.clear();
161    }
162
163    pub fn has_pending_keystrokes(&self) -> bool {
164        !self.pending_keystrokes.is_empty()
165    }
166
167    pub fn push_keystroke(
168        &mut self,
169        keystroke: Keystroke,
170        dispatch_path: Vec<(usize, Context)>,
171    ) -> MatchResult {
172        let mut any_pending = false;
173        let mut matched_bindings = Vec::new();
174
175        let first_keystroke = self.pending_keystrokes.is_empty();
176        self.pending_keystrokes.push(keystroke);
177
178        for (view_id, context) in dispatch_path {
179            // Don't require pending view entry if there are no pending keystrokes
180            if !first_keystroke && !self.pending_views.contains_key(&view_id) {
181                continue;
182            }
183
184            // If there is a previous view context, invalidate that view if it
185            // has changed
186            if let Some(previous_view_context) = self.pending_views.remove(&view_id) {
187                if previous_view_context != context {
188                    continue;
189                }
190            }
191
192            // Find the bindings which map the pending keystrokes and current context
193            for binding in self.keymap.bindings.iter().rev() {
194                if binding.keystrokes.starts_with(&self.pending_keystrokes)
195                    && binding
196                        .context_predicate
197                        .as_ref()
198                        .map(|c| c.eval(&context))
199                        .unwrap_or(true)
200                {
201                    // If the binding is completed, push it onto the matches list
202                    if binding.keystrokes.len() == self.pending_keystrokes.len() {
203                        matched_bindings.push((view_id, binding.action.boxed_clone()));
204                    } else {
205                        // Otherwise, the binding is still pending
206                        self.pending_views.insert(view_id, context.clone());
207                        any_pending = true;
208                    }
209                }
210            }
211        }
212
213        if !any_pending {
214            self.clear_pending();
215        }
216
217        if !matched_bindings.is_empty() {
218            MatchResult::Matches(matched_bindings)
219        } else if any_pending {
220            MatchResult::Pending
221        } else {
222            MatchResult::None
223        }
224    }
225
226    pub fn keystrokes_for_action(
227        &self,
228        action: &dyn Action,
229        cx: &Context,
230    ) -> Option<SmallVec<[Keystroke; 2]>> {
231        for binding in self.keymap.bindings.iter().rev() {
232            if binding.action.eq(action)
233                && binding
234                    .context_predicate
235                    .as_ref()
236                    .map_or(true, |predicate| predicate.eval(cx))
237            {
238                return Some(binding.keystrokes.clone());
239            }
240        }
241        None
242    }
243}
244
245impl Default for Matcher {
246    fn default() -> Self {
247        Self::new(Keymap::default())
248    }
249}
250
251impl Keymap {
252    pub fn new(bindings: Vec<Binding>) -> Self {
253        let mut binding_indices_by_action_type = HashMap::new();
254        for (ix, binding) in bindings.iter().enumerate() {
255            binding_indices_by_action_type
256                .entry(binding.action.as_any().type_id())
257                .or_insert_with(SmallVec::new)
258                .push(ix);
259        }
260        Self {
261            binding_indices_by_action_type,
262            bindings,
263        }
264    }
265
266    fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
267        self.binding_indices_by_action_type
268            .get(&action_type)
269            .map(SmallVec::as_slice)
270            .unwrap_or(&[])
271            .iter()
272            .map(|ix| &self.bindings[*ix])
273    }
274
275    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
276        for binding in bindings {
277            self.binding_indices_by_action_type
278                .entry(binding.action.as_any().type_id())
279                .or_default()
280                .push(self.bindings.len());
281            self.bindings.push(binding);
282        }
283    }
284
285    fn clear(&mut self) {
286        self.bindings.clear();
287        self.binding_indices_by_action_type.clear();
288    }
289}
290
291impl Binding {
292    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
293        Self::load(keystrokes, Box::new(action), context).unwrap()
294    }
295
296    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
297        let context = if let Some(context) = context {
298            Some(ContextPredicate::parse(context)?)
299        } else {
300            None
301        };
302
303        let keystrokes = keystrokes
304            .split_whitespace()
305            .map(Keystroke::parse)
306            .collect::<Result<_>>()?;
307
308        Ok(Self {
309            keystrokes,
310            action,
311            context_predicate: context,
312        })
313    }
314
315    pub fn keystrokes(&self) -> &[Keystroke] {
316        &self.keystrokes
317    }
318
319    pub fn action(&self) -> &dyn Action {
320        self.action.as_ref()
321    }
322}
323
324impl Keystroke {
325    pub fn parse(source: &str) -> anyhow::Result<Self> {
326        let mut ctrl = false;
327        let mut alt = false;
328        let mut shift = false;
329        let mut cmd = false;
330        let mut function = false;
331        let mut key = None;
332
333        let mut components = source.split('-').peekable();
334        while let Some(component) = components.next() {
335            match component {
336                "ctrl" => ctrl = true,
337                "alt" => alt = true,
338                "shift" => shift = true,
339                "cmd" => cmd = true,
340                "fn" => function = true,
341                _ => {
342                    if let Some(component) = components.peek() {
343                        if component.is_empty() && source.ends_with('-') {
344                            key = Some(String::from("-"));
345                            break;
346                        } else {
347                            return Err(anyhow!("Invalid keystroke `{}`", source));
348                        }
349                    } else {
350                        key = Some(String::from(component));
351                    }
352                }
353            }
354        }
355
356        let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
357
358        Ok(Keystroke {
359            ctrl,
360            alt,
361            shift,
362            cmd,
363            function,
364            key,
365        })
366    }
367
368    pub fn modified(&self) -> bool {
369        self.ctrl || self.alt || self.shift || self.cmd
370    }
371}
372
373impl std::fmt::Display for Keystroke {
374    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375        if self.ctrl {
376            f.write_char('^')?;
377        }
378        if self.alt {
379            f.write_char('⎇')?;
380        }
381        if self.cmd {
382            f.write_char('⌘')?;
383        }
384        if self.shift {
385            f.write_char('⇧')?;
386        }
387        let key = match self.key.as_str() {
388            "backspace" => '⌫',
389            "up" => '↑',
390            "down" => '↓',
391            "left" => '←',
392            "right" => '→',
393            "tab" => '⇥',
394            "escape" => '⎋',
395            key => {
396                if key.len() == 1 {
397                    key.chars().next().unwrap().to_ascii_uppercase()
398                } else {
399                    return f.write_str(key);
400                }
401            }
402        };
403        f.write_char(key)
404    }
405}
406
407impl Context {
408    pub fn extend(&mut self, other: &Context) {
409        for v in &other.set {
410            self.set.insert(v.clone());
411        }
412        for (k, v) in &other.map {
413            self.map.insert(k.clone(), v.clone());
414        }
415    }
416}
417
418impl ContextPredicate {
419    fn parse(source: &str) -> anyhow::Result<Self> {
420        let mut parser = Parser::new();
421        let language = unsafe { tree_sitter_context_predicate() };
422        parser.set_language(language).unwrap();
423        let source = source.as_bytes();
424        let tree = parser.parse(source, None).unwrap();
425        Self::from_node(tree.root_node(), source)
426    }
427
428    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
429        let parse_error = "error parsing context predicate";
430        let kind = node.kind();
431
432        match kind {
433            "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
434            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
435            "not" => {
436                let child = Self::from_node(
437                    node.child_by_field_name("expression")
438                        .ok_or_else(|| anyhow!(parse_error))?,
439                    source,
440                )?;
441                Ok(Self::Not(Box::new(child)))
442            }
443            "and" | "or" => {
444                let left = Box::new(Self::from_node(
445                    node.child_by_field_name("left")
446                        .ok_or_else(|| anyhow!(parse_error))?,
447                    source,
448                )?);
449                let right = Box::new(Self::from_node(
450                    node.child_by_field_name("right")
451                        .ok_or_else(|| anyhow!(parse_error))?,
452                    source,
453                )?);
454                if kind == "and" {
455                    Ok(Self::And(left, right))
456                } else {
457                    Ok(Self::Or(left, right))
458                }
459            }
460            "equal" | "not_equal" => {
461                let left = node
462                    .child_by_field_name("left")
463                    .ok_or_else(|| anyhow!(parse_error))?
464                    .utf8_text(source)?
465                    .into();
466                let right = node
467                    .child_by_field_name("right")
468                    .ok_or_else(|| anyhow!(parse_error))?
469                    .utf8_text(source)?
470                    .into();
471                if kind == "equal" {
472                    Ok(Self::Equal(left, right))
473                } else {
474                    Ok(Self::NotEqual(left, right))
475                }
476            }
477            "parenthesized" => Self::from_node(
478                node.child_by_field_name("expression")
479                    .ok_or_else(|| anyhow!(parse_error))?,
480                source,
481            ),
482            _ => Err(anyhow!(parse_error)),
483        }
484    }
485
486    fn eval(&self, cx: &Context) -> bool {
487        match self {
488            Self::Identifier(name) => cx.set.contains(name.as_str()),
489            Self::Equal(left, right) => cx
490                .map
491                .get(left)
492                .map(|value| value == right)
493                .unwrap_or(false),
494            Self::NotEqual(left, right) => {
495                cx.map.get(left).map(|value| value != right).unwrap_or(true)
496            }
497            Self::Not(pred) => !pred.eval(cx),
498            Self::And(left, right) => left.eval(cx) && right.eval(cx),
499            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
500        }
501    }
502}
503
#[cfg(test)]
mod tests {
    // Unit tests for keystroke parsing, context predicates, and the matcher.
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C, D, DA]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
            Binding::new("d", D, Some("1")),
            Binding::new("d", D, Some("2")),
            Binding::new("d a", DA, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        // Binding with pending prefix always takes precedence
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Pending,
        );
        // B alone doesn't match because a was pending, so AB is returned instead
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(1, Box::new(AB))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        // Without an a prefix, B is dispatched like expected
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(B))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        // If a is prefixed, C will not be dispatched because there
        // was a pending binding for it
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Pending,
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path.clone()),
            MatchResult::None,
        );
        assert!(!matcher.has_pending_keystrokes());

        // If a single keystroke matches multiple bindings in the tree
        // all of them are returned so that we can fallback if the action
        // handler decides to propagate the action
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("d")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(D)), (1, Box::new(D))]),
        );
        // If none of the d action handlers consume the binding, a pending
        // binding may then be used
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone()),
            MatchResult::Matches(vec![(2, Box::new(DA))]),
        );
        assert!(!matcher.has_pending_keystrokes());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_a.clone())]),
            MatchResult::Matches(vec![(1, Box::new(A("x".to_string())))])
        );
        matcher.clear_pending();

        // Multi-keystroke match
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_b.clone())]),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_b.clone())]),
            MatchResult::Matches(vec![(1, Box::new(Ab))])
        );
        matcher.clear_pending();

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("x")?, vec![(1, ctx_a.clone())]),
            MatchResult::None
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_a.clone())]),
            MatchResult::Matches(vec![(1, Box::new(A("x".to_string())))])
        );
        matcher.clear_pending();

        // Pending keystrokes are cleared when the context changes
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("a")?, vec![(1, ctx_b.clone())]),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_a.clone())]),
            MatchResult::None
        );
        matcher.clear_pending();

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert_eq!(
            matcher.push_keystroke(
                Keystroke::parse("a")?,
                vec![(1, ctx_b.clone()), (2, ctx_c.clone())]
            ),
            MatchResult::Pending
        );
        assert_eq!(
            matcher.push_keystroke(Keystroke::parse("b")?, vec![(1, ctx_b.clone())]),
            MatchResult::Matches(vec![(1, Box::new(Ab))])
        );

        Ok(())
    }
}