1use anyhow::anyhow;
2use std::{
3 any::Any,
4 collections::{HashMap, HashSet},
5};
6use tree_sitter::{Language, Node, Parser};
7
extern "C" {
    /// Returns the tree-sitter `Language` for context-predicate expressions; called by
    /// `ContextPredicate::parse`. NOTE(review): presumably provided by a C parser generated
    /// from a tree-sitter grammar and linked into this crate — confirm against the build setup.
    fn tree_sitter_context_predicate() -> Language;
}
11
/// Resolves keystrokes to actions using a `Keymap`, remembering partially typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by the `view_id` passed to `push_keystroke`.
    pending: HashMap<usize, Pending>,
    /// Bindings that incoming keystrokes are matched against.
    keymap: Keymap,
}
16
/// A single view's partially typed keystroke sequence.
#[derive(Default)]
struct Pending {
    /// Keystrokes typed so far; the sequence is kept only while some applicable
    /// binding still starts with it.
    keystrokes: Vec<Keystroke>,
    /// Context recorded when the sequence was last kept. If the next keystroke arrives
    /// under a different context, the sequence is discarded before matching.
    context: Option<Context>,
}
22
/// An ordered list of key bindings.
///
/// `Matcher::push_keystroke` scans the bindings from last to first, so when two
/// applicable bindings match the same complete sequence, the later one wins.
pub struct Keymap(Vec<Binding>);
24
/// Associates a keystroke sequence with a named action, optionally gated by a
/// context predicate.
pub struct Binding {
    /// The full sequence that triggers the action (`"a b"` means `a` followed by `b`).
    keystrokes: Vec<Keystroke>,
    /// Action name reported in `MatchResult::Action`.
    action: String,
    /// Optional argument; a fresh clone is handed out with every match.
    action_arg: Option<Box<dyn ActionArg>>,
    /// Predicate the current `Context` must satisfy; `None` means the binding always applies.
    context: Option<ContextPredicate>,
}
31
/// A key plus its modifier flags, as produced by `Keystroke::parse` from text such as
/// `ctrl-p` or `shift-cmd--`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    pub ctrl: bool,
    pub alt: bool,
    pub shift: bool,
    pub cmd: bool,
    /// Name of the non-modifier key, e.g. `p`, `down`, or `-`.
    pub key: String,
}
40
/// The values a binding's context predicate is evaluated against.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Names that make a bare identifier in a predicate (e.g. `menu`) evaluate to true.
    pub set: HashSet<String>,
    /// Key/value pairs tested by `==` and `!=` in a predicate.
    pub map: HashMap<String, String>,
}
46
/// A parsed boolean expression over a `Context`, e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the name is present in `Context::set`.
    Identifier(String),
    /// True when `Context::map` maps the key to exactly this value; false if the key is absent.
    Equal(String, String),
    /// True when the key is absent from `Context::map` or maps to a different value.
    NotEqual(String, String),
    /// Logical negation (`!`).
    Not(Box<ContextPredicate>),
    /// Logical conjunction (`&&`).
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// Logical disjunction (`||`).
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
56
/// Object-safe cloning for binding arguments.
///
/// A `Binding` stores its argument as `Box<dyn ActionArg>`; each match hands out an
/// independent copy as `Box<dyn Any>` so callers can `downcast_ref` to the concrete type.
trait ActionArg {
    /// Returns a heap-allocated clone of `self`, type-erased to `dyn Any`.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

/// Every cloneable `'static` type can serve as an action argument.
impl<T: 'static + Any + Clone> ActionArg for T {
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
69
/// Outcome of feeding one keystroke to a `Matcher`.
///
/// Derives `Debug` (every payload, including `Box<dyn Any>`, is `Debug`) so results
/// can be logged and compared in assertions.
#[derive(Debug)]
pub enum MatchResult {
    /// No applicable binding starts with the keystrokes typed so far; the view's
    /// pending sequence has been discarded.
    None,
    /// The keystrokes typed so far are a strict prefix of at least one applicable binding.
    Pending,
    /// An applicable binding matched completely.
    Action {
        /// The matched binding's action name.
        name: String,
        /// A fresh clone of the binding's argument, if it has one; recover the concrete
        /// type with `downcast_ref`.
        arg: Option<Box<dyn Any>>,
    },
}
78
79impl Matcher {
80 pub fn new(keymap: Keymap) -> Self {
81 Self {
82 pending: HashMap::new(),
83 keymap,
84 }
85 }
86
87 pub fn set_keymap(&mut self, keymap: Keymap) {
88 self.pending.clear();
89 self.keymap = keymap;
90 }
91
92 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
93 self.pending.clear();
94 self.keymap.add_bindings(bindings);
95 }
96
97 pub fn push_keystroke(
98 &mut self,
99 keystroke: Keystroke,
100 view_id: usize,
101 ctx: &Context,
102 ) -> MatchResult {
103 let pending = self.pending.entry(view_id).or_default();
104
105 if let Some(pending_ctx) = pending.context.as_ref() {
106 if pending_ctx != ctx {
107 pending.keystrokes.clear();
108 }
109 }
110
111 pending.keystrokes.push(keystroke);
112
113 let mut retain_pending = false;
114 for binding in self.keymap.0.iter().rev() {
115 if binding.keystrokes.starts_with(&pending.keystrokes)
116 && binding
117 .context
118 .as_ref()
119 .map(|c| c.eval(ctx))
120 .unwrap_or(true)
121 {
122 if binding.keystrokes.len() == pending.keystrokes.len() {
123 self.pending.remove(&view_id);
124 return MatchResult::Action {
125 name: binding.action.clone(),
126 arg: binding.action_arg.as_ref().map(|arg| (*arg).boxed_clone()),
127 };
128 } else {
129 retain_pending = true;
130 pending.context = Some(ctx.clone());
131 }
132 }
133 }
134
135 if retain_pending {
136 MatchResult::Pending
137 } else {
138 self.pending.remove(&view_id);
139 MatchResult::None
140 }
141 }
142}
143
144impl Default for Matcher {
145 fn default() -> Self {
146 Self::new(Keymap::default())
147 }
148}
149
150impl Keymap {
151 pub fn new(bindings: Vec<Binding>) -> Self {
152 Self(bindings)
153 }
154
155 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
156 self.0.extend(bindings.into_iter());
157 }
158}
159
160impl Default for Keymap {
161 fn default() -> Self {
162 Self(vec![
163 Binding::new("up", "menu:select_prev", Some("menu")),
164 Binding::new("ctrl-p", "menu:select_prev", Some("menu")),
165 Binding::new("down", "menu:select_next", Some("menu")),
166 Binding::new("ctrl-n", "menu:select_next", Some("menu")),
167 ])
168 }
169}
170
171impl Binding {
172 pub fn new<S: Into<String>>(keystrokes: &str, action: S, context: Option<&str>) -> Self {
173 let context = if let Some(context) = context {
174 Some(ContextPredicate::parse(context).unwrap())
175 } else {
176 None
177 };
178
179 Self {
180 keystrokes: keystrokes
181 .split_whitespace()
182 .map(|key| Keystroke::parse(key).unwrap())
183 .collect(),
184 action: action.into(),
185 action_arg: None,
186 context,
187 }
188 }
189
190 pub fn with_arg<T: 'static + Any + Clone>(mut self, arg: T) -> Self {
191 self.action_arg = Some(Box::new(arg));
192 self
193 }
194}
195
196impl Keystroke {
197 pub fn parse(source: &str) -> anyhow::Result<Self> {
198 let mut ctrl = false;
199 let mut alt = false;
200 let mut shift = false;
201 let mut cmd = false;
202 let mut key = None;
203
204 let mut components = source.split("-").peekable();
205 while let Some(component) = components.next() {
206 match component {
207 "ctrl" => ctrl = true,
208 "alt" => alt = true,
209 "shift" => shift = true,
210 "cmd" => cmd = true,
211 _ => {
212 if let Some(component) = components.peek() {
213 if component.is_empty() && source.ends_with('-') {
214 key = Some(String::from("-"));
215 break;
216 } else {
217 return Err(anyhow!("Invalid keystroke `{}`", source));
218 }
219 } else {
220 key = Some(String::from(component));
221 }
222 }
223 }
224 }
225
226 Ok(Keystroke {
227 ctrl,
228 alt,
229 shift,
230 cmd,
231 key: key.unwrap(),
232 })
233 }
234}
235
236impl Context {
237 pub fn extend(&mut self, other: Context) {
238 for v in other.set {
239 self.set.insert(v);
240 }
241 for (k, v) in other.map {
242 self.map.insert(k, v);
243 }
244 }
245}
246
247impl ContextPredicate {
248 fn parse(source: &str) -> anyhow::Result<Self> {
249 let mut parser = Parser::new();
250 let language = unsafe { tree_sitter_context_predicate() };
251 parser.set_language(language).unwrap();
252 let source = source.as_bytes();
253 let tree = parser.parse(source, None).unwrap();
254 Self::from_node(tree.root_node(), source)
255 }
256
257 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
258 let parse_error = "error parsing context predicate";
259 let kind = node.kind();
260
261 match kind {
262 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
263 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
264 "not" => {
265 let child = Self::from_node(
266 node.child_by_field_name("expression")
267 .ok_or(anyhow!(parse_error))?,
268 source,
269 )?;
270 Ok(Self::Not(Box::new(child)))
271 }
272 "and" | "or" => {
273 let left = Box::new(Self::from_node(
274 node.child_by_field_name("left")
275 .ok_or(anyhow!(parse_error))?,
276 source,
277 )?);
278 let right = Box::new(Self::from_node(
279 node.child_by_field_name("right")
280 .ok_or(anyhow!(parse_error))?,
281 source,
282 )?);
283 if kind == "and" {
284 Ok(Self::And(left, right))
285 } else {
286 Ok(Self::Or(left, right))
287 }
288 }
289 "equal" | "not_equal" => {
290 let left = node
291 .child_by_field_name("left")
292 .ok_or(anyhow!(parse_error))?
293 .utf8_text(source)?
294 .into();
295 let right = node
296 .child_by_field_name("right")
297 .ok_or(anyhow!(parse_error))?
298 .utf8_text(source)?
299 .into();
300 if kind == "equal" {
301 Ok(Self::Equal(left, right))
302 } else {
303 Ok(Self::NotEqual(left, right))
304 }
305 }
306 "parenthesized" => Self::from_node(
307 node.child_by_field_name("expression")
308 .ok_or(anyhow!(parse_error))?,
309 source,
310 ),
311 _ => Err(anyhow!(parse_error)),
312 }
313 }
314
315 fn eval(&self, ctx: &Context) -> bool {
316 match self {
317 Self::Identifier(name) => ctx.set.contains(name.as_str()),
318 Self::Equal(left, right) => ctx
319 .map
320 .get(left)
321 .map(|value| value == right)
322 .unwrap_or(false),
323 Self::NotEqual(left, right) => ctx
324 .map
325 .get(left)
326 .map(|value| value != right)
327 .unwrap_or(true),
328 Self::Not(pred) => !pred.eval(ctx),
329 Self::And(left, right) => left.eval(ctx) && right.eval(ctx),
330 Self::Or(left, right) => left.eval(ctx) || right.eval(ctx),
331 }
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Modifier flags, multi-word key names, and the `--` spelling of the `-` key.
    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        // `cmd--` is `cmd` plus the `-` key.
        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    /// Predicate source text maps onto the expected `ContextPredicate` tree.
    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    /// Evaluation against a context that is mutated step by step.
    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        // `c` present but with the wrong value.
        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    /// End-to-end matcher behavior. The assertions share one `Matcher`, so each step
    /// depends on the pending state left by the previous ones.
    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        // This local struct shadows the module's `ActionArg` trait within this function.
        #[derive(Clone, Debug, Eq, PartialEq)]
        struct ActionArg {
            a: &'static str,
        }

        let keymap = Keymap(vec![
            Binding::new("a", "a", Some("a")).with_arg(ActionArg { a: "b" }),
            Binding::new("b", "b", Some("a")),
            Binding::new("a b", "a_b", Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        // In ctx_a, "a" is both an exact binding and a prefix of "a b"; the exact match wins.
        assert_eq!(
            matcher.test_keystroke("a", 1, &ctx_a),
            Some(("a".to_string(), Some(ActionArg { a: "b" })))
        );

        // Multi-keystroke match
        // In ctx_b the single-key "a" binding does not apply, so "a" is only a prefix.
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_b),
            Some(("a_b".to_string(), None))
        );

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(matcher.test_keystroke::<()>("x", 1, &ctx_a), None);
        assert_eq!(
            matcher.test_keystroke("a", 1, &ctx_a),
            Some(("a".to_string(), Some(ActionArg { a: "b" })))
        );

        // Pending keystrokes are cleared when the context changes
        // "a" is left pending under ctx_b; "b" under ctx_a discards it and matches "b" alone.
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_a),
            Some(("b".to_string(), None))
        );

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        // View 2's failed keystroke must not disturb view 1's pending "a".
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke::<()>("a", 2, &ctx_c), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_b),
            Some(("a_b".to_string(), None))
        );

        Ok(())
    }

    impl Matcher {
        /// Test helper: parses and pushes one keystroke. Returns the action name and its
        /// argument downcast to `A` when a binding matches, or `None` for
        /// `MatchResult::Pending`/`MatchResult::None`. An argument that is not an `A`
        /// is also reported as `None`.
        fn test_keystroke<A: Any + Clone>(
            &mut self,
            keystroke: &str,
            view_id: usize,
            ctx: &Context,
        ) -> Option<(String, Option<A>)> {
            if let MatchResult::Action { name, arg } =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, ctx)
            {
                Some((name, arg.and_then(|arg| arg.downcast_ref::<A>().cloned())))
            } else {
                None
            }
        }
    }
}