Convert query capture indices to style ids

Max Brunsfeld created

* Introduce a Theme struct as a new part of the app's settings
* Store on each Language a ThemeMap, which converts the capture ids
  from that language's highlight query into StyleIds, which identify
  styles in the current Theme.
* Update `highlighted_chunks` methods to provide StyleIds instead of
  capture ids.

Change summary

Cargo.lock                             |  12 +
zed/Cargo.toml                         |   1 
zed/assets/themes/light.toml           |  13 +
zed/languages/rust/highlights.scm      |  45 ++++
zed/src/editor/buffer/mod.rs           |  19 +
zed/src/editor/buffer_view.rs          |  12 
zed/src/editor/display_map/fold_map.rs |   7 
zed/src/editor/display_map/mod.rs      |  20 +
zed/src/language.rs                    |  22 ++
zed/src/main.rs                        |   1 
zed/src/settings.rs                    | 263 +++++++++++++++++++++++++++
11 files changed, 389 insertions(+), 26 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -347,7 +347,7 @@ dependencies = [
  "tar",
  "target_build_utils",
  "term",
- "toml",
+ "toml 0.4.10",
  "uuid",
  "walkdir",
 ]
@@ -2702,6 +2702,15 @@ dependencies = [
  "serde 1.0.125",
 ]
 
+[[package]]
+name = "toml"
+version = "0.5.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa"
+dependencies = [
+ "serde 1.0.125",
+]
+
 [[package]]
 name = "tree-sitter"
 version = "0.19.5"
@@ -2996,6 +3005,7 @@ dependencies = [
  "smallvec",
  "smol",
  "tempdir",
+ "toml 0.5.8",
  "tree-sitter",
  "tree-sitter-rust",
  "unindent",

zed/Cargo.toml 🔗

@@ -38,6 +38,7 @@ similar = "1.3"
 simplelog = "0.9"
 smallvec = {version = "1.6", features = ["union"]}
 smol = "1.2.5"
+toml = "0.5"
 tree-sitter = "0.19.5"
 tree-sitter-rust = "0.19.0"
 

zed/assets/themes/light.toml 🔗

@@ -0,0 +1,13 @@
+[ui]
+background = 0xffffff
+line_numbers = 0x237791
+text = 0x0d0d0d
+
+[syntax]
+keyword = 0xaf00db
+function = 0x795e26
+string = 0xa31515
+type = 0x267599
+number = 0x0d885b
+comment = 0x048204
+property = 0x001080

zed/languages/rust/highlights.scm 🔗

@@ -1,6 +1,49 @@
+(type_identifier) @type
+
+(call_expression
+  function: [
+    (identifier) @function
+    (scoped_identifier
+      name: (identifier) @function)
+    (field_expression
+      field: (field_identifier) @function.method)
+  ])
+
+(field_identifier) @property
+
+(function_item
+  name: (identifier) @function.definition)
+
 [
+  "async"
+  "break"
+  "const"
+  "continue"
+  "dyn"
   "else"
+  "enum"
+  "for"
   "fn"
   "if"
+  "impl"
+  "let"
+  "loop"
+  "match"
+  "mod"
+  "move"
+  "pub"
+  "return"
+  "struct"
+  "trait"
+  "type"
+  "use"
+  "where"
   "while"
-] @keyword
+] @keyword
+
+(string_literal) @string
+
+[
+  (line_comment)
+  (block_comment)
+] @comment

zed/src/editor/buffer/mod.rs 🔗

