action.rs

  1use crate::SharedString;
  2use anyhow::{anyhow, Context, Result};
  3use collections::{HashMap, HashSet};
  4use serde::Deserialize;
  5use std::any::{type_name, Any};
  6
  7pub trait Action: std::fmt::Debug + 'static {
  8    fn qualified_name() -> SharedString
  9    where
 10        Self: Sized;
 11    fn build(value: Option<serde_json::Value>) -> Result<Box<dyn Action>>
 12    where
 13        Self: Sized;
 14
 15    fn partial_eq(&self, action: &dyn Action) -> bool;
 16    fn boxed_clone(&self) -> Box<dyn Action>;
 17    fn as_any(&self) -> &dyn Any;
 18}
 19
 20// actions defines structs that can be used as actions.
 21#[macro_export]
 22macro_rules! actions {
 23    () => {};
 24
 25    ( $name:ident ) => {
 26        #[derive(::std::clone::Clone, ::std::default::Default, ::std::fmt::Debug, ::std::cmp::PartialEq, $crate::serde::Deserialize)]
 27        pub struct $name;
 28    };
 29
 30    ( $name:ident { $($token:tt)* } ) => {
 31        #[derive(::std::clone::Clone, ::std::default::Default, ::std::fmt::Debug, ::std::cmp::PartialEq, $crate::serde::Deserialize)]
 32        pub struct $name { $($token)* }
 33    };
 34
 35    ( $name:ident, $($rest:tt)* ) => {
 36        actions!($name);
 37        actions!($($rest)*);
 38    };
 39
 40    ( $name:ident { $($token:tt)* }, $($rest:tt)* ) => {
 41        actions!($name { $($token)* });
 42        actions!($($rest)*);
 43    };
 44}
 45
 46impl<A> Action for A
 47where
 48    A: for<'a> Deserialize<'a> + PartialEq + Clone + Default + std::fmt::Debug + 'static,
 49{
 50    fn qualified_name() -> SharedString {
 51        // todo!() remove the 2 replacement when migration is done
 52        type_name::<A>().replace("2::", "::").into()
 53    }
 54
 55    fn build(params: Option<serde_json::Value>) -> Result<Box<dyn Action>>
 56    where
 57        Self: Sized,
 58    {
 59        let action = if let Some(params) = params {
 60            serde_json::from_value(params).context("failed to deserialize action")?
 61        } else {
 62            Self::default()
 63        };
 64        Ok(Box::new(action))
 65    }
 66
 67    fn partial_eq(&self, action: &dyn Action) -> bool {
 68        action
 69            .as_any()
 70            .downcast_ref::<Self>()
 71            .map_or(false, |a| self == a)
 72    }
 73
 74    fn boxed_clone(&self) -> Box<dyn Action> {
 75        Box::new(self.clone())
 76    }
 77
 78    fn as_any(&self) -> &dyn Any {
 79        self
 80    }
 81}
 82
 83#[derive(Clone, Debug, Default, Eq, PartialEq)]
 84pub struct DispatchContext {
 85    set: HashSet<SharedString>,
 86    map: HashMap<SharedString, SharedString>,
 87}
 88
 89impl<'a> TryFrom<&'a str> for DispatchContext {
 90    type Error = anyhow::Error;
 91
 92    fn try_from(value: &'a str) -> Result<Self> {
 93        Self::parse(value)
 94    }
 95}
 96
 97impl DispatchContext {
 98    pub fn parse(source: &str) -> Result<Self> {
 99        let mut context = Self::default();
100        let source = skip_whitespace(source);
101        Self::parse_expr(&source, &mut context)?;
102        Ok(context)
103    }
104
105    fn parse_expr(mut source: &str, context: &mut Self) -> Result<()> {
106        if source.is_empty() {
107            return Ok(());
108        }
109
110        let key = source
111            .chars()
112            .take_while(|c| is_identifier_char(*c))
113            .collect::<String>();
114        source = skip_whitespace(&source[key.len()..]);
115        if let Some(suffix) = source.strip_prefix('=') {
116            source = skip_whitespace(suffix);
117            let value = source
118                .chars()
119                .take_while(|c| is_identifier_char(*c))
120                .collect::<String>();
121            source = skip_whitespace(&source[value.len()..]);
122            context.set(key, value);
123        } else {
124            context.insert(key);
125        }
126
127        Self::parse_expr(source, context)
128    }
129
130    pub fn is_empty(&self) -> bool {
131        self.set.is_empty() && self.map.is_empty()
132    }
133
134    pub fn clear(&mut self) {
135        self.set.clear();
136        self.map.clear();
137    }
138
139    pub fn extend(&mut self, other: &Self) {
140        for v in &other.set {
141            self.set.insert(v.clone());
142        }
143        for (k, v) in &other.map {
144            self.map.insert(k.clone(), v.clone());
145        }
146    }
147
148    pub fn insert<I: Into<SharedString>>(&mut self, identifier: I) {
149        self.set.insert(identifier.into());
150    }
151
152    pub fn set<S1: Into<SharedString>, S2: Into<SharedString>>(&mut self, key: S1, value: S2) {
153        self.map.insert(key.into(), value.into());
154    }
155}
156
157#[derive(Clone, Debug, Eq, PartialEq, Hash)]
158pub enum DispatchContextPredicate {
159    Identifier(SharedString),
160    Equal(SharedString, SharedString),
161    NotEqual(SharedString, SharedString),
162    Child(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
163    Not(Box<DispatchContextPredicate>),
164    And(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
165    Or(Box<DispatchContextPredicate>, Box<DispatchContextPredicate>),
166}
167
168impl DispatchContextPredicate {
169    pub fn parse(source: &str) -> Result<Self> {
170        let source = skip_whitespace(source);
171        let (predicate, rest) = Self::parse_expr(source, 0)?;
172        if let Some(next) = rest.chars().next() {
173            Err(anyhow!("unexpected character {next:?}"))
174        } else {
175            Ok(predicate)
176        }
177    }
178
179    pub fn eval(&self, contexts: &[&DispatchContext]) -> bool {
180        let Some(context) = contexts.last() else {
181            return false;
182        };
183        match self {
184            Self::Identifier(name) => context.set.contains(name),
185            Self::Equal(left, right) => context
186                .map
187                .get(left)
188                .map(|value| value == right)
189                .unwrap_or(false),
190            Self::NotEqual(left, right) => context
191                .map
192                .get(left)
193                .map(|value| value != right)
194                .unwrap_or(true),
195            Self::Not(pred) => !pred.eval(contexts),
196            Self::Child(parent, child) => {
197                parent.eval(&contexts[..contexts.len() - 1]) && child.eval(contexts)
198            }
199            Self::And(left, right) => left.eval(contexts) && right.eval(contexts),
200            Self::Or(left, right) => left.eval(contexts) || right.eval(contexts),
201        }
202    }
203
204    fn parse_expr(mut source: &str, min_precedence: u32) -> anyhow::Result<(Self, &str)> {
205        type Op = fn(
206            DispatchContextPredicate,
207            DispatchContextPredicate,
208        ) -> Result<DispatchContextPredicate>;
209
210        let (mut predicate, rest) = Self::parse_primary(source)?;
211        source = rest;
212
213        'parse: loop {
214            for (operator, precedence, constructor) in [
215                (">", PRECEDENCE_CHILD, Self::new_child as Op),
216                ("&&", PRECEDENCE_AND, Self::new_and as Op),
217                ("||", PRECEDENCE_OR, Self::new_or as Op),
218                ("==", PRECEDENCE_EQ, Self::new_eq as Op),
219                ("!=", PRECEDENCE_EQ, Self::new_neq as Op),
220            ] {
221                if source.starts_with(operator) && precedence >= min_precedence {
222                    source = skip_whitespace(&source[operator.len()..]);
223                    let (right, rest) = Self::parse_expr(source, precedence + 1)?;
224                    predicate = constructor(predicate, right)?;
225                    source = rest;
226                    continue 'parse;
227                }
228            }
229            break;
230        }
231
232        Ok((predicate, source))
233    }
234
235    fn parse_primary(mut source: &str) -> anyhow::Result<(Self, &str)> {
236        let next = source
237            .chars()
238            .next()
239            .ok_or_else(|| anyhow!("unexpected eof"))?;
240        match next {
241            '(' => {
242                source = skip_whitespace(&source[1..]);
243                let (predicate, rest) = Self::parse_expr(source, 0)?;
244                if rest.starts_with(')') {
245                    source = skip_whitespace(&rest[1..]);
246                    Ok((predicate, source))
247                } else {
248                    Err(anyhow!("expected a ')'"))
249                }
250            }
251            '!' => {
252                let source = skip_whitespace(&source[1..]);
253                let (predicate, source) = Self::parse_expr(&source, PRECEDENCE_NOT)?;
254                Ok((DispatchContextPredicate::Not(Box::new(predicate)), source))
255            }
256            _ if is_identifier_char(next) => {
257                let len = source
258                    .find(|c: char| !is_identifier_char(c))
259                    .unwrap_or(source.len());
260                let (identifier, rest) = source.split_at(len);
261                source = skip_whitespace(rest);
262                Ok((
263                    DispatchContextPredicate::Identifier(identifier.to_string().into()),
264                    source,
265                ))
266            }
267            _ => Err(anyhow!("unexpected character {next:?}")),
268        }
269    }
270
271    fn new_or(self, other: Self) -> Result<Self> {
272        Ok(Self::Or(Box::new(self), Box::new(other)))
273    }
274
275    fn new_and(self, other: Self) -> Result<Self> {
276        Ok(Self::And(Box::new(self), Box::new(other)))
277    }
278
279    fn new_child(self, other: Self) -> Result<Self> {
280        Ok(Self::Child(Box::new(self), Box::new(other)))
281    }
282
283    fn new_eq(self, other: Self) -> Result<Self> {
284        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
285            Ok(Self::Equal(left, right))
286        } else {
287            Err(anyhow!("operands must be identifiers"))
288        }
289    }
290
291    fn new_neq(self, other: Self) -> Result<Self> {
292        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
293            Ok(Self::NotEqual(left, right))
294        } else {
295            Err(anyhow!("operands must be identifiers"))
296        }
297    }
298}
299
// Binding strength of predicate operators: a higher value binds more tightly.

