action.rs

  1use crate::SharedString;
  2use anyhow::{anyhow, Context, Result};
  3use collections::{HashMap, HashSet};
  4use lazy_static::lazy_static;
  5use parking_lot::{MappedRwLockReadGuard, RwLock, RwLockReadGuard};
  6use serde::Deserialize;
  7use std::any::{type_name, Any, TypeId};
  8
  9/// Actions are used to implement keyboard-driven UI.
 10/// When you declare an action, you can bind keys to the action in the keymap and
 11/// listeners for that action in the element tree.
 12///
 13/// To declare a list of simple actions, you can use the actions! macro, which defines a simple unit struct
 14/// action for each listed action name.
 15/// ```rust
 16/// actions!(MoveUp, MoveDown, MoveLeft, MoveRight, Newline);
 17/// ```
 18/// More complex data types can also be actions. If you annotate your type with the `#[action]` proc macro,
 19/// it will automatically
 20/// ```
 21/// #[action]
 22/// pub struct SelectNext {
 23///     pub replace_newest: bool,
 24/// }
 25///
 26/// Any type A that satisfies the following bounds is automatically an action:
 27///
 28/// ```
 29/// A: for<'a> Deserialize<'a> + PartialEq + Clone + Default + std::fmt::Debug + 'static,
 30/// ```
 31///
 32/// The `#[action]` annotation will derive these implementations for your struct automatically. If you
 33/// want to control them manually, you can use the lower-level `#[register_action]` macro, which only
 34/// generates the code needed to register your action before `main`. Then you'll need to implement all
 35/// the traits manually.
 36///
 37/// ```
 38/// #[gpui::register_action]
 39/// #[derive(gpui::serde::Deserialize, std::cmp::PartialEq, std::clone::Clone, std::fmt::Debug)]
 40/// pub struct Paste {
 41///     pub content: SharedString,
 42/// }
 43///
 44/// impl std::default::Default for Paste {
 45///     fn default() -> Self {
 46///         Self {
 47///             content: SharedString::from("🍝"),
 48///         }
 49///     }
 50/// }
 51/// ```
 52pub trait Action: std::fmt::Debug + 'static {
 53    fn qualified_name() -> SharedString
 54    where
 55        Self: Sized;
 56    fn build(value: Option<serde_json::Value>) -> Result<Box<dyn Action>>
 57    where
 58        Self: Sized;
 59
 60    fn partial_eq(&self, action: &dyn Action) -> bool;
 61    fn boxed_clone(&self) -> Box<dyn Action>;
 62    fn as_any(&self) -> &dyn Any;
 63}
 64
 65// Types become actions by satisfying a list of trait bounds.
 66impl<A> Action for A
 67where
 68    A: for<'a> Deserialize<'a> + PartialEq + Clone + Default + std::fmt::Debug + 'static,
 69{
 70    fn qualified_name() -> SharedString {
 71        // todo!() remove the 2 replacement when migration is done
 72        type_name::<A>().replace("2::", "::").into()
 73    }
 74
 75    fn build(params: Option<serde_json::Value>) -> Result<Box<dyn Action>>
 76    where
 77        Self: Sized,
 78    {
 79        let action = if let Some(params) = params {
 80            serde_json::from_value(params).context("failed to deserialize action")?
 81        } else {
 82            Self::default()
 83        };
 84        Ok(Box::new(action))
 85    }
 86
 87    fn partial_eq(&self, action: &dyn Action) -> bool {
 88        action
 89            .as_any()
 90            .downcast_ref::<Self>()
 91            .map_or(false, |a| self == a)
 92    }
 93
 94    fn boxed_clone(&self) -> Box<dyn Action> {
 95        Box::new(self.clone())
 96    }
 97
 98    fn as_any(&self) -> &dyn Any {
 99        self
100    }
101}
102
103impl dyn Action {
104    pub fn type_id(&self) -> TypeId {
105        self.as_any().type_id()
106    }
107
108    pub fn name(&self) -> SharedString {
109        ACTION_REGISTRY
110            .read()
111            .names_by_type_id
112            .get(&self.type_id())
113            .expect("type is not a registered action")
114            .clone()
115    }
116}
117
118type ActionBuilder = fn(json: Option<serde_json::Value>) -> anyhow::Result<Box<dyn Action>>;
119
120lazy_static! {
121    static ref ACTION_REGISTRY: RwLock<ActionRegistry> = RwLock::default();
122}
123
124#[derive(Default)]
125struct ActionRegistry {
126    builders_by_name: HashMap<SharedString, ActionBuilder>,
127    names_by_type_id: HashMap<TypeId, SharedString>,
128    all_names: Vec<SharedString>, // So we can return a static slice.
129}
130
131/// Register an action type to allow it to be referenced in keymaps.
132pub fn register_action<A: Action>() {
133    let name = A::qualified_name();
134    let mut lock = ACTION_REGISTRY.write();
135    lock.builders_by_name.insert(name.clone(), A::build);
136    lock.names_by_type_id
137        .insert(TypeId::of::<A>(), name.clone());
138    lock.all_names.push(name);
139}
140
141/// Construct an action based on its name and optional JSON parameters sourced from the keymap.
142pub fn build_action_from_type(type_id: &TypeId) -> Result<Box<dyn Action>> {
143    let lock = ACTION_REGISTRY.read();
144    let name = lock
145        .names_by_type_id
146        .get(type_id)
147        .ok_or_else(|| anyhow!("no action type registered for {:?}", type_id))?
148        .clone();
149    drop(lock);
150
151    build_action(&name, None)
152}
153
154/// Construct an action based on its name and optional JSON parameters sourced from the keymap.
155pub fn build_action(name: &str, params: Option<serde_json::Value>) -> Result<Box<dyn Action>> {
156    let lock = ACTION_REGISTRY.read();
157
158    let build_action = lock
159        .builders_by_name
160        .get(name)
161        .ok_or_else(|| anyhow!("no action type registered for {}", name))?;
162    (build_action)(params)
163}
164
165pub fn all_action_names() -> MappedRwLockReadGuard<'static, [SharedString]> {
166    let lock = ACTION_REGISTRY.read();
167    RwLockReadGuard::map(lock, |registry: &ActionRegistry| {
168        registry.all_names.as_slice()
169    })
170}
171
/// Defines unit structs that can be used as actions.
/// To use more complex data types as actions, annotate your type with the #[action] macro.
///
/// Accepts a comma-separated list of identifiers, with an optional trailing comma; an empty
/// list expands to nothing.
#[macro_export]
macro_rules! actions {
    // A single repetition arm instead of self-recursion: the expansion does not depend on an
    // `actions!` name being in scope at the call site, and its depth does not grow with the
    // number of actions listed.
    ( $( $name:ident ),* $(,)? ) => {
        $(
            #[gpui::register_action]
            #[derive(::std::clone::Clone, ::std::default::Default, ::std::fmt::Debug, ::std::cmp::PartialEq, $crate::serde::Deserialize)]
            pub struct $name;
        )*
    };
}
189
190#[derive(Clone, Debug, Default, Eq, PartialEq)]
191pub struct DispatchContext {
192    set: HashSet<SharedString>,
193    map: HashMap<SharedString, SharedString>,
194}
195
196impl<'a> TryFrom<&'a str> for DispatchContext {
197    type Error = anyhow::Error;
198
199    fn try_from(value: &'a str) -> Result<Self> {
200        Self::parse(value)
201    }
202}
203
204impl DispatchContext {
205    pub fn parse(source: &str) -> Result<Self> {
206        let mut context = Self::default();
207        let source = skip_whitespace(source);
208        Self::parse_expr(&source, &mut context)?;
209        Ok(context)
210    }
211
212    fn parse_expr(mut source: &str, context: &mut Self) -> Result<()> {
213        if source.is_empty() {
214            return Ok(());
215        }
216
217        let key = source
218            .chars()
219            .take_while(|c| is_identifier_char(*c))
220            .collect::<String>();
221        source = skip_whitespace(&source[key.len()..]);
222        if let Some(suffix) = source.strip_prefix('=') {
223            source = skip_whitespace(suffix);
224            let value = source
225                .chars()
226                .take_while(|c| is_identifier_char(*c))
227                .collect::<String>();
228            source = skip_whitespace(&source[value.len()..]);
229            context.set(key, value);
230        } else {
231            context.insert(key);
232        }
233
234        Self::parse_expr(source, context)
235    }
236
237    pub fn is_empty(&self) -> bool {
238        self.set.is_empty() && self.map.is_empty()
239    }
240
241    pub fn clear(&mut self) {
242        self.set.clear();
243        self.map.clear();
244    }
245
246    pub fn extend(&mut self, other: &Self) {
247        for v in &other.set {
248            self.set.insert(v.clone());
249        }
250        for (k, v) in &other.map {
251            self.map.insert(k.clone(), v.clone());
252        }
253    }
254
255    pub fn insert<I: Into<SharedString>>(&mut self, identifier: I) {
256        self.set.insert(identifier.into());
257    }
258
259    pub fn set<S1: Into<SharedString>, S2: Into<SharedString>>(&mut self, key: S1, value: S2) {
260        self.map.insert(key.into(), value.into());
261    }
262}
263
264#[derive(Clone, Debug, Eq, PartialEq, Hash)]
265pub enum DispatchContextPredicate {
266    Identifier(SharedString),
267    Equal(SharedString, SharedString),
268    NotEqual(SharedString, SharedString),
269    Child(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
270    Not(Box<DispatchContextPredicate>),
271    And(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
272    Or(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
273}
274
275impl DispatchContextPredicate {
276    pub fn parse(source: &str) -> Result<Self> {
277        let source = skip_whitespace(source);
278        let (predicate, rest) = Self::parse_expr(source, 0)?;
279        if let Some(next) = rest.chars().next() {
280            Err(anyhow!("unexpected character {next:?}"))
281        } else {
282            Ok(predicate)
283        }
284    }
285
286    pub fn eval(&self, contexts: &[&DispatchContext]) -> bool {
287        let Some(context) = contexts.last() else {
288            return false;
289        };
290        match self {
291            Self::Identifier(name) => context.set.contains(name),
292            Self::Equal(left, right) => context
293                .map
294                .get(left)
295                .map(|value| value == right)
296                .unwrap_or(false),
297            Self::NotEqual(left, right) => context
298                .map
299                .get(left)
300                .map(|value| value != right)
301                .unwrap_or(true),
302            Self::Not(pred) => !pred.eval(contexts),
303            Self::Child(parent, child) => {
304                parent.eval(&contexts[..contexts.len() - 1]) && child.eval(contexts)
305            }
306            Self::And(left, right) => left.eval(contexts) && right.eval(contexts),
307            Self::Or(left, right) => left.eval(contexts) || right.eval(contexts),
308        }
309    }
310
311    fn parse_expr(mut source: &str, min_precedence: u32) -> anyhow::Result<(Self, &str)> {
312        type Op = fn(
313            DispatchContextPredicate,
314            DispatchContextPredicate,
315        ) -> Result<DispatchContextPredicate>;
316
317        let (mut predicate, rest) = Self::parse_primary(source)?;
318        source = rest;
319
320        'parse: loop {
321            for (operator, precedence, constructor) in [
322                (">", PRECEDENCE_CHILD, Self::new_child as Op),
323                ("&&", PRECEDENCE_AND, Self::new_and as Op),
324                ("||", PRECEDENCE_OR, Self::new_or as Op),
325                ("==", PRECEDENCE_EQ, Self::new_eq as Op),
326                ("!=", PRECEDENCE_EQ, Self::new_neq as Op),
327            ] {
328                if source.starts_with(operator) && precedence >= min_precedence {
329                    source = skip_whitespace(&source[operator.len()..]);
330                    let (right, rest) = Self::parse_expr(source, precedence + 1)?;
331                    predicate = constructor(predicate, right)?;
332                    source = rest;
333                    continue 'parse;
334                }
335            }
336            break;
337        }
338
339        Ok((predicate, source))
340    }
341
342    fn parse_primary(mut source: &str) -> anyhow::Result<(Self, &str)> {
343        let next = source
344            .chars()
345            .next()
346            .ok_or_else(|| anyhow!("unexpected eof"))?;
347        match next {
348            '(' => {
349                source = skip_whitespace(&source[1..]);
350                let (predicate, rest) = Self::parse_expr(source, 0)?;
351                if rest.starts_with(')') {
352                    source = skip_whitespace(&rest[1..]);
353                    Ok((predicate, source))
354                } else {
355                    Err(anyhow!("expected a ')'"))
356                }
357            }
358            '!' => {
359                let source = skip_whitespace(&source[1..]);
360                let (predicate, source) = Self::parse_expr(&source, PRECEDENCE_NOT)?;
361                Ok((DispatchContextPredicate::Not(Box::new(predicate)), source))
362            }
363            _ if is_identifier_char(next) => {
364                let len = source
365                    .find(|c: char| !is_identifier_char(c))
366                    .unwrap_or(source.len());
367                let (identifier, rest) = source.split_at(len);
368                source = skip_whitespace(rest);
369                Ok((
370                    DispatchContextPredicate::Identifier(identifier.to_string().into()),
371                    source,
372                ))
373            }
374            _ => Err(anyhow!("unexpected character {next:?}")),
375        }
376    }
377
378    fn new_or(self, other: Self) -> Result<Self> {
379        Ok(Self::Or(Box::new(self), Box::new(other)))
380    }
381
382    fn new_and(self, other: Self) -> Result<Self> {
383        Ok(Self::And(Box::new(self), Box::new(other)))
384    }
385
386    fn new_child(self, other: Self) -> Result<Self> {
387        Ok(Self::Child(Box::new(self), Box::new(other)))
388    }
389
390    fn new_eq(self, other: Self) -> Result<Self> {
391        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
392            Ok(Self::Equal(left, right))
393        } else {
394            Err(anyhow!("operands must be identifiers"))
395        }
396    }
397
398    fn new_neq(self, other: Self) -> Result<Self> {
399        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
400            Ok(Self::NotEqual(left, right))
401        } else {
402            Err(anyhow!("operands must be identifiers"))
403        }
404    }
405}
406
// Binding strength of the predicate operators, weakest first. Each level is defined relative
// to the previous one because only the ordering matters to the parser.
const PRECEDENCE_CHILD: u32 = 1;
const PRECEDENCE_OR: u32 = PRECEDENCE_CHILD + 1;
const PRECEDENCE_AND: u32 = PRECEDENCE_OR + 1;
const PRECEDENCE_EQ: u32 = PRECEDENCE_AND + 1;
const PRECEDENCE_NOT: u32 = PRECEDENCE_EQ + 1;
412
/// Returns `true` for characters allowed in identifiers: alphanumerics, `_` and `-`.
fn is_identifier_char(c: char) -> bool {
    matches!(c, '_' | '-') || c.is_alphanumeric()
}
416
/// Returns `source` with all leading whitespace removed.
fn skip_whitespace(source: &str) -> &str {
    source.trim_start()
}
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate as gpui;
    use DispatchContextPredicate::*;

    /// Parses `source` as a predicate, panicking at the caller's location on failure.
    #[track_caller]
    fn parsed(source: &str) -> DispatchContextPredicate {
        DispatchContextPredicate::parse(source).unwrap()
    }

    /// Boxed `Identifier` predicate.
    fn ident(name: &'static str) -> Box<DispatchContextPredicate> {
        Box::new(Identifier(name.into()))
    }

    /// Boxed `Equal` predicate.
    fn equal(left: &'static str, right: &'static str) -> Box<DispatchContextPredicate> {
        Box::new(Equal(left.into(), right.into()))
    }

    /// Boxed `NotEqual` predicate.
    fn not_equal(left: &'static str, right: &'static str) -> Box<DispatchContextPredicate> {
        Box::new(NotEqual(left.into(), right.into()))
    }

    #[test]
    fn test_actions_definition() {
        {
            actions!(A, B, C, D, E, F, G);
        }

        {
            actions!(
                A,
                B,
                C,
                D,
                E,
                F,
                G, // Don't wrap, test the trailing comma
            );
        }
    }

    #[test]
    fn test_parse_context() {
        let mut expected = DispatchContext::default();
        expected.set("foo", "bar");
        expected.insert("baz");

        // Entry order, duplicates and surrounding whitespace must not affect the result.
        for source in [
            "baz foo=bar",
            "foo = bar baz",
            "  baz foo   =   bar baz",
            " foo = bar baz",
        ] {
            assert_eq!(DispatchContext::parse(source).unwrap(), expected);
        }
    }

    #[test]
    fn test_parse_identifiers() {
        // Identifiers
        assert_eq!(parsed("abc12"), *ident("abc12"));
        assert_eq!(parsed("_1a"), *ident("_1a"));
    }

    #[test]
    fn test_parse_negations() {
        assert_eq!(parsed("!abc"), Not(ident("abc")));
        assert_eq!(parsed(" ! ! abc"), Not(Box::new(Not(ident("abc")))));
    }

    #[test]
    fn test_parse_equality_operators() {
        assert_eq!(parsed("a == b"), *equal("a", "b"));
        assert_eq!(parsed("c!=d"), *not_equal("c", "d"));

        let error = DispatchContextPredicate::parse("c == !d").unwrap_err();
        assert_eq!(error.to_string(), "operands must be identifiers");
    }

    #[test]
    fn test_parse_boolean_operators() {
        assert_eq!(parsed("a || b"), Or(ident("a"), ident("b")));
        assert_eq!(
            parsed("a || !b && c"),
            Or(
                ident("a"),
                Box::new(And(Box::new(Not(ident("b"))), ident("c")))
            )
        );
        assert_eq!(
            parsed("a && b || c&&d"),
            Or(
                Box::new(And(ident("a"), ident("b"))),
                Box::new(And(ident("c"), ident("d")))
            )
        );
        assert_eq!(
            parsed("a == b && c || d == e && f"),
            Or(
                Box::new(And(equal("a", "b"), ident("c"))),
                Box::new(And(equal("d", "e"), ident("f")))
            )
        );
        // Repeated `&&` groups to the left.
        assert_eq!(
            parsed("a && b && c && d"),
            And(
                Box::new(And(Box::new(And(ident("a"), ident("b"))), ident("c"))),
                ident("d")
            ),
        );
    }

    #[test]
    fn test_parse_parenthesized_expressions() {
        assert_eq!(
            parsed("a && (b == c || d != e)"),
            And(
                ident("a"),
                Box::new(Or(equal("b", "c"), not_equal("d", "e"))),
            ),
        );
        assert_eq!(parsed(" ( a || b ) "), Or(ident("a"), ident("b")));
    }
}
586}