Add runnable for rust main function (#13087)

Panghu created

Release Notes:

- N/A



https://github.com/zed-industries/zed/assets/21101490/7a57805c-1d31-48b2-bc2c-3a6f0b730d72

Change summary

crates/languages/src/rust.rs            | 156 +++++++++++++++++++++++++-
crates/languages/src/rust/runnables.scm |  12 ++
2 files changed, 157 insertions(+), 11 deletions(-)

Detailed changes

crates/languages/src/rust.rs 🔗

@@ -346,10 +346,17 @@ pub(crate) struct RustContextProvider;
 const RUST_PACKAGE_TASK_VARIABLE: VariableName =
     VariableName::Custom(Cow::Borrowed("RUST_PACKAGE"));
 
+/// The bin name corresponding to the current file in Cargo.toml
+const RUST_BIN_NAME_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("RUST_BIN_NAME"));
+
+const RUST_MAIN_FUNCTION_TASK_VARIABLE: VariableName =
+    VariableName::Custom(Cow::Borrowed("_rust_main_function_end"));
+
 impl ContextProvider for RustContextProvider {
     fn build_context(
         &self,
-        _: &TaskVariables,
+        task_variables: &TaskVariables,
         location: &Location,
         cx: &mut gpui::AppContext,
     ) -> Result<TaskVariables> {
@@ -358,17 +365,35 @@ impl ContextProvider for RustContextProvider {
             .read(cx)
             .file()
             .and_then(|file| Some(file.as_local()?.abs_path(cx)));
-        Ok(
-            if let Some(package_name) = local_abs_path
-                .as_deref()
-                .and_then(|local_abs_path| local_abs_path.parent())
-                .and_then(human_readable_package_name)
+
+        let local_abs_path = local_abs_path.as_deref();
+
+        let is_main_function = task_variables
+            .get(&RUST_MAIN_FUNCTION_TASK_VARIABLE)
+            .is_some();
+
+        if is_main_function {
+            if let Some((package_name, bin_name)) = local_abs_path
+                .and_then(|local_abs_path| package_name_and_bin_name_from_abs_path(local_abs_path))
             {
-                TaskVariables::from_iter(Some((RUST_PACKAGE_TASK_VARIABLE.clone(), package_name)))
-            } else {
-                TaskVariables::default()
-            },
-        )
+                return Ok(TaskVariables::from_iter([
+                    (RUST_PACKAGE_TASK_VARIABLE.clone(), package_name),
+                    (RUST_BIN_NAME_TASK_VARIABLE.clone(), bin_name),
+                ]));
+            }
+        }
+
+        if let Some(package_name) = local_abs_path
+            .and_then(|local_abs_path| local_abs_path.parent())
+            .and_then(human_readable_package_name)
+        {
+            return Ok(TaskVariables::from_iter([(
+                RUST_PACKAGE_TASK_VARIABLE.clone(),
+                package_name,
+            )]));
+        }
+
+        Ok(TaskVariables::default())
     }
 
     fn associated_tasks(&self) -> Option<TaskTemplates> {
@@ -426,6 +451,23 @@ impl ContextProvider for RustContextProvider {
                 tags: vec!["rust-mod-test".to_owned()],
                 ..TaskTemplate::default()
             },
+            TaskTemplate {
+                label: format!(
+                    "cargo run -p {} --bin {}",
+                    RUST_PACKAGE_TASK_VARIABLE.template_value(),
+                    RUST_BIN_NAME_TASK_VARIABLE.template_value(),
+                ),
+                command: "cargo".into(),
+                args: vec![
+                    "run".into(),
+                    "-p".into(),
+                    RUST_PACKAGE_TASK_VARIABLE.template_value(),
+                    "--bin".into(),
+                    RUST_BIN_NAME_TASK_VARIABLE.template_value(),
+                ],
+                tags: vec!["rust-main".to_owned()],
+                ..TaskTemplate::default()
+            },
             TaskTemplate {
                 label: format!(
                     "cargo test -p {}",
@@ -455,6 +497,65 @@ impl ContextProvider for RustContextProvider {
     }
 }
 
+/// Part of the data structure of Cargo metadata
+#[derive(serde::Deserialize)]
+struct CargoMetadata {
+    packages: Vec<CargoPackage>,
+}
+
+#[derive(serde::Deserialize)]
+struct CargoPackage {
+    id: String,
+    targets: Vec<CargoTarget>,
+}
+
+#[derive(serde::Deserialize)]
+struct CargoTarget {
+    name: String,
+    kind: Vec<String>,
+    src_path: String,
+}
+
+fn package_name_and_bin_name_from_abs_path(abs_path: &Path) -> Option<(String, String)> {
+    let output = std::process::Command::new("cargo")
+        .current_dir(abs_path.parent()?)
+        .arg("metadata")
+        .arg("--no-deps")
+        .arg("--format-version")
+        .arg("1")
+        .output()
+        .log_err()?
+        .stdout;
+
+    let metadata: CargoMetadata = serde_json::from_slice(&output).log_err()?;
+
+    retrieve_package_id_and_bin_name_from_metadata(metadata, abs_path).and_then(
+        |(package_id, bin_name)| {
+            let package_name = package_name_from_pkgid(&package_id);
+
+            package_name.map(|package_name| (package_name.to_owned(), bin_name))
+        },
+    )
+}
+
+fn retrieve_package_id_and_bin_name_from_metadata(
+    metadata: CargoMetadata,
+    abs_path: &Path,
+) -> Option<(String, String)> {
+    let abs_path = abs_path.to_str()?;
+
+    for package in metadata.packages {
+        for target in package.targets {
+            let is_bin = target.kind.iter().any(|kind| kind == "bin");
+            if target.src_path == abs_path && is_bin {
+                return Some((package.id, target.name));
+            }
+        }
+    }
+
+    None
+}
+
 fn human_readable_package_name(package_directory: &Path) -> Option<String> {
     let pkgid = String::from_utf8(
         std::process::Command::new("cargo")
@@ -815,4 +916,37 @@ mod tests {
             assert_eq!(package_name_from_pkgid(input), Some(expected));
         }
     }
+
+    #[test]
+    fn test_retrieve_package_id_and_bin_name_from_metadata() {
+        for (input, absolute_path, expected) in [
+            (
+                r#"{"packages":[{"id":"path+file:///path/to/zed/crates/zed#0.131.0","targets":[{"name":"zed","kind":["bin"],"src_path":"/path/to/zed/src/main.rs"}]}]}"#,
+                "/path/to/zed/src/main.rs",
+                Some(("path+file:///path/to/zed/crates/zed#0.131.0", "zed")),
+            ),
+            (
+                r#"{"packages":[{"id":"path+file:///path/to/custom-package#my-custom-package@0.1.0","targets":[{"name":"my-custom-bin","kind":["bin"],"src_path":"/path/to/custom-package/src/main.rs"}]}]}"#,
+                "/path/to/custom-package/src/main.rs",
+                Some((
+                    "path+file:///path/to/custom-package#my-custom-package@0.1.0",
+                    "my-custom-bin",
+                )),
+            ),
+            (
+                r#"{"packages":[{"id":"path+file:///path/to/custom-package#my-custom-package@0.1.0","targets":[{"name":"my-custom-package","kind":["lib"],"src_path":"/path/to/custom-package/src/main.rs"}]}]}"#,
+                "/path/to/custom-package/src/main.rs",
+                None,
+            ),
+        ] {
+            let metadata: CargoMetadata = serde_json::from_str(input).unwrap();
+
+            let absolute_path = Path::new(absolute_path);
+
+            assert_eq!(
+                retrieve_package_id_and_bin_name_from_metadata(metadata, absolute_path),
+                expected.map(|(pkgid, bin)| (pkgid.to_owned(), bin.to_owned()))
+            );
+        }
+    }
 }

crates/languages/src/rust/runnables.scm 🔗

@@ -25,3 +25,15 @@
     )
     (#set! tag rust-test)
 )
+
+; Rust main function
+(
+    (
+        (function_item
+            name: (_) @run
+            body: _
+        ) @_rust_main_function_end
+        (#eq? @run "main")
+    )
+    (#set! tag rust-main)
+)