keymap.rs

  1use crate::Action;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::{
  5    any::{Any, TypeId},
  6    collections::{HashMap, HashSet},
  7    fmt::{Debug, Write},
  8};
  9use tree_sitter::{Language, Node, Parser};
 10
 11extern "C" {
 12    fn tree_sitter_context_predicate() -> Language;
 13}
 14
 15pub struct Matcher {
 16    pending: HashMap<usize, Pending>,
 17    keymap: Keymap,
 18}
 19
 20#[derive(Default)]
 21struct Pending {
 22    keystrokes: Vec<Keystroke>,
 23    context: Option<Context>,
 24}
 25
 26#[derive(Default)]
 27pub struct Keymap {
 28    bindings: Vec<Binding>,
 29    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
 30}
 31
 32pub struct Binding {
 33    keystrokes: SmallVec<[Keystroke; 2]>,
 34    action: Box<dyn Action>,
 35    context_predicate: Option<ContextPredicate>,
 36}
 37
 38#[derive(Clone, Debug, Eq, PartialEq)]
 39pub struct Keystroke {
 40    pub ctrl: bool,
 41    pub alt: bool,
 42    pub shift: bool,
 43    pub cmd: bool,
 44    pub function: bool,
 45    pub key: String,
 46}
 47
 48#[derive(Clone, Debug, Default, Eq, PartialEq)]
 49pub struct Context {
 50    pub set: HashSet<String>,
 51    pub map: HashMap<String, String>,
 52}
 53
 54#[derive(Debug, Eq, PartialEq)]
 55enum ContextPredicate {
 56    Identifier(String),
 57    Equal(String, String),
 58    NotEqual(String, String),
 59    Not(Box<ContextPredicate>),
 60    And(Box<ContextPredicate>, Box<ContextPredicate>),
 61    Or(Box<ContextPredicate>, Box<ContextPredicate>),
 62}
 63
 64trait ActionArg {
 65    fn boxed_clone(&self) -> Box<dyn Any>;
 66}
 67
 68impl<T> ActionArg for T
 69where
 70    T: 'static + Any + Clone,
 71{
 72    fn boxed_clone(&self) -> Box<dyn Any> {
 73        Box::new(self.clone())
 74    }
 75}
 76
 77pub enum MatchResult {
 78    None,
 79    Pending,
 80    Action(Box<dyn Action>),
 81}
 82
 83impl Debug for MatchResult {
 84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 85        match self {
 86            MatchResult::None => f.debug_struct("MatchResult::None").finish(),
 87            MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
 88            MatchResult::Action(action) => f
 89                .debug_tuple("MatchResult::Action")
 90                .field(&action.name())
 91                .finish(),
 92        }
 93    }
 94}
 95
 96impl Matcher {
 97    pub fn new(keymap: Keymap) -> Self {
 98        Self {
 99            pending: HashMap::new(),
100            keymap,
101        }
102    }
103
104    pub fn set_keymap(&mut self, keymap: Keymap) {
105        self.pending.clear();
106        self.keymap = keymap;
107    }
108
109    pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
110        self.pending.clear();
111        self.keymap.add_bindings(bindings);
112    }
113
114    pub fn clear_bindings(&mut self) {
115        self.pending.clear();
116        self.keymap.clear();
117    }
118
119    pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
120        self.keymap.bindings_for_action_type(action_type)
121    }
122
123    pub fn clear_pending(&mut self) {
124        self.pending.clear();
125    }
126
127    pub fn has_pending_keystrokes(&self) -> bool {
128        !self.pending.is_empty()
129    }
130
131    pub fn push_keystroke(
132        &mut self,
133        keystroke: Keystroke,
134        view_id: usize,
135        cx: &Context,
136    ) -> MatchResult {
137        let pending = self.pending.entry(view_id).or_default();
138
139        if let Some(pending_ctx) = pending.context.as_ref() {
140            if pending_ctx != cx {
141                pending.keystrokes.clear();
142            }
143        }
144
145        pending.keystrokes.push(keystroke);
146
147        let mut retain_pending = false;
148        for binding in self.keymap.bindings.iter().rev() {
149            if binding.keystrokes.starts_with(&pending.keystrokes)
150                && binding
151                    .context_predicate
152                    .as_ref()
153                    .map(|c| c.eval(cx))
154                    .unwrap_or(true)
155            {
156                if binding.keystrokes.len() == pending.keystrokes.len() {
157                    self.pending.remove(&view_id);
158                    return MatchResult::Action(binding.action.boxed_clone());
159                } else {
160                    retain_pending = true;
161                    pending.context = Some(cx.clone());
162                }
163            }
164        }
165
166        if retain_pending {
167            MatchResult::Pending
168        } else {
169            self.pending.remove(&view_id);
170            MatchResult::None
171        }
172    }
173
174    pub fn keystrokes_for_action(
175        &self,
176        action: &dyn Action,
177        cx: &Context,
178    ) -> Option<SmallVec<[Keystroke; 2]>> {
179        for binding in self.keymap.bindings.iter().rev() {
180            if binding.action.eq(action)
181                && binding
182                    .context_predicate
183                    .as_ref()
184                    .map_or(true, |predicate| predicate.eval(cx))
185            {
186                return Some(binding.keystrokes.clone());
187            }
188        }
189        None
190    }
191}
192
193impl Default for Matcher {
194    fn default() -> Self {
195        Self::new(Keymap::default())
196    }
197}
198
199impl Keymap {
200    pub fn new(bindings: Vec<Binding>) -> Self {
201        let mut binding_indices_by_action_type = HashMap::new();
202        for (ix, binding) in bindings.iter().enumerate() {
203            binding_indices_by_action_type
204                .entry(binding.action.as_any().type_id())
205                .or_insert_with(SmallVec::new)
206                .push(ix);
207        }
208        Self {
209            binding_indices_by_action_type,
210            bindings,
211        }
212    }
213
214    fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
215        self.binding_indices_by_action_type
216            .get(&action_type)
217            .map(SmallVec::as_slice)
218            .unwrap_or(&[])
219            .iter()
220            .map(|ix| &self.bindings[*ix])
221    }
222
223    fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
224        for binding in bindings {
225            self.binding_indices_by_action_type
226                .entry(binding.action.as_any().type_id())
227                .or_default()
228                .push(self.bindings.len());
229            self.bindings.push(binding);
230        }
231    }
232
233    fn clear(&mut self) {
234        self.bindings.clear();
235        self.binding_indices_by_action_type.clear();
236    }
237}
238
239impl Binding {
240    pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
241        Self::load(keystrokes, Box::new(action), context).unwrap()
242    }
243
244    pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
245        let context = if let Some(context) = context {
246            Some(ContextPredicate::parse(context)?)
247        } else {
248            None
249        };
250
251        let keystrokes = keystrokes
252            .split_whitespace()
253            .map(Keystroke::parse)
254            .collect::<Result<_>>()?;
255
256        Ok(Self {
257            keystrokes,
258            action,
259            context_predicate: context,
260        })
261    }
262
263    pub fn keystrokes(&self) -> &[Keystroke] {
264        &self.keystrokes
265    }
266
267    pub fn action(&self) -> &dyn Action {
268        self.action.as_ref()
269    }
270}
271
272impl Keystroke {
273    pub fn parse(source: &str) -> anyhow::Result<Self> {
274        let mut ctrl = false;
275        let mut alt = false;
276        let mut shift = false;
277        let mut cmd = false;
278        let mut function = false;
279        let mut key = None;
280
281        let mut components = source.split('-').peekable();
282        while let Some(component) = components.next() {
283            match component {
284                "ctrl" => ctrl = true,
285                "alt" => alt = true,
286                "shift" => shift = true,
287                "cmd" => cmd = true,
288                "fn" => function = true,
289                _ => {
290                    if let Some(component) = components.peek() {
291                        if component.is_empty() && source.ends_with('-') {
292                            key = Some(String::from("-"));
293                            break;
294                        } else {
295                            return Err(anyhow!("Invalid keystroke `{}`", source));
296                        }
297                    } else {
298                        key = Some(String::from(component));
299                    }
300                }
301            }
302        }
303
304        if key.is_none() {
305            return Err(anyhow!("Invalid keystroke `{}`", source));
306        }
307
308        Ok(Keystroke {
309            ctrl,
310            alt,
311            shift,
312            cmd,
313            function,
314            key: key.unwrap(),
315        })
316    }
317
318    pub fn modified(&self) -> bool {
319        self.ctrl || self.alt || self.shift || self.cmd
320    }
321}
322
323impl std::fmt::Display for Keystroke {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        if self.ctrl {
326            f.write_char('^')?;
327        }
328        if self.alt {
329            f.write_char('⎇')?;
330        }
331        if self.cmd {
332            f.write_char('⌘')?;
333        }
334        if self.shift {
335            f.write_char('⇧')?;
336        }
337        let key = match self.key.as_str() {
338            "backspace" => '⌫',
339            "up" => '↑',
340            "down" => '↓',
341            "left" => '←',
342            "right" => '→',
343            "tab" => '⇥',
344            "escape" => '⎋',
345            key => {
346                if key.len() == 1 {
347                    key.chars().next().unwrap().to_ascii_uppercase()
348                } else {
349                    return f.write_str(key);
350                }
351            }
352        };
353        f.write_char(key)
354    }
355}
356
357impl Context {
358    pub fn extend(&mut self, other: &Context) {
359        for v in &other.set {
360            self.set.insert(v.clone());
361        }
362        for (k, v) in &other.map {
363            self.map.insert(k.clone(), v.clone());
364        }
365    }
366}
367
368impl ContextPredicate {
369    fn parse(source: &str) -> anyhow::Result<Self> {
370        let mut parser = Parser::new();
371        let language = unsafe { tree_sitter_context_predicate() };
372        parser.set_language(language).unwrap();
373        let source = source.as_bytes();
374        let tree = parser.parse(source, None).unwrap();
375        Self::from_node(tree.root_node(), source)
376    }
377
378    fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
379        let parse_error = "error parsing context predicate";
380        let kind = node.kind();
381
382        match kind {
383            "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
384            "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
385            "not" => {
386                let child = Self::from_node(
387                    node.child_by_field_name("expression")
388                        .ok_or_else(|| anyhow!(parse_error))?,
389                    source,
390                )?;
391                Ok(Self::Not(Box::new(child)))
392            }
393            "and" | "or" => {
394                let left = Box::new(Self::from_node(
395                    node.child_by_field_name("left")
396                        .ok_or_else(|| anyhow!(parse_error))?,
397                    source,
398                )?);
399                let right = Box::new(Self::from_node(
400                    node.child_by_field_name("right")
401                        .ok_or_else(|| anyhow!(parse_error))?,
402                    source,
403                )?);
404                if kind == "and" {
405                    Ok(Self::And(left, right))
406                } else {
407                    Ok(Self::Or(left, right))
408                }
409            }
410            "equal" | "not_equal" => {
411                let left = node
412                    .child_by_field_name("left")
413                    .ok_or_else(|| anyhow!(parse_error))?
414                    .utf8_text(source)?
415                    .into();
416                let right = node
417                    .child_by_field_name("right")
418                    .ok_or_else(|| anyhow!(parse_error))?
419                    .utf8_text(source)?
420                    .into();
421                if kind == "equal" {
422                    Ok(Self::Equal(left, right))
423                } else {
424                    Ok(Self::NotEqual(left, right))
425                }
426            }
427            "parenthesized" => Self::from_node(
428                node.child_by_field_name("expression")
429                    .ok_or_else(|| anyhow!(parse_error))?,
430                source,
431            ),
432            _ => Err(anyhow!(parse_error)),
433        }
434    }
435
436    fn eval(&self, cx: &Context) -> bool {
437        match self {
438            Self::Identifier(name) => cx.set.contains(name.as_str()),
439            Self::Equal(left, right) => cx
440                .map
441                .get(left)
442                .map(|value| value == right)
443                .unwrap_or(false),
444            Self::NotEqual(left, right) => {
445                cx.map.get(left).map(|value| value != right).unwrap_or(true)
446            }
447            Self::Not(pred) => !pred.eval(cx),
448            Self::And(left, right) => left.eval(cx) && right.eval(cx),
449            Self::Or(left, right) => left.eval(cx) || right.eval(cx),
450        }
451    }
452}
453
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    /// Modifier flags and key names are extracted from dash-separated
    /// keystroke strings, including the trailing-dash form for the `-` key.
    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    /// Predicate source text is turned into the expected expression tree.
    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    /// Predicates evaluate against the context's identifier set and map.
    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    /// End-to-end matching: single and multi-keystroke bindings, recovery
    /// after a failed match, context changes, and per-view pending state.
    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Borrows the matched action as the concrete type `A`, if there is a
    /// match and it has that type.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, pushes it, and returns the matched action;
        /// both "pending" and "no match" are reported as `None`.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                // The result already owns its action; no further clone is needed.
                MatchResult::Action(action) => Some(action),
                _ => None,
            }
        }
    }
}