@@ -16,6 +16,7 @@ use crate::{
     editor::Bias,
     language::{Language, Tree},
     operation_queue::{self, OperationQueue},
+    settings::{StyleId, ThemeMap},
     sum_tree::{self, FilterCursor, SeekBias, SumTree},
     time::{self, ReplicaId},
     worktree::FileHandle,
@@ -617,6 +618,7 @@ impl Buffer {
                     handle.update(&mut ctx, |this, ctx| {
                         this.tree = Some((new_tree, new_version));
                         ctx.emit(Event::Reparsed);
+                        ctx.notify();
                     });
                 }
                 handle.update(&mut ctx, |this, _| this.is_parsing = false);
@@ -778,6 +780,7 @@ impl Buffer {
                 highlights: Some(Highlights {
                     captures: captures.peekable(),
                     stack: Default::default(),
+                    theme_mapping: language.theme_mapping(),
                     cursor,
                 }),
                 buffer: self,
@@ -2200,8 +2203,9 @@ impl<'a> tree_sitter::TextProvider<'a> for TextProvider<'a> {
 
 struct Highlights<'a> {
     captures: iter::Peekable<tree_sitter::QueryCaptures<'a, 'a, TextProvider<'a>>>,
-    stack: Vec<(usize, usize)>,
+    stack: Vec<(usize, StyleId)>,
     cursor: QueryCursor,
+    theme_mapping: ThemeMap,
 }
 
 pub struct HighlightedChunks<'a> {
@@ -2240,7 +2244,7 @@ impl<'a> HighlightedChunks<'a> {
 }
 
 impl<'a> Iterator for HighlightedChunks<'a> {
-    type Item = (&'a str, Option<usize>);
+    type Item = (&'a str, StyleId);
 
     fn next(&mut self) -> Option<Self::Item> {
         let mut next_capture_start = usize::MAX;
@@ -2260,9 +2264,8 @@ impl<'a> Iterator for HighlightedChunks<'a> {
                     next_capture_start = capture.node.start_byte();
                     break;
                 } else {
-                    highlights
-                        .stack
-                        .push((capture.node.end_byte(), capture.index as usize));
+                    let style_id = highlights.theme_mapping.get(capture.index);
+                    highlights.stack.push((capture.node.end_byte(), style_id));
                     highlights.captures.next().unwrap();
                 }
             }
@@ -2271,12 +2274,12 @@ impl<'a> Iterator for HighlightedChunks<'a> {
         if let Some(chunk) = self.chunks.peek() {
             let chunk_start = self.range.start;
             let mut chunk_end = (self.chunks.offset() + chunk.len()).min(next_capture_start);
-            let mut capture_ix = None;
-            if let Some((parent_capture_end, parent_capture_ix)) =
+            let mut capture_ix = StyleId::default();
+            if let Some((parent_capture_end, parent_style_id)) =
                 self.highlights.as_ref().and_then(|h| h.stack.last())
             {
                 chunk_end = chunk_end.min(*parent_capture_end);
-                capture_ix = Some(*parent_capture_ix);
+                capture_ix = *parent_style_id;
             }
 
             let slice =

zed/src/editor/buffer_view.rs 🔗

@@ -2,7 +2,12 @@ use super::{
     buffer, movement, Anchor, Bias, Buffer, BufferElement, DisplayMap, DisplayPoint, Point,
     Selection, SelectionGoal, SelectionSetId, ToOffset, ToPoint,
 };
-use crate::{settings::Settings, util::post_inc, workspace, worktree::FileHandle};
+use crate::{
+    settings::{Settings, StyleId},
+    util::post_inc,
+    workspace,
+    worktree::FileHandle,
+};
 use anyhow::Result;
 use gpui::{
     color::ColorU, fonts::Properties as FontProperties, geometry::vector::Vector2F,
@@ -2145,8 +2150,9 @@ impl BufferView {
         let mut row = rows.start;
         let snapshot = self.display_map.snapshot(ctx);
         let chunks = snapshot.highlighted_chunks_at(rows.start, ctx);
+        let theme = settings.theme.clone();
 
-        'outer: for (chunk, capture_ix) in chunks.chain(Some(("\n", None))) {
+        'outer: for (chunk, style_ix) in chunks.chain(Some(("\n", StyleId::default()))) {
             for (ix, line_chunk) in chunk.split('\n').enumerate() {
                 if ix > 0 {
                     layouts.push(layout_cache.layout_str(&line, font_size, &styles));
@@ -2160,7 +2166,7 @@ impl BufferView {
 
                 if !line_chunk.is_empty() {
                     line.push_str(line_chunk);
-                    styles.push((line_chunk.len(), font_id, ColorU::black()));
+                    styles.push((line_chunk.len(), font_id, theme.syntax_style(style_ix).0));
                 }
             }
         }

zed/src/editor/display_map/fold_map.rs 🔗

@@ -4,6 +4,7 @@ use super::{
 };
 use crate::{
     editor::buffer,
+    settings::StyleId,
     sum_tree::{self, Cursor, FilterCursor, SeekBias, SumTree},
     time,
 };
@@ -741,12 +742,12 @@ impl<'a> Iterator for Chunks<'a> {
 pub struct HighlightedChunks<'a> {
     transform_cursor: Cursor<'a, Transform, DisplayOffset, TransformSummary>,
     buffer_chunks: buffer::HighlightedChunks<'a>,
-    buffer_chunk: Option<(usize, &'a str, Option<usize>)>,
+    buffer_chunk: Option<(usize, &'a str, StyleId)>,
     buffer_offset: usize,
 }
 
 impl<'a> Iterator for HighlightedChunks<'a> {
-    type Item = (&'a str, Option<usize>);
+    type Item = (&'a str, StyleId);
 
     fn next(&mut self) -> Option<Self::Item> {
         let transform = if let Some(item) = self.transform_cursor.item() {
@@ -768,7 +769,7 @@ impl<'a> Iterator for HighlightedChunks<'a> {
                 self.transform_cursor.next();
             }
 
-            return Some((display_text, None));
+            return Some((display_text, StyleId::default()));
         }
 
         // Retrieve a chunk from the current location in the buffer.

zed/src/editor/display_map/mod.rs 🔗

@@ -1,5 +1,7 @@
 mod fold_map;
 
+use crate::settings::StyleId;
+
 use super::{buffer, Anchor, Bias, Buffer, Edit, Point, ToOffset, ToPoint};
 pub use fold_map::BufferRows;
 use fold_map::{FoldMap, FoldMapSnapshot};
@@ -163,7 +165,7 @@ impl DisplayMapSnapshot {
             column: 0,
             tab_size: self.tab_size,
             chunk: "",
-            capture_ix: None,
+            style_id: Default::default(),
         }
     }
 
@@ -355,19 +357,19 @@ impl<'a> Iterator for Chunks<'a> {
 pub struct HighlightedChunks<'a> {
     fold_chunks: fold_map::HighlightedChunks<'a>,
     chunk: &'a str,
-    capture_ix: Option<usize>,
+    style_id: StyleId,
     column: usize,
     tab_size: usize,
 }
 
 impl<'a> Iterator for HighlightedChunks<'a> {
-    type Item = (&'a str, Option<usize>);
+    type Item = (&'a str, StyleId);
 
     fn next(&mut self) -> Option<Self::Item> {
         if self.chunk.is_empty() {
-            if let Some((chunk, capture_ix)) = self.fold_chunks.next() {
+            if let Some((chunk, style_id)) = self.fold_chunks.next() {
                 self.chunk = chunk;
-                self.capture_ix = capture_ix;
+                self.style_id = style_id;
             } else {
                 return None;
             }
@@ -379,12 +381,12 @@ impl<'a> Iterator for HighlightedChunks<'a> {
                     if ix > 0 {
                         let (prefix, suffix) = self.chunk.split_at(ix);
                         self.chunk = suffix;
-                        return Some((prefix, self.capture_ix));
+                        return Some((prefix, self.style_id));
                     } else {
                         self.chunk = &self.chunk[1..];
                         let len = self.tab_size - self.column % self.tab_size;
                         self.column += len;
-                        return Some((&SPACES[0..len], self.capture_ix));
+                        return Some((&SPACES[0..len], self.style_id));
                     }
                 }
                 '\n' => self.column = 0,
@@ -392,7 +394,9 @@ impl<'a> Iterator for HighlightedChunks<'a> {
             }
         }
 
-        Some((mem::take(&mut self.chunk), self.capture_ix.take()))
+        let style_id = self.style_id;
+        self.style_id = StyleId::default();
+        Some((mem::take(&mut self.chunk), style_id))
     }
 }
 

zed/src/language.rs 🔗

@@ -1,3 +1,5 @@
+use crate::settings::{Theme, ThemeMap};
+use parking_lot::Mutex;
 use rust_embed::RustEmbed;
 use std::{path::Path, sync::Arc};
 use tree_sitter::{Language as Grammar, Query};
@@ -13,12 +15,23 @@ pub struct Language {
     pub grammar: Grammar,
     pub highlight_query: Query,
     path_suffixes: Vec<String>,
+    theme_mapping: Mutex<ThemeMap>,
 }
 
 pub struct LanguageRegistry {
     languages: Vec<Arc<Language>>,
 }
 
+impl Language {
+    pub fn theme_mapping(&self) -> ThemeMap {
+        self.theme_mapping.lock().clone()
+    }
+
+    fn set_theme(&self, theme: &Theme) {
+        *self.theme_mapping.lock() = ThemeMap::new(self.highlight_query.capture_names(), theme);
+    }
+}
+
 impl LanguageRegistry {
     pub fn new() -> Self {
         let grammar = tree_sitter_rust::language();
@@ -32,6 +45,7 @@ impl LanguageRegistry {
             )
             .unwrap(),
             path_suffixes: vec!["rs".to_string()],
+            theme_mapping: Mutex::new(ThemeMap::default()),
         };
 
         Self {
@@ -39,6 +53,12 @@ impl LanguageRegistry {
         }
     }
 
+    pub fn set_theme(&self, theme: &Theme) {
+        for language in &self.languages {
+            language.set_theme(theme);
+        }
+    }
+
     pub fn select_language(&self, path: impl AsRef<Path>) -> Option<&Arc<Language>> {
         let path = path.as_ref();
         let filename = path.file_name().and_then(|name| name.to_str());
@@ -67,12 +87,14 @@ mod tests {
                     grammar,
                     highlight_query: Query::new(grammar, "").unwrap(),
                     path_suffixes: vec!["rs".to_string()],
+                    theme_mapping: Default::default(),
                 }),
                 Arc::new(Language {
                     name: "Make".to_string(),
                     grammar,
                     highlight_query: Query::new(grammar, "").unwrap(),
                     path_suffixes: vec!["Makefile".to_string(), "mk".to_string()],
+                    theme_mapping: Default::default(),
                 }),
             ],
         };

zed/src/main.rs 🔗

@@ -18,6 +18,7 @@ fn main() {
 
     let (_, settings) = settings::channel(&app.font_cache()).unwrap();
     let language_registry = Arc::new(language::LanguageRegistry::new());
+    language_registry.set_theme(&settings.borrow().theme);
     let app_state = AppState {
         language_registry,
         settings,

zed/src/settings.rs 🔗

@@ -1,6 +1,15 @@
-use anyhow::Result;
-use gpui::font_cache::{FamilyId, FontCache};
+use super::assets::Assets;
+use anyhow::{anyhow, Context, Result};
+use gpui::{
+    color::ColorU,
+    font_cache::{FamilyId, FontCache},
+    fonts::Weight,
+};
 use postage::watch;
+use serde::Deserialize;
+use std::{collections::HashMap, sync::Arc};
+
+const DEFAULT_STYLE_ID: StyleId = StyleId(u32::MAX);
 
 #[derive(Clone)]
 pub struct Settings {
@@ -9,8 +18,23 @@ pub struct Settings {
     pub tab_size: usize,
     pub ui_font_family: FamilyId,
     pub ui_font_size: f32,
+    pub theme: Arc<Theme>,
+}
+
+#[derive(Clone, Default)]
+pub struct Theme {
+    pub background_color: ColorU,
+    pub line_number_color: ColorU,
+    pub default_text_color: ColorU,
+    syntax_styles: Vec<(String, ColorU, Weight)>,
 }
 
+#[derive(Clone, Debug)]
+pub struct ThemeMap(Arc<[StyleId]>);
+
+#[derive(Clone, Copy, Debug)]
+pub struct StyleId(u32);
+
 impl Settings {
     pub fn new(font_cache: &FontCache) -> Result<Self> {
         Ok(Self {
@@ -19,12 +43,247 @@ impl Settings {
             tab_size: 4,
             ui_font_family: font_cache.load_family(&["SF Pro", "Helvetica"])?,
             ui_font_size: 12.0,
+            theme: Arc::new(
+                Theme::parse(Assets::get("themes/light.toml").unwrap())
+                    .expect("Failed to parse built-in theme"),
+            ),
         })
     }
 }
 
+impl Theme {
+    pub fn parse(source: impl AsRef<[u8]>) -> Result<Self> {
+        #[derive(Deserialize)]
+        struct ThemeToml {
+            #[serde(default)]
+            syntax: HashMap<String, StyleToml>,
+            #[serde(default)]
+            ui: HashMap<String, u32>,
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        enum StyleToml {
+            Color(u32),
+            Full {
+                color: Option<u32>,
+                weight: Option<toml::Value>,
+            },
+        }
+
+        let theme_toml: ThemeToml =
+            toml::from_slice(source.as_ref()).context("failed to parse theme TOML")?;
+
+        let mut syntax_styles = Vec::<(String, ColorU, Weight)>::new();
+        for (key, style) in theme_toml.syntax {
+            let (color, weight) = match style {
+                StyleToml::Color(color) => (color, None),
+                StyleToml::Full { color, weight } => (color.unwrap_or(0), weight),
+            };
+            match syntax_styles.binary_search_by_key(&&key, |e| &e.0) {
+                Ok(i) | Err(i) => syntax_styles.insert(
+                    i,
+                    (key, deserialize_color(color), deserialize_weight(weight)?),
+                ),
+            }
+        }
+
+        let background_color = theme_toml
+            .ui
+            .get("background")
+            .copied()
+            .map_or(ColorU::from_u32(0xffffffff), deserialize_color);
+        let line_number_color = theme_toml
+            .ui
+            .get("line_numbers")
+            .copied()
+            .map_or(ColorU::black(), deserialize_color);
+        let default_text_color = theme_toml
+            .ui
+            .get("text")
+            .copied()
+            .map_or(ColorU::black(), deserialize_color);
+
+        Ok(Theme {
+            background_color,
+            line_number_color,
+            default_text_color,
+            syntax_styles,
+        })
+    }
+
+    pub fn syntax_style(&self, id: StyleId) -> (ColorU, Weight) {
+        self.syntax_styles
+            .get(id.0 as usize)
+            .map_or((self.default_text_color, Weight::NORMAL), |entry| {
+                (entry.1, entry.2)
+            })
+    }
+
+    #[cfg(test)]
+    pub fn syntax_style_name(&self, id: StyleId) -> Option<&str> {
+        self.syntax_styles.get(id.0 as usize).map(|e| e.0.as_str())
+    }
+}
+
+impl ThemeMap {
+    pub fn new(capture_names: &[String], theme: &Theme) -> Self {
+        // For each capture name in the highlight query, find the longest
+        // key in the theme's syntax styles that matches all of the
+        // dot-separated components of the capture name.
+        ThemeMap(
+            capture_names
+                .iter()
+                .map(|capture_name| {
+                    theme
+                        .syntax_styles
+                        .iter()
+                        .enumerate()
+                        .filter_map(|(i, (key, _, _))| {
+                            let mut len = 0;
+                            let capture_parts = capture_name.split('.');
+                            for key_part in key.split('.') {
+                                if capture_parts.clone().any(|part| part == key_part) {
+                                    len += 1;
+                                } else {
+                                    return None;
+                                }
+                            }
+                            Some((i, len))
+                        })
+                        .max_by_key(|(_, len)| *len)
+                        .map_or(DEFAULT_STYLE_ID, |(i, _)| StyleId(i as u32))
+                })
+                .collect(),
+        )
+    }
+
+    pub fn get(&self, capture_id: u32) -> StyleId {
+        self.0
+            .get(capture_id as usize)
+            .copied()
+            .unwrap_or(DEFAULT_STYLE_ID)
+    }
+}
+
+impl Default for ThemeMap {
+    fn default() -> Self {
+        Self(Arc::new([]))
+    }
+}
+
+impl Default for StyleId {
+    fn default() -> Self {
+        DEFAULT_STYLE_ID
+    }
+}
+
 pub fn channel(
     font_cache: &FontCache,
 ) -> Result<(watch::Sender<Settings>, watch::Receiver<Settings>)> {
     Ok(watch::channel_with(Settings::new(font_cache)?))
 }
+
+fn deserialize_color(color: u32) -> ColorU {
+    ColorU::from_u32((color << 8) + 0xFF)
+}
+
+fn deserialize_weight(weight: Option<toml::Value>) -> Result<Weight> {
+    match &weight {
+        None => return Ok(Weight::NORMAL),
+        Some(toml::Value::Integer(i)) => return Ok(Weight(*i as f32)),
+        Some(toml::Value::String(s)) => match s.as_str() {
+            "normal" => return Ok(Weight::NORMAL),
+            "bold" => return Ok(Weight::BOLD),
+            "light" => return Ok(Weight::LIGHT),
+            "semibold" => return Ok(Weight::SEMIBOLD),
+            _ => {}
+        },
+        _ => {}
+    }
+    Err(anyhow!("Invalid weight {}", weight.unwrap()))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_theme() {
+        let theme = Theme::parse(
+            r#"
+            [ui]
+            background = 0x00ed00
+            line_numbers = 0xdddddd
+
+            [syntax]
+            "beta.two" = 0xAABBCC
+            "alpha.one" = {color = 0x112233, weight = "bold"}
+            "gamma.three" = {weight = "light"}
+            "#,
+        )
+        .unwrap();
+
+        assert_eq!(theme.background_color, ColorU::from_u32(0x00ED00FF));
+        assert_eq!(theme.line_number_color, ColorU::from_u32(0xddddddff));
+        assert_eq!(
+            theme.syntax_styles,
+            &[
+                (
+                    "alpha.one".to_string(),
+                    ColorU::from_u32(0x112233FF),
+                    Weight::BOLD
+                ),
+                (
+                    "beta.two".to_string(),
+                    ColorU::from_u32(0xAABBCCFF),
+                    Weight::NORMAL
+                ),
+                (
+                    "gamma.three".to_string(),
+                    ColorU::from_u32(0x000000FF),
+                    Weight::LIGHT,
+                ),
+            ]
+        );
+    }
+
+    #[test]
+    fn test_parse_empty_theme() {
+        Theme::parse("").unwrap();
+    }
+
+    #[test]
+    fn test_theme_map() {
+        let theme = Theme {
+            default_text_color: Default::default(),
+            background_color: ColorU::default(),
+            line_number_color: ColorU::default(),
+            syntax_styles: [
+                ("function", ColorU::from_u32(0x100000ff)),
+                ("function.method", ColorU::from_u32(0x200000ff)),
+                ("function.async", ColorU::from_u32(0x300000ff)),
+                ("variable.builtin.self.rust", ColorU::from_u32(0x400000ff)),
+                ("variable.builtin", ColorU::from_u32(0x500000ff)),
+                ("variable", ColorU::from_u32(0x600000ff)),
+            ]
+            .iter()
+            .map(|e| (e.0.to_string(), e.1, Weight::NORMAL))
+            .collect(),
+        };
+
+        let capture_names = &[
+            "function.special".to_string(),
+            "function.async.rust".to_string(),
+            "variable.builtin.self".to_string(),
+        ];
+
+        let map = ThemeMap::new(capture_names, &theme);
+        assert_eq!(theme.syntax_style_name(map.get(0)), Some("function"));
+        assert_eq!(theme.syntax_style_name(map.get(1)), Some("function.async"));
+        assert_eq!(
+            theme.syntax_style_name(map.get(2)),
+            Some("variable.builtin")
+        );
+    }
+}