keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::Debug,
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
extern "C" {
    /// Returns the tree-sitter grammar used by `ContextPredicate::parse` to parse binding
    /// context predicates.
    // NOTE(review): presumably implemented by a generated C parser that is compiled and linked
    // into this crate (e.g. by a build script) — confirm.
    fn tree_sitter_context_predicate() -> Language;
}
 14
/// Resolves keystrokes into actions using a [`Keymap`], tracking partially-typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    // In-progress keystroke sequences, keyed by view id.
    pending: HashMap<usize, Pending>,
    keymap: Keymap,
}

// A view's in-progress keystroke sequence.
#[derive(Default)]
struct Pending {
    keystrokes: Vec<Keystroke>,
    // Context recorded at the last partial match; if the next keystroke arrives under a
    // different context, the accumulated keystrokes are discarded.
    context: Option<Context>,
}

/// An ordered list of key bindings, indexed by the concrete type of each binding's action.
#[derive(Default)]
pub struct Keymap {
    bindings: Vec<Binding>,
    // Positions in `bindings` for each action `TypeId`, in the order the bindings were added.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}

/// A keystroke sequence bound to an action, optionally restricted by a context predicate.
pub struct Binding {
    keystrokes: Vec<Keystroke>,
    action: Box<dyn Action>,
    // `None` means the binding applies in every context.
    context: Option<ContextPredicate>,
}
 37
/// A single key press: a key name plus the modifier keys held with it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// The key's name as written in keystroke strings, e.g. `p`, `down`, or `-`.
    pub key: String,
}

/// The state a binding's context predicate is evaluated against: a set of flags plus
/// key/value pairs.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Flags tested by bare identifiers in a predicate.
    pub set: HashSet<String>,
    /// Values tested by `key == value` and `key != value` in a predicate.
    pub map: HashMap<String, String>,
}

