1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
extern "C" {
    // C-ABI entry point of the context-predicate grammar; `ContextPredicate::parse`
    // passes the returned `Language` to its tree-sitter `Parser`.
    // NOTE(review): presumably a tree-sitter generated grammar linked in by a build
    // script — confirm.
    fn tree_sitter_context_predicate() -> Language;
}
14
15pub struct Matcher {
16 pending: HashMap<usize, Pending>,
17 keymap: Keymap,
18}
19
20#[derive(Default)]
21struct Pending {
22 keystrokes: Vec<Keystroke>,
23 context: Option<Context>,
24}
25
26#[derive(Default)]
27pub struct Keymap {
28 bindings: Vec<Binding>,
29 binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
30}
31
32pub struct Binding {
33 keystrokes: SmallVec<[Keystroke; 2]>,
34 action: Box<dyn Action>,
35 context_predicate: Option<ContextPredicate>,
36}
37
/// A single key press together with the modifier keys held during it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Keystroke {
    /// Control modifier.
    pub ctrl: bool,
    /// Alt / option modifier.
    pub alt: bool,
    /// Shift modifier.
    pub shift: bool,
    /// Command modifier.
    pub cmd: bool,
    /// Function (`fn`) modifier.
    pub function: bool,
    /// The key name, e.g. `"p"`, `"down"`, or `"-"`.
    pub key: String,
}
47
/// The state against which binding context predicates are evaluated.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Context {
    /// Identifiers that are present; tested by bare-identifier predicates.
    pub set: HashSet<String>,
    /// Key/value pairs; tested by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}