/// `parent > child` — the loosest operator.
const PRECEDENCE_CHILD: u32 = 1;
/// `left || right`.
const PRECEDENCE_OR: u32 = 2;
/// `left && right`.
const PRECEDENCE_AND: u32 = 3;
/// `key == value` and `key != value`.
const PRECEDENCE_EQ: u32 = 4;
/// Prefix `!` — the tightest operator.
const PRECEDENCE_NOT: u32 = 5;
305
/// Returns `true` if `c` may appear in an identifier: any alphanumeric
/// character, `_` or `-`.
fn is_identifier_char(c: char) -> bool {
    matches!(c, '_' | '-') || c.is_alphanumeric()
}
309
/// Returns `source` with all leading whitespace (as defined by
/// `char::is_whitespace`) removed.
fn skip_whitespace(source: &str) -> &str {
    source.trim_start()
}
316
#[cfg(test)]
mod tests {
    use super::*;
    use DispatchContextPredicate::*;

    #[test]
    fn test_actions_definition() {
        {
            actions!(A, B { field: i32 }, C, D, E, F {}, G);
        }

        {
            actions!(
                A,
                B { field: i32 },
                C,
                D,
                E,
                F {},
                G, // Don't wrap, test the trailing comma
            );
        }
    }

    #[test]
    fn test_parse_context() {
        let mut expected = DispatchContext::default();
        expected.set("foo", "bar");
        expected.insert("baz");
        assert_eq!(DispatchContext::parse("baz foo=bar").unwrap(), expected);
        assert_eq!(DispatchContext::parse("foo = bar baz").unwrap(), expected);
        assert_eq!(
            DispatchContext::parse("  baz foo   =   bar baz").unwrap(),
            expected
        );
        assert_eq!(DispatchContext::parse(" foo = bar baz").unwrap(), expected);
    }

