keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use std::{
  4    any::Any,
  5    collections::{HashMap, HashSet},
  6    fmt::Debug,
  7};
  8use tree_sitter::{Language, Node, Parser};
  9
// Foreign declaration of the tree-sitter grammar used by `ContextPredicate::parse`.
// NOTE(review): presumably the symbol comes from a generated grammar that is compiled and
// linked into this crate — confirm against the build script.
extern "C" {
    /// Returns the tree-sitter `Language` for context predicate expressions.
    fn tree_sitter_context_predicate() -> Language;
}
 13
/// Resolves incoming keystrokes against a [`Keymap`], keeping track of partially typed
/// multi-keystroke bindings separately for every view.
pub struct Matcher {
    // In-progress keystroke sequences, keyed by the view id passed to `push_keystroke`.
    pending: HashMap<usize, Pending>,
    // The bindings that keystrokes are matched against.
    keymap: Keymap,
}
 18
/// The state of a keystroke sequence that has not resolved to an action yet.
#[derive(Default)]
struct Pending {
    // Keystrokes received so far for the sequence.
    keystrokes: Vec<Keystroke>,
    // Context recorded when the sequence last matched a prefix of a longer binding; the
    // keystrokes are discarded if the next keystroke arrives with a different context.
    context: Option<Context>,
}
 24
/// An ordered list of bindings. `Matcher::push_keystroke` walks the list back to front, so
/// bindings added later are considered first.
#[derive(Default)]
pub struct Keymap(Vec<Binding>);
 27
/// Associates a sequence of keystrokes with an action, optionally restricted to contexts
/// that satisfy a predicate.
pub struct Binding {
    // The keystrokes that must be typed, in order, to trigger the action.
    keystrokes: Vec<Keystroke>,
    // The action handed out (as a clone) when the binding matches.
    action: Box<dyn Action>,
    // When present, the binding only applies if the predicate evaluates to true.
    context: Option<ContextPredicate>,
}
 33
/// A single key press together with the modifier keys held while it was pressed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Whether the `ctrl` modifier is part of the keystroke.
    pub ctrl: bool,
    /// Whether the `alt` modifier is part of the keystroke.
    pub alt: bool,
    /// Whether the `shift` modifier is part of the keystroke.
    pub shift: bool,
    /// Whether the `cmd` modifier is part of the keystroke.
    pub cmd: bool,
    /// The name of the non-modifier key, e.g. `p`, `down` or `-`.
    pub key: String,
}
 42
/// The facts that context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; tested by bare-identifier predicates.
    pub set: HashSet<String>,
    /// Key/value pairs; tested by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
 48
/// A boolean expression over a [`Context`], produced by `ContextPredicate::parse`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is contained in `Context::set`.
    Identifier(String),
    /// True when `Context::map` holds the first string as a key mapped to the second string.
    Equal(String, String),
    /// True when the key is missing from `Context::map` or is mapped to a different value.
    NotEqual(String, String),
    /// Negation of the inner predicate.
    Not(Box<ContextPredicate>),
    /// True when both operands are true.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// True when at least one operand is true.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
 58
 59trait ActionArg {
 60    fn boxed_clone(&self) -> Box<dyn Any>;
 61}
 62
 63impl<T> ActionArg for T
 64where
 65    T: 'static + Any + Clone,
 66{
 67    fn boxed_clone(&self) -> Box<dyn Any> {
 68        Box::new(self.clone())
 69    }
 70}
 71