53
/// A parsed boolean expression over a [`Context`].
#[derive(PartialEq, Eq, Debug)]
enum ContextPredicate {
    /// A bare identifier: true when it is in the context's set.
    Identifier(String),
    /// `key == value`.
    Equal(String, String),
    /// `key != value`.
    NotEqual(String, String),
    /// `!expr`.
    Not(Box<ContextPredicate>),
    /// `left && right`.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// `left || right`.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
63
/// Clones a value into a type-erased box.
///
/// NOTE(review): this trait is private and nothing in this file calls it through
/// `ActionArg` (action cloning goes through `Action::boxed_clone`); candidate for removal.
trait ActionArg {
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// Blanket impl: every cloneable `Any` type (`Any` already implies `'static`) can be
// boxed-cloned.
impl<T: 'static + Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = self.clone();
        Box::new(copy)
    }
}
76
77pub enum MatchResult {
78 None,
79 Pending,
80 Action(Box<dyn Action>),
81}
82
83impl Debug for MatchResult {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
87 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
88 MatchResult::Action(action) => f
89 .debug_tuple("MatchResult::Action")
90 .field(&action.name())
91 .finish(),
92 }
93 }
94}
95
96impl Matcher {
97 pub fn new(keymap: Keymap) -> Self {
98 Self {
99 pending: HashMap::new(),
100 keymap,
101 }
102 }
103
104 pub fn set_keymap(&mut self, keymap: Keymap) {
105 self.pending.clear();
106 self.keymap = keymap;
107 }
108
109 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
110 self.pending.clear();
111 self.keymap.add_bindings(bindings);
112 }
113
114 pub fn clear_bindings(&mut self) {
115 self.pending.clear();
116 self.keymap.clear();
117 }
118
119 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
120 self.keymap.bindings_for_action_type(action_type)
121 }
122
123 pub fn clear_pending(&mut self) {
124 self.pending.clear();
125 }
126
127 pub fn has_pending_keystrokes(&self) -> bool {
128 !self.pending.is_empty()
129 }
130
131 pub fn push_keystroke(
132 &mut self,
133 keystroke: Keystroke,
134 view_id: usize,
135 cx: &Context,
136 ) -> MatchResult {
137 let pending = self.pending.entry(view_id).or_default();
138
139 if let Some(pending_ctx) = pending.context.as_ref() {
140 if pending_ctx != cx {
141 pending.keystrokes.clear();
142 }
143 }
144
145 pending.keystrokes.push(keystroke);
146
147 let mut retain_pending = false;
148 for binding in self.keymap.bindings.iter().rev() {
149 if binding.keystrokes.starts_with(&pending.keystrokes)
150 && binding
151 .context_predicate
152 .as_ref()
153 .map(|c| c.eval(cx))
154 .unwrap_or(true)
155 {
156 if binding.keystrokes.len() == pending.keystrokes.len() {
157 self.pending.remove(&view_id);
158 return MatchResult::Action(binding.action.boxed_clone());
159 } else {
160 retain_pending = true;
161 pending.context = Some(cx.clone());
162 }
163 }
164 }
165
166 if retain_pending {
167 MatchResult::Pending
168 } else {
169 self.pending.remove(&view_id);
170 MatchResult::None
171 }
172 }
173
174 pub fn keystrokes_for_action(
175 &self,
176 action: &dyn Action,
177 cx: &Context,
178 ) -> Option<SmallVec<[Keystroke; 2]>> {
179 for binding in self.keymap.bindings.iter().rev() {
180 if binding.action.eq(action)
181 && binding
182 .context_predicate
183 .as_ref()
184 .map_or(true, |predicate| predicate.eval(cx))
185 {
186 return Some(binding.keystrokes.clone());
187 }
188 }
189 None
190 }
191}
192
193impl Default for Matcher {
194 fn default() -> Self {
195 Self::new(Keymap::default())
196 }
197}
198
199impl Keymap {
200 pub fn new(bindings: Vec<Binding>) -> Self {
201 let mut binding_indices_by_action_type = HashMap::new();
202 for (ix, binding) in bindings.iter().enumerate() {
203 binding_indices_by_action_type
204 .entry(binding.action.as_any().type_id())
205 .or_insert_with(|| SmallVec::new())
206 .push(ix);
207 }
208 Self {
209 binding_indices_by_action_type,
210 bindings,
211 }
212 }
213
214 fn bindings_for_action_type<'a>(
215 &'a self,
216 action_type: TypeId,
217 ) -> impl Iterator<Item = &'a Binding> {
218 self.binding_indices_by_action_type
219 .get(&action_type)
220 .map(SmallVec::as_slice)
221 .unwrap_or(&[])
222 .iter()
223 .map(|ix| &self.bindings[*ix])
224 }
225
226 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
227 for binding in bindings {
228 self.binding_indices_by_action_type
229 .entry(binding.action.as_any().type_id())
230 .or_default()
231 .push(self.bindings.len());
232 self.bindings.push(binding);
233 }
234 }
235
236 fn clear(&mut self) {
237 self.bindings.clear();
238 self.binding_indices_by_action_type.clear();
239 }
240}
241
242impl Binding {
243 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
244 Self::load(keystrokes, Box::new(action), context).unwrap()
245 }
246
247 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
248 let context = if let Some(context) = context {
249 Some(ContextPredicate::parse(context)?)
250 } else {
251 None
252 };
253
254 let keystrokes = keystrokes
255 .split_whitespace()
256 .map(|key| Keystroke::parse(key))
257 .collect::<Result<_>>()?;
258
259 Ok(Self {
260 keystrokes,
261 action,
262 context_predicate: context,
263 })
264 }
265
266 pub fn keystrokes(&self) -> &[Keystroke] {
267 &self.keystrokes
268 }
269
270 pub fn action(&self) -> &dyn Action {
271 self.action.as_ref()
272 }
273}
274
275impl Keystroke {
276 pub fn parse(source: &str) -> anyhow::Result<Self> {
277 let mut ctrl = false;
278 let mut alt = false;
279 let mut shift = false;
280 let mut cmd = false;
281 let mut function = false;
282 let mut key = None;
283
284 let mut components = source.split("-").peekable();
285 while let Some(component) = components.next() {
286 match component {
287 "ctrl" => ctrl = true,
288 "alt" => alt = true,
289 "shift" => shift = true,
290 "cmd" => cmd = true,
291 "fn" => function = true,
292 _ => {
293 if let Some(component) = components.peek() {
294 if component.is_empty() && source.ends_with('-') {
295 key = Some(String::from("-"));
296 break;
297 } else {
298 return Err(anyhow!("Invalid keystroke `{}`", source));
299 }
300 } else {
301 key = Some(String::from(component));
302 }
303 }
304 }
305 }
306
307 Ok(Keystroke {
308 ctrl,
309 alt,
310 shift,
311 cmd,
312 function,
313 key: key.unwrap(),
314 })
315 }
316
317 pub fn modified(&self) -> bool {
318 self.ctrl || self.alt || self.shift || self.cmd
319 }
320}
321
322impl std::fmt::Display for Keystroke {
323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324 if self.ctrl {
325 f.write_char('^')?;
326 }
327 if self.alt {
328 f.write_char('⎇')?;
329 }
330 if self.cmd {
331 f.write_char('⌘')?;
332 }
333 if self.shift {
334 f.write_char('⇧')?;
335 }
336 let key = match self.key.as_str() {
337 "backspace" => '⌫',
338 "up" => '↑',
339 "down" => '↓',
340 "left" => '←',
341 "right" => '→',
342 "tab" => '⇥',
343 "escape" => '⎋',
344 key => {
345 if key.len() == 1 {
346 key.chars().next().unwrap().to_ascii_uppercase()
347 } else {
348 return f.write_str(key);
349 }
350 }
351 };
352 f.write_char(key)
353 }
354}
355
356impl Context {
357 pub fn extend(&mut self, other: &Context) {
358 for v in &other.set {
359 self.set.insert(v.clone());
360 }
361 for (k, v) in &other.map {
362 self.map.insert(k.clone(), v.clone());
363 }
364 }
365}
366
367impl ContextPredicate {
368 fn parse(source: &str) -> anyhow::Result<Self> {
369 let mut parser = Parser::new();
370 let language = unsafe { tree_sitter_context_predicate() };
371 parser.set_language(language).unwrap();
372 let source = source.as_bytes();
373 let tree = parser.parse(source, None).unwrap();
374 Self::from_node(tree.root_node(), source)
375 }
376
377 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
378 let parse_error = "error parsing context predicate";
379 let kind = node.kind();
380
381 match kind {
382 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
383 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
384 "not" => {
385 let child = Self::from_node(
386 node.child_by_field_name("expression")
387 .ok_or(anyhow!(parse_error))?,
388 source,
389 )?;
390 Ok(Self::Not(Box::new(child)))
391 }
392 "and" | "or" => {
393 let left = Box::new(Self::from_node(
394 node.child_by_field_name("left")
395 .ok_or(anyhow!(parse_error))?,
396 source,
397 )?);
398 let right = Box::new(Self::from_node(
399 node.child_by_field_name("right")
400 .ok_or(anyhow!(parse_error))?,
401 source,
402 )?);
403 if kind == "and" {
404 Ok(Self::And(left, right))
405 } else {
406 Ok(Self::Or(left, right))
407 }
408 }
409 "equal" | "not_equal" => {
410 let left = node
411 .child_by_field_name("left")
412 .ok_or(anyhow!(parse_error))?
413 .utf8_text(source)?
414 .into();
415 let right = node
416 .child_by_field_name("right")
417 .ok_or(anyhow!(parse_error))?
418 .utf8_text(source)?
419 .into();
420 if kind == "equal" {
421 Ok(Self::Equal(left, right))
422 } else {
423 Ok(Self::NotEqual(left, right))
424 }
425 }
426 "parenthesized" => Self::from_node(
427 node.child_by_field_name("expression")
428 .ok_or(anyhow!(parse_error))?,
429 source,
430 ),
431 _ => Err(anyhow!(parse_error)),
432 }
433 }
434
435 fn eval(&self, cx: &Context) -> bool {
436 match self {
437 Self::Identifier(name) => cx.set.contains(name.as_str()),
438 Self::Equal(left, right) => cx
439 .map
440 .get(left)
441 .map(|value| value == right)
442 .unwrap_or(false),
443 Self::NotEqual(left, right) => {
444 cx.map.get(left).map(|value| value != right).unwrap_or(true)
445 }
446 Self::Not(pred) => !pred.eval(cx),
447 Self::And(left, right) => left.eval(cx) && right.eval(cx),
448 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
449 }
450 }
451}
452
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
                function: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
                function: false,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Downcasts a matched action to the concrete action type `A`, if it is one.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, feeds it for `view_id`, and returns the matched action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                MatchResult::Action(action) => Some(action),
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}