context.rs

  1use crate::SharedString;
  2use anyhow::{anyhow, Result};
  3use smallvec::SmallVec;
  4use std::fmt;
  5
  6#[derive(Clone, Default, Eq, PartialEq, Hash)]
  7pub struct KeyContext(SmallVec<[ContextEntry; 8]>);
  8
  9#[derive(Clone, Debug, Eq, PartialEq, Hash)]
 10struct ContextEntry {
 11    key: SharedString,
 12    value: Option<SharedString>,
 13}
 14
 15impl<'a> TryFrom<&'a str> for KeyContext {
 16    type Error = anyhow::Error;
 17
 18    fn try_from(value: &'a str) -> Result<Self> {
 19        Self::parse(value)
 20    }
 21}
 22
 23impl KeyContext {
 24    pub fn parse(source: &str) -> Result<Self> {
 25        let mut context = Self::default();
 26        let source = skip_whitespace(source);
 27        Self::parse_expr(&source, &mut context)?;
 28        Ok(context)
 29    }
 30
 31    fn parse_expr(mut source: &str, context: &mut Self) -> Result<()> {
 32        if source.is_empty() {
 33            return Ok(());
 34        }
 35
 36        let key = source
 37            .chars()
 38            .take_while(|c| is_identifier_char(*c))
 39            .collect::<String>();
 40        source = skip_whitespace(&source[key.len()..]);
 41        if let Some(suffix) = source.strip_prefix('=') {
 42            source = skip_whitespace(suffix);
 43            let value = source
 44                .chars()
 45                .take_while(|c| is_identifier_char(*c))
 46                .collect::<String>();
 47            source = skip_whitespace(&source[value.len()..]);
 48            context.set(key, value);
 49        } else {
 50            context.add(key);
 51        }
 52
 53        Self::parse_expr(source, context)
 54    }
 55
 56    pub fn is_empty(&self) -> bool {
 57        self.0.is_empty()
 58    }
 59
 60    pub fn clear(&mut self) {
 61        self.0.clear();
 62    }
 63
 64    pub fn extend(&mut self, other: &Self) {
 65        for entry in &other.0 {
 66            if !self.contains(&entry.key) {
 67                self.0.push(entry.clone());
 68            }
 69        }
 70    }
 71
 72    pub fn add<I: Into<SharedString>>(&mut self, identifier: I) {
 73        let key = identifier.into();
 74
 75        if !self.contains(&key) {
 76            self.0.push(ContextEntry { key, value: None })
 77        }
 78    }
 79
 80    pub fn set<S1: Into<SharedString>, S2: Into<SharedString>>(&mut self, key: S1, value: S2) {
 81        let key = key.into();
 82        if !self.contains(&key) {
 83            self.0.push(ContextEntry {
 84                key,
 85                value: Some(value.into()),
 86            })
 87        }
 88    }
 89
 90    pub fn contains(&self, key: &str) -> bool {
 91        self.0.iter().any(|entry| entry.key.as_ref() == key)
 92    }
 93
 94    pub fn get(&self, key: &str) -> Option<&SharedString> {
 95        self.0
 96            .iter()
 97            .find(|entry| entry.key.as_ref() == key)?
 98            .value
 99            .as_ref()
100    }
101}
102
103impl fmt::Debug for KeyContext {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        let mut entries = self.0.iter().peekable();
106        while let Some(entry) = entries.next() {
107            if let Some(ref value) = entry.value {
108                write!(f, "{}={}", entry.key, value)?;
109            } else {
110                write!(f, "{}", entry.key)?;
111            }
112            if entries.peek().is_some() {
113                write!(f, " ")?;
114            }
115        }
116        Ok(())
117    }
118}
119
120#[derive(Clone, Debug, Eq, PartialEq, Hash)]
121pub enum KeyBindingContextPredicate {
122    Identifier(SharedString),
123    Equal(SharedString, SharedString),
124    NotEqual(SharedString, SharedString),
125    Child(
126        Box<KeyBindingContextPredicate>,
127        Box<KeyBindingContextPredicate>,
128    ),
129    Not(Box<KeyBindingContextPredicate>),
130    And(
131        Box<KeyBindingContextPredicate>,
132        Box<KeyBindingContextPredicate>,
133    ),
134    Or(
135        Box<KeyBindingContextPredicate>,
136        Box<KeyBindingContextPredicate>,
137    ),
138}
139
140impl KeyBindingContextPredicate {
141    pub fn parse(source: &str) -> Result<Self> {
142        let source = skip_whitespace(source);
143        let (predicate, rest) = Self::parse_expr(source, 0)?;
144        if let Some(next) = rest.chars().next() {
145            Err(anyhow!("unexpected character {next:?}"))
146        } else {
147            Ok(predicate)
148        }
149    }
150
151    pub fn eval(&self, contexts: &[KeyContext]) -> bool {
152        let Some(context) = contexts.last() else {
153            return false;
154        };
155        match self {
156            Self::Identifier(name) => context.contains(name),
157            Self::Equal(left, right) => context
158                .get(left)
159                .map(|value| value == right)
160                .unwrap_or(false),
161            Self::NotEqual(left, right) => context
162                .get(left)
163                .map(|value| value != right)
164                .unwrap_or(true),
165            Self::Not(pred) => !pred.eval(contexts),
166            Self::Child(parent, child) => {
167                parent.eval(&contexts[..contexts.len() - 1]) && child.eval(contexts)
168            }
169            Self::And(left, right) => left.eval(contexts) && right.eval(contexts),
170            Self::Or(left, right) => left.eval(contexts) || right.eval(contexts),
171        }
172    }
173
174    fn parse_expr(mut source: &str, min_precedence: u32) -> anyhow::Result<(Self, &str)> {
175        type Op = fn(
176            KeyBindingContextPredicate,
177            KeyBindingContextPredicate,
178        ) -> Result<KeyBindingContextPredicate>;
179
180        let (mut predicate, rest) = Self::parse_primary(source)?;
181        source = rest;
182
183        'parse: loop {
184            for (operator, precedence, constructor) in [
185                (">", PRECEDENCE_CHILD, Self::new_child as Op),
186                ("&&", PRECEDENCE_AND, Self::new_and as Op),
187                ("||", PRECEDENCE_OR, Self::new_or as Op),
188                ("==", PRECEDENCE_EQ, Self::new_eq as Op),
189                ("!=", PRECEDENCE_EQ, Self::new_neq as Op),
190            ] {
191                if source.starts_with(operator) && precedence >= min_precedence {
192                    source = skip_whitespace(&source[operator.len()..]);
193                    let (right, rest) = Self::parse_expr(source, precedence + 1)?;
194                    predicate = constructor(predicate, right)?;
195                    source = rest;
196                    continue 'parse;
197                }
198            }
199            break;
200        }
201
202        Ok((predicate, source))
203    }
204
205    fn parse_primary(mut source: &str) -> anyhow::Result<(Self, &str)> {
206        let next = source
207            .chars()
208            .next()
209            .ok_or_else(|| anyhow!("unexpected eof"))?;
210        match next {
211            '(' => {
212                source = skip_whitespace(&source[1..]);
213                let (predicate, rest) = Self::parse_expr(source, 0)?;
214                if rest.starts_with(')') {
215                    source = skip_whitespace(&rest[1..]);
216                    Ok((predicate, source))
217                } else {
218                    Err(anyhow!("expected a ')'"))
219                }
220            }
221            '!' => {
222                let source = skip_whitespace(&source[1..]);
223                let (predicate, source) = Self::parse_expr(&source, PRECEDENCE_NOT)?;
224                Ok((KeyBindingContextPredicate::Not(Box::new(predicate)), source))
225            }
226            _ if is_identifier_char(next) => {
227                let len = source
228                    .find(|c: char| !is_identifier_char(c))
229                    .unwrap_or(source.len());
230                let (identifier, rest) = source.split_at(len);
231                source = skip_whitespace(rest);
232                Ok((
233                    KeyBindingContextPredicate::Identifier(identifier.to_string().into()),
234                    source,
235                ))
236            }
237            _ => Err(anyhow!("unexpected character {next:?}")),
238        }
239    }
240
241    fn new_or(self, other: Self) -> Result<Self> {
242        Ok(Self::Or(Box::new(self), Box::new(other)))
243    }
244
245    fn new_and(self, other: Self) -> Result<Self> {
246        Ok(Self::And(Box::new(self), Box::new(other)))
247    }
248
249    fn new_child(self, other: Self) -> Result<Self> {
250        Ok(Self::Child(Box::new(self), Box::new(other)))
251    }
252
253    fn new_eq(self, other: Self) -> Result<Self> {
254        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
255            Ok(Self::Equal(left, right))
256        } else {
257            Err(anyhow!("operands must be identifiers"))
258        }
259    }
260
261    fn new_neq(self, other: Self) -> Result<Self> {
262        if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) {
263            Ok(Self::NotEqual(left, right))
264        } else {
265            Err(anyhow!("operands must be identifiers"))
266        }
267    }
268}
269
// Operator precedences used by the predicate parser. A larger number binds more tightly.

