1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
// Constructor for the tree-sitter grammar that parses binding context predicates
// (see `ContextPredicate::parse`). NOTE(review): presumably generated by tree-sitter
// and compiled/linked from C by the build script — confirm.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Resolves keystrokes against a [`Keymap`], tracking partially-typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by the `view_id` passed to
    /// `push_keystroke`.
    pending: HashMap<usize, Pending>,
    /// The bindings consulted when matching keystrokes.
    keymap: Keymap,
}
19
/// Keystrokes typed so far in one view that may still complete a binding.
#[derive(Default)]
struct Pending {
    /// Keystrokes accumulated so far, oldest first.
    keystrokes: Vec<Keystroke>,
    /// Context in effect when the sequence was last kept pending; a keystroke
    /// arriving under a different context discards the accumulated sequence.
    context: Option<Context>,
}
25
/// An ordered list of key bindings with an index from action type to bindings.
#[derive(Default)]
pub struct Keymap {
    /// All bindings in insertion order; the matcher scans them from last to
    /// first, so later bindings take precedence.
    bindings: Vec<Binding>,
    /// For each action `TypeId`, the indices into `bindings` of the bindings
    /// whose action has that type.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
31
/// Associates a keystroke sequence with an action, optionally restricted to
/// contexts that satisfy a predicate.
pub struct Binding {
    /// Keystrokes that must be typed, in order, to trigger the binding.
    keystrokes: SmallVec<[Keystroke; 2]>,
    /// Action returned (as a clone) by the matcher when the sequence completes.
    action: Box<dyn Action>,
    /// When present, the binding applies only if this predicate evaluates to
    /// true for the current context; `None` means it applies in every context.
    context_predicate: Option<ContextPredicate>,
}
37
/// A single key press: a key name plus the modifier keys held with it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Control modifier (`ctrl-` in keystroke syntax).
    pub ctrl: bool,
    /// Alt modifier (`alt-`).
    pub alt: bool,
    /// Shift modifier (`shift-`).
    pub shift: bool,
    /// Command modifier (`cmd-`).
    pub cmd: bool,
    /// Function modifier (`fn-`).
    pub function: bool,
    /// Name of the key itself, e.g. `p`, `down`, or `-`.
    pub key: String,
}
47
/// State that binding context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; a bare identifier predicate such as `a`
    /// is true when its name is in this set.
    pub set: HashSet<String>,
    /// Key/value pairs tested by `key == value` and `key != value` predicates.
    pub map: HashMap<String, String>,
}
53
/// Parsed form of a binding's context expression, e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// `name`: true when `name` is in the context's identifier set.
    Identifier(String),
    /// `key == value`: true when the context maps `key` to exactly `value`.
    Equal(String, String),
    /// `key != value`: true when `key` is absent or maps to a different value.
    NotEqual(String, String),
    /// `!expr`: logical negation.
    Not(Box<ContextPredicate>),
    /// `left && right`: short-circuiting conjunction.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// `left || right`: short-circuiting disjunction.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
