1use crate::Action;
2use anyhow::{anyhow, Result};
3use smallvec::SmallVec;
4use std::{
5 any::{Any, TypeId},
6 collections::{HashMap, HashSet},
7 fmt::Debug,
8};
9use tree_sitter::{Language, Node, Parser};
10
// Foreign entry point returning the tree-sitter `Language` that
// `ContextPredicate::parse` installs on its parser. Being an `extern "C"`
// declaration, the symbol has to be supplied at link time.
extern "C" {
    fn tree_sitter_context_predicate() -> Language;
}
14
/// Matches incoming keystrokes against the bindings of a [`Keymap`], keeping
/// track of partially typed multi-keystroke sequences separately per view.
pub struct Matcher {
    /// In-progress keystroke sequences, keyed by the view id passed to
    /// `push_keystroke`.
    pending: HashMap<usize, Pending>,
    /// The bindings that keystrokes are matched against.
    keymap: Keymap,
}
19
/// State of a keystroke sequence that is still being typed in one view.
#[derive(Default)]
struct Pending {
    /// Keystrokes received so far for the sequence.
    keystrokes: Vec<Keystroke>,
    /// Context that was active when the sequence was last kept as pending;
    /// `None` until that happens for the first time.
    context: Option<Context>,
}
25
/// A collection of key bindings together with an index from action type to
/// the bindings that trigger it.
#[derive(Default)]
pub struct Keymap {
    /// All bindings, in the order they were added.
    bindings: Vec<Binding>,
    /// For each action `TypeId`, the positions in `bindings` of the bindings
    /// whose action has that type.
    binding_indices_by_action_type: HashMap<TypeId, SmallVec<[usize; 3]>>,
}
31
/// Associates a sequence of keystrokes with an action, optionally restricted
/// to contexts that satisfy a predicate.
pub struct Binding {
    /// The keystrokes that must be typed, in order, to trigger the action.
    keystrokes: Vec<Keystroke>,
    /// The action produced when the binding matches.
    action: Box<dyn Action>,
    /// Predicate the current context must satisfy; `None` means the binding
    /// applies in every context.
    context: Option<ContextPredicate>,
}
37
/// A single key press: one key plus the modifiers held with it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Keystroke {
    /// Whether the `ctrl` modifier is part of the keystroke.
    pub ctrl: bool,
    /// Whether the `alt` modifier is part of the keystroke.
    pub alt: bool,
    /// Whether the `shift` modifier is part of the keystroke.
    pub shift: bool,
    /// Whether the `cmd` modifier is part of the keystroke.
    pub cmd: bool,
    /// Name of the non-modifier key, e.g. `p`, `down` or `-`.
    pub key: String,
}
46
/// The context a keystroke is evaluated in: a set of bare identifiers plus a
/// map of key/value attributes that binding predicates can test.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Context {
    /// Identifiers that are present; tested by bare-identifier predicates.
    pub set: HashSet<String>,
    /// Named attributes; tested by `==` / `!=` predicates.
    pub map: HashMap<String, String>,
}
52
/// Boolean expression over a [`Context`] that decides whether a binding is
/// applicable.
#[derive(Debug, Eq, PartialEq)]
enum ContextPredicate {
    /// True when the identifier is in the context's `set`.
    Identifier(String),
    /// True when the context's `map` holds the key (first field) with exactly
    /// the given value (second field).
    Equal(String, String),
    /// True when the key is absent from the context's `map` or maps to a
    /// different value.
    NotEqual(String, String),
    /// Logical negation of the inner predicate.
    Not(Box<ContextPredicate>),
    /// True when both operands are true.
    And(Box<ContextPredicate>, Box<ContextPredicate>),
    /// True when at least one operand is true.
    Or(Box<ContextPredicate>, Box<ContextPredicate>),
}
62
/// Values that can be cloned into a type-erased box.
trait ActionArg {
    /// Returns a clone of `self` boxed as `dyn Any`.
    fn boxed_clone(&self) -> Box<dyn Any>;
}
66
67impl<T> ActionArg for T
68where
69 T: 'static + Any + Clone,
70{
71 fn boxed_clone(&self) -> Box<dyn Any> {
72 Box::new(self.clone())
73 }
74}
75
/// Outcome of feeding one keystroke to [`Matcher::push_keystroke`].
pub enum MatchResult {
    /// No binding matches the keystrokes typed so far.
    None,
    /// The keystrokes typed so far are a prefix of at least one longer
    /// binding; more input is needed.
    Pending,
    /// A binding matched completely; carries a clone of its action.
    Action(Box<dyn Action>),
}
81
82impl Debug for MatchResult {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 MatchResult::None => f.debug_struct("MatchResult::None").finish(),
86 MatchResult::Pending => f.debug_struct("MatchResult::Pending").finish(),
87 MatchResult::Action(action) => f
88 .debug_tuple("MatchResult::Action")
89 .field(&action.name())
90 .finish(),
91 }
92 }
93}
94
95impl Matcher {
96 pub fn new(keymap: Keymap) -> Self {
97 Self {
98 pending: HashMap::new(),
99 keymap,
100 }
101 }
102
103 pub fn set_keymap(&mut self, keymap: Keymap) {
104 self.pending.clear();
105 self.keymap = keymap;
106 }
107
108 pub fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
109 self.pending.clear();
110 self.keymap.add_bindings(bindings);
111 }
112
113 pub fn clear_bindings(&mut self) {
114 self.pending.clear();
115 self.keymap.clear();
116 }
117
118 pub fn bindings_for_action_type(&self, action_type: TypeId) -> impl Iterator<Item = &Binding> {
119 self.keymap.bindings_for_action_type(action_type)
120 }
121
122 pub fn clear_pending(&mut self) {
123 self.pending.clear();
124 }
125
126 pub fn push_keystroke(
127 &mut self,
128 keystroke: Keystroke,
129 view_id: usize,
130 cx: &Context,
131 ) -> MatchResult {
132 let pending = self.pending.entry(view_id).or_default();
133
134 if let Some(pending_ctx) = pending.context.as_ref() {
135 if pending_ctx != cx {
136 pending.keystrokes.clear();
137 }
138 }
139
140 pending.keystrokes.push(keystroke);
141
142 let mut retain_pending = false;
143 for binding in self.keymap.bindings.iter().rev() {
144 if binding.keystrokes.starts_with(&pending.keystrokes)
145 && binding.context.as_ref().map(|c| c.eval(cx)).unwrap_or(true)
146 {
147 if binding.keystrokes.len() == pending.keystrokes.len() {
148 self.pending.remove(&view_id);
149 return MatchResult::Action(binding.action.boxed_clone());
150 } else {
151 retain_pending = true;
152 pending.context = Some(cx.clone());
153 }
154 }
155 }
156
157 if retain_pending {
158 MatchResult::Pending
159 } else {
160 self.pending.remove(&view_id);
161 MatchResult::None
162 }
163 }
164}
165
166impl Default for Matcher {
167 fn default() -> Self {
168 Self::new(Keymap::default())
169 }
170}
171
172impl Keymap {
173 pub fn new(bindings: Vec<Binding>) -> Self {
174 let mut binding_indices_by_action_type = HashMap::new();
175 for (ix, binding) in bindings.iter().enumerate() {
176 binding_indices_by_action_type
177 .entry(binding.action.as_any().type_id())
178 .or_insert_with(|| SmallVec::new())
179 .push(ix);
180 }
181 Self {
182 binding_indices_by_action_type,
183 bindings,
184 }
185 }
186
187 fn bindings_for_action_type<'a>(
188 &'a self,
189 action_type: TypeId,
190 ) -> impl Iterator<Item = &'a Binding> {
191 self.binding_indices_by_action_type
192 .get(&action_type)
193 .map(SmallVec::as_slice)
194 .unwrap_or(&[])
195 .iter()
196 .map(|ix| &self.bindings[*ix])
197 }
198
199 fn add_bindings<T: IntoIterator<Item = Binding>>(&mut self, bindings: T) {
200 for binding in bindings {
201 self.binding_indices_by_action_type
202 .entry(binding.action.as_any().type_id())
203 .or_default()
204 .push(self.bindings.len());
205 self.bindings.push(binding);
206 }
207 }
208
209 fn clear(&mut self) {
210 self.bindings.clear();
211 self.binding_indices_by_action_type.clear();
212 }
213}
214
215impl Binding {
216 pub fn new<A: Action>(keystrokes: &str, action: A, context: Option<&str>) -> Self {
217 Self::load(keystrokes, Box::new(action), context).unwrap()
218 }
219
220 pub fn load(keystrokes: &str, action: Box<dyn Action>, context: Option<&str>) -> Result<Self> {
221 let context = if let Some(context) = context {
222 Some(ContextPredicate::parse(context)?)
223 } else {
224 None
225 };
226
227 let keystrokes = keystrokes
228 .split_whitespace()
229 .map(|key| Keystroke::parse(key))
230 .collect::<Result<_>>()?;
231
232 Ok(Self {
233 keystrokes,
234 action,
235 context,
236 })
237 }
238
239 pub fn keystrokes(&self) -> &[Keystroke] {
240 &self.keystrokes
241 }
242}
243
244impl Keystroke {
245 pub fn parse(source: &str) -> anyhow::Result<Self> {
246 let mut ctrl = false;
247 let mut alt = false;
248 let mut shift = false;
249 let mut cmd = false;
250 let mut key = None;
251
252 let mut components = source.split("-").peekable();
253 while let Some(component) = components.next() {
254 match component {
255 "ctrl" => ctrl = true,
256 "alt" => alt = true,
257 "shift" => shift = true,
258 "cmd" => cmd = true,
259 _ => {
260 if let Some(component) = components.peek() {
261 if component.is_empty() && source.ends_with('-') {
262 key = Some(String::from("-"));
263 break;
264 } else {
265 return Err(anyhow!("Invalid keystroke `{}`", source));
266 }
267 } else {
268 key = Some(String::from(component));
269 }
270 }
271 }
272 }
273
274 Ok(Keystroke {
275 ctrl,
276 alt,
277 shift,
278 cmd,
279 key: key.unwrap(),
280 })
281 }
282
283 pub fn modified(&self) -> bool {
284 self.ctrl || self.alt || self.shift || self.cmd
285 }
286}
287
288impl Context {
289 pub fn extend(&mut self, other: &Context) {
290 for v in &other.set {
291 self.set.insert(v.clone());
292 }
293 for (k, v) in &other.map {
294 self.map.insert(k.clone(), v.clone());
295 }
296 }
297}
298
299impl ContextPredicate {
300 fn parse(source: &str) -> anyhow::Result<Self> {
301 let mut parser = Parser::new();
302 let language = unsafe { tree_sitter_context_predicate() };
303 parser.set_language(language).unwrap();
304 let source = source.as_bytes();
305 let tree = parser.parse(source, None).unwrap();
306 Self::from_node(tree.root_node(), source)
307 }
308
309 fn from_node(node: Node, source: &[u8]) -> anyhow::Result<Self> {
310 let parse_error = "error parsing context predicate";
311 let kind = node.kind();
312
313 match kind {
314 "source" => Self::from_node(node.child(0).ok_or(anyhow!(parse_error))?, source),
315 "identifier" => Ok(Self::Identifier(node.utf8_text(source)?.into())),
316 "not" => {
317 let child = Self::from_node(
318 node.child_by_field_name("expression")
319 .ok_or(anyhow!(parse_error))?,
320 source,
321 )?;
322 Ok(Self::Not(Box::new(child)))
323 }
324 "and" | "or" => {
325 let left = Box::new(Self::from_node(
326 node.child_by_field_name("left")
327 .ok_or(anyhow!(parse_error))?,
328 source,
329 )?);
330 let right = Box::new(Self::from_node(
331 node.child_by_field_name("right")
332 .ok_or(anyhow!(parse_error))?,
333 source,
334 )?);
335 if kind == "and" {
336 Ok(Self::And(left, right))
337 } else {
338 Ok(Self::Or(left, right))
339 }
340 }
341 "equal" | "not_equal" => {
342 let left = node
343 .child_by_field_name("left")
344 .ok_or(anyhow!(parse_error))?
345 .utf8_text(source)?
346 .into();
347 let right = node
348 .child_by_field_name("right")
349 .ok_or(anyhow!(parse_error))?
350 .utf8_text(source)?
351 .into();
352 if kind == "equal" {
353 Ok(Self::Equal(left, right))
354 } else {
355 Ok(Self::NotEqual(left, right))
356 }
357 }
358 "parenthesized" => Self::from_node(
359 node.child_by_field_name("expression")
360 .ok_or(anyhow!(parse_error))?,
361 source,
362 ),
363 _ => Err(anyhow!(parse_error)),
364 }
365 }
366
367 fn eval(&self, cx: &Context) -> bool {
368 match self {
369 Self::Identifier(name) => cx.set.contains(name.as_str()),
370 Self::Equal(left, right) => cx
371 .map
372 .get(left)
373 .map(|value| value == right)
374 .unwrap_or(false),
375 Self::NotEqual(left, right) => {
376 cx.map.get(left).map(|value| value != right).unwrap_or(true)
377 }
378 Self::Not(pred) => !pred.eval(cx),
379 Self::And(left, right) => left.eval(cx) && right.eval(cx),
380 Self::Or(left, right) => left.eval(cx) || right.eval(cx),
381 }
382 }
383}
384
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use crate::{actions, impl_actions};

    use super::*;

    /// Modifiers, named keys and the special trailing-dash form all parse.
    #[test]
    fn test_keystroke_parsing() -> anyhow::Result<()> {
        assert_eq!(
            Keystroke::parse("ctrl-p")?,
            Keystroke {
                key: "p".into(),
                ctrl: true,
                alt: false,
                shift: false,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("alt-shift-down")?,
            Keystroke {
                key: "down".into(),
                ctrl: false,
                alt: true,
                shift: true,
                cmd: false,
            }
        );

        assert_eq!(
            Keystroke::parse("shift-cmd--")?,
            Keystroke {
                key: "-".into(),
                ctrl: false,
                alt: false,
                shift: true,
                cmd: true,
            }
        );

        Ok(())
    }

    /// Predicate source text is turned into the expected expression tree.
    #[test]
    fn test_context_predicate_parsing() -> anyhow::Result<()> {
        use ContextPredicate::*;

        assert_eq!(
            ContextPredicate::parse("a && (b == c || d != e)")?,
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                ))
            )
        );

        assert_eq!(
            ContextPredicate::parse("!a")?,
            Not(Box::new(Identifier("a".into())))
        );

        Ok(())
    }

    /// Predicates evaluate against the set and the map of a context.
    #[test]
    fn test_context_predicate_eval() -> anyhow::Result<()> {
        let predicate = ContextPredicate::parse("a && b || c == d")?;

        let mut context = Context::default();
        context.set.insert("a".into());
        assert!(!predicate.eval(&context));

        context.set.insert("b".into());
        assert!(predicate.eval(&context));

        context.set.remove("b");
        context.map.insert("c".into(), "x".into());
        assert!(!predicate.eval(&context));

        context.map.insert("c".into(), "d".into());
        assert!(predicate.eval(&context));

        let predicate = ContextPredicate::parse("!a")?;
        assert!(predicate.eval(&Context::default()));

        Ok(())
    }

    #[test]
    fn test_matcher() -> anyhow::Result<()> {
        #[derive(Clone, Deserialize, PartialEq, Eq, Debug)]
        pub struct A(pub String);
        impl_actions!(test, [A]);
        actions!(test, [B, Ab]);

        let keymap = Keymap::new(vec![
            Binding::new("a", A("x".to_string()), Some("a")),
            Binding::new("b", B, Some("a")),
            Binding::new("a b", Ab, Some("a || b")),
        ]);

        let mut ctx_a = Context::default();
        ctx_a.set.insert("a".into());

        let mut ctx_b = Context::default();
        ctx_b.set.insert("b".into());

        let mut matcher = Matcher::new(keymap);

        // Basic match
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Multi-keystroke match
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        // Failed matches don't interfere with matching subsequent keys
        assert!(matcher.test_keystroke("x", 1, &ctx_a).is_none());
        assert_eq!(
            downcast(&matcher.test_keystroke("a", 1, &ctx_a)),
            Some(&A("x".to_string()))
        );

        // Pending keystrokes are cleared when the context changes
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_a)), Some(&B));

        let mut ctx_c = Context::default();
        ctx_c.set.insert("c".into());

        // Pending keystrokes are maintained per-view
        assert!(matcher.test_keystroke("a", 1, &ctx_b).is_none());
        assert!(matcher.test_keystroke("a", 2, &ctx_c).is_none());
        assert_eq!(downcast(&matcher.test_keystroke("b", 1, &ctx_b)), Some(&Ab));

        Ok(())
    }

    /// Borrows the matched action as the concrete type `A`, if it has that type.
    fn downcast<'a, A: Action>(action: &'a Option<Box<dyn Action>>) -> Option<&'a A> {
        action
            .as_ref()
            .and_then(|action| action.as_any().downcast_ref())
    }

    impl Matcher {
        /// Test helper: parses `keystroke`, pushes it and returns the matched
        /// action, or `None` for pending and failed matches.
        fn test_keystroke(
            &mut self,
            keystroke: &str,
            view_id: usize,
            cx: &Context,
        ) -> Option<Box<dyn Action>> {
            if let MatchResult::Action(action) =
                self.push_keystroke(Keystroke::parse(keystroke).unwrap(), view_id, cx)
            {
                // The action is already owned; hand it out without cloning.
                Some(action)
            } else {
                None
            }
        }
    }
}