keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::{Debug, Write},
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
// Entry point of the compiled grammar used by `ContextPredicate::parse` to parse context
// predicate expressions. It is linked in from native code — presumably the tree-sitter-generated
// C parser built alongside this crate (NOTE(review): confirm in the build script).
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
 14
 15pub struct Matcher {
 16    pending: HashMap<usize, Pending>,
 17    keymap: Keymap,
 18}
 19
 20#[derive(Default)]
 21struct Pending {
 22    keystrokes: Vec<Keystroke>,
 23    context: Option<Context>,
 24}
 25
 26#[derive(Default)]
 27pub struct Keymap {
 28    bindings: Vec<Binding>,
 29    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
 30}
 31
 32pub struct Binding {
 33    keystrokes: SmallVec<[Keystroke; 2]>,
 34    action: Box<dyn Action>,
 35    context_predicate: Option<ContextPredicate>,
 36}
 37
 38#[derive(Clone, Debug, Eq, PartialEq)]
 39pub struct Keystroke {
 40    pub ctrl: bool,
 41    pub alt: bool,
 42    pub shift: bool,
 43    pub cmd: bool,
 44    pub key: String,
 45}
 46
 47#[derive(Clone, Debug, Default, Eq, PartialEq)]
 48pub struct Context {
 49    pub set: HashSet<String>,
 50    pub map: HashMap<String, String>,
 51}
 52
 53#[derive(Debug, Eq, PartialEq)]
 54enum ContextPredicate {
 55    Identifier(String),
 56    Equal(String, String),
 57    NotEqual(String, String),
 58    Not(Box<ContextPredicate>),
 59    And(Box<ContextPredicate>, Box<ContextPredicate>),
 60    Or(Box<ContextPredicate>, Box<ContextPredicate>),
 61}
 62
 63trait ActionArg {
 64    fn boxed_clone(&self) -> Box<dyn Any>;
 65}
 66
 67impl<T> ActionArg for T
 68where
 69    T: 'static + Any + Clone,
 70{
 71    fn boxed_clone(&self) -> Box<dyn Any> {
 72        Box::new(self.clone())
 73    }
 74}
 75
 76pub enum MatchResult {
 77    None,
 78    Pending,
 79    Action(Box<dyn Action>),
 80}
 81
 82impl Debug for MatchResult {
 83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 84        match self {
 85            MatchResult::None => f.debug_struct("MatchResult::None").finish(),
 86            MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
 87            MatchResult::Action(action) => f
 88                .debug_tuple("MatchResult::Action")
 89                .field(&action.name())
 90                .finish(),
 91        }
 92    }
 93}
 94
 95impl Matcher {
 96    pub fn new(keymap: Keymap) -> Self {
 97        Self {
 98            pending: HashMap::new(),
 99            keymap,
100        }
101    }
102
103    pub fn set_keymap(&mut self, keymap: Keymap) {
104        self.pending.clear();
105        self.keymap = keymap;
106    }
107
108    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
109        self.pending.clear();
110        self.keymap.add_bindings(bindings);
111    }
112
113    pub fn clear_bindings(&mut self) {
114        self.pending.clear();
115        self.keymap.clear();
116    }
117
118    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
119        self.keymap.bindings_for_action_type(action_type)
120    }
121
122    pub fn clear_pending(&mut self) {
123        self.pending.clear();
124    }
125
126    pub fn has_pending_keystrokes(&self) -> bool {
127        !self.pending.is_empty()
128    }
129
130    pub fn push_keystroke(
131        &mut self,
132        keystroke: Keystroke,
133        view_id: usize,
134        cx: &Context,
135    ) -> MatchResult {
136        let pending = self.pending.entry(view_id).or_default();
137
138        if let Some(pending_ctx) = pending.context.as_ref() {
139            if pending_ctx != cx {
140                pending.keystrokes.clear();
141            }
142        }
143
144        pending.keystrokes.push(keystroke);
145
146        let mut retain_pending = false;
147        for binding in self.keymap.bindings.iter().rev() {
148            if binding.keystrokes.starts_with(&pending.keystrokes)
149                && binding
150                    .context_predicate
151                    .as_ref()
152                    .map(|c| c.eval(cx))
153                    .unwrap_or(true)
154            {
155                if binding.keystrokes.len() == pending.keystrokes.len() {
156                    self.pending.remove(&view_id);
157                    return MatchResult::Action(binding.action.boxed_clone());
158                } else {
159                    retain_pending = true;
160                    pending.context = Some(cx.clone());
161                }
162            }
163        }
164
165        if retain_pending {
166            MatchResult::Pending
167        } else {
168            self.pending.remove(&view_id);
169            MatchResult::None
170        }
171    }
172
173    pub fn keystrokes_for_action(
174        &self,
175        action: &dyn Action,
176        cx: &Context,
177    ) -> Option<SmallVec<[Keystroke; 2]>> {
178        for binding in self.keymap.bindings.iter().rev() {
179            if binding.action.eq(action)
180                && binding
181                    .context_predicate
182                    .as_ref()
183                    .map_or(true, |predicate| predicate.eval(cx))
184            {
185                return Some(binding.keystrokes.clone());
186            }
187        }
188        None
189    }
190}
191
192impl Default for Matcher {
193    fn default() -> Self {
194        Self::new(Keymap::default())
195    }
196}
197
198impl Keymap {
199    pub fn new(bindings: Vec<Binding>) -> Self {
200        let mut binding_indices_by_action_type = HashMap::new();
201        for (ix, binding) in bindings.iter().enumerate() {
202            binding_indices_by_action_type
203                .entry(binding.action.as_any().type_id())
204                .or_insert_with(|| SmallVec::new())
205                .push(ix);
206        }
207        Self {
208            binding_indices_by_action_type,
209            bindings,
210        }
211    }
212
213    fn bindings_for_action_type<'a>(
214        &'a self,
215        action_type: TypeId,
216    ) -> impl Iterator<Item = &'a Binding> {
217        self.binding_indices_by_action_type
218            .get(&action_type)
219            .map(SmallVec::as_slice)
220            .unwrap_or(&[])
221            .iter()
222            .map(|ix| &self.bindings[*ix])
223    }
224
225    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
226        for binding in bindings {
227            self.binding_indices_by_action_type
228                .entry(binding.action.as_any().type_id())
229                .or_default()
230                .push(self.bindings.len());
231            self.bindings.push(binding);
232        }
233    }
234
235    fn clear(&mut self) {
236        self.bindings.clear();
237        self.binding_indices_by_action_type.clear();
238    }
239}
240
241impl Binding {
242    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
243        Self::load(keystrokes, Box::new(action), context).unwrap()
244    }
245
246    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
247        let context = if let Some(context) = context {
248            Some(ContextPredicate::parse(context)?)
249        } else {
250            None
251        };
252
253        let keystrokes = keystrokes
254            .split_whitespace()
255            .map(|key| Keystroke::parse(key))
256            .collect::<Result<_>>()?;
257
258        Ok(Self {
259            keystrokes,
260            action,
261            context_predicate: context,
262        })
263    }
264
265    pub fn keystrokes(&self) -> &[Keystroke] {
266        &self.keystrokes
267    }
268
269    pub fn action(&self) -> &dyn Action {
270        self.action.as_ref()
271    }
272}
273
274impl Keystroke {
275    pub fn parse(source: &str) -> anyhow::Result<Self> {
276        let mut ctrl = false;
277        let mut alt = false;
278        let mut shift = false;
279        let mut cmd = false;
280        let mut key = None;
281
282        let mut components = source.split("-").peekable();
283        while let Some(component) = components.next() {
284            match component {
285                "ctrl" => ctrl = true,
286                "alt" => alt = true,
287                "shift" => shift = true,
288                "cmd" => cmd = true,
289                _ => {
290                    if let Some(component) = components.peek() {
291                        if component.is_empty() && source.ends_with('-') {
292                            key = Some(String::from("-"));
293                            break;
294                        } else {
295                            return Err(anyhow!("Invalid keystroke `{}`", source));
296                        }
297                    } else {
298                        key = Some(String::from(component));
299                    }
300                }
301            }
302        }
303
304        Ok(Keystroke {
305            ctrl,
306            alt,
307            shift,
308            cmd,
309            key: key.unwrap(),
310        })
311    }
312
313    pub fn modified(&self) -> bool {
314        self.ctrl || self.alt || self.shift || self.cmd
315    }
316}
317
318impl std::fmt::Display for Keystroke {
319    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320        if self.ctrl {
321            f.write_char('^')?;
322        }
323        if self.alt {
324            f.write_char('⎇')?;
325        }
326        if self.cmd {
327            f.write_char('⌘')?;
328        }
329        if self.shift {
330            f.write_char('⇧')?;
331        }
332        let key = match self.key.as_str() {
333            "backspace" => '⌫',
334            "up" => '↑',
335            "down" => '↓',
336            "left" => '←',
337            "right" => '→',
338            "tab" => '⇥',
339            "escape" => '⎋',
340            key => {
341                if key.len() == 1 {
342                    key.chars().next().unwrap().to_ascii_uppercase()
343                } else {
344                    return f.write_str(key);
345                }
346            }
347        };
348        f.write_char(key)
349    }
350}
351
352impl Context {
353    pub fn extend(&mut self, other: &Context) {
354        for v in &other.set {
355            self.set.insert(v.clone());
356        }
357        for (k, v) in &other.map {
358            self.map.insert(k.clone(), v.clone());
359        }
360    }
361}
362
363impl ContextPredicate {
364    fn parse(source: &str) -> anyhow::Result<Self> {
365        let mut parser = Parser::new();
366        let language = unsafe { tree_sitter_context_predicate() };
367        parser.set_language(language).unwrap();
368        let source = source.as_bytes();
369        let tree = parser.parse(source, None).unwrap();
370        Self::from_node(tree.root_node(), source)
371    }
372
373    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
374        let parse_error = "error parsing context predicate";
375        let kind = node.kind();
376
377        match kind {
378            "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
379            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
380            "not" => {
381                let child = Self::from_node(
382                    node.child_by_field_name("expression")
383                        .ok_or(anyhow!(parse_error))?,
384                    source,
385                )?;
386                Ok(Self::Not(Box::new(child)))
387            }
388            "and" | "or" => {
389                let left = Box::new(Self::from_node(
390                    node.child_by_field_name("left")
391                        .ok_or(anyhow!(parse_error))?,
392                    source,
393                )?);
394                let right = Box::new(Self::from_node(
395                    node.child_by_field_name("right")
396                        .ok_or(anyhow!(parse_error))?,
397                    source,
398                )?);
399                if kind == "and" {
400                    Ok(Self::And(left, right))
401                } else {
402                    Ok(Self::Or(left, right))
403                }
404            }
405            "equal" | "not_equal" => {
406                let left = node
407                    .child_by_field_name("left")
408                    .ok_or(anyhow!(parse_error))?
409                    .utf8_text(source)?
410                    .into();
411                let right = node
412                    .child_by_field_name("right")
413                    .ok_or(anyhow!(parse_error))?
414                    .utf8_text(source)?
415                    .into();
416                if kind == "equal" {
417                    Ok(Self::Equal(left, right))
418                } else {
419                    Ok(Self::NotEqual(left, right))
420                }
421            }
422            "parenthesized" => Self::from_node(
423                node.child_by_field_name("expression")
424                    .ok_or(anyhow!(parse_error))?,
425                source,
426            ),
427            _ => Err(anyhow!(parse_error)),
428        }
429    }
430
431    fn eval(&self, cx: &Context) -> bool {
432        match self {
433            Self::Identifier(name) => cx.set.contains(name.as_str()),
434            Self::Equal(left, right) => cx
435                .map
436                .get(left)
437                .map(|value| value == right)
438                .unwrap_or(false),
439            Self::NotEqual(left, right) => {
440                cx.map.get(left).map(|value| value != right).unwrap_or(true)
441            }
442            Self::Not(pred) => !pred.eval(cx),
443            Self::And(left, right) => left.eval(cx) && right.eval(cx),
444            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
445        }
446    }
447}
448
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        // A key may not be followed by further components.
        assert!(Keystroke::parse("a-b").is_err());

        Ok(())
    }

    #[test]
    fn test_keystroke_display() -> anyhow::Result<()> {
        assert_eq!(Keystroke::parse("cmd-shift-p")?.to_string(), "⌘⇧P");
        assert_eq!(Keystroke::parse("ctrl-alt-left")?.to_string(), "^⎇←");
        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Borrows the matched action as a concrete `A`, if there is one of that type.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, feeds it to the matcher, and returns the matched action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                // The result already owns the action, so no extra clone is needed.
                MatchResult::Action(action) => Some(action),
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}