tasks: Add ability to query active toolchains for languages (#20667)

Piotr Osiewicz created

Closes #18649

Release Notes:

- Python tasks now use active toolchain to run.

Change summary

crates/language/src/task_context.rs          |  11 +
crates/languages/src/go.rs                   |   9 
crates/languages/src/python.rs               |  38 ++++--
crates/languages/src/rust.rs                 |  25 ++--
crates/project/src/project.rs                |  34 +++--
crates/project/src/task_inventory.rs         |  12 +-
crates/project/src/task_store.rs             | 122 ++++++++++++++-------
crates/project/src/toolchain_store.rs        |   2 
crates/remote_server/src/headless_project.rs |  20 +-
9 files changed, 170 insertions(+), 103 deletions(-)

Detailed changes

crates/language/src/task_context.rs 🔗

@@ -1,10 +1,10 @@
 use std::{ops::Range, sync::Arc};
 
-use crate::{Location, Runnable};
+use crate::{LanguageToolchainStore, Location, Runnable};
 
 use anyhow::Result;
 use collections::HashMap;
-use gpui::AppContext;
+use gpui::{AppContext, Task};
 use task::{TaskTemplates, TaskVariables};
 use text::BufferId;
 
@@ -25,10 +25,11 @@ pub trait ContextProvider: Send + Sync {
         &self,
         _variables: &TaskVariables,
         _location: &Location,
-        _project_env: Option<&HashMap<String, String>>,
+        _project_env: Option<HashMap<String, String>>,
+        _toolchains: Arc<dyn LanguageToolchainStore>,
         _cx: &mut AppContext,
-    ) -> Result<TaskVariables> {
-        Ok(TaskVariables::default())
+    ) -> Task<Result<TaskVariables>> {
+        Task::ready(Ok(TaskVariables::default()))
     }
 
     /// Provides all tasks, associated with the current language.

crates/languages/src/go.rs 🔗

@@ -418,9 +418,10 @@ impl ContextProvider for GoContextProvider {
         &self,
         variables: &TaskVariables,
         location: &Location,
-        _: Option<&HashMap<String, String>>,
+        _: Option<HashMap<String, String>>,
+        _: Arc<dyn LanguageToolchainStore>,
         cx: &mut gpui::AppContext,
-    ) -> Result<TaskVariables> {
+    ) -> Task<Result<TaskVariables>> {
         let local_abs_path = location
             .buffer
             .read(cx)
@@ -468,7 +469,7 @@ impl ContextProvider for GoContextProvider {
         let go_subtest_variable = extract_subtest_name(_subtest_name.unwrap_or(""))
             .map(|subtest_name| (GO_SUBTEST_NAME_TASK_VARIABLE.clone(), subtest_name));
 
-        Ok(TaskVariables::from_iter(
+        Task::ready(Ok(TaskVariables::from_iter(
             [
                 go_package_variable,
                 go_subtest_variable,
@@ -476,7 +477,7 @@ impl ContextProvider for GoContextProvider {
             ]
             .into_iter()
             .flatten(),
-        ))
+        )))
     }
 
     fn associated_tasks(

crates/languages/src/python.rs 🔗

@@ -2,8 +2,8 @@ use anyhow::ensure;
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use collections::HashMap;
-use gpui::AppContext;
 use gpui::AsyncAppContext;
+use gpui::{AppContext, Task};
 use language::LanguageName;
 use language::LanguageToolchainStore;
 use language::Toolchain;
@@ -267,14 +267,17 @@ pub(crate) struct PythonContextProvider;
 const PYTHON_UNITTEST_TARGET_TASK_VARIABLE: VariableName =
     VariableName::Custom(Cow::Borrowed("PYTHON_UNITTEST_TARGET"));
 
+const PYTHON_ACTIVE_TOOLCHAIN_PATH: VariableName =
+    VariableName::Custom(Cow::Borrowed("PYTHON_ACTIVE_ZED_TOOLCHAIN"));
 impl ContextProvider for PythonContextProvider {
     fn build_context(
         &self,
         variables: &task::TaskVariables,
-        _location: &project::Location,
-        _: Option<&HashMap<String, String>>,
-        _cx: &mut gpui::AppContext,
-    ) -> Result<task::TaskVariables> {
+        location: &project::Location,
+        _: Option<HashMap<String, String>>,
+        toolchains: Arc<dyn LanguageToolchainStore>,
+        cx: &mut gpui::AppContext,
+    ) -> Task<Result<task::TaskVariables>> {
         let python_module_name = python_module_name_from_relative_path(
             variables.get(&VariableName::RelativeFile).unwrap_or(""),
         );
@@ -290,15 +293,26 @@ impl ContextProvider for PythonContextProvider {
             }
             (Some(class_name), None) => format!("{}.{}", python_module_name, class_name),
             (None, None) => python_module_name,
-            (None, Some(_)) => return Ok(task::TaskVariables::default()), // should never happen, a TestCase class is the unit of testing
+            (None, Some(_)) => return Task::ready(Ok(task::TaskVariables::default())), // should never happen, a TestCase class is the unit of testing
         };
 
         let unittest_target = (
             PYTHON_UNITTEST_TARGET_TASK_VARIABLE.clone(),
             unittest_target_str,
         );
-
-        Ok(task::TaskVariables::from_iter([unittest_target]))
+        let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
+        cx.spawn(move |mut cx| async move {
+            let active_toolchain = if let Some(worktree_id) = worktree_id {
+                toolchains
+                    .active_toolchain(worktree_id, "Python".into(), &mut cx)
+                    .await
+                    .map_or_else(|| "python3".to_owned(), |toolchain| toolchain.path.into())
+            } else {
+                String::from("python3")
+            };
+            let toolchain = (PYTHON_ACTIVE_TOOLCHAIN_PATH, active_toolchain);
+            Ok(task::TaskVariables::from_iter([unittest_target, toolchain]))
+        })
     }
 
     fn associated_tasks(
@@ -309,19 +323,19 @@ impl ContextProvider for PythonContextProvider {
         Some(TaskTemplates(vec![
             TaskTemplate {
                 label: "execute selection".to_owned(),
-                command: "python3".to_owned(),
+                command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(),
                 args: vec!["-c".to_owned(), VariableName::SelectedText.template_value()],
                 ..TaskTemplate::default()
             },
             TaskTemplate {
                 label: format!("run '{}'", VariableName::File.template_value()),
-                command: "python3".to_owned(),
+                command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(),
                 args: vec![VariableName::File.template_value()],
                 ..TaskTemplate::default()
             },
             TaskTemplate {
                 label: format!("unittest '{}'", VariableName::File.template_value()),
-                command: "python3".to_owned(),
+                command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(),
                 args: vec![
                     "-m".to_owned(),
                     "unittest".to_owned(),
@@ -331,7 +345,7 @@ impl ContextProvider for PythonContextProvider {
             },
             TaskTemplate {
                 label: "unittest $ZED_CUSTOM_PYTHON_UNITTEST_TARGET".to_owned(),
-                command: "python3".to_owned(),
+                command: PYTHON_ACTIVE_TOOLCHAIN_PATH.template_value(),
                 args: vec![
                     "-m".to_owned(),
                     "unittest".to_owned(),

crates/languages/src/rust.rs 🔗

@@ -3,7 +3,7 @@ use async_compression::futures::bufread::GzipDecoder;
 use async_trait::async_trait;
 use collections::HashMap;
 use futures::{io::BufReader, StreamExt};
-use gpui::{AppContext, AsyncAppContext};
+use gpui::{AppContext, AsyncAppContext, Task};
 use http_client::github::AssetKind;
 use http_client::github::{latest_github_release, GitHubLspBinaryVersion};
 pub use language::*;
@@ -424,9 +424,10 @@ impl ContextProvider for RustContextProvider {
         &self,
         task_variables: &TaskVariables,
         location: &Location,
-        project_env: Option<&HashMap<String, String>>,
+        project_env: Option<HashMap<String, String>>,
+        _: Arc<dyn LanguageToolchainStore>,
         cx: &mut gpui::AppContext,
-    ) -> Result<TaskVariables> {
+    ) -> Task<Result<TaskVariables>> {
         let local_abs_path = location
             .buffer
             .read(cx)
@@ -440,27 +441,27 @@ impl ContextProvider for RustContextProvider {
             .is_some();
 
         if is_main_function {
-            if let Some((package_name, bin_name)) = local_abs_path
-                .and_then(|path| package_name_and_bin_name_from_abs_path(path, project_env))
-            {
-                return Ok(TaskVariables::from_iter([
+            if let Some((package_name, bin_name)) = local_abs_path.and_then(|path| {
+                package_name_and_bin_name_from_abs_path(path, project_env.as_ref())
+            }) {
+                return Task::ready(Ok(TaskVariables::from_iter([
                     (RUST_PACKAGE_TASK_VARIABLE.clone(), package_name),
                     (RUST_BIN_NAME_TASK_VARIABLE.clone(), bin_name),
-                ]));
+                ])));
             }
         }
 
         if let Some(package_name) = local_abs_path
             .and_then(|local_abs_path| local_abs_path.parent())
-            .and_then(|path| human_readable_package_name(path, project_env))
+            .and_then(|path| human_readable_package_name(path, project_env.as_ref()))
         {
-            return Ok(TaskVariables::from_iter([(
+            return Task::ready(Ok(TaskVariables::from_iter([(
                 RUST_PACKAGE_TASK_VARIABLE.clone(),
                 package_name,
-            )]));
+            )])));
         }
 
-        Ok(TaskVariables::default())
+        Task::ready(Ok(TaskVariables::default()))
     }
 
     fn associated_tasks(

crates/project/src/project.rs 🔗

@@ -82,6 +82,7 @@ use std::{
 use task_store::TaskStore;
 use terminals::Terminals;
 use text::{Anchor, BufferId};
+use toolchain_store::EmptyToolchainStore;
 use util::{paths::compare_paths, ResultExt as _};
 use worktree::{CreatedEntry, Snapshot, Traversal};
 use worktree_store::{WorktreeStore, WorktreeStoreEvent};
@@ -626,12 +627,20 @@ impl Project {
             });
 
             let environment = ProjectEnvironment::new(&worktree_store, env, cx);
-
+            let toolchain_store = cx.new_model(|cx| {
+                ToolchainStore::local(
+                    languages.clone(),
+                    worktree_store.clone(),
+                    environment.clone(),
+                    cx,
+                )
+            });
             let task_store = cx.new_model(|cx| {
                 TaskStore::local(
                     fs.clone(),
                     buffer_store.downgrade(),
                     worktree_store.clone(),
+                    toolchain_store.read(cx).as_language_toolchain_store(),
                     environment.clone(),
                     cx,
                 )
@@ -647,14 +656,7 @@ impl Project {
             });
             cx.subscribe(&settings_observer, Self::on_settings_observer_event)
                 .detach();
-            let toolchain_store = cx.new_model(|cx| {
-                ToolchainStore::local(
-                    languages.clone(),
-                    worktree_store.clone(),
-                    environment.clone(),
-                    cx,
-                )
-            });
+
             let lsp_store = cx.new_model(|cx| {
                 LspStore::new_local(
                     buffer_store.clone(),
@@ -749,12 +751,15 @@ impl Project {
             });
             cx.subscribe(&buffer_store, Self::on_buffer_store_event)
                 .detach();
-
+            let toolchain_store = cx.new_model(|cx| {
+                ToolchainStore::remote(SSH_PROJECT_ID, ssh.read(cx).proto_client(), cx)
+            });
             let task_store = cx.new_model(|cx| {
                 TaskStore::remote(
                     fs.clone(),
                     buffer_store.downgrade(),
                     worktree_store.clone(),
+                    toolchain_store.read(cx).as_language_toolchain_store(),
                     ssh.read(cx).proto_client(),
                     SSH_PROJECT_ID,
                     cx,
@@ -768,14 +773,12 @@ impl Project {
                 .detach();
 
             let environment = ProjectEnvironment::new(&worktree_store, None, cx);
-            let toolchain_store = Some(cx.new_model(|cx| {
-                ToolchainStore::remote(SSH_PROJECT_ID, ssh.read(cx).proto_client(), cx)
-            }));
+
             let lsp_store = cx.new_model(|cx| {
                 LspStore::new_remote(
                     buffer_store.clone(),
                     worktree_store.clone(),
-                    toolchain_store.clone(),
+                    Some(toolchain_store.clone()),
                     languages.clone(),
                     ssh_proto.clone(),
                     SSH_PROJECT_ID,
@@ -835,7 +838,7 @@ impl Project {
                 search_included_history: Self::new_search_history(),
                 search_excluded_history: Self::new_search_history(),
 
-                toolchain_store,
+                toolchain_store: Some(toolchain_store),
             };
 
             let ssh = ssh.read(cx);
@@ -963,6 +966,7 @@ impl Project {
                     fs.clone(),
                     buffer_store.downgrade(),
                     worktree_store.clone(),
+                    Arc::new(EmptyToolchainStore),
                     client.clone().into(),
                     remote_id,
                     cx,

crates/project/src/task_inventory.rs 🔗

@@ -10,9 +10,9 @@ use std::{
 
 use anyhow::{Context, Result};
 use collections::{HashMap, HashSet, VecDeque};
-use gpui::{AppContext, Context as _, Model};
+use gpui::{AppContext, Context as _, Model, Task};
 use itertools::Itertools;
-use language::{ContextProvider, File, Language, Location};
+use language::{ContextProvider, File, Language, LanguageToolchainStore, Location};
 use settings::{parse_json_with_comments, SettingsLocation};
 use task::{
     ResolvedTask, TaskContext, TaskId, TaskTemplate, TaskTemplates, TaskVariables, VariableName,
@@ -431,15 +431,15 @@ impl BasicContextProvider {
         Self { worktree_store }
     }
 }
-
 impl ContextProvider for BasicContextProvider {
     fn build_context(
         &self,
         _: &TaskVariables,
         location: &Location,
-        _: Option<&HashMap<String, String>>,
+        _: Option<HashMap<String, String>>,
+        _: Arc<dyn LanguageToolchainStore>,
         cx: &mut AppContext,
-    ) -> Result<TaskVariables> {
+    ) -> Task<Result<TaskVariables>> {
         let buffer = location.buffer.read(cx);
         let buffer_snapshot = buffer.snapshot();
         let symbols = buffer_snapshot.symbols_containing(location.range.start, None);
@@ -517,7 +517,7 @@ impl ContextProvider for BasicContextProvider {
             task_variables.insert(VariableName::File, path_as_string);
         }
 
-        Ok(task_variables)
+        Task::ready(Ok(task_variables))
     }
 }
 

crates/project/src/task_store.rs 🔗

@@ -7,7 +7,7 @@ use futures::StreamExt as _;
 use gpui::{AppContext, AsyncAppContext, EventEmitter, Model, ModelContext, Task, WeakModel};
 use language::{
     proto::{deserialize_anchor, serialize_anchor},
-    ContextProvider as _, Location,
+    ContextProvider as _, LanguageToolchainStore, Location,
 };
 use rpc::{proto, AnyProtoClient, TypedEnvelope};
 use settings::{watch_config_file, SettingsLocation};
@@ -20,6 +20,7 @@ use crate::{
     ProjectEnvironment,
 };
 
+#[expect(clippy::large_enum_variant)]
 pub enum TaskStore {
     Functional(StoreState),
     Noop,
@@ -30,6 +31,7 @@ pub struct StoreState {
     task_inventory: Model<Inventory>,
     buffer_store: WeakModel<BufferStore>,
     worktree_store: Model<WorktreeStore>,
+    toolchain_store: Arc<dyn LanguageToolchainStore>,
     _global_task_config_watcher: Task<()>,
 }
 
@@ -155,6 +157,7 @@ impl TaskStore {
         fs: Arc<dyn Fs>,
         buffer_store: WeakModel<BufferStore>,
         worktree_store: Model<WorktreeStore>,
+        toolchain_store: Arc<dyn LanguageToolchainStore>,
         environment: Model<ProjectEnvironment>,
         cx: &mut ModelContext<'_, Self>,
     ) -> Self {
@@ -165,6 +168,7 @@ impl TaskStore {
             },
             task_inventory: Inventory::new(cx),
             buffer_store,
+            toolchain_store,
             worktree_store,
             _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
         })
@@ -174,6 +178,7 @@ impl TaskStore {
         fs: Arc<dyn Fs>,
         buffer_store: WeakModel<BufferStore>,
         worktree_store: Model<WorktreeStore>,
+        toolchain_store: Arc<dyn LanguageToolchainStore>,
         upstream_client: AnyProtoClient,
         project_id: u64,
         cx: &mut ModelContext<'_, Self>,
@@ -185,6 +190,7 @@ impl TaskStore {
             },
             task_inventory: Inventory::new(cx),
             buffer_store,
+            toolchain_store,
             worktree_store,
             _global_task_config_watcher: Self::subscribe_to_global_task_file_changes(fs, cx),
         })
@@ -200,6 +206,7 @@ impl TaskStore {
             TaskStore::Functional(state) => match &state.mode {
                 StoreMode::Local { environment, .. } => local_task_context_for_location(
                     state.worktree_store.clone(),
+                    state.toolchain_store.clone(),
                     environment.clone(),
                     captured_variables,
                     location,
@@ -210,10 +217,11 @@ impl TaskStore {
                     project_id,
                 } => remote_task_context_for_location(
                     *project_id,
-                    upstream_client,
+                    upstream_client.clone(),
                     state.worktree_store.clone(),
                     captured_variables,
                     location,
+                    state.toolchain_store.clone(),
                     cx,
                 ),
             },
@@ -314,6 +322,7 @@ impl TaskStore {
 
 fn local_task_context_for_location(
     worktree_store: Model<WorktreeStore>,
+    toolchain_store: Arc<dyn LanguageToolchainStore>,
     environment: Model<ProjectEnvironment>,
     captured_variables: TaskVariables,
     location: Location,
@@ -338,14 +347,15 @@ fn local_task_context_for_location(
                 combine_task_variables(
                     captured_variables,
                     location,
-                    project_env.as_ref(),
+                    project_env.clone(),
                     BasicContextProvider::new(worktree_store),
+                    toolchain_store,
                     cx,
                 )
-                .log_err()
             })
-            .ok()
-            .flatten()?;
+            .ok()?
+            .await
+            .log_err()?;
         // Remove all custom entries starting with _, as they're not intended for use by the end user.
         task_variables.sweep();
 
@@ -359,32 +369,46 @@ fn local_task_context_for_location(
 
 fn remote_task_context_for_location(
     project_id: u64,
-    upstream_client: &AnyProtoClient,
+    upstream_client: AnyProtoClient,
     worktree_store: Model<WorktreeStore>,
     captured_variables: TaskVariables,
     location: Location,
+    toolchain_store: Arc<dyn LanguageToolchainStore>,
     cx: &mut AppContext,
 ) -> Task<Option<TaskContext>> {
-    // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
-    let mut remote_context = BasicContextProvider::new(worktree_store)
-        .build_context(&TaskVariables::default(), &location, None, cx)
-        .log_err()
-        .unwrap_or_default();
-    remote_context.extend(captured_variables);
+    cx.spawn(|cx| async move {
+        // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
+        let mut remote_context = cx
+            .update(|cx| {
+                BasicContextProvider::new(worktree_store).build_context(
+                    &TaskVariables::default(),
+                    &location,
+                    None,
+                    toolchain_store,
+                    cx,
+                )
+            })
+            .ok()?
+            .await
+            .log_err()
+            .unwrap_or_default();
+        remote_context.extend(captured_variables);
 
-    let context_task = upstream_client.request(proto::TaskContextForLocation {
-        project_id,
-        location: Some(proto::Location {
-            buffer_id: location.buffer.read(cx).remote_id().into(),
-            start: Some(serialize_anchor(&location.range.start)),
-            end: Some(serialize_anchor(&location.range.end)),
-        }),
-        task_variables: remote_context
-            .into_iter()
-            .map(|(k, v)| (k.to_string(), v))
-            .collect(),
-    });
-    cx.spawn(|_| async move {
+        let buffer_id = cx
+            .update(|cx| location.buffer.read(cx).remote_id().to_proto())
+            .ok()?;
+        let context_task = upstream_client.request(proto::TaskContextForLocation {
+            project_id,
+            location: Some(proto::Location {
+                buffer_id,
+                start: Some(serialize_anchor(&location.range.start)),
+                end: Some(serialize_anchor(&location.range.end)),
+            }),
+            task_variables: remote_context
+                .into_iter()
+                .map(|(k, v)| (k.to_string(), v))
+                .collect(),
+        });
         let task_context = context_task.await.log_err()?;
         Some(TaskContext {
             cwd: task_context.cwd.map(PathBuf::from),
@@ -409,25 +433,45 @@ fn remote_task_context_for_location(
 fn combine_task_variables(
     mut captured_variables: TaskVariables,
     location: Location,
-    project_env: Option<&HashMap<String, String>>,
+    project_env: Option<HashMap<String, String>>,
     baseline: BasicContextProvider,
+    toolchain_store: Arc<dyn LanguageToolchainStore>,
     cx: &mut AppContext,
-) -> anyhow::Result<TaskVariables> {
+) -> Task<anyhow::Result<TaskVariables>> {
     let language_context_provider = location
         .buffer
         .read(cx)
         .language()
         .and_then(|language| language.context_provider());
-    let baseline = baseline
-        .build_context(&captured_variables, &location, project_env, cx)
-        .context("building basic default context")?;
-    captured_variables.extend(baseline);
-    if let Some(provider) = language_context_provider {
-        captured_variables.extend(
-            provider
-                .build_context(&captured_variables, &location, project_env, cx)
+    cx.spawn(move |cx| async move {
+        let baseline = cx
+            .update(|cx| {
+                baseline.build_context(
+                    &captured_variables,
+                    &location,
+                    project_env.clone(),
+                    toolchain_store.clone(),
+                    cx,
+                )
+            })?
+            .await
+            .context("building basic default context")?;
+        captured_variables.extend(baseline);
+        if let Some(provider) = language_context_provider {
+            captured_variables.extend(
+                cx.update(|cx| {
+                    provider.build_context(
+                        &captured_variables,
+                        &location,
+                        project_env,
+                        toolchain_store,
+                        cx,
+                    )
+                })?
+                .await
                 .context("building provider context")?,
-        );
-    }
-    Ok(captured_variables)
+            );
+        }
+        Ok(captured_variables)
+    })
 }

crates/project/src/toolchain_store.rs 🔗

@@ -194,7 +194,7 @@ impl ToolchainStore {
             groups,
         })
     }
-    pub(crate) fn as_language_toolchain_store(&self) -> Arc<dyn LanguageToolchainStore> {
+    pub fn as_language_toolchain_store(&self) -> Arc<dyn LanguageToolchainStore> {
         match &self.0 {
             ToolchainStoreInner::Local(local, _) => Arc::new(LocalStore(local.downgrade())),
             ToolchainStoreInner::Remote(remote) => Arc::new(RemoteStore(remote.downgrade())),

crates/remote_server/src/headless_project.rs 🔗

@@ -85,13 +85,22 @@ impl HeadlessProject {
                 cx,
             )
         });
-
         let environment = project::ProjectEnvironment::new(&worktree_store, None, cx);
+        let toolchain_store = cx.new_model(|cx| {
+            ToolchainStore::local(
+                languages.clone(),
+                worktree_store.clone(),
+                environment.clone(),
+                cx,
+            )
+        });
+
         let task_store = cx.new_model(|cx| {
             let mut task_store = TaskStore::local(
                 fs.clone(),
                 buffer_store.downgrade(),
                 worktree_store.clone(),
+                toolchain_store.read(cx).as_language_toolchain_store(),
                 environment.clone(),
                 cx,
             );
@@ -108,14 +117,7 @@ impl HeadlessProject {
             observer.shared(SSH_PROJECT_ID, session.clone().into(), cx);
             observer
         });
-        let toolchain_store = cx.new_model(|cx| {
-            ToolchainStore::local(
-                languages.clone(),
-                worktree_store.clone(),
-                environment.clone(),
-                cx,
-            )
-        });
+
         let lsp_store = cx.new_model(|cx| {
             let mut lsp_store = LspStore::new_local(
                 buffer_store.clone(),