1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::Debug,
8};
9use tree_sitter::{Language, Node, Parser};
10
// Grammar entry point used by `ContextPredicate::parse` to build the parser
// for binding context expressions. NOTE(review): presumably provided by a
// tree-sitter-generated C parser linked in by the build — confirm.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Matches keystrokes against a [`Keymap`], tracking partially typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    // In-progress keystroke sequences, keyed by the view id passed to
    // `push_keystroke`.
    pending: HashMap<usize, Pending>,
    // Bindings that keystrokes are matched against.
    keymap: Keymap,
}
19
/// In-progress keystroke sequence for a single view.
#[derive(Default)]
struct Pending {
    // Keystrokes typed so far; a strict prefix of at least one binding.
    keystrokes: Vec<Keystroke>,
    // Context in which the sequence was last extended. If the next keystroke
    // arrives under a different context, the sequence is discarded.
    context: Option<Context>,
}
25
/// An ordered list of key bindings, plus an index from action type to the
/// positions of the bindings for that action.
#[derive(Default)]
pub struct Keymap {
    // Bindings in insertion order; `Matcher` checks them last-to-first.
    bindings: Vec<Binding>,
    // `TypeId` of each binding's action -> indices into `bindings`.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
31
/// Associates a sequence of keystrokes with an action, optionally restricted
/// to contexts that satisfy a predicate.
pub struct Binding {
    // The full sequence that must be typed to trigger `action`.
    keystrokes: Vec<Keystroke>,
    // Returned (as a boxed clone) by `Matcher::push_keystroke` on a full match.
    action: Box<dyn Action>,
    // `None` means the binding applies in every context.
    context: Option<ContextPredicate>,
}
37
/// A single key press together with the modifier keys held during it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// The non-modifier key, e.g. `"p"`, `"down"` or `"-"`.
    pub key: String,
}
46
/// The state that binding context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; tested by bare identifiers in a
    /// predicate (e.g. `a`).
    pub set: HashSet<String>,
    /// Key/value pairs; tested by `key == value` and `key != value`.
    pub map: HashMap<String, String>,
}
52
/// Parsed form of a binding's context expression, e.g.
/// `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in `Context::set`.
    Identifier(String),
    /// True when `Context::map` maps the key to the value; false when the
    /// key is absent.
    Equal(String, String),
    /// True when `Context::map` maps the key to a different value, or when
    /// the key is absent.
    NotEqual(String, String),
    /// Logical negation.
    Not(Box<ContextPredicate>),
    /// Short-circuiting conjunction.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Short-circuiting disjunction.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
