1use crate::Action;
2use anyhow::{anyhow, Result};
3use std::{
4 any::Any,
5 collections::{HashMap, HashSet},
6 fmt::Debug,
7};
8use tree_sitter::{Language, Node, Parser};
9
// Entry point for the context-predicate grammar, implemented outside Rust and linked in.
// NOTE(review): presumably the tree-sitter-generated parser for the predicate language —
// confirm against the crate's build script.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
13
/// Resolves keystrokes against a [`Keymap`], tracking partially-typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by the `view_id` passed to `push_keystroke`.
    pending: HashMap<usize, Pending>,
    /// The bindings keystrokes are matched against.
    keymap: Keymap,
}
18
/// A partially-typed keystroke sequence for one view.
#[derive(Default)]
struct Pending {
    /// Keystrokes typed so far that form a prefix of at least one binding.
    keystrokes: Vec<Keystroke>,
    /// Context in effect when the sequence was last extended; a keystroke arriving
    /// under a different context discards `keystrokes` before matching.
    context: Option<Context>,
}
24
/// An ordered list of bindings. The matcher scans it from last to first, so
/// bindings added later take precedence over earlier ones.
#[derive(Default)]
pub struct Keymap(Vec<Binding>);
27
/// Maps a sequence of keystrokes to an action, optionally restricted to contexts
/// satisfying a predicate.
pub struct Binding {
    /// The full keystroke sequence that triggers the action.
    keystrokes: Vec<Keystroke>,
    /// The action dispatched (as a fresh clone) when the sequence completes.
    action: Box<dyn Action>,
    /// When present, the binding only applies if this predicate evaluates to true.
    context: Option<ContextPredicate>,
}
33
/// A single key press: a key name plus the modifier keys held with it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Whether the `ctrl` modifier is part of the keystroke.
    pub ctrl: bool,
    /// Whether the `alt` modifier is part of the keystroke.
    pub alt: bool,
    /// Whether the `shift` modifier is part of the keystroke.
    pub shift: bool,
    /// Whether the `cmd` modifier is part of the keystroke.
    pub cmd: bool,
    /// The non-modifier key, e.g. `"p"`, `"down"`, or `"-"`.
    pub key: String,
}
42
/// The state against which binding predicates are evaluated.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers tested for membership by `ContextPredicate::Identifier`.
    pub set: HashSet<String>,
    /// Key/value pairs compared by `ContextPredicate::Equal` and `NotEqual`.
    pub map: HashMap<String, String>,
}
48
/// A parsed boolean expression over a [`Context`].
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in `Context::set`.
    Identifier(String),
    /// True when `Context::map[left] == right`; false if `left` is absent.
    Equal(String, String),
    /// True when `Context::map[left] != right`; also true if `left` is absent.
    NotEqual(String, String),
    /// Logical negation.
    Not(Box<ContextPredicate>),
    /// Logical conjunction (short-circuiting).
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Logical disjunction (short-circuiting).
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
58
/// Produces type-erased, heap-allocated copies of a value.
trait ActionArg {
    /// Returns a boxed clone of `self` as `dyn Any`, recoverable via downcasting.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// Every cloneable `'static` type can be copied into a `Box<dyn Any>`.
impl<T> ActionArg for T
where
    T: 'static + Any + Clone,
{
    fn boxed_clone(&self) -> Box<dyn Any> {
        let duplicate: T = T::clone(self);
        Box::new(duplicate)
    }
}
71
/// The outcome of feeding one keystroke to a [`Matcher`].
pub enum MatchResult {
    /// No applicable binding starts with the keystrokes typed so far.
    None,
    /// The keystrokes so far are a proper prefix of an applicable binding;
    /// more keystrokes are needed.
    Pending,
    /// An applicable binding was completed; carries a clone of its action.
    Action(Box<dyn Action>),
}
77
78impl Debug for MatchResult {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
82 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
83 MatchResult::Action(action) => f
84 .debug_tuple("MatchResult::Action")
85 .field(&action.name())
86 .finish(),
87 }
88 }
89}
90
91impl Matcher {
92 pub fn new(keymap: Keymap) -> Self {
93 Self {
94 pending: HashMap::new(),
95 keymap,
96 }
97 }
98
99 pub fn set_keymap(&mut self, keymap: Keymap) {
100 self.pending.clear();
101 self.keymap = keymap;
102 }
103
104 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
105 self.pending.clear();
106 self.keymap.add_bindings(bindings);
107 }
108
109 pub fn clear_bindings(&mut self) {
110 self.pending.clear();
111 self.keymap.clear();
112 }
113
114 pub fn clear_pending(&mut self) {
115 self.pending.clear();
116 }
117
118 pub fn push_keystroke(
119 &mut self,
120 keystroke: Keystroke,
121 view_id: usize,
122 cx: &Context,
123 ) -> MatchResult {
124 let pending = self.pending.entry(view_id).or_default();
125
126 if let Some(pending_ctx) = pending.context.as_ref() {
127 if pending_ctx != cx {
128 pending.keystrokes.clear();
129 }
130 }
131
132 pending.keystrokes.push(keystroke);
133
134 let mut retain_pending = false;
135 for binding in self.keymap.0.iter().rev() {
136 if binding.keystrokes.starts_with(&pending.keystrokes)
137 && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
138 {
139 if binding.keystrokes.len() == pending.keystrokes.len() {
140 self.pending.remove(&view_id);
141 return MatchResult::Action(binding.action.boxed_clone());
142 } else {
143 retain_pending = true;
144 pending.context = Some(cx.clone());
145 }
146 }
147 }
148
149 if retain_pending {
150 MatchResult::Pending
151 } else {
152 self.pending.remove(&view_id);
153 MatchResult::None
154 }
155 }
156}
157
158impl Default for Matcher {
159 fn default() -> Self {
160 Self::new(Keymap::default())
161 }
162}
163
164impl Keymap {
165 pub fn new(bindings: Vec<Binding>) -> Self {
166 Self(bindings)
167 }
168
169 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
170 self.0.extend(bindings.into_iter());
171 }
172
173 fn clear(&mut self) {
174 self.0.clear();
175 }
176}
177
178impl Binding {
179 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
180 Self::load(keystrokes, Box::new(action), context).unwrap()
181 }
182
183 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
184 let context = if let Some(context) = context {
185 Some(ContextPredicate::parse(context)?)
186 } else {
187 None
188 };
189
190 let keystrokes = keystrokes
191 .split_whitespace()
192 .map(|key| Keystroke::parse(key))
193 .collect::<Result<_>>()?;
194
195 Ok(Self {
196 keystrokes,
197 action,
198 context,
199 })
200 }
201}
202
203impl Keystroke {
204 pub fn parse(source: &str) -> anyhow::Result<Self> {
205 let mut ctrl = false;
206 let mut alt = false;
207 let mut shift = false;
208 let mut cmd = false;
209 let mut key = None;
210
211 let mut components = source.split("-").peekable();
212 while let Some(component) = components.next() {
213 match component {
214 "ctrl" => ctrl = true,
215 "alt" => alt = true,
216 "shift" => shift = true,
217 "cmd" => cmd = true,
218 _ => {
219 if let Some(component) = components.peek() {
220 if component.is_empty() && source.ends_with('-') {
221 key = Some(String::from("-"));
222 break;
223 } else {
224 return Err(anyhow!("Invalid keystroke `{}`", source));
225 }
226 } else {
227 key = Some(String::from(component));
228 }
229 }
230 }
231 }
232
233 Ok(Keystroke {
234 ctrl,
235 alt,
236 shift,
237 cmd,
238 key: key.unwrap(),
239 })
240 }
241
242 pub fn modified(&self) -> bool {
243 self.ctrl || self.alt || self.shift || self.cmd
244 }
245}
246
247impl Context {
248 pub fn extend(&mut self, other: &Context) {
249 for v in &other.set {
250 self.set.insert(v.clone());
251 }
252 for (k, v) in &other.map {
253 self.map.insert(k.clone(), v.clone());
254 }
255 }
256}
257
258impl ContextPredicate {
259 fn parse(source: &str) -> anyhow::Result<Self> {
260 let mut parser = Parser::new();
261 let language = unsafe { tree_sitter_context_predicate() };
262 parser.set_language(language).unwrap();
263 let source = source.as_bytes();
264 let tree = parser.parse(source, None).unwrap();
265 Self::from_node(tree.root_node(), source)
266 }
267
268 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
269 let parse_error = "error parsing context predicate";
270 let kind = node.kind();
271
272 match kind {
273 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
274 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
275 "not" => {
276 let child = Self::from_node(
277 node.child_by_field_name("expression")
278 .ok_or(anyhow!(parse_error))?,
279 source,
280 )?;
281 Ok(Self::Not(Box::new(child)))
282 }
283 "and" | "or" => {
284 let left = Box::new(Self::from_node(
285 node.child_by_field_name("left")
286 .ok_or(anyhow!(parse_error))?,
287 source,
288 )?);
289 let right = Box::new(Self::from_node(
290 node.child_by_field_name("right")
291 .ok_or(anyhow!(parse_error))?,
292 source,
293 )?);
294 if kind == "and" {
295 Ok(Self::And(left, right))
296 } else {
297 Ok(Self::Or(left, right))
298 }
299 }
300 "equal" | "not_equal" => {
301 let left = node
302 .child_by_field_name("left")
303 .ok_or(anyhow!(parse_error))?
304 .utf8_text(source)?
305 .into();
306 let right = node
307 .child_by_field_name("right")
308 .ok_or(anyhow!(parse_error))?
309 .utf8_text(source)?
310 .into();
311 if kind == "equal" {
312 Ok(Self::Equal(left, right))
313 } else {
314 Ok(Self::NotEqual(left, right))
315 }
316 }
317 "parenthesized" => Self::from_node(
318 node.child_by_field_name("expression")
319 .ok_or(anyhow!(parse_error))?,
320 source,
321 ),
322 _ => Err(anyhow!(parse_error)),
323 }
324 }
325
326 fn eval(&self, cx: &Context) -> bool {
327 match self {
328 Self::Identifier(name) => cx.set.contains(name.as_str()),
329 Self::Equal(left, right) => cx
330 .map
331 .get(left)
332 .map(|value| value == right)
333 .unwrap_or(false),
334 Self::NotEqual(left, right) => {
335 cx.map.get(left).map(|value| value != right).unwrap_or(true)
336 }
337 Self::Not(pred) => !pred.eval(cx),
338 Self::And(left, right) => left.eval(cx) && right.eval(cx),
339 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
340 }
341 }
342}
343
// Unit tests for keystroke parsing, context predicates, and the matcher state machine.
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    // Keystroke descriptions: one modifier, several modifiers, and the `-` key written
    // as a trailing `--`.
    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    // Predicate parsing: nested and/or with parentheses, equality operators, and negation.
    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    // Predicate evaluation against contexts built up step by step; `!a` holds in an
    // empty context.
    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    // End-to-end matcher behavior: single and multi-keystroke bindings, context
    // filtering, context changes, and per-view pending state.
    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        // NOTE(review): this local struct appears unused by the test below.
        #[derive(Clone, Debug, Eq, PartialEq)]
        struct ActionArg {
            a: &'static str,
        }

        // `a` -> A and `b` -> B apply only in context `a`; `a b` -> Ab applies in `a` or `b`.
        let keymap = Keymap(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(&matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    // Borrows the concrete action type out of an optional boxed action, if it is an `A`.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        // Test helper: parses `keystroke`, pushes it, and returns the matched action
        // (cloned), or `None` for both `MatchResult::None` and `MatchResult::Pending`.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            if let MatchResult::Action(action) =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx)
            {
                Some(action.boxed_clone())
            } else {
                None
            }
        }
    }
}