// Parsed form of a binding's context expression, evaluated against a `Context`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    // `name`: true if `name` is in `Context::set`.
    Identifier(String),
    // `key == value`: true if `Context::map` maps `key` to `value`.
    Equal(String, String),
    // `key != value`: true unless `Context::map` maps `key` to `value`.
    NotEqual(String, String),
    // `!expr`
    Not(Box<ContextPredicate>),
    // `left && right`
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    // `left || right`
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
 62
 63trait ActionArg {
 64    fn boxed_clone(&self) -> Box<dyn Any>;
 65}
 66
 67impl<T> ActionArg for T
 68where
 69    T: 'static + Any + Clone,
 70{
 71    fn boxed_clone(&self) -> Box<dyn Any> {
 72        Box::new(self.clone())
 73    }
 74}
 75
 76pub enum MatchResult {
 77    None,
 78    Pending,
 79    Action(Box<dyn Action>),
 80}
 81
 82impl Debug for MatchResult {
 83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 84        match self {
 85            MatchResult::None => f.debug_struct("MatchResult::None").finish(),
 86            MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
 87            MatchResult::Action(action) => f
 88                .debug_tuple("MatchResult::Action")
 89                .field(&action.name())
 90                .finish(),
 91        }
 92    }
 93}
 94
 95impl Matcher {
 96    pub fn new(keymap: Keymap) -> Self {
 97        Self {
 98            pending: HashMap::new(),
 99            keymap,
100        }
101    }
102
103    pub fn set_keymap(&mut self, keymap: Keymap) {
104        self.pending.clear();
105        self.keymap = keymap;
106    }
107
108    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
109        self.pending.clear();
110        self.keymap.add_bindings(bindings);
111    }
112
113    pub fn clear_bindings(&mut self) {
114        self.pending.clear();
115        self.keymap.clear();
116    }
117
118    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
119        self.keymap.bindings_for_action_type(action_type)
120    }
121
122    pub fn clear_pending(&mut self) {
123        self.pending.clear();
124    }
125
126    pub fn has_pending_keystrokes(&self) -> bool {
127        !self.pending.is_empty()
128    }
129
130    pub fn push_keystroke(
131        &mut self,
132        keystroke: Keystroke,
133        view_id: usize,
134        cx: &Context,
135    ) -> MatchResult {
136        let pending = self.pending.entry(view_id).or_default();
137
138        if let Some(pending_ctx) = pending.context.as_ref() {
139            if pending_ctx != cx {
140                pending.keystrokes.clear();
141            }
142        }
143
144        pending.keystrokes.push(keystroke);
145
146        let mut retain_pending = false;
147        for binding in self.keymap.bindings.iter().rev() {
148            if binding.keystrokes.starts_with(&pending.keystrokes)
149                && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
150            {
151                if binding.keystrokes.len() == pending.keystrokes.len() {
152                    self.pending.remove(&view_id);
153                    return MatchResult::Action(binding.action.boxed_clone());
154                } else {
155                    retain_pending = true;
156                    pending.context = Some(cx.clone());
157                }
158            }
159        }
160
161        if retain_pending {
162            MatchResult::Pending
163        } else {
164            self.pending.remove(&view_id);
165            MatchResult::None
166        }
167    }
168}
169
170impl Default for Matcher {
171    fn default() -> Self {
172        Self::new(Keymap::default())
173    }
174}
175
176impl Keymap {
177    pub fn new(bindings: Vec<Binding>) -> Self {
178        let mut binding_indices_by_action_type = HashMap::new();
179        for (ix, binding) in bindings.iter().enumerate() {
180            binding_indices_by_action_type
181                .entry(binding.action.as_any().type_id())
182                .or_insert_with(|| SmallVec::new())
183                .push(ix);
184        }
185        Self {
186            binding_indices_by_action_type,
187            bindings,
188        }
189    }
190
191    fn bindings_for_action_type<'a>(
192        &'a self,
193        action_type: TypeId,
194    ) -> impl Iterator<Item = &'a Binding> {
195        self.binding_indices_by_action_type
196            .get(&action_type)
197            .map(SmallVec::as_slice)
198            .unwrap_or(&[])
199            .iter()
200            .map(|ix| &self.bindings[*ix])
201    }
202
203    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
204        for binding in bindings {
205            self.binding_indices_by_action_type
206                .entry(binding.action.as_any().type_id())
207                .or_default()
208                .push(self.bindings.len());
209            self.bindings.push(binding);
210        }
211    }
212
213    fn clear(&mut self) {
214        self.bindings.clear();
215        self.binding_indices_by_action_type.clear();
216    }
217}
218
219impl Binding {
220    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
221        Self::load(keystrokes, Box::new(action), context).unwrap()
222    }
223
224    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
225        let context = if let Some(context) = context {
226            Some(ContextPredicate::parse(context)?)
227        } else {
228            None
229        };
230
231        let keystrokes = keystrokes
232            .split_whitespace()
233            .map(|key| Keystroke::parse(key))
234            .collect::<Result<_>>()?;
235
236        Ok(Self {
237            keystrokes,
238            action,
239            context,
240        })
241    }
242
243    pub fn keystrokes(&self) -> &[Keystroke] {
244        &self.keystrokes
245    }
246}
247
248impl Keystroke {
249    pub fn parse(source: &str) -> anyhow::Result<Self> {
250        let mut ctrl = false;
251        let mut alt = false;
252        let mut shift = false;
253        let mut cmd = false;
254        let mut key = None;
255
256        let mut components = source.split("-").peekable();
257        while let Some(component) = components.next() {
258            match component {
259                "ctrl" => ctrl = true,
260                "alt" => alt = true,
261                "shift" => shift = true,
262                "cmd" => cmd = true,
263                _ => {
264                    if let Some(component) = components.peek() {
265                        if component.is_empty() && source.ends_with('-') {
266                            key = Some(String::from("-"));
267                            break;
268                        } else {
269                            return Err(anyhow!("Invalid keystroke `{}`", source));
270                        }
271                    } else {
272                        key = Some(String::from(component));
273                    }
274                }
275            }
276        }
277
278        Ok(Keystroke {
279            ctrl,
280            alt,
281            shift,
282            cmd,
283            key: key.unwrap(),
284        })
285    }
286
287    pub fn modified(&self) -> bool {
288        self.ctrl || self.alt || self.shift || self.cmd
289    }
290}
291
292impl Context {
293    pub fn extend(&mut self, other: &Context) {
294        for v in &other.set {
295            self.set.insert(v.clone());
296        }
297        for (k, v) in &other.map {
298            self.map.insert(k.clone(), v.clone());
299        }
300    }
301}
302
303impl ContextPredicate {
304    fn parse(source: &str) -> anyhow::Result<Self> {
305        let mut parser = Parser::new();
306        let language = unsafe { tree_sitter_context_predicate() };
307        parser.set_language(language).unwrap();
308        let source = source.as_bytes();
309        let tree = parser.parse(source, None).unwrap();
310        Self::from_node(tree.root_node(), source)
311    }
312
313    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
314        let parse_error = "error parsing context predicate";
315        let kind = node.kind();
316
317        match kind {
318            "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
319            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
320            "not" => {
321                let child = Self::from_node(
322                    node.child_by_field_name("expression")
323                        .ok_or(anyhow!(parse_error))?,
324                    source,
325                )?;
326                Ok(Self::Not(Box::new(child)))
327            }
328            "and" | "or" => {
329                let left = Box::new(Self::from_node(
330                    node.child_by_field_name("left")
331                        .ok_or(anyhow!(parse_error))?,
332                    source,
333                )?);
334                let right = Box::new(Self::from_node(
335                    node.child_by_field_name("right")
336                        .ok_or(anyhow!(parse_error))?,
337                    source,
338                )?);
339                if kind == "and" {
340                    Ok(Self::And(left, right))
341                } else {
342                    Ok(Self::Or(left, right))
343                }
344            }
345            "equal" | "not_equal" => {
346                let left = node
347                    .child_by_field_name("left")
348                    .ok_or(anyhow!(parse_error))?
349                    .utf8_text(source)?
350                    .into();
351                let right = node
352                    .child_by_field_name("right")
353                    .ok_or(anyhow!(parse_error))?
354                    .utf8_text(source)?
355                    .into();
356                if kind == "equal" {
357                    Ok(Self::Equal(left, right))
358                } else {
359                    Ok(Self::NotEqual(left, right))
360                }
361            }
362            "parenthesized" => Self::from_node(
363                node.child_by_field_name("expression")
364                    .ok_or(anyhow!(parse_error))?,
365                source,
366            ),
367            _ => Err(anyhow!(parse_error)),
368        }
369    }
370
371    fn eval(&self, cx: &Context) -> bool {
372        match self {
373            Self::Identifier(name) => cx.set.contains(name.as_str()),
374            Self::Equal(left, right) => cx
375                .map
376                .get(left)
377                .map(|value| value == right)
378                .unwrap_or(false),
379            Self::NotEqual(left, right) => {
380                cx.map.get(left).map(|value| value != right).unwrap_or(true)
381            }
382            Self::Not(pred) => !pred.eval(cx),
383            Self::And(left, right) => left.eval(cx) && right.eval(cx),
384            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
385        }
386    }
387}
388
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Downcasts a matched action to the concrete action type `A`, if it is one.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, pushes it for `view_id` under `cx`, and returns the matched
        /// action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                MatchResult::Action(action) => Some(action),
                MatchResult::Pending | MatchResult::None => None,
            }
        }
    }
}