1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
// Foreign entry point of the context-predicate grammar. `ContextPredicate::parse`
// hands the returned `Language` to a tree-sitter `Parser`.
// NOTE(review): presumably generated by tree-sitter and compiled/linked by the
// build script — confirm.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Resolves keystrokes to actions using a [`Keymap`], tracking multi-keystroke
/// sequences that are still in progress separately for each view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by view id.
    pending: HashMap<usize, Pending>,
    /// The bindings that keystrokes are matched against.
    keymap: Keymap,
}
19
/// A keystroke sequence that has been started, but not yet completed, in one view.
#[derive(Default)]
struct Pending {
    /// Keystrokes accumulated so far for this view.
    keystrokes: Vec<Keystroke>,
    /// The context the sequence was typed in. A keystroke arriving under a
    /// different context discards the accumulated keystrokes.
    context: Option<Context>,
}
25
/// An ordered list of key bindings, additionally indexed by the concrete type of
/// each binding's action.
#[derive(Default)]
pub struct Keymap {
    /// All bindings in insertion order. The matcher scans them from last to
    /// first, so later bindings are consulted first.
    bindings: Vec<Binding>,
    /// Positions in `bindings`, grouped by the `TypeId` of each binding's action.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
31
/// Associates a keystroke sequence with an action, optionally restricted to
/// contexts that satisfy a predicate.
pub struct Binding {
    /// The keystrokes that must be typed, in order, to trigger the action.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// The action a boxed clone of which is produced when the sequence completes.
    action: Box<dyn Action>,
    /// When present, the binding only applies in contexts for which this predicate holds.
    context_predicate: Option<ContextPredicate>,
}
37
/// A single key press together with the modifier keys held during it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// Whether the `fn` modifier was held.
    pub function: bool,
    /// The key name, e.g. `"p"`, `"-"`, or `"down"`.
    pub key: String,
}
47
/// The state that binding context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers tested by bare-identifier predicates (e.g. `a` in `a && b`).
    pub set: HashSet<String>,
    /// Key/value pairs tested by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
