// keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::Debug,
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
 11extern "C" {
 12    fn tree_sitter_context_predicate() -> Language;
 13}
 14
 15pub struct Matcher {
 16    pending: HashMap<usize, Pending>,
 17    keymap: Keymap,
 18}
 19
 20#[derive(Default)]
 21struct Pending {
 22    keystrokes: Vec<Keystroke>,
 23    context: Option<Context>,
 24}
 25
 26#[derive(Default)]
 27pub struct Keymap {
 28    bindings: Vec<Binding>,
 29    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
 30}
 31
 32pub struct Binding {
 33    keystrokes: Vec<Keystroke>,
 34    action: Box<dyn Action>,
 35    context: Option<ContextPredicate>,
 36}
 37
 38#[derive(Clone, Debug, Eq, PartialEq)]
 39pub struct Keystroke {
 40    pub ctrl: bool,
 41    pub alt: bool,
 42    pub shift: bool,
 43    pub cmd: bool,
 44    pub key: String,
 45}
 46
 47#[derive(Clone, Debug, Default, Eq, PartialEq)]
 48pub struct Context {
 49    pub set: HashSet<String>,
 50    pub map: HashMap<String, String>,
 51}
 52
 53#[derive(Debug, Eq, PartialEq)]
 54enum ContextPredicate {
 55    Identifier(String),
 56    Equal(String, String),
 57    NotEqual(String, String),
 58    Not(Box<ContextPredicate>),
 59    And(Box<ContextPredicate>, Box<ContextPredicate>),
 60    Or(Box<ContextPredicate>, Box<ContextPredicate>),
 61}
 62
 63trait ActionArg {
 64    fn boxed_clone(&self) -> Box<dyn Any>;
 65}
 66
 67impl<T> ActionArg for T
 68where
 69    T: 'static + Any + Clone,
 70{
 71    fn boxed_clone(&self) -> Box<dyn Any> {
 72        Box::new(self.clone())
 73    }
 74}
 75
 76pub enum MatchResult {
 77    None,
 78    Pending,
 79    Action(Box<dyn Action>),
 80}
 81
 82impl Debug for MatchResult {
 83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 84        match self {
 85            MatchResult::None => f.debug_struct("MatchResult::None").finish(),
 86            MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
 87            MatchResult::Action(action) => f
 88                .debug_tuple("MatchResult::Action")
 89                .field(&action.name())
 90                .finish(),
 91        }
 92    }
 93}
 94
 95impl Matcher {
 96    pub fn new(keymap: Keymap) -> Self {
 97        Self {
 98            pending: HashMap::new(),
 99            keymap,
100        }
101    }
102
103    pub fn set_keymap(&mut self, keymap: Keymap) {
104        self.pending.clear();
105        self.keymap = keymap;
106    }
107
108    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
109        self.pending.clear();
110        self.keymap.add_bindings(bindings);
111    }
112
113    pub fn clear_bindings(&mut self) {
114        self.pending.clear();
115        self.keymap.clear();
116    }
117
118    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
119        self.keymap.bindings_for_action_type(action_type)
120    }
121
122    pub fn clear_pending(&mut self) {
123        self.pending.clear();
124    }
125
126    pub fn push_keystroke(
127        &mut self,
128        keystroke: Keystroke,
129        view_id: usize,
130        cx: &Context,
131    ) -> MatchResult {
132        let pending = self.pending.entry(view_id).or_default();
133
134        if let Some(pending_ctx) = pending.context.as_ref() {
135            if pending_ctx != cx {
136                pending.keystrokes.clear();
137            }
138        }
139
140        pending.keystrokes.push(keystroke);
141
142        let mut retain_pending = false;
143        for binding in self.keymap.bindings.iter().rev() {
144            if binding.keystrokes.starts_with(&pending.keystrokes)
145                && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
146            {
147                if binding.keystrokes.len() == pending.keystrokes.len() {
148                    self.pending.remove(&view_id);
149                    return MatchResult::Action(binding.action.boxed_clone());
150                } else {
151                    retain_pending = true;
152                    pending.context = Some(cx.clone());
153                }
154            }
155        }
156
157        if retain_pending {
158            MatchResult::Pending
159        } else {
160            self.pending.remove(&view_id);
161            MatchResult::None
162        }
163    }
164}
165
166impl Default for Matcher {
167    fn default() -> Self {
168        Self::new(Keymap::default())
169    }
170}
171
172impl Keymap {
173    pub fn new(bindings: Vec<Binding>) -> Self {
174        let mut binding_indices_by_action_type = HashMap::new();
175        for (ix, binding) in bindings.iter().enumerate() {
176            binding_indices_by_action_type
177                .entry(binding.action.as_any().type_id())
178                .or_insert_with(|| SmallVec::new())
179                .push(ix);
180        }
181        Self {
182            binding_indices_by_action_type,
183            bindings,
184        }
185    }
186
187    fn bindings_for_action_type<'a>(
188        &'a self,
189        action_type: TypeId,
190    ) -> impl Iterator<Item = &'a Binding> {
191        self.binding_indices_by_action_type
192            .get(&action_type)
193            .map(SmallVec::as_slice)
194            .unwrap_or(&[])
195            .iter()
196            .map(|ix| &self.bindings[*ix])
197    }
198
199    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
200        for binding in bindings {
201            self.binding_indices_by_action_type
202                .entry(binding.action.as_any().type_id())
203                .or_default()
204                .push(self.bindings.len());
205            self.bindings.push(binding);
206        }
207    }
208
209    fn clear(&mut self) {
210        self.bindings.clear();
211        self.binding_indices_by_action_type.clear();
212    }
213}
214
215impl Binding {
216    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
217        Self::load(keystrokes, Box::new(action), context).unwrap()
218    }
219
220    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
221        let context = if let Some(context) = context {
222            Some(ContextPredicate::parse(context)?)
223        } else {
224            None
225        };
226
227        let keystrokes = keystrokes
228            .split_whitespace()
229            .map(|key| Keystroke::parse(key))
230            .collect::<Result<_>>()?;
231
232        Ok(Self {
233            keystrokes,
234            action,
235            context,
236        })
237    }
238
239    pub fn keystrokes(&self) -> &[Keystroke] {
240        &self.keystrokes
241    }
242}
243
244impl Keystroke {
245    pub fn parse(source: &str) -> anyhow::Result<Self> {
246        let mut ctrl = false;
247        let mut alt = false;
248        let mut shift = false;
249        let mut cmd = false;
250        let mut key = None;
251
252        let mut components = source.split("-").peekable();
253        while let Some(component) = components.next() {
254            match component {
255                "ctrl" => ctrl = true,
256                "alt" => alt = true,
257                "shift" => shift = true,
258                "cmd" => cmd = true,
259                _ => {
260                    if let Some(component) = components.peek() {
261                        if component.is_empty() && source.ends_with('-') {
262                            key = Some(String::from("-"));
263                            break;
264                        } else {
265                            return Err(anyhow!("Invalid keystroke `{}`", source));
266                        }
267                    } else {
268                        key = Some(String::from(component));
269                    }
270                }
271            }
272        }
273
274        Ok(Keystroke {
275            ctrl,
276            alt,
277            shift,
278            cmd,
279            key: key.unwrap(),
280        })
281    }
282
283    pub fn modified(&self) -> bool {
284        self.ctrl || self.alt || self.shift || self.cmd
285    }
286}
287
288impl Context {
289    pub fn extend(&mut self, other: &Context) {
290        for v in &other.set {
291            self.set.insert(v.clone());
292        }
293        for (k, v) in &other.map {
294            self.map.insert(k.clone(), v.clone());
295        }
296    }
297}
298
299impl ContextPredicate {
300    fn parse(source: &str) -> anyhow::Result<Self> {
301        let mut parser = Parser::new();
302        let language = unsafe { tree_sitter_context_predicate() };
303        parser.set_language(language).unwrap();
304        let source = source.as_bytes();
305        let tree = parser.parse(source, None).unwrap();
306        Self::from_node(tree.root_node(), source)
307    }
308
309    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
310        let parse_error = "error parsing context predicate";
311        let kind = node.kind();
312
313        match kind {
314            "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
315            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
316            "not" => {
317                let child = Self::from_node(
318                    node.child_by_field_name("expression")
319                        .ok_or(anyhow!(parse_error))?,
320                    source,
321                )?;
322                Ok(Self::Not(Box::new(child)))
323            }
324            "and" | "or" => {
325                let left = Box::new(Self::from_node(
326                    node.child_by_field_name("left")
327                        .ok_or(anyhow!(parse_error))?,
328                    source,
329                )?);
330                let right = Box::new(Self::from_node(
331                    node.child_by_field_name("right")
332                        .ok_or(anyhow!(parse_error))?,
333                    source,
334                )?);
335                if kind == "and" {
336                    Ok(Self::And(left, right))
337                } else {
338                    Ok(Self::Or(left, right))
339                }
340            }
341            "equal" | "not_equal" => {
342                let left = node
343                    .child_by_field_name("left")
344                    .ok_or(anyhow!(parse_error))?
345                    .utf8_text(source)?
346                    .into();
347                let right = node
348                    .child_by_field_name("right")
349                    .ok_or(anyhow!(parse_error))?
350                    .utf8_text(source)?
351                    .into();
352                if kind == "equal" {
353                    Ok(Self::Equal(left, right))
354                } else {
355                    Ok(Self::NotEqual(left, right))
356                }
357            }
358            "parenthesized" => Self::from_node(
359                node.child_by_field_name("expression")
360                    .ok_or(anyhow!(parse_error))?,
361                source,
362            ),
363            _ => Err(anyhow!(parse_error)),
364        }
365    }
366
367    fn eval(&self, cx: &Context) -> bool {
368        match self {
369            Self::Identifier(name) => cx.set.contains(name.as_str()),
370            Self::Equal(left, right) => cx
371                .map
372                .get(left)
373                .map(|value| value == right)
374                .unwrap_or(false),
375            Self::NotEqual(left, right) => {
376                cx.map.get(left).map(|value| value != right).unwrap_or(true)
377            }
378            Self::Not(pred) => !pred.eval(cx),
379            Self::And(left, right) => left.eval(cx) && right.eval(cx),
380            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
381        }
382    }
383}
384
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        // A non-modifier component may only appear last.
        assert!(Keystroke::parse("a-b").is_err());

        assert!(!Keystroke::parse("p")?.modified());
        assert!(Keystroke::parse("shift-p")?.modified());

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_context_extend() {
        let mut base = Context::default();
        base.set.insert("a".into());
        base.map.insert("mode".into(), "normal".into());

        let mut other = Context::default();
        other.set.insert("b".into());
        other.map.insert("mode".into(), "insert".into());

        base.extend(&other);
        assert!(base.set.contains("a") && base.set.contains("b"));
        assert_eq!(base.map.get("mode").map(String::as_str), Some("insert"));
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Borrows the action as concrete type `A`, if present and of that type.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, feeds it to the matcher, and returns the matched
        /// action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                // The action is already an owned clone; no need to clone it again.
                MatchResult::Action(action) => Some(action),
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}