/// The outcome of feeding one keystroke to `Matcher::push_keystroke`.
pub enum MatchResult {
    /// No binding matches the keystrokes typed so far.
    None,
    /// The keystrokes typed so far are a strict prefix of at least one applicable binding.
    Pending,
    /// A binding matched completely; carries a clone of its action.
    Action(Box<dyn Action>),
}
 77
 78impl Debug for MatchResult {
 79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 80        match self {
 81            MatchResult::None => f.debug_struct("MatchResult::None").finish(),
 82            MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
 83            MatchResult::Action(action) => f
 84                .debug_tuple("MatchResult::Action")
 85                .field(&action.name())
 86                .finish(),
 87        }
 88    }
 89}
 90
 91impl Matcher {
 92    pub fn new(keymap: Keymap) -> Self {
 93        Self {
 94            pending: HashMap::new(),
 95            keymap,
 96        }
 97    }
 98
 99    pub fn set_keymap(&mut self, keymap: Keymap) {
100        self.pending.clear();
101        self.keymap = keymap;
102    }
103
104    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
105        self.pending.clear();
106        self.keymap.add_bindings(bindings);
107    }
108
109    pub fn clear_pending(&mut self) {
110        self.pending.clear();
111    }
112
113    pub fn push_keystroke(
114        &mut self,
115        keystroke: Keystroke,
116        view_id: usize,
117        cx: &Context,
118    ) -> MatchResult {
119        let pending = self.pending.entry(view_id).or_default();
120
121        if let Some(pending_ctx) = pending.context.as_ref() {
122            if pending_ctx != cx {
123                pending.keystrokes.clear();
124            }
125        }
126
127        pending.keystrokes.push(keystroke);
128
129        let mut retain_pending = false;
130        for binding in self.keymap.0.iter().rev() {
131            if binding.keystrokes.starts_with(&pending.keystrokes)
132                && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
133            {
134                if binding.keystrokes.len() == pending.keystrokes.len() {
135                    self.pending.remove(&view_id);
136                    return MatchResult::Action(binding.action.boxed_clone());
137                } else {
138                    retain_pending = true;
139                    pending.context = Some(cx.clone());
140                }
141            }
142        }
143
144        if retain_pending {
145            MatchResult::Pending
146        } else {
147            self.pending.remove(&view_id);
148            MatchResult::None
149        }
150    }
151}
152
153impl Default for Matcher {
154    fn default() -> Self {
155        Self::new(Keymap::default())
156    }
157}
158
159impl Keymap {
160    pub fn new(bindings: Vec<Binding>) -> Self {
161        Self(bindings)
162    }
163
164    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
165        self.0.extend(bindings.into_iter());
166    }
167}
168
169impl Binding {
170    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
171        Self::load(keystrokes, Box::new(action), context).unwrap()
172    }
173
174    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
175        let context = if let Some(context) = context {
176            Some(ContextPredicate::parse(context)?)
177        } else {
178            None
179        };
180
181        let keystrokes = keystrokes
182            .split_whitespace()
183            .map(|key| Keystroke::parse(key))
184            .collect::<Result<_>>()?;
185
186        Ok(Self {
187            keystrokes,
188            action,
189            context,
190        })
191    }
192}
193
194impl Keystroke {
195    pub fn parse(source: &str) -> anyhow::Result<Self> {
196        let mut ctrl = false;
197        let mut alt = false;
198        let mut shift = false;
199        let mut cmd = false;
200        let mut key = None;
201
202        let mut components = source.split("-").peekable();
203        while let Some(component) = components.next() {
204            match component {
205                "ctrl" => ctrl = true,
206                "alt" => alt = true,
207                "shift" => shift = true,
208                "cmd" => cmd = true,
209                _ => {
210                    if let Some(component) = components.peek() {
211                        if component.is_empty() && source.ends_with('-') {
212                            key = Some(String::from("-"));
213                            break;
214                        } else {
215                            return Err(anyhow!("Invalid keystroke `{}`", source));
216                        }
217                    } else {
218                        key = Some(String::from(component));
219                    }
220                }
221            }
222        }
223
224        Ok(Keystroke {
225            ctrl,
226            alt,
227            shift,
228            cmd,
229            key: key.unwrap(),
230        })
231    }
232
233    pub fn modified(&self) -> bool {
234        self.ctrl || self.alt || self.shift || self.cmd
235    }
236}
237
238impl Context {
239    pub fn extend(&mut self, other: &Context) {
240        for v in &other.set {
241            self.set.insert(v.clone());
242        }
243        for (k, v) in &other.map {
244            self.map.insert(k.clone(), v.clone());
245        }
246    }
247}
248
249impl ContextPredicate {
250    fn parse(source: &str) -> anyhow::Result<Self> {
251        let mut parser = Parser::new();
252        let language = unsafe { tree_sitter_context_predicate() };
253        parser.set_language(language).unwrap();
254        let source = source.as_bytes();
255        let tree = parser.parse(source, None).unwrap();
256        Self::from_node(tree.root_node(), source)
257    }
258
259    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
260        let parse_error = "error parsing context predicate";
261        let kind = node.kind();
262
263        match kind {
264            "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
265            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
266            "not" => {
267                let child = Self::from_node(
268                    node.child_by_field_name("expression")
269                        .ok_or(anyhow!(parse_error))?,
270                    source,
271                )?;
272                Ok(Self::Not(Box::new(child)))
273            }
274            "and" | "or" => {
275                let left = Box::new(Self::from_node(
276                    node.child_by_field_name("left")
277                        .ok_or(anyhow!(parse_error))?,
278                    source,
279                )?);
280                let right = Box::new(Self::from_node(
281                    node.child_by_field_name("right")
282                        .ok_or(anyhow!(parse_error))?,
283                    source,
284                )?);
285                if kind == "and" {
286                    Ok(Self::And(left, right))
287                } else {
288                    Ok(Self::Or(left, right))
289                }
290            }
291            "equal" | "not_equal" => {
292                let left = node
293                    .child_by_field_name("left")
294                    .ok_or(anyhow!(parse_error))?
295                    .utf8_text(source)?
296                    .into();
297                let right = node
298                    .child_by_field_name("right")
299                    .ok_or(anyhow!(parse_error))?
300                    .utf8_text(source)?
301                    .into();
302                if kind == "equal" {
303                    Ok(Self::Equal(left, right))
304                } else {
305                    Ok(Self::NotEqual(left, right))
306                }
307            }
308            "parenthesized" => Self::from_node(
309                node.child_by_field_name("expression")
310                    .ok_or(anyhow!(parse_error))?,
311                source,
312            ),
313            _ => Err(anyhow!(parse_error)),
314        }
315    }
316
317    fn eval(&self, cx: &Context) -> bool {
318        match self {
319            Self::Identifier(name) => cx.set.contains(name.as_str()),
320            Self::Equal(left, right) => cx
321                .map
322                .get(left)
323                .map(|value| value == right)
324                .unwrap_or(false),
325            Self::NotEqual(left, right) => {
326                cx.map.get(left).map(|value| value != right).unwrap_or(true)
327            }
328            Self::Not(pred) => !pred.eval(cx),
329            Self::And(left, right) => left.eval(cx) && right.eval(cx),
330            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    /// Modifiers, named keys and the `-` key are parsed; a key followed by more components is
    /// rejected.
    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        // A non-modifier component must be the last one.
        assert!(Keystroke::parse("a-b").is_err());

        Ok(())
    }

    /// Operators, parentheses and negation produce the expected predicate tree.
    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    /// Predicates are evaluated against the identifiers and key/value pairs of a context.
    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    /// End-to-end matching: single and multi-keystroke bindings, context changes and
    /// per-view pending state.
    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Borrows the matched action as the concrete type `A`, if there is one and it has
    /// that type.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Test helper: parses `keystroke`, pushes it and returns the matched action, mapping
        /// both `MatchResult::None` and `MatchResult::Pending` to `None`.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            if let MatchResult::Action(action) =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx)
            {
                // The box is already owned here, so it is returned without another clone.
                Some(action)
            } else {
                None
            }
        }
    }
}