keymap.rs

  1use anyhow::anyhow;
  2use std::{
  3    any::Any,
  4    collections::{HashMap, HashSet},
  5    fmt::Debug,
  6};
  7use tree_sitter::{Language, Node, Parser};
  8
  9use crate::{Action, AnyAction};
 10
extern "C" {
    /// Returns the tree-sitter `Language` used by `ContextPredicate::parse` to parse
    /// binding context expressions such as `a && (b == c || d != e)`.
    ///
    /// NOTE(review): the symbol is defined outside this file (presumably a generated
    /// tree-sitter parser linked by the build) — confirm its signature matches this
    /// declaration, since calling it is only sound if it does.
    fn tree_sitter_context_predicate() -> Language;
}
 14
/// Resolves keystrokes to actions using a [`Keymap`], tracking partially-typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by the id of the view they were typed in.
    pending: HashMap<usize, Pending>,
    /// The bindings that keystrokes are matched against.
    keymap: Keymap,
}
 19
/// A view's keystroke sequence that is a strict prefix of at least one applicable binding.
#[derive(Default)]
struct Pending {
    /// Keystrokes typed so far in this sequence.
    keystrokes: Vec<Keystroke>,
    /// The context the sequence was last extended under. If the view's context differs
    /// on the next keystroke, the accumulated keystrokes are discarded first.
    context: Option<Context>,
}
 25
/// An ordered list of key bindings.
///
/// [`Matcher::push_keystroke`] scans bindings from last to first, so when several
/// bindings exactly match the typed keystrokes, the one added last wins.
#[derive(Default)]
pub struct Keymap(Vec<Binding>);
 28
/// Associates a keystroke sequence with an action, optionally restricted to
/// contexts that satisfy a predicate.
pub struct Binding {
    /// The full sequence of keystrokes that triggers the action.
    keystrokes: Vec<Keystroke>,
    /// Action prototype; a clone of it is returned each time the binding matches.
    action: Box<dyn AnyAction>,
    /// When `Some`, the binding only applies if the predicate evaluates to true against
    /// the view's current [`Context`]; `None` means it applies in every context.
    context: Option<ContextPredicate>,
}
 34
/// A single key press plus the modifier keys held during it, as parsed from strings
/// like `ctrl-p` or `shift-cmd--` by [`Keystroke::parse`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Whether the `ctrl` modifier is part of the keystroke.
    pub ctrl: bool,
    /// Whether the `alt` modifier is part of the keystroke.
    pub alt: bool,
    /// Whether the `shift` modifier is part of the keystroke.
    pub shift: bool,
    /// Whether the `cmd` modifier is part of the keystroke.
    pub cmd: bool,
    /// The non-modifier key as written in the keystroke string (e.g. `p`, `down`, `-`).
    pub key: String,
}
 43
/// The state a view exposes for evaluating binding context predicates.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Flags: an identifier predicate such as `a` holds when `"a"` is in this set.
    pub set: HashSet<String>,
    /// Key/value pairs consulted by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
 49
/// Parsed form of a binding's context expression, e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// `name`: true when the name is in [`Context::set`].
    Identifier(String),
    /// `key == value`: true when [`Context::map`] maps `key` to `value`; false when
    /// `key` is absent.
    Equal(String, String),
    /// `key != value`: true when `key` maps to a different value or is absent.
    NotEqual(String, String),
    /// `!expr`
    Not(Box<ContextPredicate>),
    /// `left && right`
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// `left || right`
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
 59
 60trait ActionArg {
 61    fn boxed_clone(&self) -> Box<dyn Any>;
 62}
 63
 64impl<T> ActionArg for T
 65where
 66    T: 'static + Any + Clone,
 67{
 68    fn boxed_clone(&self) -> Box<dyn Any> {
 69        Box::new(self.clone())
 70    }
 71}
 72
/// Outcome of feeding one keystroke to [`Matcher::push_keystroke`].
pub enum MatchResult {
    /// The keystrokes typed so far match no applicable binding; the view's pending
    /// sequence was discarded.
    None,
    /// The keystrokes are a strict prefix of at least one applicable binding.
    Pending,
    /// The keystrokes completed a binding; carries a clone of that binding's action.
    Action(Box<dyn AnyAction>),
}
 78
 79impl Matcher {
 80    pub fn new(keymap: Keymap) -> Self {
 81        Self {
 82            pending: HashMap::new(),
 83            keymap,
 84        }
 85    }
 86
 87    pub fn set_keymap(&mut self, keymap: Keymap) {
 88        self.pending.clear();
 89        self.keymap = keymap;
 90    }
 91
 92    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
 93        self.pending.clear();
 94        self.keymap.add_bindings(bindings);
 95    }
 96
 97    pub fn clear_pending(&mut self) {
 98        self.pending.clear();
 99    }
