task: Allow obtaining custom task variables from tree-sitter queries (#11624)

Piotr Osiewicz and Remco created

From now on, only top-level captures are treated as runnable tags and
the rest is appended to task context as custom environmental variables
(unless the name is prefixed with _, in which case the capture is
ignored). This is most likely gonna help with Pest-like test runners.



Release Notes:

- N/A

---------

Co-authored-by: Remco <djsmits12@gmail.com>

Change summary

crates/editor/src/editor.rs             | 69 ++++++++++++++++----------
crates/editor/src/tasks.rs              | 18 ++++++
crates/language/src/buffer.rs           | 65 ++++++++++++++++---------
crates/language/src/language.rs         |  4 +
crates/language/src/task_context.rs     | 13 +++-
crates/languages/src/rust/runnables.scm |  4 
crates/multi_buffer/src/multi_buffer.rs | 20 +++---
7 files changed, 126 insertions(+), 67 deletions(-)

Detailed changes

crates/editor/src/editor.rs 🔗

@@ -79,7 +79,6 @@ use inlay_hint_cache::{InlayHintCache, InlaySplice, InvalidationStrategy};
 pub use inline_completion_provider::*;
 pub use items::MAX_TAB_TITLE_LEN;
 use itertools::Itertools;
-use language::Runnable;
 use language::{
     char_kind,
     language_settings::{self, all_language_settings, InlayHintSettings},
@@ -87,7 +86,8 @@ use language::{
     CursorShape, Diagnostic, Documentation, IndentKind, IndentSize, Language, OffsetRangeExt,
     Point, Selection, SelectionGoal, TransactionId,
 };
-use task::{ResolvedTask, TaskTemplate};
+use language::{Runnable, RunnableRange};
+use task::{ResolvedTask, TaskTemplate, TaskVariables};
 
 use hover_links::{HoverLink, HoveredLinkState, InlayHighlight};
 use lsp::{DiagnosticSeverity, LanguageServerId};
@@ -404,6 +404,7 @@ struct RunnableTasks {
     templates: Vec<(TaskSourceKind, TaskTemplate)>,
     // We need the column at which the task context evaluation should take place.
     column: u32,
+    extra_variables: HashMap<String, String>,
 }
 
 #[derive(Clone)]
@@ -3909,23 +3910,33 @@ impl Editor {
                                 .flatten()
                         },
                     );
-                    let tasks = tasks
-                        .zip(task_context.as_ref())
-                        .map(|(tasks, task_context)| {
-                            Arc::new(ResolvedTasks {
-                                templates: tasks
-                                    .1
-                                    .templates
-                                    .iter()
-                                    .filter_map(|(kind, template)| {
-                                        template
-                                            .resolve_task(&kind.to_id_base(), &task_context)
-                                            .map(|task| (kind.clone(), task))
-                                    })
-                                    .collect(),
-                                position: Point::new(buffer_row, tasks.1.column),
-                            })
-                        });
+                    let tasks = tasks.zip(task_context).map(|(tasks, mut task_context)| {
+                        // Fill in the environmental variables from the tree-sitter captures
+                        let mut additional_task_variables = TaskVariables::default();
+                        for (capture_name, value) in tasks.1.extra_variables.clone() {
+                            additional_task_variables.insert(
+                                task::VariableName::Custom(capture_name.into()),
+                                value.clone(),
+                            );
+                        }
+                        task_context
+                            .task_variables
+                            .extend(additional_task_variables);
+
+                        Arc::new(ResolvedTasks {
+                            templates: tasks
+                                .1
+                                .templates
+                                .iter()
+                                .filter_map(|(kind, template)| {
+                                    template
+                                        .resolve_task(&kind.to_id_base(), &task_context)
+                                        .map(|task| (kind.clone(), task))
+                                })
+                                .collect(),
+                            position: Point::new(buffer_row, tasks.1.column),
+                        })
+                    });
                     let spawn_straight_away = tasks
                         .as_ref()
                         .map_or(false, |tasks| tasks.templates.len() == 1)
@@ -7745,39 +7756,45 @@ impl Editor {
     fn fetch_runnable_ranges(
         snapshot: &DisplaySnapshot,
         range: Range<Anchor>,
-    ) -> Vec<(BufferId, Range<usize>, Runnable)> {
+    ) -> Vec<language::RunnableRange> {
         snapshot.buffer_snapshot.runnable_ranges(range).collect()
     }
 
     fn runnable_rows(
         project: Model<Project>,
         snapshot: DisplaySnapshot,
-        runnable_ranges: Vec<(BufferId, Range<usize>, Runnable)>,
+        runnable_ranges: Vec<RunnableRange>,
         mut cx: AsyncWindowContext,
     ) -> Vec<((BufferId, u32), (usize, RunnableTasks))> {
         runnable_ranges
             .into_iter()
-            .filter_map(|(buffer_id, multi_buffer_range, mut runnable)| {
+            .filter_map(|mut runnable| {
                 let (tasks, _) = cx
-                    .update(|cx| Self::resolve_runnable(project.clone(), &mut runnable, cx))
+                    .update(|cx| {
+                        Self::resolve_runnable(project.clone(), &mut runnable.runnable, cx)
+                    })
                     .ok()?;
                 if tasks.is_empty() {
                     return None;
                 }
-                let point = multi_buffer_range.start.to_point(&snapshot.buffer_snapshot);
+
+                let point = runnable.run_range.start.to_point(&snapshot.buffer_snapshot);
+
                 let row = snapshot
                     .buffer_snapshot
                     .buffer_line_for_row(point.row)?
                     .1
                     .start
                     .row;
+
                 Some((
-                    (buffer_id, row),
+                    (runnable.buffer_id, row),
                     (
-                        multi_buffer_range.start,
+                        runnable.run_range.start,
                         RunnableTasks {
                             templates: tasks,
                             column: point.column,
+                            extra_variables: runnable.extra_captures,
                         },
                     ),
                 ))

crates/editor/src/tasks.rs 🔗

@@ -6,7 +6,7 @@ use anyhow::Context;
 use gpui::WindowContext;
 use language::{BasicContextProvider, ContextProvider};
 use project::{Location, WorktreeId};
-use task::{TaskContext, TaskVariables};
+use task::{TaskContext, TaskVariables, VariableName};
 use util::ResultExt;
 use workspace::Workspace;
 
@@ -79,7 +79,21 @@ pub(crate) fn task_context_with_editor(
         buffer,
         range: start..end,
     };
-    task_context_for_location(workspace, location, cx)
+    task_context_for_location(workspace, location.clone(), cx).map(|mut task_context| {
+        for range in location
+            .buffer
+            .read(cx)
+            .snapshot()
+            .runnable_ranges(location.range)
+        {
+            for (capture_name, value) in range.extra_captures {
+                task_context
+                    .task_variables
+                    .insert(VariableName::Custom(capture_name.into()), value);
+            }
+        }
+        task_context
+    })
 }
 
 pub fn task_context(workspace: &Workspace, cx: &mut WindowContext<'_>) -> TaskContext {

crates/language/src/buffer.rs 🔗

@@ -13,6 +13,7 @@ use crate::{
         SyntaxLayer, SyntaxMap, SyntaxMapCapture, SyntaxMapCaptures, SyntaxMapMatches,
         SyntaxSnapshot, ToTreeSitterPoint,
     },
+    task_context::RunnableRange,
     LanguageScope, Outline, RunnableTag,
 };
 use anyhow::{anyhow, Context, Result};
@@ -2993,7 +2994,7 @@ impl BufferSnapshot {
     pub fn runnable_ranges(
         &self,
         range: Range<Anchor>,
-    ) -> impl Iterator<Item = (Range<usize>, Runnable)> + '_ {
+    ) -> impl Iterator<Item = RunnableRange> + '_ {
         let offset_range = range.start.to_offset(self)..range.end.to_offset(self);
 
         let mut syntax_matches = self.syntax.matches(offset_range, self, |grammar| {
@@ -3007,31 +3008,49 @@ impl BufferSnapshot {
             .collect::<Vec<_>>();
 
         iter::from_fn(move || {
-            let test_range = syntax_matches
-                .peek()
-                .and_then(|mat| {
-                    test_configs[mat.grammar_index].and_then(|test_configs| {
-                        let tags = SmallVec::from_iter(mat.captures.iter().filter_map(|capture| {
-                            test_configs.runnable_tags.get(&capture.index).cloned()
+            let test_range = syntax_matches.peek().and_then(|mat| {
+                test_configs[mat.grammar_index].and_then(|test_configs| {
+                    let mut tags: SmallVec<[(Range<usize>, RunnableTag); 1]> =
+                        SmallVec::from_iter(mat.captures.iter().filter_map(|capture| {
+                            test_configs
+                                .runnable_tags
+                                .get(&capture.index)
+                                .cloned()
+                                .map(|tag_name| (capture.node.byte_range(), tag_name))
                         }));
-
-                        if tags.is_empty() {
-                            return None;
-                        }
-
-                        Some((
-                            mat.captures
-                                .iter()
-                                .find(|capture| capture.index == test_configs.run_capture_ix)?,
-                            Runnable {
-                                tags,
-                                language: mat.language,
-                                buffer: self.remote_id(),
-                            },
-                        ))
+                    let maximum_range = tags
+                        .iter()
+                        .max_by_key(|(byte_range, _)| byte_range.len())
+                        .map(|(range, _)| range)?
+                        .clone();
+                    tags.sort_by_key(|(range, _)| range == &maximum_range);
+                    let split_point = tags.partition_point(|(range, _)| range != &maximum_range);
+                    let (extra_captures, tags) = tags.split_at(split_point);
+                    let extra_captures = extra_captures
+                        .into_iter()
+                        .map(|(range, name)| {
+                            (
+                                name.0.to_string(),
+                                self.text_for_range(range.clone()).collect::<String>(),
+                            )
+                        })
+                        .collect();
+                    Some(RunnableRange {
+                        run_range: mat
+                            .captures
+                            .iter()
+                            .find(|capture| capture.index == test_configs.run_capture_ix)
+                            .map(|mat| mat.node.byte_range())?,
+                        runnable: Runnable {
+                            tags: tags.into_iter().cloned().map(|(_, tag)| tag).collect(),
+                            language: mat.language,
+                            buffer: self.remote_id(),
+                        },
+                        extra_captures,
+                        buffer_id: self.remote_id(),
                     })
                 })
-                .map(|(mat, test_tags)| (mat.node.byte_range(), test_tags));
+            });
             syntax_matches.advance();
             test_range
         })

crates/language/src/language.rs 🔗

@@ -57,7 +57,9 @@ use std::{
 };
 use syntax_map::{QueryCursorHandle, SyntaxSnapshot};
 use task::RunnableTag;
-pub use task_context::{BasicContextProvider, ContextProvider, ContextProviderWithTasks};
+pub use task_context::{
+    BasicContextProvider, ContextProvider, ContextProviderWithTasks, RunnableRange,
+};
 use theme::SyntaxTheme;
 use tree_sitter::{self, wasmtime, Query, QueryCursor, WasmStore};
 use util::http::HttpClient;

crates/language/src/task_context.rs 🔗

@@ -1,12 +1,19 @@
-use std::path::Path;
+use std::{ops::Range, path::Path};
 
-use crate::Location;
+use crate::{Location, Runnable};
 
 use anyhow::Result;
+use collections::HashMap;
 use gpui::AppContext;
 use task::{TaskTemplates, TaskVariables, VariableName};
-use text::{Point, ToPoint};
+use text::{BufferId, Point, ToPoint};
 
+pub struct RunnableRange {
+    pub buffer_id: BufferId,
+    pub run_range: Range<usize>,
+    pub runnable: Runnable,
+    pub extra_captures: HashMap<String, String>,
+}
 /// Language Contexts are used by Zed tasks to extract information about the source file where the tasks are supposed to be scheduled from.
 /// Multiple context providers may be used together: by default, Zed provides a base [`BasicContextProvider`] context that fills all non-custom [`VariableName`] variants.
 ///

crates/languages/src/rust/runnables.scm 🔗

@@ -1,6 +1,6 @@
 (
-    (attribute_item (attribute) @_attribute
-        (#match? @_attribute ".*test"))
+    (attribute_item (attribute) @attribute
+        (#match? @attribute ".*test"))
     .
     (function_item
         name: (_) @run)

crates/multi_buffer/src/multi_buffer.rs 🔗

@@ -13,7 +13,7 @@ use language::{
     language_settings::{language_settings, LanguageSettings},
     AutoindentMode, Buffer, BufferChunks, BufferSnapshot, Capability, CharKind, Chunk, CursorShape,
     DiagnosticEntry, File, IndentSize, Language, LanguageScope, OffsetRangeExt, OffsetUtf16,
-    Outline, OutlineItem, Point, PointUtf16, Runnable, Selection, TextDimension, ToOffset as _,
+    Outline, OutlineItem, Point, PointUtf16, Selection, TextDimension, ToOffset as _,
     ToOffsetUtf16 as _, ToPoint as _, ToPointUtf16 as _, TransactionId, Unclipped,
 };
 use smallvec::SmallVec;
@@ -3168,7 +3168,7 @@ impl MultiBufferSnapshot {
     pub fn runnable_ranges(
         &self,
         range: Range<Anchor>,
-    ) -> impl Iterator<Item = (BufferId, Range<usize>, Runnable)> + '_ {
+    ) -> impl Iterator<Item = language::RunnableRange> + '_ {
         let range = range.start.to_offset(self)..range.end.to_offset(self);
         self.excerpts_for_range(range.clone())
             .flat_map(move |(excerpt, excerpt_offset)| {
@@ -3177,16 +3177,16 @@ impl MultiBufferSnapshot {
                 excerpt
                     .buffer
                     .runnable_ranges(excerpt.range.context.clone())
-                    .map(move |(mut match_range, runnable)| {
+                    .map(move |mut runnable| {
                         // Re-base onto the excerpts coordinates in the multibuffer
-                        match_range.start =
-                            excerpt_offset + (match_range.start - excerpt_buffer_start);
-                        match_range.end = excerpt_offset + (match_range.end - excerpt_buffer_start);
-
-                        (excerpt.buffer_id, match_range, runnable)
+                        runnable.run_range.start =
+                            excerpt_offset + (runnable.run_range.start - excerpt_buffer_start);
+                        runnable.run_range.end =
+                            excerpt_offset + (runnable.run_range.end - excerpt_buffer_start);
+                        runnable
                     })
-                    .skip_while(move |(_, match_range, _)| match_range.end < range.start)
-                    .take_while(move |(_, match_range, _)| match_range.start < range.end)
+                    .skip_while(move |runnable| runnable.run_range.end < range.start)
+                    .take_while(move |runnable| runnable.run_range.start < range.end)
             })
     }