1use anyhow::anyhow;
2use std::{
3 any::Any,
4 collections::{HashMap, HashSet},
5};
6use tree_sitter::{Language, Node, Parser};
7
// Foreign entry point returning the `Language` handle of the context-predicate grammar; it is
// loaded into a `Parser` by `ContextPredicate::parse`.
// NOTE(review): presumably a tree-sitter parser generated from a `context_predicate` grammar and
// compiled/linked by this crate's build script — confirm. Calling it requires `unsafe` (FFI).
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
11
/// Resolves keystrokes into actions using a `Keymap`, tracking partially typed
/// multi-keystroke sequences separately for each view.
pub struct Matcher {
    // Pending keystroke sequences, keyed by the `view_id` passed to `push_keystroke`.
    pending: HashMap<usize, Pending>,
    // The bindings keystrokes are matched against.
    keymap: Keymap,
}
16
/// Keystroke sequence in progress for a single view.
#[derive(Default)]
struct Pending {
    // Keystrokes accumulated for this view while a multi-keystroke binding may still match.
    keystrokes: Vec<Keystroke>,
    // Context recorded when the sequence was last kept pending; if the next keystroke arrives
    // under a different context, the accumulated keystrokes are discarded.
    context: Option<Context>,
}
22
/// An ordered list of bindings. When matching, later bindings are tried before earlier ones.
pub struct Keymap(Vec<Binding>);
24
/// A keystroke sequence bound to a named action, optionally guarded by a context predicate.
pub struct Binding {
    // Keystrokes that trigger the action, parsed from the whitespace-separated string given
    // to `Binding::new`.
    keystrokes: Vec<Keystroke>,
    // Name of the action reported in `MatchResult::Action`.
    action: String,
    // Optional argument; a fresh clone is handed out with every match.
    action_arg: Option<Box<dyn ActionArg>>,
    // When present, the binding only matches while this predicate holds for the context.
    context: Option<ContextPredicate>,
}
31
/// A single key press together with the modifier keys held during it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Whether `ctrl` was held.
    pub ctrl: bool,
    /// Whether `alt` was held.
    pub alt: bool,
    /// Whether `shift` was held.
    pub shift: bool,
    /// Whether `cmd` was held.
    pub cmd: bool,
    /// Name of the key itself, e.g. `p`, `down` or `-`.
    pub key: String,
}
40
/// State against which binding context predicates are evaluated.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; a bare identifier predicate holds when it is in this set.
    pub set: HashSet<String>,
    /// Key/value pairs tested by `key == value` and `key != value` predicates.
    pub map: HashMap<String, String>,
}
46
/// Parsed form of a binding's context expression, e.g. `a && (b == c || d != e)`.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    // Holds when the identifier is in `Context::set`.
    Identifier(String),
    // `left == right`: holds when `Context::map` maps `left` to `right`.
    Equal(String, String),
    // `left != right`: holds unless `Context::map` maps `left` to `right`.
    NotEqual(String, String),
    // `!expr`
    Not(Box<ContextPredicate>),
    // `left && right`
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    // `left || right`
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
56
/// Object-safe cloning for binding arguments: lets a `Box<dyn ActionArg>` hand out owned,
/// type-erased copies of whatever value it holds.
trait ActionArg {
    /// Returns a heap-allocated clone of `self`, erased to `dyn Any` so the receiver can
    /// `downcast` it back to the concrete type.
    fn boxed_clone(&self) -> Box<dyn Any>;
}