    #[test]
    fn test_context_extend_and_clear() {
        let mut context = DispatchContext::parse("a x=1").unwrap();
        context.extend(&DispatchContext::parse("b x=2").unwrap());

        let mut expected = DispatchContext::default();
        expected.insert("a");
        expected.insert("b");
        expected.set("x", "2");
        assert_eq!(context, expected);

        assert!(!context.is_empty());
        context.clear();
        assert!(context.is_empty());
    }

    #[test]
    fn test_parse_identifiers() {
        // Identifiers
        assert_eq!(
            DispatchContextPredicate::parse("abc12").unwrap(),
            Identifier("abc12".into())
        );
        assert_eq!(
            DispatchContextPredicate::parse("_1a").unwrap(),
            Identifier("_1a".into())
        );
    }

    #[test]
    fn test_parse_negations() {
        assert_eq!(
            DispatchContextPredicate::parse("!abc").unwrap(),
            Not(Box::new(Identifier("abc".into())))
        );
        assert_eq!(
            DispatchContextPredicate::parse(" ! ! abc").unwrap(),
            Not(Box::new(Not(Box::new(Identifier("abc".into())))))
        );
    }

    #[test]
    fn test_parse_equality_operators() {
        assert_eq!(
            DispatchContextPredicate::parse("a == b").unwrap(),
            Equal("a".into(), "b".into())
        );
        assert_eq!(
            DispatchContextPredicate::parse("c!=d").unwrap(),
            NotEqual("c".into(), "d".into())
        );
        assert_eq!(
            DispatchContextPredicate::parse("c == !d")
                .unwrap_err()
                .to_string(),
            "operands must be identifiers"
        );
    }

