context.rs

  1use crate::SharedString;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::fmt;
  5
  6#[derive(Clone, Default, Eq, PartialEq, Hash)]
  7pub struct KeyContext(SmallVec<[ContextEntry; 1]>);
  8
  9#[derive(Clone, Debug, Eq, PartialEq, Hash)]
 10struct ContextEntry {
 11    key: SharedString,
 12    value: Option<SharedString>,
 13}
 14
 15impl<'a> TryFrom<&'a str> for KeyContext {
 16    type Error = anyhow::Error;
 17
 18    fn try_from(value: &'a str) -> Result<Self> {
 19        Self::parse(value)
 20    }
 21}
 22
 23impl KeyContext {
 24    pub fn parse(source: &str) -> Result<Self> {
 25        let mut context = Self::default();
 26        let source = skip_whitespace(source);
 27        Self::parse_expr(source, &mut context)?;
 28        Ok(context)
 29    }
 30
 31    fn parse_expr(mut source: &str, context: &mut Self) -> Result<()> {
 32        if source.is_empty() {
 33            return Ok(());
 34        }
 35
 36        let key = source
 37            .chars()
 38            .take_while(|c| is_identifier_char(*c))
 39            .collect::<String>();
 40        source = skip_whitespace(&source[key.len()..]);
 41        if let Some(suffix) = source.strip_prefix('=') {
 42            source = skip_whitespace(suffix);
 43            let value = source
 44                .chars()
 45                .take_while(|c| is_identifier_char(*c))
 46                .collect::<String>();
 47            source = skip_whitespace(&source[value.len()..]);
 48            context.set(key, value);
 49        } else {
 50            context.add(key);
 51        }
 52
 53        Self::parse_expr(source, context)
 54    }
 55
 56    pub fn is_empty(&self) -> bool {
 57        self.0.is_empty()
 58    }
 59
 60    pub fn clear(&mut self) {
 61        self.0.clear();
 62    }
 63
 64    pub fn extend(&mut self, other: &Self) {
 65        for entry in &other.0 {
 66            if !self.contains(&entry.key) {
 67                self.0.push(entry.clone());
 68            }
 69        }
 70    }
 71
 72    pub fn add<I: Into<SharedString>>(&mut self, identifier: I) {
 73        let key = identifier.into();
 74
 75        if !self.contains(&key) {
 76            self.0.push(ContextEntry { key, value: None })
 77        }
 78    }
 79
 80    pub fn set<S1: Into<SharedString>, S2: Into<SharedString>>(&mut self, key: S1, value: S2) {
 81        let key = key.into();
 82        if !self.contains(&key) {
 83            self.0.push(ContextEntry {
 84                key,
 85                value: Some(value.into()),
 86            })
 87        }
 88    }
 89
 90    pub fn contains(&self, key: &str) -> bool {
 91        self.0.iter().any(|entry| entry.key.as_ref() == key)
 92    }
 93
 94    pub fn get(&self, key: &str) -> Option<&SharedString> {
 95        self.0
 96            .iter()
 97            .find(|entry| entry.key.as_ref() == key)?
 98            .value
 99            .as_ref()
100    }
101}
102
103impl fmt::Debug for KeyContext {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        let mut entries = self.0.iter().peekable();
106        while let Some(entry) = entries.next() {
107            if let Some(ref value) = entry.value {
108                write!(f, "{}={}", entry.key, value)?;
109            } else {
110                write!(f, "{}", entry.key)?;
111            }
112            if entries.peek().is_some() {
113                write!(f, " ")?;
114            }
115        }
116        Ok(())
117    }
118}
119
120#[derive(Clone, Debug, Eq, PartialEq, Hash)]
121pub enum KeyBindingContextPredicate {
122    Identifier(SharedString),
123    Equal(SharedString, SharedString),
124    NotEqual(SharedString, SharedString),
125    Child(
126        Box<KeyBindingContextPredicate>,
127        Box<KeyBindingContextPredicate>,
128    ),
129    Not(Box<KeyBindingContextPredicate>),
130    And(
131        Box<KeyBindingContextPredicate>,
132        Box<KeyBindingContextPredicate>,
133    ),
134    Or(
135        Box<KeyBindingContextPredicate>,
136        Box<KeyBindingContextPredicate>,
137    ),
138}
139
140impl KeyBindingContextPredicate {
141    pub fn parse(source: &str) -> Result<Self> {
142        let source = skip_whitespace(source);
143        let (predicate, rest) = Self::parse_expr(source, 0)?;
144        if let Some(next) = rest.chars().next() {
145            Err(anyhow!("unexpected character {next:?}"))
146        } else {
147            Ok(predicate)
148        }
149    }
150
151    pub fn eval(&self, contexts: &[KeyContext]) -> bool {
152        let Some(context) = contexts.last() else {
153            return false;
154        };
155        match self {
156            Self::Identifier(name) => context.contains(name),
157            Self::Equal(left, right) => context
158                .get(left)
159                .map(|value| value == right)
160                .unwrap_or(false),
161            Self::NotEqual(left, right) => context
162                .get(left)
163                .map(|value| value != right)
164                .unwrap_or(true),
165            Self::Not(pred) => !pred.eval(contexts),
166            Self::Child(parent, child) => {
167                parent.eval(&contexts[..contexts.len() - 1]) && child.eval(contexts)
168            }
169            Self::And(left, right) => left.eval(contexts) && right.eval(contexts),
170            Self::Or(left, right) => left.eval(contexts) || right.eval(contexts),
171        }
172    }
173
174    fn parse_expr(mut source: &str, min_precedence: u32) -> anyhow::Result<(Self, &str)> {
175        type Op = fn(
176            KeyBindingContextPredicate,
177            KeyBindingContextPredicate,
178        ) -> Result<KeyBindingContextPredicate>;
179
180        let (mut predicate, rest) = Self::parse_primary(source)?;
181        source = rest;
182
183        'parse: loop {
184            for (operator, precedence, constructor) in [
185                (">", PRECEDENCE_CHILD, Self::new_child as Op),
186                ("&&", PRECEDENCE_AND, Self::new_and as Op),
187                ("||", PRECEDENCE_OR, Self::new_or as Op),
188                ("==", PRECEDENCE_EQ, Self::new_eq as Op),
189                ("!=", PRECEDENCE_EQ, Self::new_neq as Op),
190            ] {
191                if source.starts_with(operator) && precedence >= min_precedence {
192                    source = skip_whitespace(&source[operator.len()..]);
193                    let (right, rest) = Self::parse_expr(source, precedence + 1)?;
194                    predicate = constructor(predicate, right)?;
195                    source = rest;
196                    continue 'parse;
197                }
198            }
199            break;
200        }
201
202        Ok((predicate, source))
203    }
204
205    fn parse_primary(mut source: &str) -> anyhow::Result<(Self, &str)> {
206        let next = source
207            .chars()
208            .next()
209            .ok_or_else(|| anyhow!("unexpected eof"))?;
210        match next {
211            '(' => {
212                source = skip_whitespace(&source[1..]);
213                let (predicate, rest) = Self::parse_expr(source, 0)?;
214                if rest.starts_with(')') {
215                    source = skip_whitespace(&rest[1..]);
216                    Ok((predicate, source))
217                } else {
218                    Err(anyhow!("expected a ')'"))
219                }
220            }
221            '!' => {
222                let source = skip_whitespace(&source[1..]);
223                let (predicate, source) = Self::parse_expr(source, PRECEDENCE_NOT)?;
224                Ok((KeyBindingContextPredicate::Not(Box::new(predicate)), source))
225            }
226            _ if is_identifier_char(next) => {
227                let len = source
228                    .find(|c: char| !is_identifier_char(c))
229                    .unwrap_or(source.len());
230                let (identifier, rest) = source.split_at(len);
231                source = skip_whitespace(rest);
232                Ok((
233                    KeyBindingContextPredicate::Identifier(identifier.to_string().into()),
234                    source,
235                ))
236            }
237            _ => Err(anyhow!("unexpected character {next:?}")),
238        }
239    }
240
241    fn new_or(self, other: Self) -> Result<Self> {
242        Ok(Self::Or(Box::new(self), Box::new(other)))
243    }
244
245    fn new_and(self, other: Self) -> Result<Self> {
246        Ok(Self::And(Box::new(self), Box::new(other)))
247    }
248
249    fn new_child(self, other: Self) -> Result<Self> {
250        Ok(Self::Child(Box::new(self), Box::new(other)))
251    }
252
253    fn new_eq(self, other: Self) -> Result<Self> {
254        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
255            Ok(Self::Equal(left, right))
256        } else {
257            Err(anyhow!("operands must be identifiers"))
258        }
259    }
260
261    fn new_neq(self, other: Self) -> Result<Self> {
262        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
263            Ok(Self::NotEqual(left, right))
264        } else {
265            Err(anyhow!("operands must be identifiers"))
266        }
267    }
268}
269
// Binding strength of predicate operators; a higher value binds tighter.
// Each level is defined relative to the previous one so the ordering
// `>` < `||` < `&&` < `==`/`!=` < `!` is explicit.
const PRECEDENCE_CHILD: u32 = 1;
const PRECEDENCE_OR: u32 = PRECEDENCE_CHILD + 1;
const PRECEDENCE_AND: u32 = PRECEDENCE_OR + 1;
const PRECEDENCE_EQ: u32 = PRECEDENCE_AND + 1;
const PRECEDENCE_NOT: u32 = PRECEDENCE_EQ + 1;
275
/// Returns `true` for characters allowed in identifiers and values: Unicode
/// alphanumerics plus `_` and `-`.
fn is_identifier_char(c: char) -> bool {
    matches!(c, '_' | '-') || c.is_alphanumeric()
}
279
/// Returns `source` with its leading Unicode whitespace removed.
fn skip_whitespace(source: &str) -> &str {
    source.trim_start()
}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate as gpui;
    use KeyBindingContextPredicate::*;

    #[test]
    fn test_actions_definition() {
        {
            actions!(test, [A, B, C, D, E, F, G]);
        }

        {
            actions!(
                test,
                [
                    A,
                    B,
                    C,
                    D,
                    E,
                    F,
                    G, // Don't wrap, test the trailing comma
                ]
            );
        }
    }

    #[test]
    fn test_parse_context() {
        let mut expected = KeyContext::default();
        expected.add("baz");
        expected.set("foo", "bar");
        assert_eq!(KeyContext::parse("baz foo=bar").unwrap(), expected);
        assert_eq!(KeyContext::parse("baz foo = bar").unwrap(), expected);
        assert_eq!(
            KeyContext::parse("  baz foo   =   bar baz").unwrap(),
            expected
        );
        assert_eq!(KeyContext::parse(" baz foo = bar").unwrap(), expected);
    }

    #[test]
    fn test_context_debug_format() {
        assert_eq!(format!("{:?}", KeyContext::default()), "");
        assert_eq!(
            format!("{:?}", KeyContext::parse("editor mode = full").unwrap()),
            "editor mode=full"
        );
    }

    #[test]
    fn test_context_extend_keeps_existing_values() {
        let mut context = KeyContext::parse("a b=1").unwrap();
        context.extend(&KeyContext::parse("b=2 c").unwrap());
        // `b` keeps its original value; only the new key `c` is appended.
        assert_eq!(format!("{context:?}"), "a b=1 c");
    }

    #[test]
    fn test_parse_identifiers() {
        // Identifiers
        assert_eq!(
            KeyBindingContextPredicate::parse("abc12").unwrap(),
            Identifier("abc12".into())
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("_1a").unwrap(),
            Identifier("_1a".into())
        );
    }

    #[test]
    fn test_parse_negations() {
        assert_eq!(
            KeyBindingContextPredicate::parse("!abc").unwrap(),
            Not(Box::new(Identifier("abc".into())))
        );
        assert_eq!(
            KeyBindingContextPredicate::parse(" ! ! abc").unwrap(),
            Not(Box::new(Not(Box::new(Identifier("abc".into())))))
        );
    }

    #[test]
    fn test_parse_equality_operators() {
        assert_eq!(
            KeyBindingContextPredicate::parse("a == b").unwrap(),
            Equal("a".into(), "b".into())
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("c!=d").unwrap(),
            NotEqual("c".into(), "d".into())
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("c == !d")
                .unwrap_err()
                .to_string(),
            "operands must be identifiers"
        );
    }

    #[test]
    fn test_parse_boolean_operators() {
        assert_eq!(
            KeyBindingContextPredicate::parse("a || b").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(Identifier("b".into()))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a || !b && c").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(And(
                    Box::new(Not(Box::new(Identifier("b".into())))),
                    Box::new(Identifier("c".into()))
                ))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a && b || c&&d").unwrap(),
            Or(
                Box::new(And(
                    Box::new(Identifier("a".into())),
                    Box::new(Identifier("b".into()))
                )),
                Box::new(And(
                    Box::new(Identifier("c".into())),
                    Box::new(Identifier("d".into()))
                ))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a == b && c || d == e && f").unwrap(),
            Or(
                Box::new(And(
                    Box::new(Equal("a".into(), "b".into())),
                    Box::new(Identifier("c".into()))
                )),
                Box::new(And(
                    Box::new(Equal("d".into(), "e".into())),
                    Box::new(Identifier("f".into()))
                ))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a && b && c && d").unwrap(),
            And(
                Box::new(And(
                    Box::new(And(
                        Box::new(Identifier("a".into())),
                        Box::new(Identifier("b".into()))
                    )),
                    Box::new(Identifier("c".into())),
                )),
                Box::new(Identifier("d".into()))
            ),
        );
    }

    #[test]
    fn test_parse_child_operator() {
        // `>` binds more loosely than `&&`.
        assert_eq!(
            KeyBindingContextPredicate::parse("a > b && c").unwrap(),
            Child(
                Box::new(Identifier("a".into())),
                Box::new(And(
                    Box::new(Identifier("b".into())),
                    Box::new(Identifier("c".into()))
                ))
            )
        );
        assert_eq!(
            KeyBindingContextPredicate::parse("a && b > c").unwrap(),
            Child(
                Box::new(And(
                    Box::new(Identifier("a".into())),
                    Box::new(Identifier("b".into()))
                )),
                Box::new(Identifier("c".into()))
            )
        );
    }

    #[test]
    fn test_parse_parenthesized_expressions() {
        assert_eq!(
            KeyBindingContextPredicate::parse("a && (b == c || d != e)").unwrap(),
            And(
                Box::new(Identifier("a".into())),
                Box::new(Or(
                    Box::new(Equal("b".into(), "c".into())),
                    Box::new(NotEqual("d".into(), "e".into())),
                )),
            ),
        );
        assert_eq!(
            KeyBindingContextPredicate::parse(" ( a || b ) ").unwrap(),
            Or(
                Box::new(Identifier("a".into())),
                Box::new(Identifier("b".into())),
            )
        );
    }

    #[test]
    fn test_parse_errors() {
        let error = |source: &str| {
            KeyBindingContextPredicate::parse(source)
                .unwrap_err()
                .to_string()
        };
        assert_eq!(error(""), "unexpected eof");
        assert_eq!(error("a &&"), "unexpected eof");
        assert_eq!(error("(a"), "expected a ')'");
        assert_eq!(error("a b"), "unexpected character 'b'");
    }

    #[test]
    fn test_eval() {
        let contexts = [
            KeyContext::parse("workspace").unwrap(),
            KeyContext::parse("editor mode=full").unwrap(),
        ];
        let eval = |source: &str| {
            KeyBindingContextPredicate::parse(source)
                .unwrap()
                .eval(&contexts)
        };

        // Leaf predicates only see the innermost context.
        assert!(eval("editor"));
        assert!(!eval("workspace"));
        assert!(eval("!workspace"));
        assert!(eval("mode == full"));
        assert!(!eval("mode == auto"));
        assert!(eval("mode != auto"));
        // A missing key is never equal to anything.
        assert!(!eval("missing == x"));
        assert!(eval("missing != x"));
        assert!(eval("editor && mode == full"));
        assert!(eval("workspace || editor"));
        // `>` checks the parent against the next context out.
        assert!(eval("workspace > editor"));
        assert!(!eval("editor > workspace"));
        // An empty context stack matches nothing, not even a negation.
        assert!(!KeyBindingContextPredicate::parse("!a").unwrap().eval(&[]));
    }
}
451    }
452}