keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::{Debug, Write},
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
 11extern "C" {
 12    fn tree_sitter_context_predicate() -> Language;
 13}
 14
 15pub struct Matcher {
 16    pending: HashMap<usize, Pending>,
 17    keymap: Keymap,
 18}
 19
 20#[derive(Default)]
 21struct Pending {
 22    keystrokes: Vec<Keystroke>,
 23    context: Option<Context>,
 24}
 25
 26#[derive(Default)]
 27pub struct Keymap {
 28    bindings: Vec<Binding>,
 29    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
 30}
 31
 32pub struct Binding {
 33    keystrokes: SmallVec<[Keystroke; 2]>,
 34    action: Box<dyn Action>,
 35    context_predicate: Option<ContextPredicate>,
 36}
 37
 38#[derive(Clone, Debug, Eq, PartialEq)]
 39pub struct Keystroke {
 40    pub ctrl: bool,
 41    pub alt: bool,
 42    pub shift: bool,
 43    pub cmd: bool,
 44    pub function: bool,
 45    pub key: String,
 46}
 47
 48#[derive(Clone, Debug, Default, Eq, PartialEq)]
 49pub struct Context {
 50    pub set: HashSet<String>,
 51    pub map: HashMap<String, String>,
 52}
 53
 54#[derive(Debug, Eq, PartialEq)]
 55enum ContextPredicate {
 56    Identifier(String),
 57    Equal(String, String),
 58    NotEqual(String, String),
 59    Not(Box<ContextPredicate>),
 60    And(Box<ContextPredicate>, Box<ContextPredicate>),
 61    Or(Box<ContextPredicate>, Box<ContextPredicate>),
 62}
 63
 64trait ActionArg {
 65    fn boxed_clone(&self) -> Box<dyn Any>;
 66}
 67
 68impl<T> ActionArg for T
 69where
 70    T: 'static + Any + Clone,
 71{
 72    fn boxed_clone(&self) -> Box<dyn Any> {
 73        Box::new(self.clone())
 74    }
 75}
 76
 77pub enum MatchResult {
 78    None,
 79    Pending,
 80    Match(Vec<(usize, Box<dyn Action>)>),
 81}
 82
 83impl Debug for MatchResult {
 84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 85        match self {
 86            MatchResult::None => f.debug_struct("MatchResult2::None").finish(),
 87            MatchResult::Pending => f.debug_struct("MatchResult2::Pending").finish(),
 88            MatchResult::Match { view_id, action } => f
 89                .debug_struct("MatchResult::Match")
 90                .field("view_id", view_id)
 91                .field("action", &action.name())
 92                .finish(),
 93        }
 94    }
 95}
 96
 97impl PartialEq for MatchResult {
 98    fn eq(&self, other: &Self) -> bool {
 99        match (self, other) {
100            (MatchResult::None, MatchResult::None) => true,
101            (MatchResult::Pending, MatchResult::Pending) => true,
102            (
103                MatchResult::Match { view_id, action },
104                MatchResult::Match {
105                    view_id: other_view_id,
106                    action: other_action,
107                },
108            ) => view_id == other_view_id && action.eq(other_action.as_ref()),
109            _ => false,
110        }
111    }
112}
113
114impl Eq for MatchResult {}
115
116impl Matcher {
117    pub fn new(keymap: Keymap) -> Self {
118        Self {
119            pending: HashMap::new(),
120            keymap,
121        }
122    }
123
124    pub fn set_keymap(&mut self, keymap: Keymap) {
125        self.pending.clear();
126        self.keymap = keymap;
127    }
128
129    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
130        self.pending.clear();
131        self.keymap.add_bindings(bindings);
132    }
133
134    pub fn clear_bindings(&mut self) {
135        self.pending.clear();
136        self.keymap.clear();
137    }
138
139    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
140        self.keymap.bindings_for_action_type(action_type)
141    }
142
143    pub fn clear_pending(&mut self) {
144        self.pending.clear();
145    }
146
147    pub fn has_pending_keystrokes(&self) -> bool {
148        !self.pending.is_empty()
149    }
150
151    pub fn push_keystroke(
152        &mut self,
153        keystroke: Keystroke,
154        dispatch_path: Vec<(usize, Context)>,
155    ) -> MatchResult {
156        let mut any_pending = false;
157        let mut matched_bindings = Vec::new();
158
159        let first_keystroke = self.pending.is_empty();
160        dbg!(&dispatch_path);
161        for (view_id, context) in dispatch_path {
162            if !first_keystroke && !self.pending.contains_key(&view_id) {
163                continue;
164            }
165
166            let pending = self.pending.entry(view_id).or_default();
167
168            if let Some(pending_context) = pending.context.as_ref() {
169                if pending_context != &context {
170                    pending.keystrokes.clear();
171                }
172            }
173
174            pending.keystrokes.push(keystroke.clone());
175
176            let mut retain_pending = false;
177            for binding in self.keymap.bindings.iter().rev() {
178                if binding.keystrokes.starts_with(&pending.keystrokes)
179                    && binding
180                        .context_predicate
181                        .as_ref()
182                        .map(|c| c.eval(&context))
183                        .unwrap_or(true)
184                {
185                    if binding.keystrokes.len() == pending.keystrokes.len() {
186                        self.pending.remove(&view_id);
187                        matched_bindings.push((view_id, binding.action.boxed_clone()));
188                    } else {
189                        retain_pending = true;
190                        pending.context = Some(context.clone());
191                    }
192                }
193            }
194
195            if retain_pending {
196                any_pending = true;
197            } else {
198                self.pending.remove(&view_id);
199            }
200        }
201
202        if !matched_bindings.is_empty() {
203            MatchResult::Match(matched_bindings)
204        } else if any_pending {
205            MatchResult::Pending
206        } else {
207            MatchResult::None
208        }
209    }
210
211    pub fn keystrokes_for_action(
212        &self,
213        action: &dyn Action,
214        cx: &Context,
215    ) -> Option<SmallVec<[Keystroke; 2]>> {
216        for binding in self.keymap.bindings.iter().rev() {
217            if binding.action.eq(action)
218                && binding
219                    .context_predicate
220                    .as_ref()
221                    .map_or(true, |predicate| predicate.eval(cx))
222            {
223                return Some(binding.keystrokes.clone());
224            }
225        }
226        None
227    }
228}
229
230impl Default for Matcher {
231    fn default() -> Self {
232        Self::new(Keymap::default())
233    }
234}
235
236impl Keymap {
237    pub fn new(bindings: Vec<Binding>) -> Self {
238        let mut binding_indices_by_action_type = HashMap::new();
239        for (ix, binding) in bindings.iter().enumerate() {
240            binding_indices_by_action_type
241                .entry(binding.action.as_any().type_id())
242                .or_insert_with(SmallVec::new)
243                .push(ix);
244        }
245        Self {
246            binding_indices_by_action_type,
247            bindings,
248        }
249    }
250
251    fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
252        self.binding_indices_by_action_type
253            .get(&action_type)
254            .map(SmallVec::as_slice)
255            .unwrap_or(&[])
256            .iter()
257            .map(|ix| &self.bindings[*ix])
258    }
259
260    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
261        for binding in bindings {
262            self.binding_indices_by_action_type
263                .entry(binding.action.as_any().type_id())
264                .or_default()
265                .push(self.bindings.len());
266            self.bindings.push(binding);
267        }
268    }
269
270    fn clear(&mut self) {
271        self.bindings.clear();
272        self.binding_indices_by_action_type.clear();
273    }
274}
275
276impl Binding {
277    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
278        Self::load(keystrokes, Box::new(action), context).unwrap()
279    }
280
281    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
282        let context = if let Some(context) = context {
283            Some(ContextPredicate::parse(context)?)
284        } else {
285            None
286        };
287
288        let keystrokes = keystrokes
289            .split_whitespace()
290            .map(Keystroke::parse)
291            .collect::<Result<_>>()?;
292
293        Ok(Self {
294            keystrokes,
295            action,
296            context_predicate: context,
297        })
298    }
299
300    pub fn keystrokes(&self) -> &[Keystroke] {
301        &self.keystrokes
302    }
303
304    pub fn action(&self) -> &dyn Action {
305        self.action.as_ref()
306    }
307}
308
309impl Keystroke {
310    pub fn parse(source: &str) -> anyhow::Result<Self> {
311        let mut ctrl = false;
312        let mut alt = false;
313        let mut shift = false;
314        let mut cmd = false;
315        let mut function = false;
316        let mut key = None;
317
318        let mut components = source.split('-').peekable();
319        while let Some(component) = components.next() {
320            match component {
321                "ctrl" => ctrl = true,
322                "alt" => alt = true,
323                "shift" => shift = true,
324                "cmd" => cmd = true,
325                "fn" => function = true,
326                _ => {
327                    if let Some(component) = components.peek() {
328                        if component.is_empty() && source.ends_with('-') {
329                            key = Some(String::from("-"));
330                            break;
331                        } else {
332                            return Err(anyhow!("Invalid keystroke `{}`", source));
333                        }
334                    } else {
335                        key = Some(String::from(component));
336                    }
337                }
338            }
339        }
340
341        let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
342
343        Ok(Keystroke {
344            ctrl,
345            alt,
346            shift,
347            cmd,
348            function,
349            key,
350        })
351    }
352
353    pub fn modified(&self) -> bool {
354        self.ctrl || self.alt || self.shift || self.cmd
355    }
356}
357
358impl std::fmt::Display for Keystroke {
359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        if self.ctrl {
361            f.write_char('^')?;
362        }
363        if self.alt {
364            f.write_char('⎇')?;
365        }
366        if self.cmd {
367            f.write_char('⌘')?;
368        }
369        if self.shift {
370            f.write_char('⇧')?;
371        }
372        let key = match self.key.as_str() {
373            "backspace" => '⌫',
374            "up" => '↑',
375            "down" => '↓',
376            "left" => '←',
377            "right" => '→',
378            "tab" => '⇥',
379            "escape" => '⎋',
380            key => {
381                if key.len() == 1 {
382                    key.chars().next().unwrap().to_ascii_uppercase()
383                } else {
384                    return f.write_str(key);
385                }
386            }
387        };
388        f.write_char(key)
389    }
390}
391
392impl Context {
393    pub fn extend(&mut self, other: &Context) {
394        for v in &other.set {
395            self.set.insert(v.clone());
396        }
397        for (k, v) in &other.map {
398            self.map.insert(k.clone(), v.clone());
399        }
400    }
401}
402
403impl ContextPredicate {
404    fn parse(source: &str) -> anyhow::Result<Self> {
405        let mut parser = Parser::new();
406        let language = unsafe { tree_sitter_context_predicate() };
407        parser.set_language(language).unwrap();
408        let source = source.as_bytes();
409        let tree = parser.parse(source, None).unwrap();
410        Self::from_node(tree.root_node(), source)
411    }
412
413    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
414        let parse_error = "error parsing context predicate";
415        let kind = node.kind();
416
417        match kind {
418            "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
419            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
420            "not" => {
421                let child = Self::from_node(
422                    node.child_by_field_name("expression")
423                        .ok_or_else(|| anyhow!(parse_error))?,
424                    source,
425                )?;
426                Ok(Self::Not(Box::new(child)))
427            }
428            "and" | "or" => {
429                let left = Box::new(Self::from_node(
430                    node.child_by_field_name("left")
431                        .ok_or_else(|| anyhow!(parse_error))?,
432                    source,
433                )?);
434                let right = Box::new(Self::from_node(
435                    node.child_by_field_name("right")
436                        .ok_or_else(|| anyhow!(parse_error))?,
437                    source,
438                )?);
439                if kind == "and" {
440                    Ok(Self::And(left, right))
441                } else {
442                    Ok(Self::Or(left, right))
443                }
444            }
445            "equal" | "not_equal" => {
446                let left = node
447                    .child_by_field_name("left")
448                    .ok_or_else(|| anyhow!(parse_error))?
449                    .utf8_text(source)?
450                    .into();
451                let right = node
452                    .child_by_field_name("right")
453                    .ok_or_else(|| anyhow!(parse_error))?
454                    .utf8_text(source)?
455                    .into();
456                if kind == "equal" {
457                    Ok(Self::Equal(left, right))
458                } else {
459                    Ok(Self::NotEqual(left, right))
460                }
461            }
462            "parenthesized" => Self::from_node(
463                node.child_by_field_name("expression")
464                    .ok_or_else(|| anyhow!(parse_error))?,
465                source,
466            ),
467            _ => Err(anyhow!(parse_error)),
468        }
469    }
470
471    fn eval(&self, cx: &Context) -> bool {
472        match self {
473            Self::Identifier(name) => cx.set.contains(name.as_str()),
474            Self::Equal(left, right) => cx
475                .map
476                .get(left)
477                .map(|value| value == right)
478                .unwrap_or(false),
479            Self::NotEqual(left, right) => {
480                cx.map.get(left).map(|value| value != right).unwrap_or(true)
481            }
482            Self::Not(pred) => !pred.eval(cx),
483            Self::And(left, right) => left.eval(cx) && right.eval(cx),
484            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
485        }
486    }
487}
488
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_push_keystroke() -> Result<()> {
        actions!(test, [B, AB, C]);

        let mut ctx1 = Context::default();
        ctx1.set.insert("1".into());

        let mut ctx2 = Context::default();
        ctx2.set.insert("2".into());

        let dispatch_path = vec![(2, ctx2), (1, ctx1)];

        let keymap = Keymap::new(vec![
            Binding::new("a b", AB, Some("1")),
            Binding::new("b", B, Some("2")),
            Binding::new("c", C, Some("2")),
        ]);

        let mut matcher = Matcher::new(keymap);

        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        assert_eq!(
            MatchResult::Match(vec![(1, Box::new(AB) as Box<dyn Action>)]),
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());
        assert_eq!(
            MatchResult::Match(vec![(2, Box::new(B) as Box<dyn Action>)]),
            matcher.push_keystroke(Keystroke::parse("b")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());
        assert_eq!(
            MatchResult::Pending,
            matcher.push_keystroke(Keystroke::parse("a")?, dispatch_path.clone())
        );
        assert_eq!(
            MatchResult::None,
            matcher.push_keystroke(Keystroke::parse("c")?, dispatch_path.clone())
        );
        assert!(matcher.pending.is_empty());

        Ok(())
    }

    #[test]
    fn test_keystroke_parsing() -> Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher
            .test_keystroke("x", vec![(1, ctx_a.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", vec![(1, ctx_a.clone())])),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_a.clone())])),
            Some(&B)
        );

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher
            .test_keystroke("a", vec![(1, ctx_b.clone()), (2, ctx_c.clone())])
            .is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("b", vec![(1, ctx_b.clone())])),
            Some(&Ab)
        );

        Ok(())
    }

    /// Downcasts a matched action to the concrete action type `A`, if it is one.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Pushes `keystroke` and returns the first matched action, or `None` when the result
        /// is `Pending` or `None`.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            dispatch_path: Vec<(usize, Context)>,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), dispatch_path) {
                MatchResult::Match(matches) => {
                    matches.into_iter().next().map(|(_, action)| action)
                }
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}