    #[test]
    fn test_parse_boolean_operators() {
        assert_eq!(
            DispatchContextPredicate::parse("a || b").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(Identifier("b".into()))
            )
        );
        assert_eq!(
            DispatchContextPredicate::parse("a || !b && c").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(And(
                    Box::new(Not(Box::new(Identifier("b".into())))),
                    Box::new(Identifier("c".into()))
                ))
            )
        );
        assert_eq!(
            DispatchContextPredicate::parse("a && b || c&&d").unwrap(),
            Or(
                Box::new(And(
                    Box::new(Identifier("a".into())),
                    Box::new(Identifier("b".into()))
                )),
                Box::new(And(
                    Box::new(Identifier("c".into())),
                    Box::new(Identifier("d".into()))
                ))
            )
        );
        assert_eq!(
            DispatchContextPredicate::parse("a == b && c || d == e && f").unwrap(),
            Or(
                Box::new(And(
                    Box::new(Equal("a".into(), "b".into())),
                    Box::new(Identifier("c".into()))
                )),
                Box::new(And(
                    Box::new(Equal("d".into(), "e".into())),
                    Box::new(Identifier("f".into()))
                ))
            )
        );
        assert_eq!(
            DispatchContextPredicate::parse("a && b && c && d").unwrap(),
            And(
                Box::new(And(
                    Box::new(And(
                        Box::new(Identifier("a".into())),
                        Box::new(Identifier("b".into()))
                    )),
                    Box::new(Identifier("c".into())),
                )),
                Box::new(Identifier("d".into()))
            ),
        );
    }

    #[test]
    fn test_parse_child_operator() {
        // `>` binds more loosely than `&&`.
        assert_eq!(
            DispatchContextPredicate::parse("a > b && c").unwrap(),
            Child(
                Box::new(Identifier("a".into())),
                Box::new(And(
                    Box::new(Identifier("b".into())),
                    Box::new(Identifier("c".into()))
                ))
            )
        );
    }

    #[test]
    fn test_parse_parenthesized_expressions() {
        assert_eq!(
            DispatchContextPredicate::parse("a && (b == c || d != e)").unwrap(),
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                )),
            ),
        );
        assert_eq!(
            DispatchContextPredicate::parse(" ( a || b ) ").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(Identifier("b".into())),
            )
        );
    }

    #[test]
    fn test_parse_predicate_errors() {
        let error = |source: &str| {
            DispatchContextPredicate::parse(source)
                .unwrap_err()
                .to_string()
        };
        assert_eq!(error(""), "unexpected eof");
        assert_eq!(error("a &&"), "unexpected eof");
        assert_eq!(error("(a || b"), "expected a ')'");
        assert_eq!(error("a b"), "unexpected character 'b'");
        assert_eq!(error("a && )"), "unexpected character ')'");
    }

    #[test]
    fn test_eval() {
        let mut parent = DispatchContext::default();
        parent.insert("workspace");
        let mut child = DispatchContext::default();
        child.insert("editor");
        child.set("mode", "full");
        let contexts = [&parent, &child];

        let eval = |source: &str| {
            DispatchContextPredicate::parse(source)
                .unwrap()
                .eval(&contexts)
        };

        // Only the innermost context is tested directly.
        assert!(eval("editor"));
        assert!(!eval("workspace"));
        assert!(eval("!workspace"));

        assert!(eval("mode == full"));
        assert!(!eval("mode == auto_height"));
        assert!(eval("mode != auto_height"));
        // Missing keys: `==` is false, `!=` is true.
        assert!(!eval("missing == x"));
        assert!(eval("missing != x"));

        assert!(eval("workspace > editor"));
        assert!(!eval("editor > workspace"));
        assert!(eval("editor && mode == full"));
        assert!(eval("missing || editor"));

        // An empty context stack matches nothing.
        assert!(!DispatchContextPredicate::parse("editor")
            .unwrap()
            .eval(&[]));
    }
}