1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
// Language handle for the context-predicate grammar; `ContextPredicate::parse` passes it to
// `Parser::set_language`. NOTE(review): presumably a tree-sitter generated grammar linked in by
// the build script — confirm.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Resolves incoming keystrokes to actions using a [`Keymap`], tracking partially typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    // Keystrokes that prefix at least one binding, keyed by the view id passed to `push_keystroke`.
    pending: HashMap<usize, Pending>,
    keymap: Keymap,
}
19
/// Partially matched keystroke state for a single view.
#[derive(Default)]
struct Pending {
    // Keystrokes typed so far that still prefix some binding.
    keystrokes: Vec<Keystroke>,
    // Context the keystrokes were typed in; a keystroke arriving under a different context
    // discards them.
    context: Option<Context>,
}
25
/// An ordered collection of key bindings. When matching, later bindings are consulted first.
#[derive(Default)]
pub struct Keymap {
    bindings: Vec<Binding>,
    // Indices into `bindings`, grouped by the `TypeId` of each binding's action.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
31
/// Associates a keystroke sequence with an action, optionally restricted by a context predicate.
pub struct Binding {
    keystrokes: SmallVec<[Keystroke; 2]>,
    // Cloned (via `boxed_clone`) each time the binding matches.
    action: Box<dyn Action>,
    // `None` means the binding applies in every context.
    context_predicate: Option<ContextPredicate>,
}
37
/// A single key press together with its modifier flags, e.g. parsed from `"cmd-shift-p"`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    // Set by the `fn` modifier when parsing.
    pub function: bool,
    // Key name, e.g. `"p"`, `"down"`, or `"-"`.
    pub key: String,
}
47
/// The state that context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    // Identifiers tested by bare-identifier predicates such as `Editor`.
    pub set: HashSet<String>,
    // Key/value pairs tested by `key == value` and `key != value` predicates.
    pub map: HashMap<String, String>,
}
53
/// Parsed form of a binding's context expression, e.g. `a && (b == c || !d)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    // Holds when the identifier is in `Context::set`.
    Identifier(String),
    // Holds when `Context::map[key] == value`.
    Equal(String, String),
    // Holds when the key is absent or maps to a different value.
    NotEqual(String, String),
    Not(Box<ContextPredicate>),
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
63
/// Clones a value into a type-erased `Box<dyn Any>`.
///
/// NOTE(review): nothing visible in this file uses this trait; it may be dead code — confirm
/// before removing.
trait ActionArg {
    /// Returns a heap-allocated copy of `self`, erased to `dyn Any`.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// `Any` already implies `'static`, so every cloneable `'static` type qualifies.
impl<T: Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = Clone::clone(self);
        Box::new(copy)
    }
}
76
/// Outcome of feeding one keystroke to [`Matcher::push_keystroke`].
pub enum MatchResult {
    // No binding starts with the typed keystrokes; pending state for the view was discarded.
    None,
    // The typed keystrokes are a strict prefix of at least one applicable binding.
    Pending,
    // A binding matched completely; this is a clone of its action.
    Action(Box<dyn Action>),
}
82
83impl Debug for MatchResult {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
87 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
88 MatchResult::Action(action) => f
89 .debug_tuple("MatchResult::Action")
90 .field(&action.name())
91 .finish(),
92 }
93 }
94}
95
96impl Matcher {
97 pub fn new(keymap: Keymap) -> Self {
98 Self {
99 pending: HashMap::new(),
100 keymap,
101 }
102 }
103
104 pub fn set_keymap(&mut self, keymap: Keymap) {
105 self.pending.clear();
106 self.keymap = keymap;
107 }
108
109 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
110 self.pending.clear();
111 self.keymap.add_bindings(bindings);
112 }
113
114 pub fn clear_bindings(&mut self) {
115 self.pending.clear();
116 self.keymap.clear();
117 }
118
119 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
120 self.keymap.bindings_for_action_type(action_type)
121 }
122
123 pub fn clear_pending(&mut self) {
124 self.pending.clear();
125 }
126
127 pub fn has_pending_keystrokes(&self) -> bool {
128 !self.pending.is_empty()
129 }
130
131 pub fn push_keystroke(
132 &mut self,
133 keystroke: Keystroke,
134 view_id: usize,
135 cx: &Context,
136 ) -> MatchResult {
137 let pending = self.pending.entry(view_id).or_default();
138
139 if let Some(pending_ctx) = pending.context.as_ref() {
140 if pending_ctx != cx {
141 pending.keystrokes.clear();
142 }
143 }
144
145 pending.keystrokes.push(keystroke);
146
147 let mut retain_pending = false;
148 for binding in self.keymap.bindings.iter().rev() {
149 if binding.keystrokes.starts_with(&pending.keystrokes)
150 && binding
151 .context_predicate
152 .as_ref()
153 .map(|c| c.eval(cx))
154 .unwrap_or(true)
155 {
156 if binding.keystrokes.len() == pending.keystrokes.len() {
157 self.pending.remove(&view_id);
158 return MatchResult::Action(binding.action.boxed_clone());
159 } else {
160 retain_pending = true;
161 pending.context = Some(cx.clone());
162 }
163 }
164 }
165
166 if retain_pending {
167 MatchResult::Pending
168 } else {
169 self.pending.remove(&view_id);
170 MatchResult::None
171 }
172 }
173
174 pub fn keystrokes_for_action(
175 &self,
176 action: &dyn Action,
177 cx: &Context,
178 ) -> Option<SmallVec<[Keystroke; 2]>> {
179 for binding in self.keymap.bindings.iter().rev() {
180 if binding.action.eq(action)
181 && binding
182 .context_predicate
183 .as_ref()
184 .map_or(true, |predicate| predicate.eval(cx))
185 {
186 return Some(binding.keystrokes.clone());
187 }
188 }
189 None
190 }
191}
192
193impl Default for Matcher {
194 fn default() -> Self {
195 Self::new(Keymap::default())
196 }
197}
198
199impl Keymap {
200 pub fn new(bindings: Vec<Binding>) -> Self {
201 let mut binding_indices_by_action_type = HashMap::new();
202 for (ix, binding) in bindings.iter().enumerate() {
203 binding_indices_by_action_type
204 .entry(binding.action.as_any().type_id())
205 .or_insert_with(SmallVec::new)
206 .push(ix);
207 }
208 Self {
209 binding_indices_by_action_type,
210 bindings,
211 }
212 }
213
214 fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
215 self.binding_indices_by_action_type
216 .get(&action_type)
217 .map(SmallVec::as_slice)
218 .unwrap_or(&[])
219 .iter()
220 .map(|ix| &self.bindings[*ix])
221 }
222
223 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
224 for binding in bindings {
225 self.binding_indices_by_action_type
226 .entry(binding.action.as_any().type_id())
227 .or_default()
228 .push(self.bindings.len());
229 self.bindings.push(binding);
230 }
231 }
232
233 fn clear(&mut self) {
234 self.bindings.clear();
235 self.binding_indices_by_action_type.clear();
236 }
237}
238
239impl Binding {
240 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
241 Self::load(keystrokes, Box::new(action), context).unwrap()
242 }
243
244 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
245 let context = if let Some(context) = context {
246 Some(ContextPredicate::parse(context)?)
247 } else {
248 None
249 };
250
251 let keystrokes = keystrokes
252 .split_whitespace()
253 .map(Keystroke::parse)
254 .collect::<Result<_>>()?;
255
256 Ok(Self {
257 keystrokes,
258 action,
259 context_predicate: context,
260 })
261 }
262
263 pub fn keystrokes(&self) -> &[Keystroke] {
264 &self.keystrokes
265 }
266
267 pub fn action(&self) -> &dyn Action {
268 self.action.as_ref()
269 }
270}
271
272impl Keystroke {
273 pub fn parse(source: &str) -> anyhow::Result<Self> {
274 let mut ctrl = false;
275 let mut alt = false;
276 let mut shift = false;
277 let mut cmd = false;
278 let mut function = false;
279 let mut key = None;
280
281 let mut components = source.split('-').peekable();
282 while let Some(component) = components.next() {
283 match component {
284 "ctrl" => ctrl = true,
285 "alt" => alt = true,
286 "shift" => shift = true,
287 "cmd" => cmd = true,
288 "fn" => function = true,
289 _ => {
290 if let Some(component) = components.peek() {
291 if component.is_empty() && source.ends_with('-') {
292 key = Some(String::from("-"));
293 break;
294 } else {
295 return Err(anyhow!("Invalid keystroke `{}`", source));
296 }
297 } else {
298 key = Some(String::from(component));
299 }
300 }
301 }
302 }
303
304 Ok(Keystroke {
305 ctrl,
306 alt,
307 shift,
308 cmd,
309 function,
310 key: key.unwrap(),
311 })
312 }
313
314 pub fn modified(&self) -> bool {
315 self.ctrl || self.alt || self.shift || self.cmd
316 }
317}
318
319impl std::fmt::Display for Keystroke {
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 if self.ctrl {
322 f.write_char('^')?;
323 }
324 if self.alt {
325 f.write_char('⎇')?;
326 }
327 if self.cmd {
328 f.write_char('⌘')?;
329 }
330 if self.shift {
331 f.write_char('⇧')?;
332 }
333 let key = match self.key.as_str() {
334 "backspace" => '⌫',
335 "up" => '↑',
336 "down" => '↓',
337 "left" => '←',
338 "right" => '→',
339 "tab" => '⇥',
340 "escape" => '⎋',
341 key => {
342 if key.len() == 1 {
343 key.chars().next().unwrap().to_ascii_uppercase()
344 } else {
345 return f.write_str(key);
346 }
347 }
348 };
349 f.write_char(key)
350 }
351}
352
353impl Context {
354 pub fn extend(&mut self, other: &Context) {
355 for v in &other.set {
356 self.set.insert(v.clone());
357 }
358 for (k, v) in &other.map {
359 self.map.insert(k.clone(), v.clone());
360 }
361 }
362}
363
364impl ContextPredicate {
365 fn parse(source: &str) -> anyhow::Result<Self> {
366 let mut parser = Parser::new();
367 let language = unsafe { tree_sitter_context_predicate() };
368 parser.set_language(language).unwrap();
369 let source = source.as_bytes();
370 let tree = parser.parse(source, None).unwrap();
371 Self::from_node(tree.root_node(), source)
372 }
373
374 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
375 let parse_error = "error parsing context predicate";
376 let kind = node.kind();
377
378 match kind {
379 "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
380 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
381 "not" => {
382 let child = Self::from_node(
383 node.child_by_field_name("expression")
384 .ok_or_else(|| anyhow!(parse_error))?,
385 source,
386 )?;
387 Ok(Self::Not(Box::new(child)))
388 }
389 "and" | "or" => {
390 let left = Box::new(Self::from_node(
391 node.child_by_field_name("left")
392 .ok_or_else(|| anyhow!(parse_error))?,
393 source,
394 )?);
395 let right = Box::new(Self::from_node(
396 node.child_by_field_name("right")
397 .ok_or_else(|| anyhow!(parse_error))?,
398 source,
399 )?);
400 if kind == "and" {
401 Ok(Self::And(left, right))
402 } else {
403 Ok(Self::Or(left, right))
404 }
405 }
406 "equal" | "not_equal" => {
407 let left = node
408 .child_by_field_name("left")
409 .ok_or_else(|| anyhow!(parse_error))?
410 .utf8_text(source)?
411 .into();
412 let right = node
413 .child_by_field_name("right")
414 .ok_or_else(|| anyhow!(parse_error))?
415 .utf8_text(source)?
416 .into();
417 if kind == "equal" {
418 Ok(Self::Equal(left, right))
419 } else {
420 Ok(Self::NotEqual(left, right))
421 }
422 }
423 "parenthesized" => Self::from_node(
424 node.child_by_field_name("expression")
425 .ok_or_else(|| anyhow!(parse_error))?,
426 source,
427 ),
428 _ => Err(anyhow!(parse_error)),
429 }
430 }
431
432 fn eval(&self, cx: &Context) -> bool {
433 match self {
434 Self::Identifier(name) => cx.set.contains(name.as_str()),
435 Self::Equal(left, right) => cx
436 .map
437 .get(left)
438 .map(|value| value == right)
439 .unwrap_or(false),
440 Self::NotEqual(left, right) => {
441 cx.map.get(left).map(|value| value != right).unwrap_or(true)
442 }
443 Self::Not(pred) => !pred.eval(cx),
444 Self::And(left, right) => left.eval(cx) && right.eval(cx),
445 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
446 }
447 }
448}
449
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Downcasts a matched action to the concrete type `A`. Returns `None` if there is no action
    /// or it has a different type.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, pushes it for `view_id`, and returns the action if a binding
        /// fully matched.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                // The action is already an owned clone; no need to clone it again.
                MatchResult::Action(action) => Some(action),
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}