context.rs

  1use crate::SharedString;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4
  5#[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
  6pub struct KeyContext(SmallVec<[ContextEntry; 8]>);
  7
  8#[derive(Clone, Debug, Eq, PartialEq, Hash)]
  9struct ContextEntry {
 10    key: SharedString,
 11    value: Option<SharedString>,
 12}
 13
 14impl<'a> TryFrom<&'a str> for KeyContext {
 15    type Error = anyhow::Error;
 16
 17    fn try_from(value: &'a str) -> Result<Self> {
 18        Self::parse(value)
 19    }
 20}
 21
 22impl KeyContext {
 23    pub fn parse(source: &str) -> Result<Self> {
 24        let mut context = Self::default();
 25        let source = skip_whitespace(source);
 26        Self::parse_expr(&source, &mut context)?;
 27        Ok(context)
 28    }
 29
 30    fn parse_expr(mut source: &str, context: &mut Self) -> Result<()> {
 31        if source.is_empty() {
 32            return Ok(());
 33        }
 34
 35        let key = source
 36            .chars()
 37            .take_while(|c| is_identifier_char(*c))
 38            .collect::<String>();
 39        source = skip_whitespace(&source[key.len()..]);
 40        if let Some(suffix) = source.strip_prefix('=') {
 41            source = skip_whitespace(suffix);
 42            let value = source
 43                .chars()
 44                .take_while(|c| is_identifier_char(*c))
 45                .collect::<String>();
 46            source = skip_whitespace(&source[value.len()..]);
 47            context.set(key, value);
 48        } else {
 49            context.add(key);
 50        }
 51
 52        Self::parse_expr(source, context)
 53    }
 54
 55    pub fn is_empty(&self) -> bool {
 56        self.0.is_empty()
 57    }
 58
 59    pub fn clear(&mut self) {
 60        self.0.clear();
 61    }
 62
 63    pub fn extend(&mut self, other: &Self) {
 64        for entry in &other.0 {
 65            if !self.contains(&entry.key) {
 66                self.0.push(entry.clone());
 67            }
 68        }
 69    }
 70
 71    pub fn add<I: Into<SharedString>>(&mut self, identifier: I) {
 72        let key = identifier.into();
 73
 74        if !self.contains(&key) {
 75            self.0.push(ContextEntry { key, value: None })
 76        }
 77    }
 78
 79    pub fn set<S1: Into<SharedString>, S2: Into<SharedString>>(&mut self, key: S1, value: S2) {
 80        let key = key.into();
 81        if !self.contains(&key) {
 82            self.0.push(ContextEntry {
 83                key,
 84                value: Some(value.into()),
 85            })
 86        }
 87    }
 88
 89    pub fn contains(&self, key: &str) -> bool {
 90        self.0.iter().any(|entry| entry.key.as_ref() == key)
 91    }
 92
 93    pub fn get(&self, key: &str) -> Option<&SharedString> {
 94        self.0
 95            .iter()
 96            .find(|entry| entry.key.as_ref() == key)?
 97            .value
 98            .as_ref()
 99    }
100}
101
102#[derive(Clone, Debug, Eq, PartialEq, Hash)]
103pub enum KeyBindingContextPredicate {
104    Identifier(SharedString),
105    Equal(SharedString, SharedString),
106    NotEqual(SharedString, SharedString),
107    Child(
108        Box<KeyBindingContextPredicate>,
109        Box<KeyBindingContextPredicate>,
110    ),
111    Not(Box<KeyBindingContextPredicate>),
112    And(
113        Box<KeyBindingContextPredicate>,
114        Box<KeyBindingContextPredicate>,
115    ),
116    Or(
117        Box<KeyBindingContextPredicate>,
118        Box<KeyBindingContextPredicate>,
119    ),
120}
121
122impl KeyBindingContextPredicate {
123    pub fn parse(source: &str) -> Result<Self> {
124        let source = skip_whitespace(source);
125        let (predicate, rest) = Self::parse_expr(source, 0)?;
126        if let Some(next) = rest.chars().next() {
127            Err(anyhow!("unexpected character {next:?}"))
128        } else {
129            Ok(predicate)
130        }
131    }
132
133    pub fn eval(&self, contexts: &[KeyContext]) -> bool {
134        let Some(context) = contexts.last() else {
135            return false;
136        };
137        match self {
138            Self::Identifier(name) => context.contains(name),
139            Self::Equal(left, right) => context
140                .get(left)
141                .map(|value| value == right)
142                .unwrap_or(false),
143            Self::NotEqual(left, right) => context
144                .get(left)
145                .map(|value| value != right)
146                .unwrap_or(true),
147            Self::Not(pred) => !pred.eval(contexts),
148            Self::Child(parent, child) => {
149                parent.eval(&contexts[..contexts.len() - 1]) && child.eval(contexts)
150            }
151            Self::And(left, right) => left.eval(contexts) && right.eval(contexts),
152            Self::Or(left, right) => left.eval(contexts) || right.eval(contexts),
153        }
154    }
155
156    fn parse_expr(mut source: &str, min_precedence: u32) -> anyhow::Result<(Self, &str)> {
157        type Op = fn(
158            KeyBindingContextPredicate,
159            KeyBindingContextPredicate,
160        ) -> Result<KeyBindingContextPredicate>;
161
162        let (mut predicate, rest) = Self::parse_primary(source)?;
163        source = rest;
164
165        'parse: loop {
166            for (operator, precedence, constructor) in [
167                (">", PRECEDENCE_CHILD, Self::new_child as Op),
168                ("&&", PRECEDENCE_AND, Self::new_and as Op),
169                ("||", PRECEDENCE_OR, Self::new_or as Op),
170                ("==", PRECEDENCE_EQ, Self::new_eq as Op),
171                ("!=", PRECEDENCE_EQ, Self::new_neq as Op),
172            ] {
173                if source.starts_with(operator) && precedence >= min_precedence {
174                    source = skip_whitespace(&source[operator.len()..]);
175                    let (right, rest) = Self::parse_expr(source, precedence + 1)?;
176                    predicate = constructor(predicate, right)?;
177                    source = rest;
178                    continue 'parse;
179                }
180            }
181            break;
182        }
183
184        Ok((predicate, source))
185    }
186
187    fn parse_primary(mut source: &str) -> anyhow::Result<(Self, &str)> {
188        let next = source
189            .chars()
190            .next()
191            .ok_or_else(|| anyhow!("unexpected eof"))?;
192        match next {
193            '(' => {
194                source = skip_whitespace(&source[1..]);
195                let (predicate, rest) = Self::parse_expr(source, 0)?;
196                if rest.starts_with(')') {
197                    source = skip_whitespace(&rest[1..]);
198                    Ok((predicate, source))
199                } else {
200                    Err(anyhow!("expected a ')'"))
201                }
202            }
203            '!' => {
204                let source = skip_whitespace(&source[1..]);
205                let (predicate, source) = Self::parse_expr(&source, PRECEDENCE_NOT)?;
206                Ok((KeyBindingContextPredicate::Not(Box::new(predicate)), source))
207            }
208            _ if is_identifier_char(next) => {
209                let len = source
210                    .find(|c: char| !is_identifier_char(c))
211                    .unwrap_or(source.len());
212                let (identifier, rest) = source.split_at(len);
213                source = skip_whitespace(rest);
214                Ok((
215                    KeyBindingContextPredicate::Identifier(identifier.to_string().into()),
216                    source,
217                ))
218            }
219            _ => Err(anyhow!("unexpected character {next:?}")),
220        }
221    }
222
223    fn new_or(self, other: Self) -> Result<Self> {
224        Ok(Self::Or(Box::new(self), Box::new(other)))
225    }
226
227    fn new_and(self, other: Self) -> Result<Self> {
228        Ok(Self::And(Box::new(self), Box::new(other)))
229    }
230
231    fn new_child(self, other: Self) -> Result<Self> {
232        Ok(Self::Child(Box::new(self), Box::new(other)))
233    }
234
235    fn new_eq(self, other: Self) -> Result<Self> {
236        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
237            Ok(Self::Equal(left, right))
238        } else {
239            Err(anyhow!("operands must be identifiers"))
240        }
241    }
242
243    fn new_neq(self, other: Self) -> Result<Self> {
244        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
245            Ok(Self::NotEqual(left, right))
246        } else {
247            Err(anyhow!("operands must be identifiers"))
248        }
249    }
250}
251
// Operator precedences for the predicate parser; a higher value binds more
// tightly. Each level is defined relative to the one below it so the ordering
// is explicit.

