keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::Debug,
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
// Constructor for the context-predicate grammar, resolved at link time. Presumably this is a
// tree-sitter-generated parser compiled by the build (not visible here) — TODO confirm.
// Called (inside `unsafe`) by `ContextPredicate::parse`.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
 14
/// Resolves keystrokes into actions according to a [`Keymap`], tracking partially-typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by the `view_id` passed to `push_keystroke`.
    pending: HashMap<usize, Pending>,
    /// The bindings keystrokes are matched against.
    keymap: Keymap,
}
 19
/// A partially-matched keystroke sequence for a single view.
#[derive(Default)]
struct Pending {
    /// Keystrokes typed so far that form a strict prefix of at least one applicable binding.
    keystrokes: Vec<Keystroke>,
    /// The context the pending keystrokes were typed in; a keystroke arriving in a different
    /// context discards them.
    context: Option<Context>,
}
 25
/// An ordered list of key bindings plus an index from action type to binding positions.
#[derive(Default)]
pub struct Keymap {
    /// All bindings, in insertion order.
    bindings: Vec<Binding>,
    /// For each action's concrete `TypeId`, the indices into `bindings` that dispatch it.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
 31
/// Associates a keystroke sequence with an action, optionally restricted by a context predicate.
pub struct Binding {
    /// The keystroke sequence, parsed from a whitespace-separated string.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// The action produced when the full sequence is typed.
    action: Box<dyn Action>,
    /// When present, the binding only applies if this predicate is true for the current
    /// `Context`; `None` means the binding applies in every context.
    context_predicate: Option<ContextPredicate>,
}
 37
/// A single key press together with the modifier keys held during it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// The key name as written in binding source, e.g. `p`, `down`, or `-`.
    pub key: String,
}
 46
/// The state that context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Flags tested by bare identifiers in a predicate (e.g. `a` in `a && b`).
    pub set: HashSet<String>,
    /// Key/value pairs tested by `==` and `!=` in a predicate.
    pub map: HashMap<String, String>,
}
 52
/// A parsed boolean expression over a [`Context`], e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is present in `Context::set`.
    Identifier(String),
    /// True when `Context::map[left] == right`; false when `left` is absent.
    Equal(String, String),
    /// True when `Context::map[left] != right`, and also when `left` is absent.
    NotEqual(String, String),
    /// Logical negation of the inner predicate.
    Not(Box<ContextPredicate>),
    /// Short-circuiting conjunction.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Short-circuiting disjunction.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
 62
 63trait ActionArg {
 64    fn boxed_clone(&self) -> Box<dyn Any>;
 65}
 66
 67impl<T> ActionArg for T
 68where
 69    T: 'static + Any + Clone,
 70{
 71    fn boxed_clone(&self) -> Box<dyn Any> {
 72        Box::new(self.clone())
 73    }
 74}
 75
