keymap.rs

  1use anyhow::anyhow;
  2use std::{
  3    any::Any,
  4    collections::{HashMap, HashSet},
  5};
  6use tree_sitter::{Language, Node, Parser};
  7
  8extern "C" {
  9    fn tree_sitter_context_predicate() -> Language;
 10}
 11
 12pub struct Matcher {
 13    pending: HashMap<usize, Pending>,
 14    keymap: Keymap,
 15}
 16
 17#[derive(Default)]
 18struct Pending {
 19    keystrokes: Vec<Keystroke>,
 20    context: Option<Context>,
 21}
 22
 23pub struct Keymap(Vec<Binding>);
 24
 25pub struct Binding {
 26    keystrokes: Vec<Keystroke>,
 27    action: String,
 28    action_arg: Option<Box<dyn ActionArg>>,
 29    context: Option<ContextPredicate>,
 30}
 31
 32#[derive(Clone, Debug, Eq, PartialEq)]
 33pub struct Keystroke {
 34    pub ctrl: bool,
 35    pub alt: bool,
 36    pub shift: bool,
 37    pub cmd: bool,
 38    pub key: String,
 39}
 40
 41#[derive(Clone, Debug, Default, Eq, PartialEq)]
 42pub struct Context {
 43    pub set: HashSet<String>,
 44    pub map: HashMap<String, String>,
 45}
 46
 47#[derive(Debug, Eq, PartialEq)]
 48enum ContextPredicate {
 49    Identifier(String),
 50    Equal(String, String),
 51    NotEqual(String, String),
 52    Not(Box<ContextPredicate>),
 53    And(Box<ContextPredicate>, Box<ContextPredicate>),
 54    Or(Box<ContextPredicate>, Box<ContextPredicate>),
 55}
 56
 57trait ActionArg {
 58    fn boxed_clone(&self) -> Box<dyn Any>;
 59}
 60
 61impl<T> ActionArg for T
 62where
 63    T: 'static + Any + Clone,
 64{
 65    fn boxed_clone(&self) -> Box<dyn Any> {
 66        Box::new(self.clone())
 67    }
 68}
 69
 70pub enum MatchResult {
 71    None,
 72    Pending,
 73    Action {
 74        name: String,
 75        arg: Option<Box<dyn Any>>,
 76    },
 77}
 78
 79impl Matcher {
 80    pub fn new(keymap: Keymap) -> Self {
 81        Self {
 82            pending: HashMap::new(),
 83            keymap,
 84        }
 85    }
 86
 87    pub fn set_keymap(&mut self, keymap: Keymap) {
 88        self.pending.clear();
 89        self.keymap = keymap;
 90    }
 91
 92    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
 93        self.pending.clear();
 94        self.keymap.add_bindings(bindings);
 95    }
 96
 97    pub fn push_keystroke(
 98        &mut self,
 99        keystroke: Keystroke,
100        view_id: usize,
101        cx: &Context,
102    ) -> MatchResult {
103        let pending = self.pending.entry(view_id).or_default();
104
105        if let Some(pending_ctx) = pending.context.as_ref() {
106            if pending_ctx != cx {
107                pending.keystrokes.clear();
108            }
109        }
110
111        pending.keystrokes.push(keystroke);
112
113        let mut retain_pending = false;
114        for binding in self.keymap.0.iter().rev() {
115            if binding.keystrokes.starts_with(&pending.keystrokes)
116                && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
117            {
118                if binding.keystrokes.len() == pending.keystrokes.len() {
119                    self.pending.remove(&view_id);
120                    return MatchResult::Action {
121                        name: binding.action.clone(),
122                        arg: binding.action_arg.as_ref().map(|arg| (*arg).boxed_clone()),
123                    };
124                } else {
125                    retain_pending = true;
126                    pending.context = Some(cx.clone());
127                }
128            }
129        }
130
131        if retain_pending {
132            MatchResult::Pending
133        } else {
134            self.pending.remove(&view_id);
135            MatchResult::None
136        }
137    }
138}
139
140impl Default for Matcher {
141    fn default() -> Self {
142        Self::new(Keymap::default())
143    }
144}
145
146impl Keymap {
147    pub fn new(bindings: Vec<Binding>) -> Self {
148        Self(bindings)
149    }
150
151    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
152        self.0.extend(bindings.into_iter());
153    }
154}
155
156impl Default for Keymap {
157    fn default() -> Self {
158        Self(vec![
159            Binding::new("up", "menu:select_prev", Some("menu")),
160            Binding::new("ctrl-p", "menu:select_prev", Some("menu")),
161            Binding::new("down", "menu:select_next", Some("menu")),
162            Binding::new("ctrl-n", "menu:select_next", Some("menu")),
163        ])
164    }
165}
166
167impl Binding {
168    pub fn new<S: Into<String>>(keystrokes: &str, action: S, context: Option<&str>) -> Self {
169        let context = if let Some(context) = context {
170            Some(ContextPredicate::parse(context).unwrap())
171        } else {
172            None
173        };
174
175        Self {
176            keystrokes: keystrokes
177                .split_whitespace()
178                .map(|key| Keystroke::parse(key).unwrap())
179                .collect(),
180            action: action.into(),
181            action_arg: None,
182            context,
183        }
184    }
185
186    pub fn with_arg<T: 'static + Any + Clone>(mut self, arg: T) -> Self {
187        self.action_arg = Some(Box::new(arg));
188        self
189    }
190}
191
192impl Keystroke {
193    pub fn parse(source: &str) -> anyhow::Result<Self> {
194        let mut ctrl = false;
195        let mut alt = false;
196        let mut shift = false;
197        let mut cmd = false;
198        let mut key = None;
199
200        let mut components = source.split("-").peekable();
201        while let Some(component) = components.next() {
202            match component {
203                "ctrl" => ctrl = true,
204                "alt" => alt = true,
205                "shift" => shift = true,
206                "cmd" => cmd = true,
207                _ => {
208                    if let Some(component) = components.peek() {
209                        if component.is_empty() && source.ends_with('-') {
210                            key = Some(String::from("-"));
211                            break;
212                        } else {
213                            return Err(anyhow!("Invalid keystroke `{}`", source));
214                        }
215                    } else {
216                        key = Some(String::from(component));
217                    }
218                }
219            }
220        }
221
222        Ok(Keystroke {
223            ctrl,
224            alt,
225            shift,
226            cmd,
227            key: key.unwrap(),
228        })
229    }
230}
231
232impl Context {
233    pub fn extend(&mut self, other: Context) {
234        for v in other.set {
235            self.set.insert(v);
236        }
237        for (k, v) in other.map {
238            self.map.insert(k, v);
239        }
240    }
241}
242
243impl ContextPredicate {
244    fn parse(source: &str) -> anyhow::Result<Self> {
245        let mut parser = Parser::new();
246        let language = unsafe { tree_sitter_context_predicate() };
247        parser.set_language(language).unwrap();
248        let source = source.as_bytes();
249        let tree = parser.parse(source, None).unwrap();
250        Self::from_node(tree.root_node(), source)
251    }
252
253    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
254        let parse_error = "error parsing context predicate";
255        let kind = node.kind();
256
257        match kind {
258            "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
259            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
260            "not" => {
261                let child = Self::from_node(
262                    node.child_by_field_name("expression")
263                        .ok_or(anyhow!(parse_error))?,
264                    source,
265                )?;
266                Ok(Self::Not(Box::new(child)))
267            }
268            "and" | "or" => {
269                let left = Box::new(Self::from_node(
270                    node.child_by_field_name("left")
271                        .ok_or(anyhow!(parse_error))?,
272                    source,
273                )?);
274                let right = Box::new(Self::from_node(
275                    node.child_by_field_name("right")
276                        .ok_or(anyhow!(parse_error))?,
277                    source,
278                )?);
279                if kind == "and" {
280                    Ok(Self::And(left, right))
281                } else {
282                    Ok(Self::Or(left, right))
283                }
284            }
285            "equal" | "not_equal" => {
286                let left = node
287                    .child_by_field_name("left")
288                    .ok_or(anyhow!(parse_error))?
289                    .utf8_text(source)?
290                    .into();
291                let right = node
292                    .child_by_field_name("right")
293                    .ok_or(anyhow!(parse_error))?
294                    .utf8_text(source)?
295                    .into();
296                if kind == "equal" {
297                    Ok(Self::Equal(left, right))
298                } else {
299                    Ok(Self::NotEqual(left, right))
300                }
301            }
302            "parenthesized" => Self::from_node(
303                node.child_by_field_name("expression")
304                    .ok_or(anyhow!(parse_error))?,
305                source,
306            ),
307            _ => Err(anyhow!(parse_error)),
308        }
309    }
310
311    fn eval(&self, cx: &Context) -> bool {
312        match self {
313            Self::Identifier(name) => cx.set.contains(name.as_str()),
314            Self::Equal(left, right) => cx
315                .map
316                .get(left)
317                .map(|value| value == right)
318                .unwrap_or(false),
319            Self::NotEqual(left, right) => {
320                cx.map.get(left).map(|value| value != right).unwrap_or(true)
321            }
322            Self::Not(pred) => !pred.eval(cx),
323            Self::And(left, right) => left.eval(cx) && right.eval(cx),
324            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        // (source, key, ctrl, alt, shift, cmd)
        let cases = [
            ("ctrl-p", "p", true, false, false, false),
            ("alt-shift-down", "down", false, true, true, false),
            ("shift-cmd--", "-", false, false, true, true),
        ];

        for &(source, key, ctrl, alt, shift, cmd) in cases.iter() {
            let expected = Keystroke {
                key: key.into(),
                ctrl,
                alt,
                shift,
                cmd,
            };
            assert_eq!(Keystroke::parse(source)?, expected);
        }

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        let ident = |name: &str| Box::new(Identifier(name.into()));

        let expected = And(
            ident("a"),
            Box::new(Or(
                Box::new(Equal("b".into(), "c".into())),
                Box::new(NotEqual("d".into(), "e".into())),
            )),
        );
        assert_eq!(ContextPredicate::parse("a && (b == c || d != e)")?, expected);
        assert_eq!(ContextPredicate::parse("!a")?, Not(ident("a")));

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;
        let mut context = Context::default();

        // Only `a` is present.
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        // Both `a` and `b` are present.
        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        // `b` removed again; `c` maps to a value other than `d`.
        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        // `c` now maps to `d`.
        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        // Negating an absent identifier holds.
        let negated = ContextPredicate::parse("!a")?;
        assert!(negated.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Debug, Eq, PartialEq)]
        struct Arg {
            a: &'static str,
        }

        let keymap = Keymap(vec![
            Binding::new("a", "a", Some("a")).with_arg(Arg { a: "b" }),
            Binding::new("b", "b", Some("a")),
            Binding::new("a b", "a_b", Some("a || b")),
        ]);

        let context_with = |name: &str| {
            let mut cx = Context::default();
            cx.set.insert(name.into());
            cx
        };
        let ctx_a = context_with("a");
        let ctx_b = context_with("b");
        let ctx_c = context_with("c");

        let action_a = Some(("a".to_string(), Some(Arg { a: "b" })));
        let action_a_b: Option<(String, Option<()>)> = Some(("a_b".to_string(), None));

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(matcher.test_keystroke("a", 1, &ctx_a), action_a);

        // Multi-keystroke match
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke::<()>("b", 1, &ctx_b), action_a_b);

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(matcher.test_keystroke::<()>("x", 1, &ctx_a), None);
        assert_eq!(matcher.test_keystroke("a", 1, &ctx_a), action_a);

        // Pending keystrokes are cleared when the context changes
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_a),
            Some(("b".to_string(), None))
        );

        // Pending keystrokes are maintained per-view
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke::<()>("a", 2, &ctx_c), None);
        assert_eq!(matcher.test_keystroke::<()>("b", 1, &ctx_b), action_a_b);

        Ok(())
    }

    impl Matcher {
        /// Parses `keystroke`, feeds it to the matcher, and returns the matched
        /// action's name and argument (downcast to `A`), or `None` if no action fired.
        fn test_keystroke<A: Any + Clone>(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<(String, Option<A>)> {
            let keystroke = Keystroke::parse(keystroke).unwrap();
            match self.push_keystroke(keystroke, view_id, cx) {
                MatchResult::Action { name, arg } => {
                    let arg = arg.and_then(|arg| arg.downcast_ref::<A>().cloned());
                    Some((name, arg))
                }
                _ => None,
            }
        }
    }
}