Merge Zed task context providing logic (#10544)

Kirill Bulatov created

Before, `tasks_ui` set most of the context with `SymbolContextProvider`
providing the symbol data part of the context. Now, there's a
`BasicContextProvider` that forms all standard Zed context and it
automatically serves as a base, with no need for other providers like
`RustContextProvider` to call it as before.

Also, stop adding `SelectedText` task variable into the context for
blank text selection.

Release Notes:

- N/A

Change summary

crates/language/src/language.rs     |   2 
crates/language/src/task_context.rs |  92 +++++++++++--
crates/languages/src/lib.rs         |   4 
crates/languages/src/rust.rs        |  25 ++-
crates/tasks_ui/src/lib.rs          | 196 ++++++++++++++----------------
5 files changed, 181 insertions(+), 138 deletions(-)

Detailed changes

crates/language/src/language.rs 🔗

@@ -56,7 +56,7 @@ use std::{
     },
 };
 use syntax_map::SyntaxSnapshot;
-pub use task_context::{ContextProvider, ContextProviderWithTasks, SymbolContextProvider};
+pub use task_context::{BasicContextProvider, ContextProvider, ContextProviderWithTasks};
 use theme::SyntaxTheme;
 use tree_sitter::{self, wasmtime, Query, WasmStore};
 use util::http::HttpClient;

crates/language/src/task_context.rs 🔗

@@ -1,34 +1,56 @@
+use std::path::Path;
+
 use crate::Location;
 
 use anyhow::Result;
 use gpui::AppContext;
 use task::{static_source::TaskDefinitions, TaskVariables, VariableName};
+use text::{Point, ToPoint};
 
