Improve TypeScript task detection (#31711)

Kirill Bulatov created

Parses project's package.json to better detect Jasmine, Jest, Vitest and
Mocha and `test`, `build` scripts presence.
Also tries to detect `pnpm` and `npx` as test runners, falls back to
`npm`.


https://github.com/user-attachments/assets/112d3d8b-8daa-4ba5-8cb5-2f483036bd98

Release Notes:

- Improved TypeScript task detection

Change summary

Cargo.lock                           |   2 
crates/language/src/language.rs      |   2 
crates/language/src/task_context.rs  |  13 
crates/languages/Cargo.toml          |   2 
crates/languages/src/go.rs           |   3 
crates/languages/src/lib.rs          |   2 
crates/languages/src/python.rs       |  20 
crates/languages/src/rust.rs         |   3 
crates/languages/src/typescript.rs   | 397 ++++++++++++++++++++++++++++-
crates/project/src/task_inventory.rs |   5 
crates/project/src/task_store.rs     |  54 +++
crates/project/src/worktree_store.rs |   7 
12 files changed, 468 insertions(+), 42 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -8934,6 +8934,7 @@ dependencies = [
  "async-compression",
  "async-tar",
  "async-trait",
+ "chrono",
  "collections",
  "dap",
  "futures 0.3.31",
@@ -8987,6 +8988,7 @@ dependencies = [
  "tree-sitter-yaml",
  "unindent",
  "util",
+ "which 6.0.3",
  "workspace",
  "workspace-hack",
 ]

crates/language/src/language.rs 🔗

@@ -64,7 +64,7 @@ use std::{
 use std::{num::NonZeroU32, sync::OnceLock};
 use syntax_map::{QueryCursorHandle, SyntaxSnapshot};
 use task::RunnableTag;
-pub use task_context::{ContextProvider, RunnableRange};
+pub use task_context::{ContextLocation, ContextProvider, RunnableRange};
 pub use text_diff::{
     DiffOptions, apply_diff_patch, line_diff, text_diff, text_diff_with_options, unified_diff,
 };

crates/language/src/task_context.rs 🔗

@@ -1,9 +1,10 @@
-use std::{ops::Range, sync::Arc};
+use std::{ops::Range, path::PathBuf, sync::Arc};
 
 use crate::{LanguageToolchainStore, Location, Runnable};
 
 use anyhow::Result;
 use collections::HashMap;
+use fs::Fs;
 use gpui::{App, Task};
 use lsp::LanguageServerName;
 use task::{TaskTemplates, TaskVariables};
@@ -26,11 +27,12 @@ pub trait ContextProvider: Send + Sync {
     fn build_context(
         &self,
         _variables: &TaskVariables,
-        _location: &Location,
+        _location: ContextLocation<'_>,
         _project_env: Option<HashMap<String, String>>,
         _toolchains: Arc<dyn LanguageToolchainStore>,
         _cx: &mut App,
     ) -> Task<Result<TaskVariables>> {
+        let _ = _location;
         Task::ready(Ok(TaskVariables::default()))
     }
 
@@ -48,3 +50,10 @@ pub trait ContextProvider: Send + Sync {
         None
     }
 }
+
+/// Metadata about the place in the project we gather the context for.
+pub struct ContextLocation<'a> {
+    pub fs: Option<Arc<dyn Fs>>,
+    pub worktree_root: Option<PathBuf>,
+    pub file_location: &'a Location,
+}

crates/languages/Cargo.toml 🔗

@@ -38,6 +38,7 @@ anyhow.workspace = true
 async-compression.workspace = true
 async-tar.workspace = true
 async-trait.workspace = true
+chrono.workspace = true
 collections.workspace = true
 dap.workspace = true
 futures.workspace = true
@@ -87,6 +88,7 @@ tree-sitter-rust = { workspace = true, optional = true }
 tree-sitter-typescript = { workspace = true, optional = true }
 tree-sitter-yaml = { workspace = true, optional = true }
 util.workspace = true
+which.workspace = true
 workspace-hack.workspace = true
 
 [dev-dependencies]

crates/languages/src/go.rs 🔗

@@ -444,12 +444,13 @@ impl ContextProvider for GoContextProvider {
     fn build_context(
         &self,
         variables: &TaskVariables,
-        location: &Location,
+        location: ContextLocation<'_>,
         _: Option<HashMap<String, String>>,
         _: Arc<dyn LanguageToolchainStore>,
         cx: &mut gpui::App,
     ) -> Task<Result<TaskVariables>> {
         let local_abs_path = location
+            .file_location
             .buffer
             .read(cx)
             .file()

crates/languages/src/lib.rs 🔗

@@ -88,7 +88,7 @@ pub fn init(languages: Arc<LanguageRegistry>, node: NodeRuntime, cx: &mut App) {
     let rust_context_provider = Arc::new(rust::RustContextProvider);
     let rust_lsp_adapter = Arc::new(rust::RustLspAdapter);
     let tailwind_adapter = Arc::new(tailwind::TailwindLspAdapter::new(node.clone()));
-    let typescript_context = Arc::new(typescript::typescript_task_context());
+    let typescript_context = Arc::new(typescript::TypeScriptContextProvider::new());
     let typescript_lsp_adapter = Arc::new(typescript::TypeScriptLspAdapter::new(node.clone()));
     let vtsls_adapter = Arc::new(vtsls::VtslsLspAdapter::new(node.clone()));
     let yaml_lsp_adapter = Arc::new(yaml::YamlLspAdapter::new(node.clone()));

crates/languages/src/python.rs 🔗

@@ -4,11 +4,11 @@ use async_trait::async_trait;
 use collections::HashMap;
 use gpui::{App, Task};
 use gpui::{AsyncApp, SharedString};
-use language::LanguageToolchainStore;
 use language::Toolchain;
 use language::ToolchainList;
 use language::ToolchainLister;
 use language::language_settings::language_settings;
+use language::{ContextLocation, LanguageToolchainStore};
 use language::{ContextProvider, LspAdapter, LspAdapterDelegate};
 use language::{LanguageName, ManifestName, ManifestProvider, ManifestQuery};
 use lsp::LanguageServerBinary;
@@ -367,18 +367,24 @@ impl ContextProvider for PythonContextProvider {
     fn build_context(
         &self,
         variables: &task::TaskVariables,
-        location: &project::Location,
+        location: ContextLocation<'_>,
         _: Option<HashMap<String, String>>,
         toolchains: Arc<dyn LanguageToolchainStore>,
         cx: &mut gpui::App,
     ) -> Task<Result<task::TaskVariables>> {
-        let test_target = match selected_test_runner(location.buffer.read(cx).file(), cx) {
-            TestRunner::UNITTEST => self.build_unittest_target(variables),
-            TestRunner::PYTEST => self.build_pytest_target(variables),
-        };
+        let test_target =
+            match selected_test_runner(location.file_location.buffer.read(cx).file(), cx) {
+                TestRunner::UNITTEST => self.build_unittest_target(variables),
+                TestRunner::PYTEST => self.build_pytest_target(variables),
+            };
 
         let module_target = self.build_module_target(variables);
-        let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
+        let worktree_id = location
+            .file_location
+            .buffer
+            .read(cx)
+            .file()
+            .map(|f| f.worktree_id(cx));
 
         cx.spawn(async move |cx| {
             let raw_toolchain = if let Some(worktree_id) = worktree_id {

crates/languages/src/rust.rs 🔗

@@ -557,12 +557,13 @@ impl ContextProvider for RustContextProvider {
     fn build_context(
         &self,
         task_variables: &TaskVariables,
-        location: &Location,
+        location: ContextLocation<'_>,
         project_env: Option<HashMap<String, String>>,
         _: Arc<dyn LanguageToolchainStore>,
         cx: &mut gpui::App,
     ) -> Task<Result<TaskVariables>> {
         let local_abs_path = location
+            .file_location
             .buffer
             .read(cx)
             .file()

crates/languages/src/typescript.rs 🔗

@@ -2,56 +2,407 @@ use anyhow::{Context as _, Result};
 use async_compression::futures::bufread::GzipDecoder;
 use async_tar::Archive;
 use async_trait::async_trait;
+use chrono::{DateTime, Local};
 use collections::HashMap;
-use gpui::AsyncApp;
+use gpui::{App, AppContext, AsyncApp, Task};
 use http_client::github::{AssetKind, GitHubLspBinaryVersion, build_asset_url};
-use language::{LanguageToolchainStore, LspAdapter, LspAdapterDelegate};
+use language::{
+    ContextLocation, ContextProvider, File, LanguageToolchainStore, LspAdapter, LspAdapterDelegate,
+};
 use lsp::{CodeActionKind, LanguageServerBinary, LanguageServerName};
 use node_runtime::NodeRuntime;
-use project::ContextProviderWithTasks;
 use project::{Fs, lsp_store::language_server_settings};
 use serde_json::{Value, json};
-use smol::{fs, io::BufReader, stream::StreamExt};
+use smol::{fs, io::BufReader, lock::RwLock, stream::StreamExt};
 use std::{
     any::Any,
+    borrow::Cow,
     ffi::OsString,
     path::{Path, PathBuf},
     sync::Arc,
 };
-use task::{TaskTemplate, TaskTemplates, VariableName};
+use task::{TaskTemplate, TaskTemplates, TaskVariables, VariableName};
 use util::archive::extract_zip;
 use util::merge_json_value_into;
 use util::{ResultExt, fs::remove_matching, maybe};
 
-pub(super) fn typescript_task_context() -> ContextProviderWithTasks {
-    ContextProviderWithTasks::new(TaskTemplates(vec![
-        TaskTemplate {
-            label: "jest file test".to_owned(),
-            command: "npx jest".to_owned(),
-            args: vec![VariableName::File.template_value()],
+pub(crate) struct TypeScriptContextProvider {
+    last_package_json: PackageJsonContents,
+}
+
+const TYPESCRIPT_RUNNER_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("TYPESCRIPT_RUNNER"));
+const TYPESCRIPT_JEST_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("TYPESCRIPT_JEST"));
+const TYPESCRIPT_MOCHA_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("TYPESCRIPT_MOCHA"));
+
+const TYPESCRIPT_VITEST_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("TYPESCRIPT_VITEST"));
+const TYPESCRIPT_JASMINE_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("TYPESCRIPT_JASMINE"));
+const TYPESCRIPT_BUILD_SCRIPT_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("TYPESCRIPT_BUILD_SCRIPT"));
+const TYPESCRIPT_TEST_SCRIPT_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("TYPESCRIPT_TEST_SCRIPT"));
+
+#[derive(Clone, Default)]
+struct PackageJsonContents(Arc<RwLock<HashMap<PathBuf, PackageJson>>>);
+
+struct PackageJson {
+    mtime: DateTime<Local>,
+    data: PackageJsonData,
+}
+
+#[derive(Clone, Copy, Default)]
+struct PackageJsonData {
+    jest: bool,
+    mocha: bool,
+    vitest: bool,
+    jasmine: bool,
+    build_script: bool,
+    test_script: bool,
+    runner: Runner,
+}
+
+#[derive(Clone, Copy, Default)]
+enum Runner {
+    #[default]
+    Npm,
+    Npx,
+    Pnpm,
+}
+
+impl PackageJsonData {
+    fn new(package_json: HashMap<String, Value>) -> Self {
+        let mut build_script = false;
+        let mut test_script = false;
+        if let Some(serde_json::Value::Object(scripts)) = package_json.get("scripts") {
+            build_script |= scripts.contains_key("build");
+            test_script |= scripts.contains_key("test");
+        }
+
+        let mut jest = false;
+        let mut mocha = false;
+        let mut vitest = false;
+        let mut jasmine = false;
+        if let Some(serde_json::Value::Object(dependencies)) = package_json.get("devDependencies") {
+            jest |= dependencies.contains_key("jest");
+            mocha |= dependencies.contains_key("mocha");
+            vitest |= dependencies.contains_key("vitest");
+            jasmine |= dependencies.contains_key("jasmine");
+        }
+        if let Some(serde_json::Value::Object(dev_dependencies)) = package_json.get("dependencies")
+        {
+            jest |= dev_dependencies.contains_key("jest");
+            mocha |= dev_dependencies.contains_key("mocha");
+            vitest |= dev_dependencies.contains_key("vitest");
+            jasmine |= dev_dependencies.contains_key("jasmine");
+        }
+
+        let mut runner = Runner::Npm;
+        if which::which("pnpm").is_ok() {
+            runner = Runner::Pnpm;
+        } else if which::which("npx").is_ok() {
+            runner = Runner::Npx;
+        }
+
+        Self {
+            jest,
+            mocha,
+            vitest,
+            jasmine,
+            build_script,
+            test_script,
+            runner,
+        }
+    }
+
+    fn fill_variables(&self, variables: &mut TaskVariables) {
+        let runner = match self.runner {
+            Runner::Npm => "npm",
+            Runner::Npx => "npx",
+            Runner::Pnpm => "pnpm",
+        };
+        variables.insert(TYPESCRIPT_RUNNER_VARIABLE, runner.to_owned());
+
+        if self.jest {
+            variables.insert(TYPESCRIPT_JEST_TASK_VARIABLE, "jest".to_owned());
+        }
+        if self.mocha {
+            variables.insert(TYPESCRIPT_MOCHA_TASK_VARIABLE, "mocha".to_owned());
+        }
+        if self.vitest {
+            variables.insert(TYPESCRIPT_VITEST_TASK_VARIABLE, "vitest".to_owned());
+        }
+        if self.jasmine {
+            variables.insert(TYPESCRIPT_JASMINE_TASK_VARIABLE, "jasmine".to_owned());
+        }
+        if self.build_script {
+            variables.insert(TYPESCRIPT_BUILD_SCRIPT_TASK_VARIABLE, "build".to_owned());
+        }
+        if self.test_script {
+            variables.insert(TYPESCRIPT_TEST_SCRIPT_TASK_VARIABLE, "test".to_owned());
+        }
+    }
+}
+
+impl TypeScriptContextProvider {
+    pub fn new() -> Self {
+        TypeScriptContextProvider {
+            last_package_json: PackageJsonContents::default(),
+        }
+    }
+}
+
+impl ContextProvider for TypeScriptContextProvider {
+    fn associated_tasks(&self, _: Option<Arc<dyn File>>, _: &App) -> Option<TaskTemplates> {
+        let mut task_templates = TaskTemplates(Vec::new());
+
+        // Jest tasks
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "{} file test",
+                TYPESCRIPT_JEST_TASK_VARIABLE.template_value()
+            ),
+            command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
+            args: vec![
+                TYPESCRIPT_JEST_TASK_VARIABLE.template_value(),
+                VariableName::File.template_value(),
+            ],
             ..TaskTemplate::default()
-        },
-        TaskTemplate {
-            label: "jest test $ZED_SYMBOL".to_owned(),
-            command: "npx jest".to_owned(),
+        });
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "{} test {}",
+                TYPESCRIPT_JEST_TASK_VARIABLE.template_value(),
+                VariableName::Symbol.template_value(),
+            ),
+            command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
             args: vec![
-                "--testNamePattern".into(),
+                TYPESCRIPT_JEST_TASK_VARIABLE.template_value(),
+                "--testNamePattern".to_owned(),
                 format!("\"{}\"", VariableName::Symbol.template_value()),
                 VariableName::File.template_value(),
             ],
-            tags: vec!["ts-test".into(), "js-test".into(), "tsx-test".into()],
+            tags: vec![
+                "ts-test".to_owned(),
+                "js-test".to_owned(),
+                "tsx-test".to_owned(),
+            ],
+            ..TaskTemplate::default()
+        });
+
+        // Vitest tasks
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "{} file test",
+                TYPESCRIPT_VITEST_TASK_VARIABLE.template_value()
+            ),
+            command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
+            args: vec![
+                TYPESCRIPT_VITEST_TASK_VARIABLE.template_value(),
+                "run".to_owned(),
+                VariableName::File.template_value(),
+            ],
             ..TaskTemplate::default()
-        },
-        TaskTemplate {
-            label: "execute selection $ZED_SELECTED_TEXT".to_owned(),
+        });
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "{} test {}",
+                TYPESCRIPT_VITEST_TASK_VARIABLE.template_value(),
+                VariableName::Symbol.template_value(),
+            ),
+            command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
+            args: vec![
+                TYPESCRIPT_VITEST_TASK_VARIABLE.template_value(),
+                "run".to_owned(),
+                "--testNamePattern".to_owned(),
+                format!("\"{}\"", VariableName::Symbol.template_value()),
+                VariableName::File.template_value(),
+            ],
+            tags: vec![
+                "ts-test".to_owned(),
+                "js-test".to_owned(),
+                "tsx-test".to_owned(),
+            ],
+            ..TaskTemplate::default()
+        });
+
+        // Mocha tasks
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "{} file test",
+                TYPESCRIPT_MOCHA_TASK_VARIABLE.template_value()
+            ),
+            command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
+            args: vec![
+                TYPESCRIPT_MOCHA_TASK_VARIABLE.template_value(),
+                VariableName::File.template_value(),
+            ],
+            ..TaskTemplate::default()
+        });
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "{} test {}",
+                TYPESCRIPT_MOCHA_TASK_VARIABLE.template_value(),
+                VariableName::Symbol.template_value(),
+            ),
+            command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
+            args: vec![
+                TYPESCRIPT_MOCHA_TASK_VARIABLE.template_value(),
+                "--grep".to_owned(),
+                format!("\"{}\"", VariableName::Symbol.template_value()),
+                VariableName::File.template_value(),
+            ],
+            tags: vec![
+                "ts-test".to_owned(),
+                "js-test".to_owned(),
+                "tsx-test".to_owned(),
+            ],
+            ..TaskTemplate::default()
+        });
+
+        // Jasmine tasks
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "{} file test",
+                TYPESCRIPT_JASMINE_TASK_VARIABLE.template_value()
+            ),
+            command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
+            args: vec![
+                TYPESCRIPT_JASMINE_TASK_VARIABLE.template_value(),
+                VariableName::File.template_value(),
+            ],
+            ..TaskTemplate::default()
+        });
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "{} test {}",
+                TYPESCRIPT_JASMINE_TASK_VARIABLE.template_value(),
+                VariableName::Symbol.template_value(),
+            ),
+            command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
+            args: vec![
+                TYPESCRIPT_JASMINE_TASK_VARIABLE.template_value(),
+                format!("--filter={}", VariableName::Symbol.template_value()),
+                VariableName::File.template_value(),
+            ],
+            tags: vec![
+                "ts-test".to_owned(),
+                "js-test".to_owned(),
+                "tsx-test".to_owned(),
+            ],
+            ..TaskTemplate::default()
+        });
+
+        for package_json_script in [
+            TYPESCRIPT_TEST_SCRIPT_TASK_VARIABLE,
+            TYPESCRIPT_BUILD_SCRIPT_TASK_VARIABLE,
+        ] {
+            task_templates.0.push(TaskTemplate {
+                label: format!(
+                    "package.json script {}",
+                    package_json_script.template_value()
+                ),
+                command: TYPESCRIPT_RUNNER_VARIABLE.template_value(),
+                args: vec![
+                    "--prefix".to_owned(),
+                    VariableName::WorktreeRoot.template_value(),
+                    "run".to_owned(),
+                    package_json_script.template_value(),
+                ],
+                tags: vec!["package-script".into()],
+                ..TaskTemplate::default()
+            });
+        }
+
+        task_templates.0.push(TaskTemplate {
+            label: format!(
+                "execute selection {}",
+                VariableName::SelectedText.template_value()
+            ),
             command: "node".to_owned(),
             args: vec![
-                "-e".into(),
+                "-e".to_owned(),
                 format!("\"{}\"", VariableName::SelectedText.template_value()),
             ],
             ..TaskTemplate::default()
-        },
-    ]))
+        });
+
+        Some(task_templates)
+    }
+
+    fn build_context(
+        &self,
+        _variables: &task::TaskVariables,
+        location: ContextLocation<'_>,
+        _project_env: Option<HashMap<String, String>>,
+        _toolchains: Arc<dyn LanguageToolchainStore>,
+        cx: &mut App,
+    ) -> Task<Result<task::TaskVariables>> {
+        let Some((fs, worktree_root)) = location.fs.zip(location.worktree_root) else {
+            return Task::ready(Ok(task::TaskVariables::default()));
+        };
+
+        let package_json_contents = self.last_package_json.clone();
+        cx.background_spawn(async move {
+            let variables = package_json_variables(fs, worktree_root, package_json_contents)
+                .await
+                .context("package.json context retrieval")
+                .log_err()
+                .unwrap_or_else(task::TaskVariables::default);
+            Ok(variables)
+        })
+    }
+}
+
+async fn package_json_variables(
+    fs: Arc<dyn Fs>,
+    worktree_root: PathBuf,
+    package_json_contents: PackageJsonContents,
+) -> anyhow::Result<task::TaskVariables> {
+    let package_json_path = worktree_root.join("package.json");
+    let metadata = fs
+        .metadata(&package_json_path)
+        .await
+        .with_context(|| format!("getting metadata for {package_json_path:?}"))?
+        .with_context(|| format!("missing FS metadata for {package_json_path:?}"))?;
+    let mtime = DateTime::<Local>::from(metadata.mtime.timestamp_for_user());
+    let existing_data = {
+        let contents = package_json_contents.0.read().await;
+        contents
+            .get(&package_json_path)
+            .filter(|package_json| package_json.mtime == mtime)
+            .map(|package_json| package_json.data)
+    };
+
+    let mut variables = TaskVariables::default();
+    if let Some(existing_data) = existing_data {
+        existing_data.fill_variables(&mut variables);
+    } else {
+        let package_json_string = fs
+            .load(&package_json_path)
+            .await
+            .with_context(|| format!("loading package.json from {package_json_path:?}"))?;
+        let package_json: HashMap<String, serde_json::Value> =
+            serde_json::from_str(&package_json_string)
+                .with_context(|| format!("parsing package.json from {package_json_path:?}"))?;
+        let new_data = PackageJsonData::new(package_json);
+        new_data.fill_variables(&mut variables);
+        {
+            let mut contents = package_json_contents.0.write().await;
+            contents.insert(
+                package_json_path,
+                PackageJson {
+                    mtime,
+                    data: new_data,
+                },
+            );
+        }
+    }
+
+    Ok(variables)
 }
 
 fn typescript_server_binary_arguments(server_path: &Path) -> Vec<OsString> {

crates/project/src/task_inventory.rs 🔗

@@ -14,7 +14,7 @@ use dap::DapRegistry;
 use gpui::{App, AppContext as _, Entity, SharedString, Task};
 use itertools::Itertools;
 use language::{
-    Buffer, ContextProvider, File, Language, LanguageToolchainStore, Location,
+    Buffer, ContextLocation, ContextProvider, File, Language, LanguageToolchainStore, Location,
     language_settings::language_settings,
 };
 use lsp::{LanguageServerId, LanguageServerName};
@@ -791,11 +791,12 @@ impl ContextProvider for BasicContextProvider {
     fn build_context(
         &self,
         _: &TaskVariables,
-        location: &Location,
+        location: ContextLocation<'_>,
         _: Option<HashMap<String, String>>,
         _: Arc<dyn LanguageToolchainStore>,
         cx: &mut App,
     ) -> Task<Result<TaskVariables>> {
+        let location = location.file_location;
         let buffer = location.buffer.read(cx);
         let buffer_snapshot = buffer.snapshot();
         let symbols = buffer_snapshot.symbols_containing(location.range.start, None);

crates/project/src/task_store.rs 🔗

@@ -5,9 +5,10 @@ use std::{
 
 use anyhow::Context as _;
 use collections::HashMap;
+use fs::Fs;
 use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
 use language::{
-    ContextProvider as _, LanguageToolchainStore, Location,
+    ContextLocation, ContextProvider as _, LanguageToolchainStore, Location,
     proto::{deserialize_anchor, serialize_anchor},
 };
 use rpc::{AnyProtoClient, TypedEnvelope, proto};
@@ -311,6 +312,7 @@ fn local_task_context_for_location(
     let worktree_abs_path = worktree_id
         .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
         .and_then(|worktree| worktree.read(cx).root_dir());
+    let fs = worktree_store.read(cx).fs();
 
     cx.spawn(async move |cx| {
         let project_env = environment
@@ -324,6 +326,8 @@ fn local_task_context_for_location(
             .update(|cx| {
                 combine_task_variables(
                     captured_variables,
+                    fs,
+                    worktree_store.clone(),
                     location,
                     project_env.clone(),
                     BasicContextProvider::new(worktree_store),
@@ -358,9 +362,15 @@ fn remote_task_context_for_location(
         // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
         let mut remote_context = cx
             .update(|cx| {
+                let worktree_root = worktree_root(&worktree_store, &location, cx);
+
                 BasicContextProvider::new(worktree_store).build_context(
                     &TaskVariables::default(),
-                    &location,
+                    ContextLocation {
+                        fs: None,
+                        worktree_root,
+                        file_location: &location,
+                    },
                     None,
                     toolchain_store,
                     cx,
@@ -408,8 +418,34 @@ fn remote_task_context_for_location(
     })
 }
 
+fn worktree_root(
+    worktree_store: &Entity<WorktreeStore>,
+    location: &Location,
+    cx: &mut App,
+) -> Option<PathBuf> {
+    location
+        .buffer
+        .read(cx)
+        .file()
+        .map(|f| f.worktree_id(cx))
+        .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
+        .and_then(|worktree| {
+            let worktree = worktree.read(cx);
+            if !worktree.is_visible() {
+                return None;
+            }
+            let root_entry = worktree.root_entry()?;
+            if !root_entry.is_dir() {
+                return None;
+            }
+            worktree.absolutize(&root_entry.path).ok()
+        })
+}
+
 fn combine_task_variables(
     mut captured_variables: TaskVariables,
+    fs: Option<Arc<dyn Fs>>,
+    worktree_store: Entity<WorktreeStore>,
     location: Location,
     project_env: Option<HashMap<String, String>>,
     baseline: BasicContextProvider,
@@ -424,9 +460,14 @@ fn combine_task_variables(
     cx.spawn(async move |cx| {
         let baseline = cx
             .update(|cx| {
+                let worktree_root = worktree_root(&worktree_store, &location, cx);
                 baseline.build_context(
                     &captured_variables,
-                    &location,
+                    ContextLocation {
+                        fs: fs.clone(),
+                        worktree_root,
+                        file_location: &location,
+                    },
                     project_env.clone(),
                     toolchain_store.clone(),
                     cx,
@@ -438,9 +479,14 @@ fn combine_task_variables(
         if let Some(provider) = language_context_provider {
             captured_variables.extend(
                 cx.update(|cx| {
+                    let worktree_root = worktree_root(&worktree_store, &location, cx);
                     provider.build_context(
                         &captured_variables,
-                        &location,
+                        ContextLocation {
+                            fs,
+                            worktree_root,
+                            file_location: &location,
+                        },
                         project_env,
                         toolchain_store,
                         cx,

crates/project/src/worktree_store.rs 🔗

@@ -967,6 +967,13 @@ impl WorktreeStore {
             .context("invalid request")?;
         Worktree::handle_expand_all_for_entry(worktree, envelope.payload, cx).await
     }
+
+    pub fn fs(&self) -> Option<Arc<dyn Fs>> {
+        match &self.state {
+            WorktreeStoreState::Local { fs } => Some(fs.clone()),
+            WorktreeStoreState::Remote { .. } => None,
+        }
+    }
 }
 
 #[derive(Clone, Debug)]