100
101    pub fn push_keystroke(
102        &mut self,
103        keystroke: Keystroke,
104        view_id: usize,
105        cx: &Context,
106    ) -> MatchResult {
107        let pending = self.pending.entry(view_id).or_default();
108
109        if let Some(pending_ctx) = pending.context.as_ref() {
110            if pending_ctx != cx {
111                pending.keystrokes.clear();
112            }
113        }
114
115        pending.keystrokes.push(keystroke);
116
117        let mut retain_pending = false;
118        for binding in self.keymap.0.iter().rev() {
119            if binding.keystrokes.starts_with(&pending.keystrokes)
120                && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
121            {
122                if binding.keystrokes.len() == pending.keystrokes.len() {
123                    self.pending.remove(&view_id);
124                    return MatchResult::Action(binding.action.boxed_clone());
125                } else {
126                    retain_pending = true;
127                    pending.context = Some(cx.clone());
128                }
129            }
130        }
131
132        if retain_pending {
133            MatchResult::Pending
134        } else {
135            self.pending.remove(&view_id);
136            MatchResult::None
137        }
138    }
139}
140
141impl Default for Matcher {
142    fn default() -> Self {
143        Self::new(Keymap::default())
144    }
145}
146
147impl Keymap {
148    pub fn new(bindings: Vec<Binding>) -> Self {
149        Self(bindings)
150    }
151
152    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
153        self.0.extend(bindings.into_iter());
154    }
155}
156
157impl Binding {
158    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
159        let context = if let Some(context) = context {
160            Some(ContextPredicate::parse(context).unwrap())
161        } else {
162            None
163        };
164
165        Self {
166            keystrokes: keystrokes
167                .split_whitespace()
168                .map(|key| Keystroke::parse(key).unwrap())
169                .collect(),
170            action: Box::new(action),
171            context,
172        }
173    }
174}
175
176impl Keystroke {
177    pub fn parse(source: &str) -> anyhow::Result<Self> {
178        let mut ctrl = false;
179        let mut alt = false;
180        let mut shift = false;
181        let mut cmd = false;
182        let mut key = None;
183
184        let mut components = source.split("-").peekable();
185        while let Some(component) = components.next() {
186            match component {
187                "ctrl" => ctrl = true,
188                "alt" => alt = true,
189                "shift" => shift = true,
190                "cmd" => cmd = true,
191                _ => {
192                    if let Some(component) = components.peek() {
193                        if component.is_empty() && source.ends_with('-') {
194                            key = Some(String::from("-"));
195                            break;
196                        } else {
197                            return Err(anyhow!("Invalid keystroke `{}`", source));
198                        }
199                    } else {
200                        key = Some(String::from(component));
201                    }
202                }
203            }
204        }
205
206        Ok(Keystroke {
207            ctrl,
208            alt,
209            shift,
210            cmd,
211            key: key.unwrap(),
212        })
213    }
214}
215
216impl Context {
217    pub fn extend(&mut self, other: Context) {
218        for v in other.set {
219            self.set.insert(v);
220        }
221        for (k, v) in other.map {
222            self.map.insert(k, v);
223        }
224    }
225}
226
227impl ContextPredicate {
228    fn parse(source: &str) -> anyhow::Result<Self> {
229        let mut parser = Parser::new();
230        let language = unsafe { tree_sitter_context_predicate() };
231        parser.set_language(language).unwrap();
232        let source = source.as_bytes();
233        let tree = parser.parse(source, None).unwrap();
234        Self::from_node(tree.root_node(), source)
235    }
236
237    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
238        let parse_error = "error parsing context predicate";
239        let kind = node.kind();
240
241        match kind {
242            "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
243            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
244            "not" => {
245                let child = Self::from_node(
246                    node.child_by_field_name("expression")
247                        .ok_or(anyhow!(parse_error))?,
248                    source,
249                )?;
250                Ok(Self::Not(Box::new(child)))
251            }
252            "and" | "or" => {
253                let left = Box::new(Self::from_node(
254                    node.child_by_field_name("left")
255                        .ok_or(anyhow!(parse_error))?,
256                    source,
257                )?);
258                let right = Box::new(Self::from_node(
259                    node.child_by_field_name("right")
260                        .ok_or(anyhow!(parse_error))?,
261                    source,
262                )?);
263                if kind == "and" {
264                    Ok(Self::And(left, right))
265                } else {
266                    Ok(Self::Or(left, right))
267                }
268            }
269            "equal" | "not_equal" => {
270                let left = node
271                    .child_by_field_name("left")
272                    .ok_or(anyhow!(parse_error))?
273                    .utf8_text(source)?
274                    .into();
275                let right = node
276                    .child_by_field_name("right")
277                    .ok_or(anyhow!(parse_error))?
278                    .utf8_text(source)?
279                    .into();
280                if kind == "equal" {
281                    Ok(Self::Equal(left, right))
282                } else {
283                    Ok(Self::NotEqual(left, right))
284                }
285            }
286            "parenthesized" => Self::from_node(
287                node.child_by_field_name("expression")
288                    .ok_or(anyhow!(parse_error))?,
289                source,
290            ),
291            _ => Err(anyhow!(parse_error)),
292        }
293    }
294
295    fn eval(&self, cx: &Context) -> bool {
296        match self {
297            Self::Identifier(name) => cx.set.contains(name.as_str()),
298            Self::Equal(left, right) => cx
299                .map
300                .get(left)
301                .map(|value| value == right)
302                .unwrap_or(false),
303            Self::NotEqual(left, right) => {
304                cx.map.get(left).map(|value| value != right).unwrap_or(true)
305            }
306            Self::Not(pred) => !pred.eval(cx),
307            Self::And(left, right) => left.eval(cx) && right.eval(cx),
308            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
309        }
310    }
311}
312
/// Unit tests for keystroke parsing, context predicates, and the keystroke matcher.
#[cfg(test)]
mod tests {
    use crate::action;

