1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::{Debug, Write},
8};
9use tree_sitter::{Language, Node, Parser};
10
// FFI declaration of the tree-sitter grammar used by `ContextPredicate::parse`
// to parse binding context expressions. NOTE(review): presumably provided by a
// C grammar compiled and linked into this crate — confirm against the build script.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
14
15pub struct Matcher {
16 pending: HashMap<usize, Pending>,
17 keymap: Keymap,
18}
19
20#[derive(Default)]
21struct Pending {
22 keystrokes: Vec<Keystroke>,
23 context: Option<Context>,
24}
25
26#[derive(Default)]
27pub struct Keymap {
28 bindings: Vec<Binding>,
29 binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
30}
31
32pub struct Binding {
33 keystrokes: SmallVec<[Keystroke; 2]>,
34 action: Box<dyn Action>,
35 context_predicate: Option<ContextPredicate>,
36}
37
/// A single key press together with the modifier keys held during it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Control modifier (`ctrl-` prefix when parsed).
    pub ctrl: bool,
    /// Alt modifier (`alt-` prefix when parsed).
    pub alt: bool,
    /// Shift modifier (`shift-` prefix when parsed).
    pub shift: bool,
    /// Command modifier (`cmd-` prefix when parsed).
    pub cmd: bool,
    /// Name of the non-modifier key, e.g. `"p"`, `"down"` or `"-"`.
    pub key: String,
}
46
/// The state that binding context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Flags tested by bare identifiers in a predicate.
    pub set: HashSet<String>,
    /// Key/value pairs tested by `key == value` and `key != value` predicates.
    pub map: HashMap<String, String>,
}
52
/// Parsed form of a binding's context expression.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// Holds when the identifier is present in the context's `set`.
    Identifier(String),
    /// Holds when the operand does not hold.
    Not(Box<Self>),
    /// Holds when the context's `map` has the key (first field) mapped to the value (second field).
    Equal(String, String),
    /// Holds unless the context's `map` has the key mapped to exactly this value.
    NotEqual(String, String),
    /// Holds when both operands hold.
    And(Box<Self>, Box<Self>),
    /// Holds when either operand holds.
    Or(Box<Self>, Box<Self>),
}
62
/// Object-safe cloning for values that are handled as `dyn Any`.
trait ActionArg {
    /// Clones `self` into a fresh, type-erased box.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// `Any` already implies `'static`, so no separate lifetime bound is needed.
impl<T: Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = self.clone();
        Box::new(copy)
    }
}
75
76pub enum MatchResult {
77 None,
78 Pending,
79 Action(Box<dyn Action>),
80}
81
82impl Debug for MatchResult {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
86 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
87 MatchResult::Action(action) => f
88 .debug_tuple("MatchResult::Action")
89 .field(&action.name())
90 .finish(),
91 }
92 }
93}
94
95impl Matcher {
96 pub fn new(keymap: Keymap) -> Self {
97 Self {
98 pending: HashMap::new(),
99 keymap,
100 }
101 }
102
103 pub fn set_keymap(&mut self, keymap: Keymap) {
104 self.pending.clear();
105 self.keymap = keymap;
106 }
107
108 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
109 self.pending.clear();
110 self.keymap.add_bindings(bindings);
111 }
112
113 pub fn clear_bindings(&mut self) {
114 self.pending.clear();
115 self.keymap.clear();
116 }
117
118 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
119 self.keymap.bindings_for_action_type(action_type)
120 }
121
122 pub fn clear_pending(&mut self) {
123 self.pending.clear();
124 }
125
126 pub fn has_pending_keystrokes(&self) -> bool {
127 !self.pending.is_empty()
128 }
129
130 pub fn push_keystroke(
131 &mut self,
132 keystroke: Keystroke,
133 view_id: usize,
134 cx: &Context,
135 ) -> MatchResult {
136 let pending = self.pending.entry(view_id).or_default();
137
138 if let Some(pending_ctx) = pending.context.as_ref() {
139 if pending_ctx != cx {
140 pending.keystrokes.clear();
141 }
142 }
143
144 pending.keystrokes.push(keystroke);
145
146 let mut retain_pending = false;
147 for binding in self.keymap.bindings.iter().rev() {
148 if binding.keystrokes.starts_with(&pending.keystrokes)
149 && binding
150 .context_predicate
151 .as_ref()
152 .map(|c| c.eval(cx))
153 .unwrap_or(true)
154 {
155 if binding.keystrokes.len() == pending.keystrokes.len() {
156 self.pending.remove(&view_id);
157 return MatchResult::Action(binding.action.boxed_clone());
158 } else {
159 retain_pending = true;
160 pending.context = Some(cx.clone());
161 }
162 }
163 }
164
165 if retain_pending {
166 MatchResult::Pending
167 } else {
168 self.pending.remove(&view_id);
169 MatchResult::None
170 }
171 }
172
173 pub fn keystrokes_for_action(
174 &self,
175 action: &dyn Action,
176 cx: &Context,
177 ) -> Option<SmallVec<[Keystroke; 2]>> {
178 for binding in self.keymap.bindings.iter().rev() {
179 if binding.action.eq(action)
180 && binding
181 .context_predicate
182 .as_ref()
183 .map_or(true, |predicate| predicate.eval(cx))
184 {
185 return Some(binding.keystrokes.clone());
186 }
187 }
188 None
189 }
190}
191
192impl Default for Matcher {
193 fn default() -> Self {
194 Self::new(Keymap::default())
195 }
196}
197
198impl Keymap {
199 pub fn new(bindings: Vec<Binding>) -> Self {
200 let mut binding_indices_by_action_type = HashMap::new();
201 for (ix, binding) in bindings.iter().enumerate() {
202 binding_indices_by_action_type
203 .entry(binding.action.as_any().type_id())
204 .or_insert_with(|| SmallVec::new())
205 .push(ix);
206 }
207 Self {
208 binding_indices_by_action_type,
209 bindings,
210 }
211 }
212
213 fn bindings_for_action_type<'a>(
214 &'a self,
215 action_type: TypeId,
216 ) -> impl Iterator<Item = &'a Binding> {
217 self.binding_indices_by_action_type
218 .get(&action_type)
219 .map(SmallVec::as_slice)
220 .unwrap_or(&[])
221 .iter()
222 .map(|ix| &self.bindings[*ix])
223 }
224
225 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
226 for binding in bindings {
227 self.binding_indices_by_action_type
228 .entry(binding.action.as_any().type_id())
229 .or_default()
230 .push(self.bindings.len());
231 self.bindings.push(binding);
232 }
233 }
234
235 fn clear(&mut self) {
236 self.bindings.clear();
237 self.binding_indices_by_action_type.clear();
238 }
239}
240
241impl Binding {
242 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
243 Self::load(keystrokes, Box::new(action), context).unwrap()
244 }
245
246 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
247 let context = if let Some(context) = context {
248 Some(ContextPredicate::parse(context)?)
249 } else {
250 None
251 };
252
253 let keystrokes = keystrokes
254 .split_whitespace()
255 .map(|key| Keystroke::parse(key))
256 .collect::<Result<_>>()?;
257
258 Ok(Self {
259 keystrokes,
260 action,
261 context_predicate: context,
262 })
263 }
264
265 pub fn keystrokes(&self) -> &[Keystroke] {
266 &self.keystrokes
267 }
268
269 pub fn action(&self) -> &dyn Action {
270 self.action.as_ref()
271 }
272}
273
274impl Keystroke {
275 pub fn parse(source: &str) -> anyhow::Result<Self> {
276 let mut ctrl = false;
277 let mut alt = false;
278 let mut shift = false;
279 let mut cmd = false;
280 let mut key = None;
281
282 let mut components = source.split("-").peekable();
283 while let Some(component) = components.next() {
284 match component {
285 "ctrl" => ctrl = true,
286 "alt" => alt = true,
287 "shift" => shift = true,
288 "cmd" => cmd = true,
289 _ => {
290 if let Some(component) = components.peek() {
291 if component.is_empty() && source.ends_with('-') {
292 key = Some(String::from("-"));
293 break;
294 } else {
295 return Err(anyhow!("Invalid keystroke `{}`", source));
296 }
297 } else {
298 key = Some(String::from(component));
299 }
300 }
301 }
302 }
303
304 Ok(Keystroke {
305 ctrl,
306 alt,
307 shift,
308 cmd,
309 key: key.unwrap(),
310 })
311 }
312
313 pub fn modified(&self) -> bool {
314 self.ctrl || self.alt || self.shift || self.cmd
315 }
316}
317
318impl std::fmt::Display for Keystroke {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 if self.ctrl {
321 f.write_char('^')?;
322 }
323 if self.alt {
324 f.write_char('⎇')?;
325 }
326 if self.cmd {
327 f.write_char('⌘')?;
328 }
329 if self.shift {
330 f.write_char('⇧')?;
331 }
332 let key = match self.key.as_str() {
333 "backspace" => '⌫',
334 "up" => '↑',
335 "down" => '↓',
336 "left" => '←',
337 "right" => '→',
338 "tab" => '⇥',
339 "escape" => '⎋',
340 key => {
341 if key.len() == 1 {
342 key.chars().next().unwrap().to_ascii_uppercase()
343 } else {
344 return f.write_str(key);
345 }
346 }
347 };
348 f.write_char(key)
349 }
350}
351
352impl Context {
353 pub fn extend(&mut self, other: &Context) {
354 for v in &other.set {
355 self.set.insert(v.clone());
356 }
357 for (k, v) in &other.map {
358 self.map.insert(k.clone(), v.clone());
359 }
360 }
361}
362
363impl ContextPredicate {
364 fn parse(source: &str) -> anyhow::Result<Self> {
365 let mut parser = Parser::new();
366 let language = unsafe { tree_sitter_context_predicate() };
367 parser.set_language(language).unwrap();
368 let source = source.as_bytes();
369 let tree = parser.parse(source, None).unwrap();
370 Self::from_node(tree.root_node(), source)
371 }
372
373 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
374 let parse_error = "error parsing context predicate";
375 let kind = node.kind();
376
377 match kind {
378 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
379 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
380 "not" => {
381 let child = Self::from_node(
382 node.child_by_field_name("expression")
383 .ok_or(anyhow!(parse_error))?,
384 source,
385 )?;
386 Ok(Self::Not(Box::new(child)))
387 }
388 "and" | "or" => {
389 let left = Box::new(Self::from_node(
390 node.child_by_field_name("left")
391 .ok_or(anyhow!(parse_error))?,
392 source,
393 )?);
394 let right = Box::new(Self::from_node(
395 node.child_by_field_name("right")
396 .ok_or(anyhow!(parse_error))?,
397 source,
398 )?);
399 if kind == "and" {
400 Ok(Self::And(left, right))
401 } else {
402 Ok(Self::Or(left, right))
403 }
404 }
405 "equal" | "not_equal" => {
406 let left = node
407 .child_by_field_name("left")
408 .ok_or(anyhow!(parse_error))?
409 .utf8_text(source)?
410 .into();
411 let right = node
412 .child_by_field_name("right")
413 .ok_or(anyhow!(parse_error))?
414 .utf8_text(source)?
415 .into();
416 if kind == "equal" {
417 Ok(Self::Equal(left, right))
418 } else {
419 Ok(Self::NotEqual(left, right))
420 }
421 }
422 "parenthesized" => Self::from_node(
423 node.child_by_field_name("expression")
424 .ok_or(anyhow!(parse_error))?,
425 source,
426 ),
427 _ => Err(anyhow!(parse_error)),
428 }
429 }
430
431 fn eval(&self, cx: &Context) -> bool {
432 match self {
433 Self::Identifier(name) => cx.set.contains(name.as_str()),
434 Self::Equal(left, right) => cx
435 .map
436 .get(left)
437 .map(|value| value == right)
438 .unwrap_or(false),
439 Self::NotEqual(left, right) => {
440 cx.map.get(left).map(|value| value != right).unwrap_or(true)
441 }
442 Self::Not(pred) => !pred.eval(cx),
443 Self::And(left, right) => left.eval(cx) && right.eval(cx),
444 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
445 }
446 }
447}
448
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Downcasts an optional boxed action to a concrete action type.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses and pushes `keystroke`, returning the completed action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                MatchResult::Action(action) => Some(action),
                MatchResult::Pending | MatchResult::None => None,
            }
        }
    }
}