-/// Language Contexts are used by Zed tasks to extract information about source file.
+/// Language Contexts are used by Zed tasks to extract information about the source file where the tasks are supposed to be scheduled from.
+/// Multiple context providers may be used together: by default, Zed provides a base [`BasicContextProvider`] context that fills all non-custom [`VariableName`] variants.
+///
+/// The context will be used to fill data for the tasks, and filter out the ones that do not have the variables required.
 pub trait ContextProvider: Send + Sync {
-    fn build_context(&self, _: Location, _: &mut AppContext) -> Result<TaskVariables> {
+    /// Builds a specific context to be placed on top of the basic one (replacing all conflicting entries) and to be used for task resolving later.
+    fn build_context(
+        &self,
+        _: Option<&Path>,
+        _: &Location,
+        _: &mut AppContext,
+    ) -> Result<TaskVariables> {
         Ok(TaskVariables::default())
     }
 
+    /// Provides all tasks, associated with the current language.
     fn associated_tasks(&self) -> Option<TaskDefinitions> {
         None
     }
+
+    // Determines whether the [`BasicContextProvider`] variables should be filled too (if `false`), or omitted (if `true`).
+    fn is_basic(&self) -> bool {
+        false
+    }
 }
 
-/// A context provider that finds out what symbol is currently focused in the buffer.
-pub struct SymbolContextProvider;
+/// A context provided that tries to provide values for all non-custom [`VariableName`] variants for a currently opened file.
+/// Applied as a base for every custom [`ContextProvider`] unless explicitly oped out.
+pub struct BasicContextProvider;
+
+impl ContextProvider for BasicContextProvider {
+    fn is_basic(&self) -> bool {
+        true
+    }
 
-impl ContextProvider for SymbolContextProvider {
     fn build_context(
         &self,
-        location: Location,
+        worktree_abs_path: Option<&Path>,
+        location: &Location,
         cx: &mut AppContext,
-    ) -> gpui::Result<TaskVariables> {
-        let symbols = location
-            .buffer
-            .read(cx)
-            .snapshot()
-            .symbols_containing(location.range.start, None);
+    ) -> Result<TaskVariables> {
+        let buffer = location.buffer.read(cx);
+        let buffer_snapshot = buffer.snapshot();
+        let symbols = buffer_snapshot.symbols_containing(location.range.start, None);
         let symbol = symbols.unwrap_or_default().last().map(|symbol| {
             let range = symbol
                 .name_ranges
@@ -37,9 +59,40 @@ impl ContextProvider for SymbolContextProvider {
                 .unwrap_or(0..symbol.text.len());
             symbol.text[range].to_string()
         });
-        Ok(TaskVariables::from_iter(
-            Some(VariableName::Symbol).zip(symbol),
-        ))
+
+        let current_file = buffer
+            .file()
+            .and_then(|file| file.as_local())
+            .map(|file| file.abs_path(cx).to_string_lossy().to_string());
+        let Point { row, column } = location.range.start.to_point(&buffer_snapshot);
+        let row = row + 1;
+        let column = column + 1;
+        let selected_text = buffer
+            .chars_for_range(location.range.clone())
+            .collect::<String>();
+
+        let mut task_variables = TaskVariables::from_iter([
+            (VariableName::Row, row.to_string()),
+            (VariableName::Column, column.to_string()),
+        ]);
+
+        if let Some(symbol) = symbol {
+            task_variables.insert(VariableName::Symbol, symbol);
+        }
+        if !selected_text.trim().is_empty() {
+            task_variables.insert(VariableName::SelectedText, selected_text);
+        }
+        if let Some(path) = current_file {
+            task_variables.insert(VariableName::File, path);
+        }
+        if let Some(worktree_path) = worktree_abs_path {
+            task_variables.insert(
+                VariableName::WorktreeRoot,
+                worktree_path.to_string_lossy().to_string(),
+            );
+        }
+
+        Ok(task_variables)
     }
 }
 
@@ -59,7 +112,12 @@ impl ContextProvider for ContextProviderWithTasks {
         Some(self.definitions.clone())
     }
 
-    fn build_context(&self, location: Location, cx: &mut AppContext) -> Result<TaskVariables> {
-        SymbolContextProvider.build_context(location, cx)
+    fn build_context(
+        &self,
+        worktree_abs_path: Option<&Path>,
+        location: &Location,
+        cx: &mut AppContext,
+    ) -> Result<TaskVariables> {
+        BasicContextProvider.build_context(worktree_abs_path, location, cx)
     }
 }

crates/languages/src/lib.rs 🔗

@@ -105,7 +105,7 @@ pub fn init(
                     Ok((
                         config.clone(),
                         load_queries($name),
-                        Some(Arc::new(language::SymbolContextProvider)),
+                        Some(Arc::new(language::BasicContextProvider)),
                     ))
                 },
             );
@@ -125,7 +125,7 @@ pub fn init(
                     Ok((
                         config.clone(),
                         load_queries($name),
-                        Some(Arc::new(language::SymbolContextProvider)),
+                        Some(Arc::new(language::BasicContextProvider)),
                     ))
                 },
             );

crates/languages/src/rust.rs 🔗

@@ -334,25 +334,26 @@ const RUST_PACKAGE_TASK_VARIABLE: VariableName =
 impl ContextProvider for RustContextProvider {
     fn build_context(
         &self,
-        location: Location,
+        _: Option<&Path>,
+        location: &Location,
         cx: &mut gpui::AppContext,
     ) -> Result<TaskVariables> {
-        let mut context = SymbolContextProvider.build_context(location.clone(), cx)?;
-
         let local_abs_path = location
             .buffer
             .read(cx)
             .file()
             .and_then(|file| Some(file.as_local()?.abs_path(cx)));
-        if let Some(package_name) = local_abs_path
-            .as_deref()
-            .and_then(|local_abs_path| local_abs_path.parent())
-            .and_then(human_readable_package_name)
-        {
-            context.insert(RUST_PACKAGE_TASK_VARIABLE.clone(), package_name);
-        }
-
-        Ok(context)
+        Ok(
+            if let Some(package_name) = local_abs_path
+                .as_deref()
+                .and_then(|local_abs_path| local_abs_path.parent())
+                .and_then(human_readable_package_name)
+            {
+                TaskVariables::from_iter(Some((RUST_PACKAGE_TASK_VARIABLE.clone(), package_name)))
+            } else {
+                TaskVariables::default()
+            },
+        )
     }
 
     fn associated_tasks(&self) -> Option<TaskDefinitions> {

crates/tasks_ui/src/lib.rs 🔗

@@ -1,12 +1,16 @@
-use std::{path::PathBuf, sync::Arc};
+use std::{
+    path::{Path, PathBuf},
+    sync::Arc,
+};
 
 use ::settings::Settings;
+use anyhow::Context;
 use editor::Editor;
 use gpui::{AppContext, ViewContext, WindowContext};
-use language::{Language, Point};
+use language::{BasicContextProvider, ContextProvider, Language};
 use modal::{Spawn, TasksModal};
 use project::{Location, WorktreeId};
-use task::{Task, TaskContext, TaskVariables, VariableName};
+use task::{Task, TaskContext, TaskVariables};
 use util::ResultExt;
 use workspace::Workspace;
 
@@ -29,8 +33,7 @@ pub fn init(cx: &mut AppContext) {
                         })
                     {
                         let task_context = if action.reevaluate_context {
-                            let cwd = task_cwd(workspace, cx).log_err().flatten();
-                            task_context(workspace, cwd, cx)
+                            task_context(workspace, cx)
                         } else {
                             old_context
                         };
@@ -48,8 +51,7 @@ fn spawn_task_or_modal(workspace: &mut Workspace, action: &Spawn, cx: &mut ViewC
         None => {
             let inventory = workspace.project().read(cx).task_inventory().clone();
             let workspace_handle = workspace.weak_handle();
-            let cwd = task_cwd(workspace, cx).log_err().flatten();
-            let task_context = task_context(workspace, cwd, cx);
+            let task_context = task_context(workspace, cx);
             workspace.toggle_modal(cx, |cx| {
                 TasksModal::new(inventory, task_context, workspace_handle, cx)
             })
@@ -68,8 +70,7 @@ fn spawn_task_with_name(name: String, cx: &mut ViewContext<Workspace>) {
                     })
                 });
                 let (_, target_task) = tasks.into_iter().find(|(_, task)| task.name() == name)?;
-                let cwd = task_cwd(workspace, cx).log_err().flatten();
-                let task_context = task_context(workspace, cwd, cx);
+                let task_context = task_context(workspace, cx);
                 schedule_task(workspace, &target_task, task_context, false, cx);
                 Some(())
             })