53
/// A parsed boolean expression over a [`Context`], e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// Holds when the identifier is present in `Context::set`.
    Identifier(String),
    /// Holds when `Context::map` contains the key (first) with exactly the value (second).
    Equal(String, String),
    /// Holds when `Context::map` lacks the key (first) or maps it to a different value.
    NotEqual(String, String),
    /// Logical negation.
    Not(Box<ContextPredicate>),
    /// Logical conjunction (short-circuiting).
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Logical disjunction (short-circuiting).
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
63
/// Type-erased cloning: produces an owned copy of `self` behind `Box<dyn Any>`.
///
/// NOTE(review): nothing else in this file appears to call this trait; confirm
/// whether it is still needed.
trait ActionArg {
    /// Returns an owned, type-erased clone of `self`.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// `Any` already implies `'static`, so this covers every cloneable `'static` type.
impl<T> ActionArg for T
where
    T: Any + Clone,
{
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
76
77pub enum MatchResult {
78 None,
79 Pending,
80 Action(Box<dyn Action>),
81}
82
83impl Debug for MatchResult {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
87 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
88 MatchResult::Action(action) => f
89 .debug_tuple("MatchResult::Action")
90 .field(&action.name())
91 .finish(),
92 }
93 }
94}
95
96impl Matcher {
97 pub fn new(keymap: Keymap) -> Self {
98 Self {
99 pending: HashMap::new(),
100 keymap,
101 }
102 }
103
104 pub fn set_keymap(&mut self, keymap: Keymap) {
105 self.pending.clear();
106 self.keymap = keymap;
107 }
108
109 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
110 self.pending.clear();
111 self.keymap.add_bindings(bindings);
112 }
113
114 pub fn clear_bindings(&mut self) {
115 self.pending.clear();
116 self.keymap.clear();
117 }
118
119 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
120 self.keymap.bindings_for_action_type(action_type)
121 }
122
123 pub fn clear_pending(&mut self) {
124 self.pending.clear();
125 }
126
127 pub fn has_pending_keystrokes(&self) -> bool {
128 !self.pending.is_empty()
129 }
130
131 pub fn push_keystroke(
132 &mut self,
133 keystroke: Keystroke,
134 view_id: usize,
135 cx: &Context,
136 ) -> MatchResult {
137 let pending = self.pending.entry(view_id).or_default();
138
139 if let Some(pending_ctx) = pending.context.as_ref() {
140 if pending_ctx != cx {
141 pending.keystrokes.clear();
142 }
143 }
144
145 pending.keystrokes.push(keystroke);
146
147 let mut retain_pending = false;
148 for binding in self.keymap.bindings.iter().rev() {
149 if binding.keystrokes.starts_with(&pending.keystrokes)
150 && binding
151 .context_predicate
152 .as_ref()
153 .map(|c| c.eval(cx))
154 .unwrap_or(true)
155 {
156 if binding.keystrokes.len() == pending.keystrokes.len() {
157 self.pending.remove(&view_id);
158 return MatchResult::Action(binding.action.boxed_clone());
159 } else {
160 retain_pending = true;
161 pending.context = Some(cx.clone());
162 }
163 }
164 }
165
166 if retain_pending {
167 MatchResult::Pending
168 } else {
169 self.pending.remove(&view_id);
170 MatchResult::None
171 }
172 }
173
174 pub fn keystrokes_for_action(
175 &self,
176 action: &dyn Action,
177 cx: &Context,
178 ) -> Option<SmallVec<[Keystroke; 2]>> {
179 for binding in self.keymap.bindings.iter().rev() {
180 if binding.action.eq(action)
181 && binding
182 .context_predicate
183 .as_ref()
184 .map_or(true, |predicate| predicate.eval(cx))
185 {
186 return Some(binding.keystrokes.clone());
187 }
188 }
189 None
190 }
191}
192
193impl Default for Matcher {
194 fn default() -> Self {
195 Self::new(Keymap::default())
196 }
197}
198
199impl Keymap {
200 pub fn new(bindings: Vec<Binding>) -> Self {
201 let mut binding_indices_by_action_type = HashMap::new();
202 for (ix, binding) in bindings.iter().enumerate() {
203 binding_indices_by_action_type
204 .entry(binding.action.as_any().type_id())
205 .or_insert_with(SmallVec::new)
206 .push(ix);
207 }
208 Self {
209 binding_indices_by_action_type,
210 bindings,
211 }
212 }
213
214 fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
215 self.binding_indices_by_action_type
216 .get(&action_type)
217 .map(SmallVec::as_slice)
218 .unwrap_or(&[])
219 .iter()
220 .map(|ix| &self.bindings[*ix])
221 }
222
223 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
224 for binding in bindings {
225 self.binding_indices_by_action_type
226 .entry(binding.action.as_any().type_id())
227 .or_default()
228 .push(self.bindings.len());
229 self.bindings.push(binding);
230 }
231 }
232
233 fn clear(&mut self) {
234 self.bindings.clear();
235 self.binding_indices_by_action_type.clear();
236 }
237}
238
239impl Binding {
240 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
241 Self::load(keystrokes, Box::new(action), context).unwrap()
242 }
243
244 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
245 let context = if let Some(context) = context {
246 Some(ContextPredicate::parse(context)?)
247 } else {
248 None
249 };
250
251 let keystrokes = keystrokes
252 .split_whitespace()
253 .map(Keystroke::parse)
254 .collect::<Result<_>>()?;
255
256 Ok(Self {
257 keystrokes,
258 action,
259 context_predicate: context,
260 })
261 }
262
263 pub fn keystrokes(&self) -> &[Keystroke] {
264 &self.keystrokes
265 }
266
267 pub fn action(&self) -> &dyn Action {
268 self.action.as_ref()
269 }
270}
271
272impl Keystroke {
273 pub fn parse(source: &str) -> anyhow::Result<Self> {
274 let mut ctrl = false;
275 let mut alt = false;
276 let mut shift = false;
277 let mut cmd = false;
278 let mut function = false;
279 let mut key = None;
280
281 let mut components = source.split('-').peekable();
282 while let Some(component) = components.next() {
283 match component {
284 "ctrl" => ctrl = true,
285 "alt" => alt = true,
286 "shift" => shift = true,
287 "cmd" => cmd = true,
288 "fn" => function = true,
289 _ => {
290 if let Some(component) = components.peek() {
291 if component.is_empty() && source.ends_with('-') {
292 key = Some(String::from("-"));
293 break;
294 } else {
295 return Err(anyhow!("Invalid keystroke `{}`", source));
296 }
297 } else {
298 key = Some(String::from(component));
299 }
300 }
301 }
302 }
303
304 let key = key.ok_or_else(|| anyhow!("Invalid keystroke `{}`", source))?;
305
306 Ok(Keystroke {
307 ctrl,
308 alt,
309 shift,
310 cmd,
311 function,
312 key,
313 })
314 }
315
316 pub fn modified(&self) -> bool {
317 self.ctrl || self.alt || self.shift || self.cmd
318 }
319}
320
321impl std::fmt::Display for Keystroke {
322 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 if self.ctrl {
324 f.write_char('^')?;
325 }
326 if self.alt {
327 f.write_char('⎇')?;
328 }
329 if self.cmd {
330 f.write_char('⌘')?;
331 }
332 if self.shift {
333 f.write_char('⇧')?;
334 }
335 let key = match self.key.as_str() {
336 "backspace" => '⌫',
337 "up" => '↑',
338 "down" => '↓',
339 "left" => '←',
340 "right" => '→',
341 "tab" => '⇥',
342 "escape" => '⎋',
343 key => {
344 if key.len() == 1 {
345 key.chars().next().unwrap().to_ascii_uppercase()
346 } else {
347 return f.write_str(key);
348 }
349 }
350 };
351 f.write_char(key)
352 }
353}
354
355impl Context {
356 pub fn extend(&mut self, other: &Context) {
357 for v in &other.set {
358 self.set.insert(v.clone());
359 }
360 for (k, v) in &other.map {
361 self.map.insert(k.clone(), v.clone());
362 }
363 }
364}
365
366impl ContextPredicate {
367 fn parse(source: &str) -> anyhow::Result<Self> {
368 let mut parser = Parser::new();
369 let language = unsafe { tree_sitter_context_predicate() };
370 parser.set_language(language).unwrap();
371 let source = source.as_bytes();
372 let tree = parser.parse(source, None).unwrap();
373 Self::from_node(tree.root_node(), source)
374 }
375
376 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
377 let parse_error = "error parsing context predicate";
378 let kind = node.kind();
379
380 match kind {
381 "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
382 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
383 "not" => {
384 let child = Self::from_node(
385 node.child_by_field_name("expression")
386 .ok_or_else(|| anyhow!(parse_error))?,
387 source,
388 )?;
389 Ok(Self::Not(Box::new(child)))
390 }
391 "and" | "or" => {
392 let left = Box::new(Self::from_node(
393 node.child_by_field_name("left")
394 .ok_or_else(|| anyhow!(parse_error))?,
395 source,
396 )?);
397 let right = Box::new(Self::from_node(
398 node.child_by_field_name("right")
399 .ok_or_else(|| anyhow!(parse_error))?,
400 source,
401 )?);
402 if kind == "and" {
403 Ok(Self::And(left, right))
404 } else {
405 Ok(Self::Or(left, right))
406 }
407 }
408 "equal" | "not_equal" => {
409 let left = node
410 .child_by_field_name("left")
411 .ok_or_else(|| anyhow!(parse_error))?
412 .utf8_text(source)?
413 .into();
414 let right = node
415 .child_by_field_name("right")
416 .ok_or_else(|| anyhow!(parse_error))?
417 .utf8_text(source)?
418 .into();
419 if kind == "equal" {
420 Ok(Self::Equal(left, right))
421 } else {
422 Ok(Self::NotEqual(left, right))
423 }
424 }
425 "parenthesized" => Self::from_node(
426 node.child_by_field_name("expression")
427 .ok_or_else(|| anyhow!(parse_error))?,
428 source,
429 ),
430 _ => Err(anyhow!(parse_error)),
431 }
432 }
433
434 fn eval(&self, cx: &Context) -> bool {
435 match self {
436 Self::Identifier(name) => cx.set.contains(name.as_str()),
437 Self::Equal(left, right) => cx
438 .map
439 .get(left)
440 .map(|value| value == right)
441 .unwrap_or(false),
442 Self::NotEqual(left, right) => {
443 cx.map.get(left).map(|value| value != right).unwrap_or(true)
444 }
445 Self::Not(pred) => !pred.eval(cx),
446 Self::And(left, right) => left.eval(cx) && right.eval(cx),
447 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Downcasts a matched action to `A`. Returns `None` if there was no match or
    /// the action has a different concrete type.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, feeds it to the matcher, and returns the completed
        /// action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                MatchResult::Action(action) => Some(action),
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}