Fix `cargo test` task for tests module in `lib.rs`, `main.rs`, `mod.rs` (#25092)

Cole Miller created

Closes #19161

Release Notes:
- Fixed not being able to spawn the `cargo test` task for a `tests`
module in `lib.rs`, `main.rs`, or `mod.rs`

Change summary

crates/languages/src/rust.rs | 102 ++++++++++++++++++++++++++++---------
1 file changed, 77 insertions(+), 25 deletions(-)

Detailed changes

crates/languages/src/rust.rs 🔗

@@ -469,8 +469,8 @@ const RUST_BIN_NAME_TASK_VARIABLE: VariableName =
 const RUST_BIN_KIND_TASK_VARIABLE: VariableName =
     VariableName::Custom(Cow::Borrowed("RUST_BIN_KIND"));
 
-const RUST_MAIN_FUNCTION_TASK_VARIABLE: VariableName =
-    VariableName::Custom(Cow::Borrowed("_rust_main_function_end"));
+const RUST_TEST_FRAGMENT_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("RUST_TEST_FRAGMENT"));
 
 impl ContextProvider for RustContextProvider {
     fn build_context(
@@ -489,36 +489,35 @@ impl ContextProvider for RustContextProvider {
 
         let local_abs_path = local_abs_path.as_deref();
 
-        let is_main_function = task_variables
-            .get(&RUST_MAIN_FUNCTION_TASK_VARIABLE)
-            .is_some();
-
-        if is_main_function {
-            if let Some(target) = local_abs_path.and_then(|path| {
-                package_name_and_bin_name_from_abs_path(path, project_env.as_ref())
-            }) {
-                return Task::ready(Ok(TaskVariables::from_iter([
-                    (RUST_PACKAGE_TASK_VARIABLE.clone(), target.package_name),
-                    (RUST_BIN_NAME_TASK_VARIABLE.clone(), target.target_name),
-                    (
-                        RUST_BIN_KIND_TASK_VARIABLE.clone(),
-                        target.target_kind.to_string(),
-                    ),
-                ])));
-            }
+        let mut variables = TaskVariables::default();
+
+        if let Some(target) = local_abs_path
+            .and_then(|path| package_name_and_bin_name_from_abs_path(path, project_env.as_ref()))
+        {
+            variables.extend(TaskVariables::from_iter([
+                (RUST_PACKAGE_TASK_VARIABLE.clone(), target.package_name),
+                (RUST_BIN_NAME_TASK_VARIABLE.clone(), target.target_name),
+                (
+                    RUST_BIN_KIND_TASK_VARIABLE.clone(),
+                    target.target_kind.to_string(),
+                ),
+            ]));
         }
 
         if let Some(package_name) = local_abs_path
             .and_then(|local_abs_path| local_abs_path.parent())
             .and_then(|path| human_readable_package_name(path, project_env.as_ref()))
         {
-            return Task::ready(Ok(TaskVariables::from_iter([(
-                RUST_PACKAGE_TASK_VARIABLE.clone(),
-                package_name,
-            )])));
+            variables.insert(RUST_PACKAGE_TASK_VARIABLE.clone(), package_name);
         }
 
-        Task::ready(Ok(TaskVariables::default()))
+        if let (Some(path), Some(stem)) = (local_abs_path, task_variables.get(&VariableName::Stem))
+        {
+            let fragment = test_fragment(&variables, path, stem);
+            variables.insert(RUST_TEST_FRAGMENT_TASK_VARIABLE, fragment);
+        };
+
+        Task::ready(Ok(variables))
     }
 
     fn associated_tasks(
@@ -589,7 +588,7 @@ impl ContextProvider for RustContextProvider {
                     "test".into(),
                     "-p".into(),
                     RUST_PACKAGE_TASK_VARIABLE.template_value(),
-                    VariableName::Stem.template_value(),
+                    RUST_TEST_FRAGMENT_TASK_VARIABLE.template_value(),
                 ],
                 tags: vec!["rust-mod-test".to_owned()],
                 cwd: Some("$ZED_DIRNAME".to_owned()),
@@ -824,6 +823,29 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option<LanguageServ
     .log_err()
 }
 
+fn test_fragment(variables: &TaskVariables, path: &Path, stem: &str) -> String {
+    let fragment = if stem == "lib" {
+        // This isn't quite right---it runs the tests for the entire library, rather than
+        // just for the top-level `mod tests`. But we don't really have the means here to
+        // filter out just that module.
+        Some("--lib".to_owned())
+    } else if stem == "mod" {
+        maybe!({ Some(path.parent()?.file_name()?.to_string_lossy().to_string()) })
+    } else if stem == "main" {
+        if let (Some(bin_name), Some(bin_kind)) = (
+            variables.get(&RUST_BIN_NAME_TASK_VARIABLE),
+            variables.get(&RUST_BIN_KIND_TASK_VARIABLE),
+        ) {
+            Some(format!("--{bin_kind}={bin_name}"))
+        } else {
+            None
+        }
+    } else {
+        Some(stem.to_owned())
+    };
+    fragment.unwrap_or_else(|| "--".to_owned())
+}
+
 #[cfg(test)]
 mod tests {
     use std::num::NonZeroU32;
@@ -1179,4 +1201,34 @@ mod tests {
             );
         }
     }
+
+    #[test]
+    fn test_rust_test_fragment() {
+        #[track_caller]
+        fn check(
+            variables: impl IntoIterator<Item = (VariableName, &'static str)>,
+            path: &str,
+            expected: &str,
+        ) {
+            let path = Path::new(path);
+            let found = test_fragment(
+                &TaskVariables::from_iter(variables.into_iter().map(|(k, v)| (k, v.to_owned()))),
+                path,
+                &path.file_stem().unwrap().to_str().unwrap(),
+            );
+            assert_eq!(expected, found);
+        }
+
+        check([], "/project/src/lib.rs", "--lib");
+        check([], "/project/src/foo/mod.rs", "foo");
+        check(
+            [
+                (RUST_BIN_KIND_TASK_VARIABLE.clone(), "bin"),
+                (RUST_BIN_NAME_TASK_VARIABLE, "x"),
+            ],
+            "/project/src/main.rs",
+            "--bin=x",
+        );
+        check([], "/project/src/main.rs", "--");
+    }
 }