@@ -111,104 +112,89 @@ fn active_item_selection_properties(
     (worktree_id, language)
 }
 
-fn task_context(
-    workspace: &Workspace,
-    cwd: Option<PathBuf>,
-    cx: &mut WindowContext<'_>,
-) -> TaskContext {
-    let current_editor = workspace
-        .active_item(cx)
-        .and_then(|item| item.act_as::<Editor>(cx));
-    if let Some(current_editor) = current_editor {
-        (|| {
-            let editor = current_editor.read(cx);
+fn task_context(workspace: &Workspace, cx: &mut WindowContext<'_>) -> TaskContext {
+    fn task_context_impl(workspace: &Workspace, cx: &mut WindowContext<'_>) -> Option<TaskContext> {
+        let cwd = task_cwd(workspace, cx).log_err().flatten();
+        let editor = workspace
+            .active_item(cx)
+            .and_then(|item| item.act_as::<Editor>(cx))?;
+
+        let (selection, buffer, editor_snapshot) = editor.update(cx, |editor, cx| {
             let selection = editor.selections.newest::<usize>(cx);
             let (buffer, _, _) = editor
                 .buffer()
                 .read(cx)
                 .point_to_buffer_offset(selection.start, cx)?;
+            let snapshot = editor.snapshot(cx);
+            Some((selection, buffer, snapshot))
+        })?;
+        let language_context_provider = buffer
+            .read(cx)
+            .language()
+            .and_then(|language| language.context_provider())?;
 
-            current_editor.update(cx, |editor, cx| {
-                let snapshot = editor.snapshot(cx);
-                let selection_range = selection.range();
-                let start = snapshot
-                    .display_snapshot
-                    .buffer_snapshot
-                    .anchor_after(selection_range.start)
-                    .text_anchor;
-                let end = snapshot
-                    .display_snapshot
-                    .buffer_snapshot
-                    .anchor_after(selection_range.end)
-                    .text_anchor;
-                let Point { row, column } = snapshot
-                    .display_snapshot
-                    .buffer_snapshot
-                    .offset_to_point(selection_range.start);
-                let row = row + 1;
-                let column = column + 1;
-                let location = Location {
-                    buffer: buffer.clone(),
-                    range: start..end,
-                };
-
-                let current_file = location
-                    .buffer
+        let selection_range = selection.range();
+        let start = editor_snapshot
+            .display_snapshot
+            .buffer_snapshot
+            .anchor_after(selection_range.start)
+            .text_anchor;
+        let end = editor_snapshot
+            .display_snapshot
+            .buffer_snapshot
+            .anchor_after(selection_range.end)
+            .text_anchor;
+        let worktree_abs_path = buffer
+            .read(cx)
+            .file()
+            .map(|file| WorktreeId::from_usize(file.worktree_id()))
+            .and_then(|worktree_id| {
+                workspace
+                    .project()
                     .read(cx)
-                    .file()
-                    .and_then(|file| file.as_local())
-                    .map(|file| file.abs_path(cx).to_string_lossy().to_string());
-                let worktree_id = location
-                    .buffer
-                    .read(cx)
-                    .file()
-                    .map(|file| WorktreeId::from_usize(file.worktree_id()));
-                let context = buffer
-                    .read(cx)
-                    .language()
-                    .and_then(|language| language.context_provider())
-                    .and_then(|provider| provider.build_context(location, cx).ok());
-
-                let worktree_path = worktree_id.and_then(|worktree_id| {
-                    workspace
-                        .project()
-                        .read(cx)
-                        .worktree_for_id(worktree_id, cx)
-                        .map(|worktree| worktree.read(cx).abs_path().to_string_lossy().to_string())
-                });
-
-                let selected_text = buffer.read(cx).chars_for_range(selection_range).collect();
-
-                let mut task_variables = TaskVariables::from_iter([
-                    (VariableName::Row, row.to_string()),
-                    (VariableName::Column, column.to_string()),
-                    (VariableName::SelectedText, selected_text),
-                ]);
-                if let Some(path) = current_file {
-                    task_variables.insert(VariableName::File, path);
-                }
-                if let Some(worktree_path) = worktree_path {
-                    task_variables.insert(VariableName::WorktreeRoot, worktree_path);
-                }
-                if let Some(language_context) = context {
-                    task_variables.extend(language_context);
-                }
-
-                Some(TaskContext {
-                    cwd: cwd.clone(),
-                    task_variables,
-                })
-            })
-        })()
-        .unwrap_or_else(|| TaskContext {
+                    .worktree_for_id(worktree_id, cx)
+                    .map(|worktree| worktree.read(cx).abs_path())
+            });
+        let location = Location {
+            buffer,
+            range: start..end,
+        };
+        let task_variables = combine_task_variables(
+            worktree_abs_path.as_deref(),
+            location,
+            language_context_provider.as_ref(),
+            cx,
+        )
+        .log_err()?;
+        Some(TaskContext {
             cwd,
-            task_variables: Default::default(),
+            task_variables,
         })
+    }
+
+    task_context_impl(workspace, cx).unwrap_or_default()
+}
+
+fn combine_task_variables(
+    worktree_abs_path: Option<&Path>,
+    location: Location,
+    context_provider: &dyn ContextProvider,
+    cx: &mut WindowContext<'_>,
+) -> anyhow::Result<TaskVariables> {
+    if context_provider.is_basic() {
+        context_provider
+            .build_context(worktree_abs_path, &location, cx)
+            .context("building basic provider context")
     } else {
-        TaskContext {
-            cwd,
-            task_variables: Default::default(),
-        }
+        let mut basic_context = BasicContextProvider
+            .build_context(worktree_abs_path, &location, cx)
+            .context("building basic default context")?;
+        basic_context.extend(
+            context_provider
+                .build_context(worktree_abs_path, &location, cx)
+                .context("building provider context ")?,
+        );
+        Ok(basic_context)
     }
 }
 
@@ -273,14 +259,14 @@ mod tests {
 
     use editor::Editor;
     use gpui::{Entity, TestAppContext};
-    use language::{Language, LanguageConfig, SymbolContextProvider};
+    use language::{BasicContextProvider, Language, LanguageConfig};
     use project::{FakeFs, Project, TaskSourceKind};
     use serde_json::json;
     use task::{oneshot_source::OneshotSource, TaskContext, TaskVariables, VariableName};
     use ui::VisualContext;
     use workspace::{AppState, Workspace};
 
-    use crate::{task_context, task_cwd};
+    use crate::task_context;
 
     #[gpui::test]
     async fn test_default_language_context(cx: &mut TestAppContext) {
@@ -323,7 +309,7 @@ mod tests {
             name: (_) @name) @item"#,
             )
             .unwrap()
-            .with_context_provider(Some(Arc::new(SymbolContextProvider))),
+            .with_context_provider(Some(Arc::new(BasicContextProvider))),
         );
 
         let typescript_language = Arc::new(
@@ -341,7 +327,7 @@ mod tests {
                       ")" @context)) @item"#,
             )
             .unwrap()
-            .with_context_provider(Some(Arc::new(SymbolContextProvider))),
+            .with_context_provider(Some(Arc::new(BasicContextProvider))),
         );
         let project = Project::test(fs, ["/dir".as_ref()], cx).await;
         project.update(cx, |project, cx| {
@@ -380,7 +366,7 @@ mod tests {
             this.add_item_to_center(Box::new(editor2.clone()), cx);
             assert_eq!(this.active_item(cx).unwrap().item_id(), editor2.entity_id());
             assert_eq!(
-                task_context(this, task_cwd(this, cx).unwrap(), cx),
+                task_context(this, cx),
                 TaskContext {
                     cwd: Some("/dir".into()),
                     task_variables: TaskVariables::from_iter([
@@ -388,7 +374,6 @@ mod tests {
                         (VariableName::WorktreeRoot, "/dir".into()),
                         (VariableName::Row, "1".into()),
                         (VariableName::Column, "1".into()),
-                        (VariableName::SelectedText, "".into())
                     ])
                 }
             );
@@ -397,7 +382,7 @@ mod tests {
                 this.change_selections(None, cx, |selections| selections.select_ranges([14..18]))
             });
             assert_eq!(
-                task_context(this, task_cwd(this, cx).unwrap(), cx),
+                task_context(this, cx),
                 TaskContext {
                     cwd: Some("/dir".into()),
                     task_variables: TaskVariables::from_iter([
@@ -414,7 +399,7 @@ mod tests {
             // Now, let's switch the active item to .ts file.
             this.activate_item(&editor1, cx);
             assert_eq!(
-                task_context(this, task_cwd(this, cx).unwrap(), cx),
+                task_context(this, cx),
                 TaskContext {
                     cwd: Some("/dir".into()),
                     task_variables: TaskVariables::from_iter([
@@ -422,7 +407,6 @@ mod tests {
                         (VariableName::WorktreeRoot, "/dir".into()),
                         (VariableName::Row, "1".into()),
                         (VariableName::Column, "1".into()),
-                        (VariableName::SelectedText, "".into()),
                         (VariableName::Symbol, "this_is_a_test".into()),
                     ])
                 }