63
/// Clones a value into a type-erased, heap-allocated `Box<dyn Any>`.
///
/// NOTE(review): nothing in this file calls `boxed_clone` through this trait;
/// it may be dead code — confirm before relying on it.
trait ActionArg {
    /// Returns an owned copy of `self` erased to `dyn Any`.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// Blanket impl: every cloneable `'static` type can be copied into a `Box<dyn Any>`.
impl<T: 'static + Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
76
/// Outcome of feeding one keystroke to a [`Matcher`].
pub enum MatchResult {
    /// No applicable binding starts with the keystrokes typed so far.
    None,
    /// The keystrokes so far are a proper prefix of an applicable binding.
    Pending,
    /// An applicable binding matched completely; carries a clone of its action.
    Action(Box<dyn Action>),
}
82
83impl Debug for MatchResult {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
87 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
88 MatchResult::Action(action) => f
89 .debug_tuple("MatchResult::Action")
90 .field(&action.name())
91 .finish(),
92 }
93 }
94}
95
96impl Matcher {
97 pub fn new(keymap: Keymap) -> Self {
98 Self {
99 pending: HashMap::new(),
100 keymap,
101 }
102 }
103
104 pub fn set_keymap(&mut self, keymap: Keymap) {
105 self.pending.clear();
106 self.keymap = keymap;
107 }
108
109 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
110 self.pending.clear();
111 self.keymap.add_bindings(bindings);
112 }
113
114 pub fn clear_bindings(&mut self) {
115 self.pending.clear();
116 self.keymap.clear();
117 }
118
119 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
120 self.keymap.bindings_for_action_type(action_type)
121 }
122
123 pub fn clear_pending(&mut self) {
124 self.pending.clear();
125 }
126
127 pub fn has_pending_keystrokes(&self) -> bool {
128 !self.pending.is_empty()
129 }
130
131 pub fn push_keystroke(
132 &mut self,
133 keystroke: Keystroke,
134 view_id: usize,
135 cx: &Context,
136 ) -> MatchResult {
137 let pending = self.pending.entry(view_id).or_default();
138
139 if let Some(pending_ctx) = pending.context.as_ref() {
140 if pending_ctx != cx {
141 pending.keystrokes.clear();
142 }
143 }
144
145 pending.keystrokes.push(keystroke);
146
147 let mut retain_pending = false;
148 for binding in self.keymap.bindings.iter().rev() {
149 if binding.keystrokes.starts_with(&pending.keystrokes)
150 && binding
151 .context_predicate
152 .as_ref()
153 .map(|c| c.eval(cx))
154 .unwrap_or(true)
155 {
156 if binding.keystrokes.len() == pending.keystrokes.len() {
157 self.pending.remove(&view_id);
158 return MatchResult::Action(binding.action.boxed_clone());
159 } else {
160 retain_pending = true;
161 pending.context = Some(cx.clone());
162 }
163 }
164 }
165
166 if retain_pending {
167 MatchResult::Pending
168 } else {
169 self.pending.remove(&view_id);
170 MatchResult::None
171 }
172 }
173
174 pub fn keystrokes_for_action(
175 &self,
176 action: &dyn Action,
177 cx: &Context,
178 ) -> Option<SmallVec<[Keystroke; 2]>> {
179 for binding in self.keymap.bindings.iter().rev() {
180 if binding.action.eq(action)
181 && binding
182 .context_predicate
183 .as_ref()
184 .map_or(true, |predicate| predicate.eval(cx))
185 {
186 return Some(binding.keystrokes.clone());
187 }
188 }
189 None
190 }
191}
192
193impl Default for Matcher {
194 fn default() -> Self {
195 Self::new(Keymap::default())
196 }
197}
198
199impl Keymap {
200 pub fn new(bindings: Vec<Binding>) -> Self {
201 let mut binding_indices_by_action_type = HashMap::new();
202 for (ix, binding) in bindings.iter().enumerate() {
203 binding_indices_by_action_type
204 .entry(binding.action.as_any().type_id())
205 .or_insert_with(SmallVec::new)
206 .push(ix);
207 }
208 Self {
209 binding_indices_by_action_type,
210 bindings,
211 }
212 }
213
214 fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &'_ Binding> {
215 self.binding_indices_by_action_type
216 .get(&action_type)
217 .map(SmallVec::as_slice)
218 .unwrap_or(&[])
219 .iter()
220 .map(|ix| &self.bindings[*ix])
221 }
222
223 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
224 for binding in bindings {
225 self.binding_indices_by_action_type
226 .entry(binding.action.as_any().type_id())
227 .or_default()
228 .push(self.bindings.len());
229 self.bindings.push(binding);
230 }
231 }
232
233 fn clear(&mut self) {
234 self.bindings.clear();
235 self.binding_indices_by_action_type.clear();
236 }
237}
238
239impl Binding {
240 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
241 Self::load(keystrokes, Box::new(action), context).unwrap()
242 }
243
244 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
245 let context = if let Some(context) = context {
246 Some(ContextPredicate::parse(context)?)
247 } else {
248 None
249 };
250
251 let keystrokes = keystrokes
252 .split_whitespace()
253 .map(Keystroke::parse)
254 .collect::<Result<_>>()?;
255
256 Ok(Self {
257 keystrokes,
258 action,
259 context_predicate: context,
260 })
261 }
262
263 pub fn keystrokes(&self) -> &[Keystroke] {
264 &self.keystrokes
265 }
266
267 pub fn action(&self) -> &dyn Action {
268 self.action.as_ref()
269 }
270}
271
272impl Keystroke {
273 pub fn parse(source: &str) -> anyhow::Result<Self> {
274 let mut ctrl = false;
275 let mut alt = false;
276 let mut shift = false;
277 let mut cmd = false;
278 let mut function = false;
279 let mut key = None;
280
281 let mut components = source.split('-').peekable();
282 while let Some(component) = components.next() {
283 match component {
284 "ctrl" => ctrl = true,
285 "alt" => alt = true,
286 "shift" => shift = true,
287 "cmd" => cmd = true,
288 "fn" => function = true,
289 _ => {
290 if let Some(component) = components.peek() {
291 if component.is_empty() && source.ends_with('-') {
292 key = Some(String::from("-"));
293 break;
294 } else {
295 return Err(anyhow!("Invalid keystroke `{}`", source));
296 }
297 } else {
298 key = Some(String::from(component));
299 }
300 }
301 }
302 }
303
304 if key.is_none() {
305 return Err(anyhow!("Invalid keystroke `{}`", source));
306 }
307
308 Ok(Keystroke {
309 ctrl,
310 alt,
311 shift,
312 cmd,
313 function,
314 key: key.unwrap(),
315 })
316 }
317
318 pub fn modified(&self) -> bool {
319 self.ctrl || self.alt || self.shift || self.cmd
320 }
321}
322
323impl std::fmt::Display for Keystroke {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 if self.ctrl {
326 f.write_char('^')?;
327 }
328 if self.alt {
329 f.write_char('⎇')?;
330 }
331 if self.cmd {
332 f.write_char('⌘')?;
333 }
334 if self.shift {
335 f.write_char('⇧')?;
336 }
337 let key = match self.key.as_str() {
338 "backspace" => '⌫',
339 "up" => '↑',
340 "down" => '↓',
341 "left" => '←',
342 "right" => '→',
343 "tab" => '⇥',
344 "escape" => '⎋',
345 key => {
346 if key.len() == 1 {
347 key.chars().next().unwrap().to_ascii_uppercase()
348 } else {
349 return f.write_str(key);
350 }
351 }
352 };
353 f.write_char(key)
354 }
355}
356
357impl Context {
358 pub fn extend(&mut self, other: &Context) {
359 for v in &other.set {
360 self.set.insert(v.clone());
361 }
362 for (k, v) in &other.map {
363 self.map.insert(k.clone(), v.clone());
364 }
365 }
366}
367
368impl ContextPredicate {
369 fn parse(source: &str) -> anyhow::Result<Self> {
370 let mut parser = Parser::new();
371 let language = unsafe { tree_sitter_context_predicate() };
372 parser.set_language(language).unwrap();
373 let source = source.as_bytes();
374 let tree = parser.parse(source, None).unwrap();
375 Self::from_node(tree.root_node(), source)
376 }
377
378 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
379 let parse_error = "error parsing context predicate";
380 let kind = node.kind();
381
382 match kind {
383 "source" => Self::from_node(node.child(0).ok_or_else(|| anyhow!(parse_error))?, source),
384 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
385 "not" => {
386 let child = Self::from_node(
387 node.child_by_field_name("expression")
388 .ok_or_else(|| anyhow!(parse_error))?,
389 source,
390 )?;
391 Ok(Self::Not(Box::new(child)))
392 }
393 "and" | "or" => {
394 let left = Box::new(Self::from_node(
395 node.child_by_field_name("left")
396 .ok_or_else(|| anyhow!(parse_error))?,
397 source,
398 )?);
399 let right = Box::new(Self::from_node(
400 node.child_by_field_name("right")
401 .ok_or_else(|| anyhow!(parse_error))?,
402 source,
403 )?);
404 if kind == "and" {
405 Ok(Self::And(left, right))
406 } else {
407 Ok(Self::Or(left, right))
408 }
409 }
410 "equal" | "not_equal" => {
411 let left = node
412 .child_by_field_name("left")
413 .ok_or_else(|| anyhow!(parse_error))?
414 .utf8_text(source)?
415 .into();
416 let right = node
417 .child_by_field_name("right")
418 .ok_or_else(|| anyhow!(parse_error))?
419 .utf8_text(source)?
420 .into();
421 if kind == "equal" {
422 Ok(Self::Equal(left, right))
423 } else {
424 Ok(Self::NotEqual(left, right))
425 }
426 }
427 "parenthesized" => Self::from_node(
428 node.child_by_field_name("expression")
429 .ok_or_else(|| anyhow!(parse_error))?,
430 source,
431 ),
432 _ => Err(anyhow!(parse_error)),
433 }
434 }
435
436 fn eval(&self, cx: &Context) -> bool {
437 match self {
438 Self::Identifier(name) => cx.set.contains(name.as_str()),
439 Self::Equal(left, right) => cx
440 .map
441 .get(left)
442 .map(|value| value == right)
443 .unwrap_or(false),
444 Self::NotEqual(left, right) => {
445 cx.map.get(left).map(|value| value != right).unwrap_or(true)
446 }
447 Self::Not(pred) => !pred.eval(cx),
448 Self::And(left, right) => left.eval(cx) && right.eval(cx),
449 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
450 }
451 }
452}
453
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        // A trailing `-` after the separator names the minus key.
        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        // `&&` binds tighter than `||`: (a && b) || (c == d).
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Downcasts an optional matched action to the concrete action type `A`.
    fn downcast<A: Action>(action: &Option<Box<dyn Action>>) -> Option<&A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Test helper: parses `keystroke`, feeds it to the matcher, and returns the
        /// matched action (if the result was `MatchResult::Action`).
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            if let MatchResult::Action(action) =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx)
            {
                Some(action.boxed_clone())
            } else {
                None
            }
        }
    }
}