1use crate::Action;
2use anyhow::{anyhow, Result};
3use std::{
4 any::Any,
5 collections::{HashMap, HashSet},
6 fmt::Debug,
7};
8use tree_sitter::{Language, Node, Parser};
9
10extern "C" {
11 fn tree_sitter_context_predicate() -> Language;
12}
13
14pub struct Matcher {
15 pending: HashMap<usize, Pending>,
16 keymap: Keymap,
17}
18
19#[derive(Default)]
20struct Pending {
21 keystrokes: Vec<Keystroke>,
22 context: Option<Context>,
23}
24
25#[derive(Default)]
26pub struct Keymap(Vec<Binding>);
27
28pub struct Binding {
29 keystrokes: Vec<Keystroke>,
30 action: Box<dyn Action>,
31 context: Option<ContextPredicate>,
32}
33
/// A single key press together with the modifier keys held during it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Keystroke {
    /// Control modifier held.
    pub ctrl: bool,
    /// Alt modifier held.
    pub alt: bool,
    /// Shift modifier held.
    pub shift: bool,
    /// Command modifier held.
    pub cmd: bool,
    /// Name of the non-modifier key, e.g. `p`, `down` or `-`.
    pub key: String,
}

/// State that binding predicates are evaluated against.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Context {
    /// Identifiers that are currently active (tested by bare identifiers).
    pub set: HashSet<String>,
    /// Key/value pairs tested by `==` and `!=` predicates.
    pub map: HashMap<String, String>,
}

