1use anyhow::anyhow;
2use std::{
3 any::Any,
4 collections::{HashMap, HashSet},
5 fmt::Debug,
6};
7use tree_sitter::{Language, Node, Parser};
8
9use crate::{Action, AnyAction};
10
extern "C" {
    /// Foreign function returning the tree-sitter [`Language`] that
    /// `ContextPredicate::parse` installs on its parser.
    ///
    /// NOTE(review): presumably the entry point of a generated tree-sitter
    /// grammar that the build compiles and links into this crate (the build
    /// script is not visible here) — confirm.
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Matches keystrokes, pushed one at a time for a given view, against the
/// bindings of a [`Keymap`].
pub struct Matcher {
    /// Keystroke sequences typed so far that are still a prefix of some
    /// binding, tracked separately for each view id.
    pending: HashMap<usize, Pending>,
    /// The bindings that keystrokes are matched against.
    keymap: Keymap,
}
19
/// Partially typed keystroke sequence for one view.
#[derive(Default)]
struct Pending {
    /// Keystrokes accumulated so far, oldest first.
    keystrokes: Vec<Keystroke>,
    /// Context in effect when the sequence last matched a binding prefix.
    /// If the next keystroke arrives under a different context, the
    /// accumulated keystrokes are discarded.
    context: Option<Context>,
}
25
/// Ordered list of key bindings. The matcher tries bindings from last to
/// first, so later bindings take precedence when several match.
#[derive(Default)]
pub struct Keymap(Vec<Binding>);
28
/// Associates a keystroke sequence, optionally guarded by a context
/// predicate, with an action.
pub struct Binding {
    /// Keystrokes that must be typed, in order, to trigger the action.
    keystrokes: Vec<Keystroke>,
    /// Action to dispatch; a fresh clone is produced each time the binding matches.
    action: Box<dyn AnyAction>,
    /// Predicate that must hold for the binding to apply; `None` means the
    /// binding applies in every context.
    context: Option<ContextPredicate>,
}
34
/// A single key press together with the modifier keys held during it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// Name of the non-modifier key, e.g. `"p"`, `"down"`, or `"-"`.
    pub key: String,
}
43
/// State that binding context predicates are evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; tested by bare-identifier predicates such as `a`.
    pub set: HashSet<String>,
    /// Key/value pairs; tested by `key == value` and `key != value` predicates.
    pub map: HashMap<String, String>,
}
49
/// Parsed boolean expression over a [`Context`], e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in the context's `set`.
    Identifier(String),
    /// True when `map[left] == right`; false when `left` is absent from `map`.
    Equal(String, String),
    /// True when `map[left] != right`, including when `left` is absent from `map`.
    NotEqual(String, String),
    /// Logical negation.
    Not(Box<ContextPredicate>),
    /// Logical conjunction (`&&`).
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Logical disjunction (`||`).
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
59
/// Type-erased cloning: produces an owned, heap-allocated copy of a value as
/// `dyn Any`.
///
/// NOTE(review): no use of this trait is visible in this file (the test
/// module's `ActionArg` is an unrelated struct); confirm it is still needed.
trait ActionArg {
    /// Clones `self` onto the heap and erases its concrete type.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

// Blanket implementation for every cloneable type. `Any` already implies
// `'static`, so no separate lifetime bound is required.
impl<T: Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let cloned: T = Clone::clone(self);
        Box::new(cloned)
    }
}
72
/// Outcome of pushing a keystroke into a [`Matcher`].
pub enum MatchResult {
    /// No binding matches; the view's pending keystrokes were discarded.
    None,
    /// The keystrokes typed so far are a strict prefix of at least one
    /// applicable binding and are being held for the next keystroke.
    Pending,
    /// A binding matched exactly; carries a clone of its action.
    Action(Box<dyn AnyAction>),
}
78
79impl Matcher {
80 pub fn new(keymap: Keymap) -> Self {
81 Self {
82 pending: HashMap::new(),
83 keymap,
84 }
85 }
86
87 pub fn set_keymap(&mut self, keymap: Keymap) {
88 self.pending.clear();
89 self.keymap = keymap;
90 }
91
92 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
93 self.pending.clear();
94 self.keymap.add_bindings(bindings);
95 }
96
97 pub fn clear_pending(&mut self) {
98 self.pending.clear();
99 }
100
101 pub fn push_keystroke(
102 &mut self,
103 keystroke: Keystroke,
104 view_id: usize,
105 cx: &Context,
106 ) -> MatchResult {
107 let pending = self.pending.entry(view_id).or_default();
108
109 if let Some(pending_ctx) = pending.context.as_ref() {
110 if pending_ctx != cx {
111 pending.keystrokes.clear();
112 }
113 }
114
115 pending.keystrokes.push(keystroke);
116
117 let mut retain_pending = false;
118 for binding in self.keymap.0.iter().rev() {
119 if binding.keystrokes.starts_with(&pending.keystrokes)
120 && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
121 {
122 if binding.keystrokes.len() == pending.keystrokes.len() {
123 self.pending.remove(&view_id);
124 return MatchResult::Action(binding.action.boxed_clone());
125 } else {
126 retain_pending = true;
127 pending.context = Some(cx.clone());
128 }
129 }
130 }
131
132 if retain_pending {
133 MatchResult::Pending
134 } else {
135 self.pending.remove(&view_id);
136 MatchResult::None
137 }
138 }
139}
140
141impl Default for Matcher {
142 fn default() -> Self {
143 Self::new(Keymap::default())
144 }
145}
146
147impl Keymap {
148 pub fn new(bindings: Vec<Binding>) -> Self {
149 Self(bindings)
150 }
151
152 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
153 self.0.extend(bindings.into_iter());
154 }
155}
156
157impl Binding {
158 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
159 let context = if let Some(context) = context {
160 Some(ContextPredicate::parse(context).unwrap())
161 } else {
162 None
163 };
164
165 Self {
166 keystrokes: keystrokes
167 .split_whitespace()
168 .map(|key| Keystroke::parse(key).unwrap())
169 .collect(),
170 action: Box::new(action),
171 context,
172 }
173 }
174}
175
176impl Keystroke {
177 pub fn parse(source: &str) -> anyhow::Result<Self> {
178 let mut ctrl = false;
179 let mut alt = false;
180 let mut shift = false;
181 let mut cmd = false;
182 let mut key = None;
183
184 let mut components = source.split("-").peekable();
185 while let Some(component) = components.next() {
186 match component {
187 "ctrl" => ctrl = true,
188 "alt" => alt = true,
189 "shift" => shift = true,
190 "cmd" => cmd = true,
191 _ => {
192 if let Some(component) = components.peek() {
193 if component.is_empty() && source.ends_with('-') {
194 key = Some(String::from("-"));
195 break;
196 } else {
197 return Err(anyhow!("Invalid keystroke `{}`", source));
198 }
199 } else {
200 key = Some(String::from(component));
201 }
202 }
203 }
204 }
205
206 Ok(Keystroke {
207 ctrl,
208 alt,
209 shift,
210 cmd,
211 key: key.unwrap(),
212 })
213 }
214}
215
216impl Context {
217 pub fn extend(&mut self, other: Context) {
218 for v in other.set {
219 self.set.insert(v);
220 }
221 for (k, v) in other.map {
222 self.map.insert(k, v);
223 }
224 }
225}
226
227impl ContextPredicate {
228 fn parse(source: &str) -> anyhow::Result<Self> {
229 let mut parser = Parser::new();
230 let language = unsafe { tree_sitter_context_predicate() };
231 parser.set_language(language).unwrap();
232 let source = source.as_bytes();
233 let tree = parser.parse(source, None).unwrap();
234 Self::from_node(tree.root_node(), source)
235 }
236
237 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
238 let parse_error = "error parsing context predicate";
239 let kind = node.kind();
240
241 match kind {
242 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
243 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
244 "not" => {
245 let child = Self::from_node(
246 node.child_by_field_name("expression")
247 .ok_or(anyhow!(parse_error))?,
248 source,
249 )?;
250 Ok(Self::Not(Box::new(child)))
251 }
252 "and" | "or" => {
253 let left = Box::new(Self::from_node(
254 node.child_by_field_name("left")
255 .ok_or(anyhow!(parse_error))?,
256 source,
257 )?);
258 let right = Box::new(Self::from_node(
259 node.child_by_field_name("right")
260 .ok_or(anyhow!(parse_error))?,
261 source,
262 )?);
263 if kind == "and" {
264 Ok(Self::And(left, right))
265 } else {
266 Ok(Self::Or(left, right))
267 }
268 }
269 "equal" | "not_equal" => {
270 let left = node
271 .child_by_field_name("left")
272 .ok_or(anyhow!(parse_error))?
273 .utf8_text(source)?
274 .into();
275 let right = node
276 .child_by_field_name("right")
277 .ok_or(anyhow!(parse_error))?
278 .utf8_text(source)?
279 .into();
280 if kind == "equal" {
281 Ok(Self::Equal(left, right))
282 } else {
283 Ok(Self::NotEqual(left, right))
284 }
285 }
286 "parenthesized" => Self::from_node(
287 node.child_by_field_name("expression")
288 .ok_or(anyhow!(parse_error))?,
289 source,
290 ),
291 _ => Err(anyhow!(parse_error)),
292 }
293 }
294
295 fn eval(&self, cx: &Context) -> bool {
296 match self {
297 Self::Identifier(name) => cx.set.contains(name.as_str()),
298 Self::Equal(left, right) => cx
299 .map
300 .get(left)
301 .map(|value| value == right)
302 .unwrap_or(false),
303 Self::NotEqual(left, right) => {
304 cx.map.get(left).map(|value| value != right).unwrap_or(true)
305 }
306 Self::Not(pred) => !pred.eval(cx),
307 Self::And(left, right) => left.eval(cx) && right.eval(cx),
308 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
309 }
310 }
311}
312
#[cfg(test)]
mod tests {
    use crate::action;