/// Any cloneable `'static` value can be attached to a binding as its argument.
impl<T> ActionArg for T
where
    T: 'static + Any + Clone,
{
    fn boxed_clone(&self) -> Box<dyn Any> {
        let copy: T = T::clone(self);
        Box::new(copy)
    }
}
69
/// Outcome of feeding one keystroke to `Matcher::push_keystroke`.
///
/// Derives `Debug` so callers can log or assert on results; the type-erased argument prints
/// as `Any { .. }`.
#[derive(Debug)]
pub enum MatchResult {
    /// No enabled binding starts with the pending keystrokes; the view's pending sequence has
    /// been discarded.
    None,
    /// The keystrokes so far are a strict prefix of at least one enabled binding; more
    /// keystrokes are needed.
    Pending,
    /// A binding matched completely.
    Action {
        /// Name of the bound action.
        name: String,
        /// A fresh clone of the binding's argument, if it has one; downcast it to recover the
        /// concrete type.
        arg: Option<Box<dyn Any>>,
    },
}
78
79impl Matcher {
80 pub fn new(keymap: Keymap) -> Self {
81 Self {
82 pending: HashMap::new(),
83 keymap,
84 }
85 }
86
87 pub fn set_keymap(&mut self, keymap: Keymap) {
88 self.pending.clear();
89 self.keymap = keymap;
90 }
91
92 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
93 self.pending.clear();
94 self.keymap.add_bindings(bindings);
95 }
96
97 pub fn push_keystroke(
98 &mut self,
99 keystroke: Keystroke,
100 view_id: usize,
101 cx: &Context,
102 ) -> MatchResult {
103 let pending = self.pending.entry(view_id).or_default();
104
105 if let Some(pending_ctx) = pending.context.as_ref() {
106 if pending_ctx != cx {
107 pending.keystrokes.clear();
108 }
109 }
110
111 pending.keystrokes.push(keystroke);
112
113 let mut retain_pending = false;
114 for binding in self.keymap.0.iter().rev() {
115 if binding.keystrokes.starts_with(&pending.keystrokes)
116 && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
117 {
118 if binding.keystrokes.len() == pending.keystrokes.len() {
119 self.pending.remove(&view_id);
120 return MatchResult::Action {
121 name: binding.action.clone(),
122 arg: binding.action_arg.as_ref().map(|arg| (*arg).boxed_clone()),
123 };
124 } else {
125 retain_pending = true;
126 pending.context = Some(cx.clone());
127 }
128 }
129 }
130
131 if retain_pending {
132 MatchResult::Pending
133 } else {
134 self.pending.remove(&view_id);
135 MatchResult::None
136 }
137 }
138}
139
140impl Default for Matcher {
141 fn default() -> Self {
142 Self::new(Keymap::default())
143 }
144}
145
146impl Keymap {
147 pub fn new(bindings: Vec<Binding>) -> Self {
148 Self(bindings)
149 }
150
151 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
152 self.0.extend(bindings.into_iter());
153 }
154}
155
156impl Default for Keymap {
157 fn default() -> Self {
158 Self(vec![
159 Binding::new("up", "menu:select_prev", Some("menu")),
160 Binding::new("ctrl-p", "menu:select_prev", Some("menu")),
161 Binding::new("down", "menu:select_next", Some("menu")),
162 Binding::new("ctrl-n", "menu:select_next", Some("menu")),
163 ])
164 }
165}
166
167impl Binding {
168 pub fn new<S: Into<String>>(keystrokes: &str, action: S, context: Option<&str>) -> Self {
169 let context = if let Some(context) = context {
170 Some(ContextPredicate::parse(context).unwrap())
171 } else {
172 None
173 };
174
175 Self {
176 keystrokes: keystrokes
177 .split_whitespace()
178 .map(|key| Keystroke::parse(key).unwrap())
179 .collect(),
180 action: action.into(),
181 action_arg: None,
182 context,
183 }
184 }
185
186 pub fn with_arg<T: 'static + Any + Clone>(mut self, arg: T) -> Self {
187 self.action_arg = Some(Box::new(arg));
188 self
189 }
190}
191
192impl Keystroke {
193 pub fn parse(source: &str) -> anyhow::Result<Self> {
194 let mut ctrl = false;
195 let mut alt = false;
196 let mut shift = false;
197 let mut cmd = false;
198 let mut key = None;
199
200 let mut components = source.split("-").peekable();
201 while let Some(component) = components.next() {
202 match component {
203 "ctrl" => ctrl = true,
204 "alt" => alt = true,
205 "shift" => shift = true,
206 "cmd" => cmd = true,
207 _ => {
208 if let Some(component) = components.peek() {
209 if component.is_empty() && source.ends_with('-') {
210 key = Some(String::from("-"));
211 break;
212 } else {
213 return Err(anyhow!("Invalid keystroke `{}`", source));
214 }
215 } else {
216 key = Some(String::from(component));
217 }
218 }
219 }
220 }
221
222 Ok(Keystroke {
223 ctrl,
224 alt,
225 shift,
226 cmd,
227 key: key.unwrap(),
228 })
229 }
230}
231
232impl Context {
233 pub fn extend(&mut self, other: Context) {
234 for v in other.set {
235 self.set.insert(v);
236 }
237 for (k, v) in other.map {
238 self.map.insert(k, v);
239 }
240 }
241}
242
243impl ContextPredicate {
244 fn parse(source: &str) -> anyhow::Result<Self> {
245 let mut parser = Parser::new();
246 let language = unsafe { tree_sitter_context_predicate() };
247 parser.set_language(language).unwrap();
248 let source = source.as_bytes();
249 let tree = parser.parse(source, None).unwrap();
250 Self::from_node(tree.root_node(), source)
251 }
252
253 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
254 let parse_error = "error parsing context predicate";
255 let kind = node.kind();
256
257 match kind {
258 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
259 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
260 "not" => {
261 let child = Self::from_node(
262 node.child_by_field_name("expression")
263 .ok_or(anyhow!(parse_error))?,
264 source,
265 )?;
266 Ok(Self::Not(Box::new(child)))
267 }
268 "and" | "or" => {
269 let left = Box::new(Self::from_node(
270 node.child_by_field_name("left")
271 .ok_or(anyhow!(parse_error))?,
272 source,
273 )?);
274 let right = Box::new(Self::from_node(
275 node.child_by_field_name("right")
276 .ok_or(anyhow!(parse_error))?,
277 source,
278 )?);
279 if kind == "and" {
280 Ok(Self::And(left, right))
281 } else {
282 Ok(Self::Or(left, right))
283 }
284 }
285 "equal" | "not_equal" => {
286 let left = node
287 .child_by_field_name("left")
288 .ok_or(anyhow!(parse_error))?
289 .utf8_text(source)?
290 .into();
291 let right = node
292 .child_by_field_name("right")
293 .ok_or(anyhow!(parse_error))?
294 .utf8_text(source)?
295 .into();
296 if kind == "equal" {
297 Ok(Self::Equal(left, right))
298 } else {
299 Ok(Self::NotEqual(left, right))
300 }
301 }
302 "parenthesized" => Self::from_node(
303 node.child_by_field_name("expression")
304 .ok_or(anyhow!(parse_error))?,
305 source,
306 ),
307 _ => Err(anyhow!(parse_error)),
308 }
309 }
310
311 fn eval(&self, cx: &Context) -> bool {
312 match self {
313 Self::Identifier(name) => cx.set.contains(name.as_str()),
314 Self::Equal(left, right) => cx
315 .map
316 .get(left)
317 .map(|value| value == right)
318 .unwrap_or(false),
319 Self::NotEqual(left, right) => {
320 cx.map.get(left).map(|value| value != right).unwrap_or(true)
321 }
322 Self::Not(pred) => !pred.eval(cx),
323 Self::And(left, right) => left.eval(cx) && right.eval(cx),
324 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
325 }
326 }
327}
328
// Unit tests for keystroke parsing, context predicates, and the keystroke matcher.
#[cfg(test)]
mod tests {
    use super::*;