/// Parsed form of a binding's context expression.
#[derive(PartialEq, Eq, Debug)]
enum ContextPredicate {
    /// Holds when the identifier is in `Context::set`.
    Identifier(String),
    /// Holds when `Context::map[key] == value`.
    Equal(String, String),
    /// Holds when `Context::map[key]` is absent or differs from `value`.
    NotEqual(String, String),
    /// Logical negation.
    Not(Box<ContextPredicate>),
    /// Logical conjunction.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Logical disjunction.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
58
/// Object-safe cloning of a value into a type-erased `Box<dyn Any>`.
///
/// NOTE(review): nothing in this file uses this trait (the test module declares
/// an unrelated struct of the same name); consider removing it if it is unused
/// elsewhere as well.
trait ActionArg {
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// `Any` already implies `'static`, so this bound matches `'static + Any + Clone`.
impl<T: Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = self.clone();
        Box::new(copy)
    }
}
71
72pub enum MatchResult {
73 None,
74 Pending,
75 Action(Box<dyn Action>),
76}
77
78impl Debug for MatchResult {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
82 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
83 MatchResult::Action(action) => f
84 .debug_tuple("MatchResult::Action")
85 .field(&action.name())
86 .finish(),
87 }
88 }
89}
90
91impl Matcher {
92 pub fn new(keymap: Keymap) -> Self {
93 Self {
94 pending: HashMap::new(),
95 keymap,
96 }
97 }
98
99 pub fn set_keymap(&mut self, keymap: Keymap) {
100 self.pending.clear();
101 self.keymap = keymap;
102 }
103
104 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
105 self.pending.clear();
106 self.keymap.add_bindings(bindings);
107 }
108
109 pub fn clear_pending(&mut self) {
110 self.pending.clear();
111 }
112
113 pub fn push_keystroke(
114 &mut self,
115 keystroke: Keystroke,
116 view_id: usize,
117 cx: &Context,
118 ) -> MatchResult {
119 let pending = self.pending.entry(view_id).or_default();
120
121 if let Some(pending_ctx) = pending.context.as_ref() {
122 if pending_ctx != cx {
123 pending.keystrokes.clear();
124 }
125 }
126
127 pending.keystrokes.push(keystroke);
128
129 let mut retain_pending = false;
130 for binding in self.keymap.0.iter().rev() {
131 if binding.keystrokes.starts_with(&pending.keystrokes)
132 && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
133 {
134 if binding.keystrokes.len() == pending.keystrokes.len() {
135 self.pending.remove(&view_id);
136 return MatchResult::Action(binding.action.boxed_clone());
137 } else {
138 retain_pending = true;
139 pending.context = Some(cx.clone());
140 }
141 }
142 }
143
144 if retain_pending {
145 MatchResult::Pending
146 } else {
147 self.pending.remove(&view_id);
148 MatchResult::None
149 }
150 }
151}
152
153impl Default for Matcher {
154 fn default() -> Self {
155 Self::new(Keymap::default())
156 }
157}
158
159impl Keymap {
160 pub fn new(bindings: Vec<Binding>) -> Self {
161 Self(bindings)
162 }
163
164 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
165 self.0.extend(bindings.into_iter());
166 }
167}
168
169impl Binding {
170 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
171 Self::load(keystrokes, Box::new(action), context).unwrap()
172 }
173
174 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
175 let context = if let Some(context) = context {
176 Some(ContextPredicate::parse(context)?)
177 } else {
178 None
179 };
180
181 let keystrokes = keystrokes
182 .split_whitespace()
183 .map(|key| Keystroke::parse(key))
184 .collect::<Result<_>>()?;
185
186 Ok(Self {
187 keystrokes,
188 action,
189 context,
190 })
191 }
192}
193
194impl Keystroke {
195 pub fn parse(source: &str) -> anyhow::Result<Self> {
196 let mut ctrl = false;
197 let mut alt = false;
198 let mut shift = false;
199 let mut cmd = false;
200 let mut key = None;
201
202 let mut components = source.split("-").peekable();
203 while let Some(component) = components.next() {
204 match component {
205 "ctrl" => ctrl = true,
206 "alt" => alt = true,
207 "shift" => shift = true,
208 "cmd" => cmd = true,
209 _ => {
210 if let Some(component) = components.peek() {
211 if component.is_empty() && source.ends_with('-') {
212 key = Some(String::from("-"));
213 break;
214 } else {
215 return Err(anyhow!("Invalid keystroke `{}`", source));
216 }
217 } else {
218 key = Some(String::from(component));
219 }
220 }
221 }
222 }
223
224 Ok(Keystroke {
225 ctrl,
226 alt,
227 shift,
228 cmd,
229 key: key.unwrap(),
230 })
231 }
232
233 pub fn modified(&self) -> bool {
234 self.ctrl || self.alt || self.shift || self.cmd
235 }
236}
237
238impl Context {
239 pub fn extend(&mut self, other: &Context) {
240 for v in &other.set {
241 self.set.insert(v.clone());
242 }
243 for (k, v) in &other.map {
244 self.map.insert(k.clone(), v.clone());
245 }
246 }
247}
248
249impl ContextPredicate {
250 fn parse(source: &str) -> anyhow::Result<Self> {
251 let mut parser = Parser::new();
252 let language = unsafe { tree_sitter_context_predicate() };
253 parser.set_language(language).unwrap();
254 let source = source.as_bytes();
255 let tree = parser.parse(source, None).unwrap();
256 Self::from_node(tree.root_node(), source)
257 }
258
259 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
260 let parse_error = "error parsing context predicate";
261 let kind = node.kind();
262
263 match kind {
264 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
265 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
266 "not" => {
267 let child = Self::from_node(
268 node.child_by_field_name("expression")
269 .ok_or(anyhow!(parse_error))?,
270 source,
271 )?;
272 Ok(Self::Not(Box::new(child)))
273 }
274 "and" | "or" => {
275 let left = Box::new(Self::from_node(
276 node.child_by_field_name("left")
277 .ok_or(anyhow!(parse_error))?,
278 source,
279 )?);
280 let right = Box::new(Self::from_node(
281 node.child_by_field_name("right")
282 .ok_or(anyhow!(parse_error))?,
283 source,
284 )?);
285 if kind == "and" {
286 Ok(Self::And(left, right))
287 } else {
288 Ok(Self::Or(left, right))
289 }
290 }
291 "equal" | "not_equal" => {
292 let left = node
293 .child_by_field_name("left")
294 .ok_or(anyhow!(parse_error))?
295 .utf8_text(source)?
296 .into();
297 let right = node
298 .child_by_field_name("right")
299 .ok_or(anyhow!(parse_error))?
300 .utf8_text(source)?
301 .into();
302 if kind == "equal" {
303 Ok(Self::Equal(left, right))
304 } else {
305 Ok(Self::NotEqual(left, right))
306 }
307 }
308 "parenthesized" => Self::from_node(
309 node.child_by_field_name("expression")
310 .ok_or(anyhow!(parse_error))?,
311 source,
312 ),
313 _ => Err(anyhow!(parse_error)),
314 }
315 }
316
317 fn eval(&self, cx: &Context) -> bool {
318 match self {
319 Self::Identifier(name) => cx.set.contains(name.as_str()),
320 Self::Equal(left, right) => cx
321 .map
322 .get(left)
323 .map(|value| value == right)
324 .unwrap_or(false),
325 Self::NotEqual(left, right) => {
326 cx.map.get(left).map(|value| value != right).unwrap_or(true)
327 }
328 Self::Not(pred) => !pred.eval(cx),
329 Self::And(left, right) => left.eval(cx) && right.eval(cx),
330 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        // The bare minus key.
        assert_eq!(Keystroke::parse("-")?.key, "-");

        // A non-modifier component may only appear last.
        assert!(Keystroke::parse("a-b").is_err());

        Ok(())
    }

    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        // `!=` holds when the key is absent or bound to a different value.
        let predicate = ContextPredicate::parse("a != b")?;
        let mut context = Context::default();
        assert!(predicate.eval(&context));
        context.map.insert("a".into(), "b".into());
        assert!(!predicate.eval(&context));
        context.map.insert("a".into(), "c".into());
        assert!(predicate.eval(&context));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Borrows the concrete action of type `A`, if `action` holds one.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Parses `keystroke`, pushes it, and returns the matched action, if any.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            match self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx) {
                MatchResult::Action(action) => Some(action),
                MatchResult::None | MatchResult::Pending => None,
            }
        }
    }
}