    use super::*;

    /// Modifiers combine with a key, and a trailing `--` denotes the `-` key.
    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        // The minus key collides with the separator.
        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    /// Predicate source text is converted into the expected `ContextPredicate` tree.
    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    /// A single predicate is evaluated against a context that is mutated step by step.
    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        // Negating an absent flag holds in an empty context.
        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    /// Exercises `Matcher` state transitions. Each assertion depends on the pending
    /// state left by the previous ones, so the order matters.
    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        action!(A, &'static str);
        action!(B);
        action!(Ab);

        // `A` carries a payload; these impls let `assert_eq!` compare and print it.
        impl PartialEq for A {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for A {}
        impl Debug for A {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "A({:?})", &self.0)
            }
        }

        // NOTE(review): this struct is not referenced anywhere in this test.
        #[derive(Clone, Debug, Eq, PartialEq)]
        struct ActionArg {
            a: &'static str,
        }

        let keymap = Keymap(vec![
            Binding::new("a", A("x"), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(matcher.test_keystroke("a", 1, &ctx_a), Some(A("x")));

        // Multi-keystroke match
        assert_eq!(matcher.test_keystroke::<A>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke("b", 1, &ctx_b), Some(Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(matcher.test_keystroke::<A>("x", 1, &ctx_a), None);
        assert_eq!(matcher.test_keystroke("a", 1, &ctx_a), Some(A("x")));

        // Pending keystrokes are cleared when the context changes
        assert_eq!(matcher.test_keystroke::<A>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke("b", 1, &ctx_a), Some(B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert_eq!(matcher.test_keystroke::<A>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke::<A>("a", 2, &ctx_c), None);
        assert_eq!(matcher.test_keystroke("b", 1, &ctx_b), Some(Ab));

        Ok(())
    }

    impl Matcher {
        /// Test helper: parses `keystroke`, feeds it to `push_keystroke`, and downcasts
        /// any resulting action to `A`. Returns `None` for `Pending` and `None` results.
        /// Panics if the keystroke fails to parse or the action is not an `A`.
        fn test_keystroke<A>(&mut self, keystroke: &str, view_id: usize, cx: &Context) -> Option<A>
        where
            A: Action + Debug + Eq,
        {
            if let MatchResult::Action(action) =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx)
            {
                Some(*action.boxed_clone_as_any().downcast().unwrap())
            } else {
                None
            }
        }
    }
}