    // Modifiers are recognised in any order, and a trailing `--` denotes the `-` key.
    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    // The parsed tree reflects operator precedence, parentheses and negation.
    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())),)
        );

        Ok(())
    }

    // Evaluates `a && b || c == d` as the context is mutated step by step.
    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    // The matcher is stateful: each assertion depends on the pending state left by the
    // previous keystrokes, so the order of these calls matters.
    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Debug, Eq, PartialEq)]
        struct ActionArg {
            a: &'static str,
        }

        let keymap = Keymap(vec![
            Binding::new("a", "a", Some("a")).with_arg(ActionArg { a: "b" }),
            Binding::new("b", "b", Some("a")),
            Binding::new("a b", "a_b", Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            matcher.test_keystroke("a", 1, &ctx_a),
            Some(("a".to_string(), Some(ActionArg { a: "b" })))
        );

        // Multi-keystroke match
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_b),
            Some(("a_b".to_string(), None))
        );

        // Failed matches don't interfere with matching subsequent keys
        assert_eq!(matcher.test_keystroke::<()>("x", 1, &ctx_a), None);
        assert_eq!(
            matcher.test_keystroke("a", 1, &ctx_a),
            Some(("a".to_string(), Some(ActionArg { a: "b" })))
        );

        // Pending keystrokes are cleared when the context changes
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_a),
            Some(("b".to_string(), None))
        );

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert_eq!(matcher.test_keystroke::<()>("a", 1, &ctx_b), None);
        assert_eq!(matcher.test_keystroke::<()>("a", 2, &ctx_c), None);
        assert_eq!(
            matcher.test_keystroke::<()>("b", 1, &ctx_b),
            Some(("a_b".to_string(), None))
        );

        Ok(())
    }

    impl Matcher {
        // Test helper: parses `keystroke`, feeds it to the matcher and, on a full match,
        // returns the action name plus the argument downcast to `A` (`None` when there is no
        // argument or it is not an `A`). Returns `None` for `Pending` and `None` results.
        fn test_keystroke<A: Any + Clone>(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<(String, Option<A>)> {
            if let MatchResult::Action { name, arg } =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx)
            {
                Some((name, arg.and_then(|arg| arg.downcast_ref::<A>().cloned())))
            } else {
                None
            }
        }
    }
}