/// `parent > child` — the loosest operator.
const PRECEDENCE_CHILD: u32 = 1;
/// `left || right`.
const PRECEDENCE_OR: u32 = PRECEDENCE_CHILD + 1;
/// `left && right`.
const PRECEDENCE_AND: u32 = PRECEDENCE_OR + 1;
/// `left == right` and `left != right`.
const PRECEDENCE_EQ: u32 = PRECEDENCE_AND + 1;
/// Operand of prefix `!`; above every binary operator.
const PRECEDENCE_NOT: u32 = PRECEDENCE_EQ + 1;
257
/// Returns `true` for characters allowed in context keys, values and
/// identifiers: any alphanumeric character, `_`, or `-`.
fn is_identifier_char(c: char) -> bool {
    matches!(c, '_' | '-') || c.is_alphanumeric()
}
261
/// Returns `source` without its leading whitespace, as defined by
/// `char::is_whitespace` (the Unicode `White_Space` property).
fn skip_whitespace(source: &str) -> &str {
    source.trim_start()
}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate as gpui;
    use KeyBindingContextPredicate::*;

    #[test]
    fn test_actions_definition() {
        {
            actions!(A, B, C, D, E, F, G);
        }

        {
            actions!(
                A,
                B,
                C,
                D,
                E,
                F,
                G, // Don't wrap, test the trailing comma
            );
        }
    }

    #[test]
    fn test_parse_context() {
        // `KeyContext` equality compares entries in order, so each expectation
        // lists its entries in the order they appear in the parsed strings.
        let mut baz_first = KeyContext::default();
        baz_first.add("baz");
        baz_first.set("foo", "bar");
        assert_eq!(KeyContext::parse("baz foo=bar").unwrap(), baz_first);
        // A repeated key keeps its first entry.
        assert_eq!(
            KeyContext::parse("  baz foo   =   bar baz").unwrap(),
            baz_first
        );

        let mut foo_first = KeyContext::default();
        foo_first.set("foo", "bar");
        foo_first.add("baz");
        assert_eq!(KeyContext::parse("foo = bar baz").unwrap(), foo_first);
        assert_eq!(KeyContext::parse(" foo = bar baz").unwrap(), foo_first);
    }

    #[test]
    fn test_parse_identifiers() {
        // Identifiers
        assert_eq!(
            KeyBindingContextPredicate::parse("abc12").unwrap(),
            Identifier("abc12".into())
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("_1a").unwrap(),
            Identifier("_1a".into())
        );
    }

    #[test]
    fn test_parse_negations() {
        assert_eq!(
            KeyBindingContextPredicate::parse("!abc").unwrap(),
            Not(Box::new(Identifier("abc".into())))
        );
        assert_eq!(
            KeyBindingContextPredicate::parse(" ! ! abc").unwrap(),
            Not(Box::new(Not(Box::new(Identifier("abc".into())))))
        );
    }

    #[test]
    fn test_parse_equality_operators() {
        assert_eq!(
            KeyBindingContextPredicate::parse("a == b").unwrap(),
            Equal("a".into(), "b".into())
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("c!=d").unwrap(),
            NotEqual("c".into(), "d".into())
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("c == !d")
                .unwrap_err()
                .to_string(),
            "operands must be identifiers"
        );
    }

    #[test]
    fn test_parse_boolean_operators() {
        assert_eq!(
            KeyBindingContextPredicate::parse("a || b").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(Identifier("b".into()))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a || !b && c").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(And(
                    Box::new(Not(Box::new(Identifier("b".into())))),
                    Box::new(Identifier("c".into()))
                ))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a && b || c&&d").unwrap(),
            Or(
                Box::new(And(
                    Box::new(Identifier("a".into())),
                    Box::new(Identifier("b".into()))
                )),
                Box::new(And(
                    Box::new(Identifier("c".into())),
                    Box::new(Identifier("d".into()))
                ))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a == b && c || d == e && f").unwrap(),
            Or(
                Box::new(And(
                    Box::new(Equal("a".into(), "b".into())),
                    Box::new(Identifier("c".into()))
                )),
                Box::new(And(
                    Box::new(Equal("d".into(), "e".into())),
                    Box::new(Identifier("f".into()))
                ))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a && b && c && d").unwrap(),
            And(
                Box::new(And(
                    Box::new(And(
                        Box::new(Identifier("a".into())),
                        Box::new(Identifier("b".into()))
                    )),
                    Box::new(Identifier("c".into())),
                )),
                Box::new(Identifier("d".into()))
            ),
        );
    }

    #[test]
    fn test_parse_parenthesized_expressions() {
        assert_eq!(
            KeyBindingContextPredicate::parse("a && (b == c || d != e)").unwrap(),
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                )),
            ),
        );
        assert_eq!(
            KeyBindingContextPredicate::parse(" ( a || b ) ").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(Identifier("b".into())),
            )
        );
    }

    #[test]
    fn test_eval() {
        let contexts = [
            KeyContext::parse("Workspace").unwrap(),
            KeyContext::parse("Editor mode=full").unwrap(),
        ];
        let eval = |source: &str| {
            KeyBindingContextPredicate::parse(source)
                .unwrap()
                .eval(&contexts)
        };

        // Plain predicates only consult the last (innermost) context.
        assert!(eval("Editor"));
        assert!(!eval("Workspace"));
        assert!(eval("mode == full"));
        assert!(!eval("mode != full"));
        // A missing key never equals a value and always differs from one.
        assert!(!eval("missing == full"));
        assert!(eval("missing != full"));
        // `>` checks the parent against the contexts outside the innermost one.
        assert!(eval("Workspace > Editor"));
        assert!(!eval("Editor > Editor"));
        assert!(eval("!Workspace && Editor"));
        // An empty context stack matches nothing, not even a negation.
        assert!(!KeyBindingContextPredicate::parse("!Editor")
            .unwrap()
            .eval(&[]));
    }
}