remote(wsl): Make shell and platform discovery more resilient to shell scripts (#44363)

Lukas Wirth created

Companion PR to https://github.com/zed-industries/zed/pull/44165

Release Notes:

- N/A

Change summary

crates/remote/src/transport.rs     | 109 +++++++++++++++++++++++++++++++
crates/remote/src/transport/ssh.rs | 111 +------------------------------
crates/remote/src/transport/wsl.rs |  24 +++---
3 files changed, 125 insertions(+), 119 deletions(-)

Detailed changes

crates/remote/src/transport.rs 🔗

@@ -1,4 +1,5 @@
 use crate::{
+    RemotePlatform,
     json_log::LogRecord,
     protocol::{MESSAGE_LEN_SIZE, message_len_from_buffer, read_message_with_len, write_message},
 };
@@ -14,6 +15,54 @@ use smol::process::Child;
 pub mod ssh;
 pub mod wsl;
 
+/// Parses the output of `uname -sm` to determine the remote platform.
+/// Takes the last line to skip possible shell initialization output.
+fn parse_platform(output: &str) -> Result<RemotePlatform> {
+    let output = output.trim();
+    let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last);
+    let Some((os, arch)) = uname.split_once(" ") else {
+        anyhow::bail!("unknown uname: {uname:?}")
+    };
+
+    let os = match os {
+        "Darwin" => "macos",
+        "Linux" => "linux",
+        _ => anyhow::bail!(
+            "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
+        ),
+    };
+
+    // exclude armv5,6,7 as they are 32-bit.
+    let arch = if arch.starts_with("armv8")
+        || arch.starts_with("armv9")
+        || arch.starts_with("arm64")
+        || arch.starts_with("aarch64")
+    {
+        "aarch64"
+    } else if arch.starts_with("x86") {
+        "x86_64"
+    } else {
+        anyhow::bail!(
+            "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
+        )
+    };
+
+    Ok(RemotePlatform { os, arch })
+}
+
+/// Parses the output of `echo $SHELL` to determine the remote shell.
+/// Takes the last line to skip possible shell initialization output.
+fn parse_shell(output: &str, fallback_shell: &str) -> String {
+    let output = output.trim();
+    let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last);
+    if shell.is_empty() {
+        log::error!("$SHELL is not set, falling back to {fallback_shell}");
+        fallback_shell.to_owned()
+    } else {
+        shell.to_owned()
+    }
+}
+
 fn handle_rpc_messages_over_child_process_stdio(
     mut ssh_proxy_process: Child,
     incoming_tx: UnboundedSender<Envelope>,
@@ -316,3 +365,63 @@ async fn which(
         )),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_platform() {
+        let result = parse_platform("Linux x86_64\n").unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "x86_64");
+
+        let result = parse_platform("Darwin arm64\n").unwrap();
+        assert_eq!(result.os, "macos");
+        assert_eq!(result.arch, "aarch64");
+
+        let result = parse_platform("Linux x86_64").unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "x86_64");
+
+        let result = parse_platform("some shell init output\nLinux aarch64\n").unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "aarch64");
+
+        let result = parse_platform("some shell init output\nLinux aarch64").unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "aarch64");
+
+        assert_eq!(parse_platform("Linux armv8l\n").unwrap().arch, "aarch64");
+        assert_eq!(parse_platform("Linux aarch64\n").unwrap().arch, "aarch64");
+        assert_eq!(parse_platform("Linux x86_64\n").unwrap().arch, "x86_64");
+
+        let result = parse_platform(
+            r#"Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n"#,
+        )
+        .unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "x86_64");
+
+        assert!(parse_platform("Windows x86_64\n").is_err());
+        assert!(parse_platform("Linux armv7l\n").is_err());
+    }
+
+    #[test]
+    fn test_parse_shell() {
+        assert_eq!(parse_shell("/bin/bash\n", "sh"), "/bin/bash");
+        assert_eq!(parse_shell("/bin/zsh\n", "sh"), "/bin/zsh");
+
+        assert_eq!(parse_shell("/bin/bash", "sh"), "/bin/bash");
+        assert_eq!(
+            parse_shell("some shell init output\n/bin/bash\n", "sh"),
+            "/bin/bash"
+        );
+        assert_eq!(
+            parse_shell("some shell init output\n/bin/bash", "sh"),
+            "/bin/bash"
+        );
+        assert_eq!(parse_shell("", "sh"), "sh");
+        assert_eq!(parse_shell("\n", "sh"), "sh");
+    }
+}

crates/remote/src/transport/ssh.rs 🔗

@@ -1,6 +1,7 @@
 use crate::{
     RemoteClientDelegate, RemotePlatform,
     remote_client::{CommandTemplate, RemoteConnection, RemoteConnectionOptions},
+    transport::{parse_platform, parse_shell},
 };
 use anyhow::{Context as _, Result, anyhow};
 use async_trait::async_trait;
@@ -1072,69 +1073,20 @@ impl SshSocket {
     }
 
     async fn shell(&self) -> String {
+        const DEFAULT_SHELL: &str = "sh";
         match self
             .run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false)
             .await
         {
-            Ok(output) => parse_shell(&output),
+            Ok(output) => parse_shell(&output, DEFAULT_SHELL),
             Err(e) => {
-                log::error!("Failed to get shell: {e}");
+                log::error!("Failed to detect remote shell: {e}");
                 DEFAULT_SHELL.to_owned()
             }
         }
     }
 }
 