/// Outcome of feeding one keystroke to a [`Matcher`].
pub enum MatchResult {
    /// No applicable binding starts with the keystrokes typed so far; the view's pending
    /// sequence was discarded.
    None,
    /// The keystrokes are a strict prefix of at least one applicable binding; more input is needed.
    Pending,
    /// A binding was fully matched; carries a clone of its action.
    Action(Box<dyn Action>),
}
 81
 82impl Debug for MatchResult {
 83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 84        match self {
 85            MatchResult::None => f.debug_struct("MatchResult::None").finish(),
 86            MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
 87            MatchResult::Action(action) => f
 88                .debug_tuple("MatchResult::Action")
 89                .field(&action.name())
 90                .finish(),
 91        }
 92    }
 93}
 94
 95impl Matcher {
 96    pub fn new(keymap: Keymap) -> Self {
 97        Self {
 98            pending: HashMap::new(),
 99            keymap,
100        }
101    }
102
103    pub fn set_keymap(&mut self, keymap: Keymap) {
104        self.pending.clear();
105        self.keymap = keymap;
106    }
107
108    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
109        self.pending.clear();
110        self.keymap.add_bindings(bindings);
111    }
112
113    pub fn clear_bindings(&mut self) {
114        self.pending.clear();
115        self.keymap.clear();
116    }
117
118    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
119        self.keymap.bindings_for_action_type(action_type)
120    }
121
122    pub fn clear_pending(&mut self) {
123        self.pending.clear();
124    }
125
126    pub fn has_pending_keystrokes(&self) -> bool {
127        !self.pending.is_empty()
128    }
129
130    pub fn push_keystroke(
131        &mut self,
132        keystroke: Keystroke,
133        view_id: usize,
134        cx: &Context,
135    ) -> MatchResult {
136        let pending = self.pending.entry(view_id).or_default();
137
138        if let Some(pending_ctx) = pending.context.as_ref() {
139            if pending_ctx != cx {
140                pending.keystrokes.clear();
141            }
142        }
143
144        pending.keystrokes.push(keystroke);
145
146        let mut retain_pending = false;
147        for binding in self.keymap.bindings.iter().rev() {
148            if binding.keystrokes.starts_with(&pending.keystrokes)
149                && binding
150                    .context_predicate
151                    .as_ref()
152                    .map(|c| c.eval(cx))
153                    .unwrap_or(true)
154            {
155                if binding.keystrokes.len() == pending.keystrokes.len() {
156                    self.pending.remove(&view_id);
157                    return MatchResult::Action(binding.action.boxed_clone());
158                } else {
159                    retain_pending = true;
160                    pending.context = Some(cx.clone());
161                }
162            }
163        }
164
165        if retain_pending {
166            MatchResult::Pending
167        } else {
168            self.pending.remove(&view_id);
169            MatchResult::None
170        }
171    }
172
173    pub fn keystrokes_for_action(
174        &self,
175        action: &dyn Action,
176        cx: &Context,
177    ) -> Option<SmallVec<[Keystroke; 2]>> {
178        for binding in self.keymap.bindings.iter().rev() {
179            if binding.action.id() == action.id()
180                && binding
181                    .context_predicate
182                    .as_ref()
183                    .map_or(true, |predicate| predicate.eval(cx))
184            {
185                return Some(binding.keystrokes.clone());
186            }
187        }
188        None
189    }
190}
191
192impl Default for Matcher {
193    fn default() -> Self {
194        Self::new(Keymap::default())
195    }
196}
197
198impl Keymap {
199    pub fn new(bindings: Vec<Binding>) -> Self {
200        let mut binding_indices_by_action_type = HashMap::new();
201        for (ix, binding) in bindings.iter().enumerate() {
202            binding_indices_by_action_type
203                .entry(binding.action.as_any().type_id())
204                .or_insert_with(|| SmallVec::new())
205                .push(ix);
206        }
207        Self {
208            binding_indices_by_action_type,
209            bindings,
210        }
211    }
212
213    fn bindings_for_action_type<'a>(
214        &'a self,
215        action_type: TypeId,
216    ) -> impl Iterator<Item = &'a Binding> {
217        self.binding_indices_by_action_type
218            .get(&action_type)
219            .map(SmallVec::as_slice)
220            .unwrap_or(&[])
221            .iter()
222            .map(|ix| &self.bindings[*ix])
223    }
224
225    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
226        for binding in bindings {
227            self.binding_indices_by_action_type
228                .entry(binding.action.as_any().type_id())
229                .or_default()
230                .push(self.bindings.len());
231            self.bindings.push(binding);
232        }
233    }
234
235    fn clear(&mut self) {
236        self.bindings.clear();
237        self.binding_indices_by_action_type.clear();
238    }
239}
240
241impl Binding {
242    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
243        Self::load(keystrokes, Box::new(action), context).unwrap()
244    }
245
246    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
247        let context = if let Some(context) = context {
248            Some(ContextPredicate::parse(context)?)
249        } else {
250            None
251        };
252
253        let keystrokes = keystrokes
254            .split_whitespace()
255            .map(|key| Keystroke::parse(key))
256            .collect::<Result<_>>()?;
257
258        Ok(Self {
259            keystrokes,
260            action,
261            context_predicate: context,
262        })
263    }
264
265    pub fn keystrokes(&self) -> &[Keystroke] {
266        &self.keystrokes
267    }
268}
269
270impl Keystroke {
271    pub fn parse(source: &str) -> anyhow::Result<Self> {
272        let mut ctrl = false;
273        let mut alt = false;
274        let mut shift = false;
275        let mut cmd = false;
276        let mut key = None;
277
278        let mut components = source.split("-").peekable();
279        while let Some(component) = components.next() {
280            match component {
281                "ctrl" => ctrl = true,
282                "alt" => alt = true,
283                "shift" => shift = true,
284                "cmd" => cmd = true,
285                _ => {
286                    if let Some(component) = components.peek() {
287                        if component.is_empty() && source.ends_with('-') {
288                            key = Some(String::from("-"));
289                            break;
290                        } else {
291                            return Err(anyhow!("Invalid keystroke `{}`", source));
292                        }
293                    } else {
294                        key = Some(String::from(component));
295                    }
296                }
297            }
298        }
299
300        Ok(Keystroke {
301            ctrl,
302            alt,
303            shift,
304            cmd,
305            key: key.unwrap(),
306        })
307    }
308
309    pub fn modified(&self) -> bool {
310        self.ctrl || self.alt || self.shift || self.cmd
311    }
312}
313
314impl std::fmt::Display for Keystroke {
315    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        if self.ctrl {
317            write!(f, "{}", "^")?;
318        }
319        if self.alt {
320            write!(f, "{}", "")?;
321        }
322        if self.cmd {
323            write!(f, "{}", "")?;
324        }
325        if self.shift {
326            write!(f, "{}", "")?;
327        }
328        let key = match self.key.as_str() {
329            "backspace" => "",
330            "up" => "",
331            "down" => "",
332            "left" => "",
333            "right" => "",
334            "tab" => "",
335            "escape" => "",
336            key => key,
337        };
338        write!(f, "{}", key)
339    }
340}
341
342impl Context {
343    pub fn extend(&mut self, other: &Context) {
344        for v in &other.set {
345            self.set.insert(v.clone());
346        }
347        for (k, v) in &other.map {
348            self.map.insert(k.clone(), v.clone());
349        }
350    }
351}
352
353impl ContextPredicate {
354    fn parse(source: &str) -> anyhow::Result<Self> {
355        let mut parser = Parser::new();
356        let language = unsafe { tree_sitter_context_predicate() };
357        parser.set_language(language).unwrap();
358        let source = source.as_bytes();
359        let tree = parser.parse(source, None).unwrap();
360        Self::from_node(tree.root_node(), source)
361    }
362
363    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
364        let parse_error = "error parsing context predicate";
365        let kind = node.kind();
366
367        match kind {
368            "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
369            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
370            "not" => {
371                let child = Self::from_node(
372                    node.child_by_field_name("expression")
373                        .ok_or(anyhow!(parse_error))?,
374                    source,
375                )?;
376                Ok(Self::Not(Box::new(child)))
377            }
378            "and" | "or" => {
379                let left = Box::new(Self::from_node(
380                    node.child_by_field_name("left")
381                        .ok_or(anyhow!(parse_error))?,
382                    source,
383                )?);
384                let right = Box::new(Self::from_node(
385                    node.child_by_field_name("right")
386                        .ok_or(anyhow!(parse_error))?,
387                    source,
388                )?);
389                if kind == "and" {
390                    Ok(Self::And(left, right))
391                } else {
392                    Ok(Self::Or(left, right))
393                }
394            }
395            "equal" | "not_equal" => {
396                let left = node
397                    .child_by_field_name("left")
398                    .ok_or(anyhow!(parse_error))?
399                    .utf8_text(source)?
400                    .into();
401                let right = node
402                    .child_by_field_name("right")
403                    .ok_or(anyhow!(parse_error))?
404                    .utf8_text(source)?
405                    .into();
406                if kind == "equal" {
407                    Ok(Self::Equal(left, right))
408                } else {
409                    Ok(Self::NotEqual(left, right))
410                }
411            }
412            "parenthesized" => Self::from_node(
413                node.child_by_field_name("expression")
414                    .ok_or(anyhow!(parse_error))?,
415                source,
416            ),
417            _ => Err(anyhow!(parse_error)),
418        }
419    }
420
421    fn eval(&self, cx: &Context) -> bool {
422        match self {
423            Self::Identifier(name) => cx.set.contains(name.as_str()),
424            Self::Equal(left, right) => cx
425                .map
426                .get(left)
427                .map(|value| value == right)
428                .unwrap_or(false),
429            Self::NotEqual(left, right) => {
430                cx.map.get(left).map(|value| value != right).unwrap_or(true)
431            }
432            Self::Not(pred) => !pred.eval(cx),
433            Self::And(left, right) => left.eval(cx) && right.eval(cx),
434            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
435        }
436    }
437}
438
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Downcasts an optional boxed action to a concrete action type, if it is one.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, pushes it, and returns the matched action (if any).
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                // The action is already owned; no need to clone it again.
                MatchResult::Action(action) => Some(action),
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}
611    }
612}