    use super::*;

    /// Keystroke descriptions parse into modifier flags plus a key name.
    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        // A trailing `--` denotes the `-` key itself.
        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    /// Predicate source text parses into the expected expression tree.
    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    /// Parsed predicates evaluate correctly as the context's set and map change.
    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        // `!a` holds when `a` is absent from an empty context.
        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    /// End-to-end matching: exact matches, multi-keystroke sequences, recovery
    /// after a failed match, context changes, and per-view pending state.
    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        action!(A, &'static str);
        action!(B);
        action!(Ab);

        impl PartialEq for A {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for A {}
        impl Debug for A {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "A({:?})", &self.0)
            }
        }

        // NOTE(review): this struct is never used in the test and shadows the
        // glob-imported `ActionArg` trait; candidate for removal.
        #[derive(Clone, Debug, Eq, PartialEq)]
        struct ActionArg {
            a: &'static str,
        }

        let keymap = Keymap(vec![
            Binding::new("a", A("x"), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match: the exact "a" binding wins even though "a b" also
        // begins with "a".
        assert_eq!(matcher.test_keystroke("a", 1, &ctx_a), Some(A("x")));

        // Multi-keystroke match: in ctx_b only "a b" applies, so "a" is held
        // as pending until "b" completes the sequence.
        assert_eq!(matcher.test_keystroke::<A>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke("b", 1, &ctx_b), Some(Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(matcher.test_keystroke::<A>("x", 1, &ctx_a), None);
        assert_eq!(matcher.test_keystroke("a", 1, &ctx_a), Some(A("x")));

        // Pending keystrokes are cleared when the context changes
        assert_eq!(matcher.test_keystroke::<A>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke("b", 1, &ctx_a), Some(B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert_eq!(matcher.test_keystroke::<A>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke::<A>("a", 2, &ctx_c), None);
        assert_eq!(matcher.test_keystroke("b", 1, &ctx_b), Some(Ab));

        Ok(())
    }

    impl Matcher {
        /// Test helper: parses `keystroke`, pushes it for `view_id`, and
        /// downcasts a resulting action to `A`. Returns `None` for
        /// `MatchResult::Pending` and `MatchResult::None`. Panics if the
        /// keystroke fails to parse or the matched action is not an `A`.
        fn test_keystroke<A>(&mut self, keystroke: &str, view_id: usize, cx: &Context) -> Option<A>
        where
            A: Action + Debug + Eq,
        {
            if let MatchResult::Action(action) =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx)
            {
                Some(*action.boxed_clone_as_any().downcast().unwrap())
            } else {
                None
            }
        }
    }
}