-const DEFAULT_SHELL: &str = "sh";
-
-/// Parses the output of `uname -sm` to determine the remote platform.
-/// Takes the last line to skip possible shell initialization output.
-fn parse_platform(output: &str) -> Result<RemotePlatform> {
-    let output = output.trim();
-    let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last);
-    let Some((os, arch)) = uname.split_once(" ") else {
-        anyhow::bail!("unknown uname: {uname:?}")
-    };
-
-    let os = match os {
-        "Darwin" => "macos",
-        "Linux" => "linux",
-        _ => anyhow::bail!(
-            "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
-        ),
-    };
-
-    // exclude armv5,6,7 as they are 32-bit.
-    let arch = if arch.starts_with("armv8")
-        || arch.starts_with("armv9")
-        || arch.starts_with("arm64")
-        || arch.starts_with("aarch64")
-    {
-        "aarch64"
-    } else if arch.starts_with("x86") {
-        "x86_64"
-    } else {
-        anyhow::bail!(
-            "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
-        )
-    };
-
-    Ok(RemotePlatform { os, arch })
-}
-
-/// Parses the output of `echo $SHELL` to determine the remote shell.
-/// Takes the last line to skip possible shell initialization output.
-fn parse_shell(output: &str) -> String {
-    let output = output.trim();
-    let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last);
-    if shell.is_empty() {
-        log::error!("$SHELL is not set, falling back to {DEFAULT_SHELL}");
-        DEFAULT_SHELL.to_owned()
-    } else {
-        shell.to_owned()
-    }
-}
-
 fn parse_port_number(port_str: &str) -> Result<u16> {
     port_str
         .parse()
@@ -1535,59 +1487,4 @@ mod tests {
             ]
         );
     }
-
-    #[test]
-    fn test_parse_platform() {
-        let result = parse_platform("Linux x86_64\n").unwrap();
-        assert_eq!(result.os, "linux");
-        assert_eq!(result.arch, "x86_64");
-
-        let result = parse_platform("Darwin arm64\n").unwrap();
-        assert_eq!(result.os, "macos");
-        assert_eq!(result.arch, "aarch64");
-
-        let result = parse_platform("Linux x86_64").unwrap();
-        assert_eq!(result.os, "linux");
-        assert_eq!(result.arch, "x86_64");
-
-        let result = parse_platform("some shell init output\nLinux aarch64\n").unwrap();
-        assert_eq!(result.os, "linux");
-        assert_eq!(result.arch, "aarch64");
-
-        let result = parse_platform("some shell init output\nLinux aarch64").unwrap();
-        assert_eq!(result.os, "linux");
-        assert_eq!(result.arch, "aarch64");
-
-        assert_eq!(parse_platform("Linux armv8l\n").unwrap().arch, "aarch64");
-        assert_eq!(parse_platform("Linux aarch64\n").unwrap().arch, "aarch64");
-        assert_eq!(parse_platform("Linux x86_64\n").unwrap().arch, "x86_64");
-
-        let result = parse_platform(
-            r#"Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n"#,
-        )
-        .unwrap();
-        assert_eq!(result.os, "linux");
-        assert_eq!(result.arch, "x86_64");
-
-        assert!(parse_platform("Windows x86_64\n").is_err());
-        assert!(parse_platform("Linux armv7l\n").is_err());
-    }
-
-    #[test]
-    fn test_parse_shell() {
-        assert_eq!(parse_shell("/bin/bash\n"), "/bin/bash");
-        assert_eq!(parse_shell("/bin/zsh\n"), "/bin/zsh");
-
-        assert_eq!(parse_shell("/bin/bash"), "/bin/bash");
-        assert_eq!(
-            parse_shell("some shell init output\n/bin/bash\n"),
-            "/bin/bash"
-        );
-        assert_eq!(
-            parse_shell("some shell init output\n/bin/bash"),
-            "/bin/bash"
-        );
-        assert_eq!(parse_shell(""), "sh");
-        assert_eq!(parse_shell("\n"), "sh");
-    }
 }

crates/remote/src/transport/wsl.rs 🔗

@@ -1,6 +1,7 @@
 use crate::{
     RemoteClientDelegate, RemotePlatform,
     remote_client::{CommandTemplate, RemoteConnection, RemoteConnectionOptions},
+    transport::{parse_platform, parse_shell},
 };
 use anyhow::{Context, Result, anyhow, bail};
 use async_trait::async_trait;
@@ -107,23 +108,22 @@ impl WslRemoteConnection {
 
     async fn detect_platform(&self) -> Result<RemotePlatform> {
         let program = self.shell_kind.prepend_command_prefix("uname");
-        let arch_str = self.run_wsl_command_with_output(&program, &["-m"]).await?;
-        let arch_str = arch_str.trim().to_string();
-        let arch = match arch_str.as_str() {
-            "x86_64" => "x86_64",
-            "aarch64" | "arm64" => "aarch64",
-            _ => "x86_64",
-        };
-        Ok(RemotePlatform { os: "linux", arch })
+        let output = self.run_wsl_command_with_output(&program, &["-sm"]).await?;
+        parse_platform(&output)
     }
 
     async fn detect_shell(&self) -> Result<String> {
-        Ok(self
+        const DEFAULT_SHELL: &str = "sh";
+        match self
             .run_wsl_command_with_output("sh", &["-c", "echo $SHELL"])
             .await
-            .inspect_err(|err| log::error!("Failed to detect remote shell: {err}"))
-            .ok()
-            .unwrap_or_else(|| "/bin/sh".to_string()))
+        {
+            Ok(output) => Ok(parse_shell(&output, DEFAULT_SHELL)),
+            Err(e) => {
+                log::error!("Failed to detect remote shell: {e}");
+                Ok(DEFAULT_SHELL.to_owned())
+            }
+        }
     }
 
     async fn detect_has_wsl_interop(&self) -> Result<bool> {