62
/// Object-safe cloning of an arbitrary value into a type-erased box.
///
/// NOTE(review): nothing else in this file calls this trait; confirm it is
/// still needed.
trait ActionArg {
    /// Returns a heap-allocated copy of `self`, erased to `dyn Any`.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// Blanket implementation: every cloneable `'static` value can be boxed-cloned.
impl<T> ActionArg for T
where
    T: Clone + Any + 'static,
{
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
75
/// Outcome of feeding one keystroke to [`Matcher::push_keystroke`].
pub enum MatchResult {
    /// No applicable binding starts with the keystrokes typed so far; the
    /// view's pending sequence was discarded.
    None,
    /// The keystrokes typed so far are a strict prefix of an applicable
    /// binding and are being kept.
    Pending,
    /// A binding matched completely; carries a clone of its action.
    Action(Box<dyn Action>),
}
81
82impl Debug for MatchResult {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
86 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
87 MatchResult::Action(action) => f
88 .debug_tuple("MatchResult::Action")
89 .field(&action.name())
90 .finish(),
91 }
92 }
93}
94
95impl Matcher {
96 pub fn new(keymap: Keymap) -> Self {
97 Self {
98 pending: HashMap::new(),
99 keymap,
100 }
101 }
102
103 pub fn set_keymap(&mut self, keymap: Keymap) {
104 self.pending.clear();
105 self.keymap = keymap;
106 }
107
108 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
109 self.pending.clear();
110 self.keymap.add_bindings(bindings);
111 }
112
113 pub fn clear_bindings(&mut self) {
114 self.pending.clear();
115 self.keymap.clear();
116 }
117
118 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
119 self.keymap.bindings_for_action_type(action_type)
120 }
121
122 pub fn clear_pending(&mut self) {
123 self.pending.clear();
124 }
125
126 pub fn has_pending_keystrokes(&self) -> bool {
127 !self.pending.is_empty()
128 }
129
130 pub fn push_keystroke(
131 &mut self,
132 keystroke: Keystroke,
133 view_id: usize,
134 cx: &Context,
135 ) -> MatchResult {
136 let pending = self.pending.entry(view_id).or_default();
137
138 if let Some(pending_ctx) = pending.context.as_ref() {
139 if pending_ctx != cx {
140 pending.keystrokes.clear();
141 }
142 }
143
144 pending.keystrokes.push(keystroke);
145
146 let mut retain_pending = false;
147 for binding in self.keymap.bindings.iter().rev() {
148 if binding.keystrokes.starts_with(&pending.keystrokes)
149 && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
150 {
151 if binding.keystrokes.len() == pending.keystrokes.len() {
152 self.pending.remove(&view_id);
153 return MatchResult::Action(binding.action.boxed_clone());
154 } else {
155 retain_pending = true;
156 pending.context = Some(cx.clone());
157 }
158 }
159 }
160
161 if retain_pending {
162 MatchResult::Pending
163 } else {
164 self.pending.remove(&view_id);
165 MatchResult::None
166 }
167 }
168}
169
170impl Default for Matcher {
171 fn default() -> Self {
172 Self::new(Keymap::default())
173 }
174}
175
176impl Keymap {
177 pub fn new(bindings: Vec<Binding>) -> Self {
178 let mut binding_indices_by_action_type = HashMap::new();
179 for (ix, binding) in bindings.iter().enumerate() {
180 binding_indices_by_action_type
181 .entry(binding.action.as_any().type_id())
182 .or_insert_with(|| SmallVec::new())
183 .push(ix);
184 }
185 Self {
186 binding_indices_by_action_type,
187 bindings,
188 }
189 }
190
191 fn bindings_for_action_type<'a>(
192 &'a self,
193 action_type: TypeId,
194 ) -> impl Iterator<Item = &'a Binding> {
195 self.binding_indices_by_action_type
196 .get(&action_type)
197 .map(SmallVec::as_slice)
198 .unwrap_or(&[])
199 .iter()
200 .map(|ix| &self.bindings[*ix])
201 }
202
203 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
204 for binding in bindings {
205 self.binding_indices_by_action_type
206 .entry(binding.action.as_any().type_id())
207 .or_default()
208 .push(self.bindings.len());
209 self.bindings.push(binding);
210 }
211 }
212
213 fn clear(&mut self) {
214 self.bindings.clear();
215 self.binding_indices_by_action_type.clear();
216 }
217}
218
219impl Binding {
220 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
221 Self::load(keystrokes, Box::new(action), context).unwrap()
222 }
223
224 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
225 let context = if let Some(context) = context {
226 Some(ContextPredicate::parse(context)?)
227 } else {
228 None
229 };
230
231 let keystrokes = keystrokes
232 .split_whitespace()
233 .map(|key| Keystroke::parse(key))
234 .collect::<Result<_>>()?;
235
236 Ok(Self {
237 keystrokes,
238 action,
239 context,
240 })
241 }
242
243 pub fn keystrokes(&self) -> &[Keystroke] {
244 &self.keystrokes
245 }
246}
247
248impl Keystroke {
249 pub fn parse(source: &str) -> anyhow::Result<Self> {
250 let mut ctrl = false;
251 let mut alt = false;
252 let mut shift = false;
253 let mut cmd = false;
254 let mut key = None;
255
256 let mut components = source.split("-").peekable();
257 while let Some(component) = components.next() {
258 match component {
259 "ctrl" => ctrl = true,
260 "alt" => alt = true,
261 "shift" => shift = true,
262 "cmd" => cmd = true,
263 _ => {
264 if let Some(component) = components.peek() {
265 if component.is_empty() && source.ends_with('-') {
266 key = Some(String::from("-"));
267 break;
268 } else {
269 return Err(anyhow!("Invalid keystroke `{}`", source));
270 }
271 } else {
272 key = Some(String::from(component));
273 }
274 }
275 }
276 }
277
278 Ok(Keystroke {
279 ctrl,
280 alt,
281 shift,
282 cmd,
283 key: key.unwrap(),
284 })
285 }
286
287 pub fn modified(&self) -> bool {
288 self.ctrl || self.alt || self.shift || self.cmd
289 }
290}
291
292impl Context {
293 pub fn extend(&mut self, other: &Context) {
294 for v in &other.set {
295 self.set.insert(v.clone());
296 }
297 for (k, v) in &other.map {
298 self.map.insert(k.clone(), v.clone());
299 }
300 }
301}
302
303impl ContextPredicate {
304 fn parse(source: &str) -> anyhow::Result<Self> {
305 let mut parser = Parser::new();
306 let language = unsafe { tree_sitter_context_predicate() };
307 parser.set_language(language).unwrap();
308 let source = source.as_bytes();
309 let tree = parser.parse(source, None).unwrap();
310 Self::from_node(tree.root_node(), source)
311 }
312
313 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
314 let parse_error = "error parsing context predicate";
315 let kind = node.kind();
316
317 match kind {
318 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
319 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
320 "not" => {
321 let child = Self::from_node(
322 node.child_by_field_name("expression")
323 .ok_or(anyhow!(parse_error))?,
324 source,
325 )?;
326 Ok(Self::Not(Box::new(child)))
327 }
328 "and" | "or" => {
329 let left = Box::new(Self::from_node(
330 node.child_by_field_name("left")
331 .ok_or(anyhow!(parse_error))?,
332 source,
333 )?);
334 let right = Box::new(Self::from_node(
335 node.child_by_field_name("right")
336 .ok_or(anyhow!(parse_error))?,
337 source,
338 )?);
339 if kind == "and" {
340 Ok(Self::And(left, right))
341 } else {
342 Ok(Self::Or(left, right))
343 }
344 }
345 "equal" | "not_equal" => {
346 let left = node
347 .child_by_field_name("left")
348 .ok_or(anyhow!(parse_error))?
349 .utf8_text(source)?
350 .into();
351 let right = node
352 .child_by_field_name("right")
353 .ok_or(anyhow!(parse_error))?
354 .utf8_text(source)?
355 .into();
356 if kind == "equal" {
357 Ok(Self::Equal(left, right))
358 } else {
359 Ok(Self::NotEqual(left, right))
360 }
361 }
362 "parenthesized" => Self::from_node(
363 node.child_by_field_name("expression")
364 .ok_or(anyhow!(parse_error))?,
365 source,
366 ),
367 _ => Err(anyhow!(parse_error)),
368 }
369 }
370
371 fn eval(&self, cx: &Context) -> bool {
372 match self {
373 Self::Identifier(name) => cx.set.contains(name.as_str()),
374 Self::Equal(left, right) => cx
375 .map
376 .get(left)
377 .map(|value| value == right)
378 .unwrap_or(false),
379 Self::NotEqual(left, right) => {
380 cx.map.get(left).map(|value| value != right).unwrap_or(true)
381 }
382 Self::Not(pred) => !pred.eval(cx),
383 Self::And(left, right) => left.eval(cx) && right.eval(cx),
384 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
385 }
386 }
387}
388
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        // The `-` key itself is written as a trailing `--`.
        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Borrows the action as its concrete type `A`, if present and of that type.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Test helper: parses `keystroke`, pushes it, and returns the matched
        /// action (if any).
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            if let MatchResult::Action(action) =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx)
            {
                Some(action.boxed_clone())
            } else {
                None
            }
        }
    }
}