keymap.rs

  1use anyhow::anyhow;
  2use std::{
  3    any::Any,
  4    collections::{HashMap, HashSet},
  5};
  6use tree_sitter::{Language, Node, Parser};
  7
  8extern "C" {
  9    fn tree_sitter_context_predicate() -> Language;
 10}
 11
 12pub struct Matcher {
 13    pending: HashMap<usize, Pending>,
 14    keymap: Keymap,
 15}
 16
 17#[derive(Default)]
 18struct Pending {
 19    keystrokes: Vec<Keystroke>,
 20    context: Option<Context>,
 21}
 22
 23pub struct Keymap(Vec<Binding>);
 24
 25pub struct Binding {
 26    keystrokes: Vec<Keystroke>,
 27    action: String,
 28    action_arg: Option<Box<dyn ActionArg>>,
 29    context: Option<ContextPredicate>,
 30}
 31
 32#[derive(Clone, Debug, Eq, PartialEq)]
 33pub struct Keystroke {
 34    pub ctrl: bool,
 35    pub alt: bool,
 36    pub shift: bool,
 37    pub cmd: bool,
 38    pub key: String,
 39}
 40
 41#[derive(Clone, Debug, Default, Eq, PartialEq)]
 42pub struct Context {
 43    pub set: HashSet<String>,
 44    pub map: HashMap<String, String>,
 45}
 46
 47#[derive(Debug, Eq, PartialEq)]
 48enum ContextPredicate {
 49    Identifier(String),
 50    Equal(String, String),
 51    NotEqual(String, String),
 52    Not(Box<ContextPredicate>),
 53    And(Box<ContextPredicate>, Box<ContextPredicate>),
 54    Or(Box<ContextPredicate>, Box<ContextPredicate>),
 55}
 56
 57trait ActionArg {
 58    fn boxed_clone(&self) -> Box<dyn Any>;
 59}
 60
 61impl<T> ActionArg for T
 62where
 63    T: 'static + Any + Clone,
 64{
 65    fn boxed_clone(&self) -> Box<dyn Any> {
 66        Box::new(self.clone())
 67    }
 68}
 69
 70pub enum MatchResult {
 71    None,
 72    Pending,
 73    Action {
 74        name: String,
 75        arg: Option<Box<dyn Any>>,
 76    },
 77}
 78
 79impl Matcher {
 80    pub fn new(keymap: Keymap) -> Self {
 81        Self {
 82            pending: HashMap::new(),
 83            keymap,
 84        }
 85    }
 86
 87    pub fn set_keymap(&mut self, keymap: Keymap) {
 88        self.pending.clear();
 89        self.keymap = keymap;
 90    }
 91
 92    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
 93        self.pending.clear();
 94        self.keymap.add_bindings(bindings);
 95    }
 96
 97    pub fn push_keystroke(
 98        &mut self,
 99        keystroke: Keystroke,
100        view_id: usize,
101        ctx: &Context,
102    ) -> MatchResult {
103        let pending = self.pending.entry(view_id).or_default();
104
105        if let Some(pending_ctx) = pending.context.as_ref() {
106            if pending_ctx != ctx {
107                pending.keystrokes.clear();
108            }
109        }
110
111        pending.keystrokes.push(keystroke);
112
113        let mut retain_pending = false;
114        for binding in self.keymap.0.iter().rev() {
115            if binding.keystrokes.starts_with(&pending.keystrokes)
116                && binding
117                    .context
118                    .as_ref()
119                    .map(|c| c.eval(ctx))
120                    .unwrap_or(true)
121            {
122                if binding.keystrokes.len() == pending.keystrokes.len() {
123                    self.pending.remove(&view_id);
124                    return MatchResult::Action {
125                        name: binding.action.clone(),
126                        arg: binding.action_arg.as_ref().map(|arg| (*arg).boxed_clone()),
127                    };
128                } else {
129                    retain_pending = true;
130                    pending.context = Some(ctx.clone());
131                }
132            }
133        }
134
135        if retain_pending {
136            MatchResult::Pending
137        } else {
138            self.pending.remove(&view_id);
139            MatchResult::None
140        }
141    }
142}
143
144impl Default for Matcher {
145    fn default() -> Self {
146        Self::new(Keymap::default())
147    }
148}
149
150impl Keymap {
151    pub fn new(bindings: Vec<Binding>) -> Self {
152        Self(bindings)
153    }
154
155    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
156        self.0.extend(bindings.into_iter());
157    }
158}
159
160impl Default for Keymap {
161    fn default() -> Self {
162        Self(vec![
163            Binding::new("up", "menu:select_prev", Some("menu")),
164            Binding::new("ctrl-p", "menu:select_prev", Some("menu")),
165            Binding::new("down", "menu:select_next", Some("menu")),
166            Binding::new("ctrl-n", "menu:select_next", Some("menu")),
167        ])
168    }
169}
170
171impl Binding {
172    pub fn new<S: Into<String>>(keystrokes: &str, action: S, context: Option<&str>) -> Self {
173        let context = if let Some(context) = context {
174            Some(ContextPredicate::parse(context).unwrap())
175        } else {
176            None
177        };
178
179        Self {
180            keystrokes: keystrokes
181                .split_whitespace()
182                .map(|key| Keystroke::parse(key).unwrap())
183                .collect(),
184            action: action.into(),
185            action_arg: None,
186            context,
187        }
188    }
189
190    pub fn with_arg<T: 'static + Any + Clone>(mut self, arg: T) -> Self {
191        self.action_arg = Some(Box::new(arg));
192        self
193    }
194}
195
196impl Keystroke {
197    pub fn parse(source: &str) -> anyhow::Result<Self> {
198        let mut ctrl = false;
199        let mut alt = false;
200        let mut shift = false;
201        let mut cmd = false;
202        let mut key = None;
203
204        let mut components = source.split("-").peekable();
205        while let Some(component) = components.next() {
206            match component {
207                "ctrl" => ctrl = true,
208                "alt" => alt = true,
209                "shift" => shift = true,
210                "cmd" => cmd = true,
211                _ => {
212                    if let Some(component) = components.peek() {
213                        if component.is_empty() && source.ends_with('-') {
214                            key = Some(String::from("-"));
215                            break;
216                        } else {
217                            return Err(anyhow!("Invalid keystroke `{}`", source));
218                        }
219                    } else {
220                        key = Some(String::from(component));
221                    }
222                }
223            }
224        }
225
226        Ok(Keystroke {
227            ctrl,
228            alt,
229            shift,
230            cmd,
231            key: key.unwrap(),
232        })
233    }
234}
235
236impl Context {
237    pub fn extend(&mut self, other: Context) {
238        for v in other.set {
239            self.set.insert(v);
240        }
241        for (k, v) in other.map {
242            self.map.insert(k, v);
243        }
244    }
245}
246
247impl ContextPredicate {
248    fn parse(source: &str) -> anyhow::Result<Self> {
249        let mut parser = Parser::new();
250        let language = unsafe { tree_sitter_context_predicate() };
251        parser.set_language(language).unwrap();
252        let source = source.as_bytes();
253        let tree = parser.parse(source, None).unwrap();
254        Self::from_node(tree.root_node(), source)
255    }
256
257    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
258        let parse_error = "error parsing context predicate";
259        let kind = node.kind();
260
261        match kind {
262            "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
263            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
264            "not" => {
265                let child = Self::from_node(
266                    node.child_by_field_name("expression")
267                        .ok_or(anyhow!(parse_error))?,
268                    source,
269                )?;
270                Ok(Self::Not(Box::new(child)))
271            }
272            "and" | "or" => {
273                let left = Box::new(Self::from_node(
274                    node.child_by_field_name("left")
275                        .ok_or(anyhow!(parse_error))?,
276                    source,
277                )?);
278                let right = Box::new(Self::from_node(
279                    node.child_by_field_name("right")
280                        .ok_or(anyhow!(parse_error))?,
281                    source,
282                )?);
283                if kind == "and" {
284                    Ok(Self::And(left, right))
285                } else {
286                    Ok(Self::Or(left, right))
287                }
288            }
289            "equal" | "not_equal" => {
290                let left = node
291                    .child_by_field_name("left")
292                    .ok_or(anyhow!(parse_error))?
293                    .utf8_text(source)?
294                    .into();
295                let right = node
296                    .child_by_field_name("right")
297                    .ok_or(anyhow!(parse_error))?
298                    .utf8_text(source)?
299                    .into();
300                if kind == "equal" {
301                    Ok(Self::Equal(left, right))
302                } else {
303                    Ok(Self::NotEqual(left, right))
304                }
305            }
306            "parenthesized" => Self::from_node(
307                node.child_by_field_name("expression")
308                    .ok_or(anyhow!(parse_error))?,
309                source,
310            ),
311            _ => Err(anyhow!(parse_error)),
312        }
313    }
314
315    fn eval(&self, ctx: &Context) -> bool {
316        match self {
317            Self::Identifier(name) => ctx.set.contains(name.as_str()),
318            Self::Equal(left, right) => ctx
319                .map
320                .get(left)
321                .map(|value| value == right)
322                .unwrap_or(false),
323            Self::NotEqual(left, right) => ctx
324                .map
325                .get(left)
326                .map(|value| value != right)
327                .unwrap_or(true),
328            Self::Not(pred) => !pred.eval(ctx),
329            Self::And(left, right) => left.eval(ctx) && right.eval(ctx),
330            Self::Or(left, right) => left.eval(ctx) || right.eval(ctx),
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context whose identifier set contains exactly `name`.
    fn context_with(name: &str) -> Context {
        let mut context = Context::default();
        context.set.insert(name.into());
        context
    }

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        // (description, expected parse result)
        let cases = vec![
            (
                "ctrl-p",
                Keystroke {
                    key: "p".into(),
                    ctrl: true,
                    alt: false,
                    shift: false,
                    cmd: false,
                },
            ),
            (
                "alt-shift-down",
                Keystroke {
                    key: "down".into(),
                    ctrl: false,
                    alt: true,
                    shift: true,
                    cmd: false,
                },
            ),
            (
                "shift-cmd--",
                Keystroke {
                    key: "-".into(),
                    ctrl: false,
                    alt: false,
                    shift: true,
                    cmd: true,
                },
            ),
        ];

        for (description, expected) in cases {
            assert_eq!(Keystroke::parse(description)?, expected);
        }

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        let nested = ContextPredicate::parse("a && (b == c || d != e)")?;
        let expected_nested = And(
            Box::new(Identifier("a".into())),
            Box::new(Or(
                Box::new(Equal("b".into(), "c".into())),
                Box::new(NotEqual("d".into(), "e".into())),
            )),
        );
        assert_eq!(nested, expected_nested);

        let negated = ContextPredicate::parse("!a")?;
        assert_eq!(negated, Not(Box::new(Identifier("a".into()))));

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        // Only `a` present: neither side of the `||` holds.
        let mut context = context_with("a");
        assert!(!predicate.eval(&context));

        // `a` and `b` present: the conjunction holds.
        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        // `b` removed and `c` set to a different value: nothing holds.
        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        // `c` set to `d`: the comparison holds.
        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let negation = ContextPredicate::parse("!a")?;
        assert!(negation.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Debug, Eq, PartialEq)]
        struct Payload {
            a: &'static str,
        }

        let keymap = Keymap(vec![
            Binding::new("a", "a", Some("a")).with_arg(Payload { a: "b" }),
            Binding::new("b", "b", Some("a")),
            Binding::new("a b", "a_b", Some("a || b")),
        ]);

        let ctx_a = context_with("a");
        let ctx_b = context_with("b");
        let ctx_c = context_with("c");

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            matcher.test_keystroke("a", 1, &ctx_a),
            Some(("a".to_string(), Some(Payload { a: "b" })))
        );

        // Multi-keystroke match
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_b),
            Some(("a_b".to_string(), None))
        );

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(matcher.test_keystroke::<()>("x", 1, &ctx_a), None);
        assert_eq!(
            matcher.test_keystroke("a", 1, &ctx_a),
            Some(("a".to_string(), Some(Payload { a: "b" })))
        );

        // Pending keystrokes are cleared when the context changes
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_a),
            Some(("b".to_string(), None))
        );

        // Pending keystrokes are maintained per-view
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke::<()>("a", 2, &ctx_c), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_b),
            Some(("a_b".to_string(), None))
        );

        Ok(())
    }

    impl Matcher {
        /// Test helper: parses `keystroke`, pushes it, and returns the action
        /// name with the argument downcast to `A` when a binding fired.
        fn test_keystroke<A: Any + Clone>(
            &mut self,
            keystroke: &str,
            view_id: usize,
            ctx: &Context,
        ) -> Option<(String, Option<A>)> {
            let parsed = Keystroke::parse(keystroke).unwrap();
            match self.push_keystroke(parsed, view_id, ctx) {
                MatchResult::Action { name, arg } => {
                    let arg = arg.and_then(|boxed| boxed.downcast_ref::<A>().cloned());
                    Some((name, arg))
                }
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}