/// Precedence of the `>` (child) operator.
const PRECEDENCE_CHILD: u32 = 1;
/// Precedence of the `||` operator.
const PRECEDENCE_OR: u32 = 2;
/// Precedence of the `&&` operator.
const PRECEDENCE_AND: u32 = 3;
/// Precedence shared by the `==` and `!=` operators.
const PRECEDENCE_EQ: u32 = 4;
/// Minimum precedence used when parsing the operand of a `!` negation.
const PRECEDENCE_NOT: u32 = 5;
275
/// Returns `true` when `c` may appear in an identifier: any alphanumeric character,
/// `_`, or `-`.
fn is_identifier_char(c: char) -> bool {
    matches!(c, '_' | '-') || c.is_alphanumeric()
}
279
/// Returns `source` with its leading whitespace removed.
///
/// The result is empty when `source` consists solely of whitespace.
fn skip_whitespace(source: &str) -> &str {
    source.trim_start_matches(char::is_whitespace)
}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate as gpui;
    use KeyBindingContextPredicate::*;

    /// Parses `source` as a predicate, panicking when parsing fails.
    fn parse(source: &str) -> KeyBindingContextPredicate {
        KeyBindingContextPredicate::parse(source).unwrap()
    }

    /// Builds a boxed `Identifier` predicate for use as an operand.
    fn ident(name: &'static str) -> Box<KeyBindingContextPredicate> {
        Box::new(Identifier(name.into()))
    }

    /// Builds a boxed `Equal` predicate for use as an operand.
    fn equal(left: &'static str, right: &'static str) -> Box<KeyBindingContextPredicate> {
        Box::new(Equal(left.into(), right.into()))
    }

    #[test]
    fn test_actions_definition() {
        {
            actions!(A, B, C, D, E, F, G);
        }

        {
            actions!(
                A,
                B,
                C,
                D,
                E,
                F,
                G, // Don't wrap, test the trailing comma
            );
        }
    }

    #[test]
    fn test_parse_context() {
        let mut expected = KeyContext::default();
        expected.add("baz");
        expected.set("foo", "bar");

        // Spacing variations and a repeated key must all yield the same context.
        for source in [
            "baz foo=bar",
            "baz foo = bar",
            "  baz foo   =   bar baz",
            " baz foo = bar",
        ] {
            assert_eq!(KeyContext::parse(source).unwrap(), expected);
        }
    }

    #[test]
    fn test_parse_identifiers() {
        // Identifiers
        assert_eq!(parse("abc12"), Identifier("abc12".into()));
        assert_eq!(parse("_1a"), Identifier("_1a".into()));
    }

    #[test]
    fn test_parse_negations() {
        assert_eq!(parse("!abc"), Not(ident("abc")));
        assert_eq!(parse(" ! ! abc"), Not(Box::new(Not(ident("abc")))));
    }

    #[test]
    fn test_parse_equality_operators() {
        assert_eq!(parse("a == b"), Equal("a".into(), "b".into()));
        assert_eq!(parse("c!=d"), NotEqual("c".into(), "d".into()));

        let error = KeyBindingContextPredicate::parse("c == !d").unwrap_err();
        assert_eq!(error.to_string(), "operands must be identifiers");
    }

    #[test]
    fn test_parse_boolean_operators() {
        assert_eq!(parse("a || b"), Or(ident("a"), ident("b")));
        assert_eq!(
            parse("a || !b && c"),
            Or(
                ident("a"),
                Box::new(And(Box::new(Not(ident("b"))), ident("c")))
            )
        );
        assert_eq!(
            parse("a && b || c&&d"),
            Or(
                Box::new(And(ident("a"), ident("b"))),
                Box::new(And(ident("c"), ident("d")))
            )
        );
        assert_eq!(
            parse("a == b && c || d == e && f"),
            Or(
                Box::new(And(equal("a", "b"), ident("c"))),
                Box::new(And(equal("d", "e"), ident("f")))
            )
        );
        assert_eq!(
            parse("a && b && c && d"),
            And(
                Box::new(And(Box::new(And(ident("a"), ident("b"))), ident("c"))),
                ident("d")
            ),
        );
    }

    #[test]
    fn test_parse_parenthesized_expressions() {
        assert_eq!(
            parse("a && (b == c || d != e)"),
            And(
                ident("a"),
                Box::new(Or(
                    equal("b", "c"),
                    Box::new(NotEqual("d".into(), "e".into())),
                )),
            ),
        );
        assert_eq!(parse(" ( a || b ) "), Or(ident("a"), ident("b")));
    }
}
449}