remote: Use last line of `uname` and shell output (#44165)

Agus Zubiaga created

We have seen cases (see
https://github.com/zed-industries/zed/issues/43694) where the user's
shell initialization script includes text that ends up in the output of
the commands we use to detect the platform and shell of the remote. This
solution isn't perfect, but it should address the issue in most
situations since both commands should only output one line.

Release Notes:

- remote: Improve resiliency when initialization scripts output text

Change summary

crates/remote/src/transport/ssh.rs | 150 +++++++++++++++++++++++--------
1 file changed, 109 insertions(+), 41 deletions(-)

Detailed changes

crates/remote/src/transport/ssh.rs 🔗

@@ -1055,57 +1055,74 @@ impl SshSocket {
     }
 
     async fn platform(&self, shell: ShellKind) -> Result<RemotePlatform> {
-        let uname = self.run_command(shell, "uname", &["-sm"], false).await?;
-        let Some((os, arch)) = uname.split_once(" ") else {
-            anyhow::bail!("unknown uname: {uname:?}")
-        };
-
-        let os = match os.trim() {
-            "Darwin" => "macos",
-            "Linux" => "linux",
-            _ => anyhow::bail!(
-                "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
-            ),
-        };
-        // exclude armv5,6,7 as they are 32-bit.
-        let arch = if arch.starts_with("armv8")
-            || arch.starts_with("armv9")
-            || arch.starts_with("arm64")
-            || arch.starts_with("aarch64")
-        {
-            "aarch64"
-        } else if arch.starts_with("x86") {
-            "x86_64"
-        } else {
-            anyhow::bail!(
-                "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
-            )
-        };
-
-        Ok(RemotePlatform { os, arch })
+        let output = self.run_command(shell, "uname", &["-sm"], false).await?;
+        parse_platform(&output)
     }
 
     async fn shell(&self) -> String {
-        let default_shell = "sh";
         match self
             .run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false)
             .await
         {
-            Ok(shell) => match shell.trim() {
-                "" => {
-                    log::error!("$SHELL is not set, falling back to {default_shell}");
-                    default_shell.to_owned()
-                }
-                shell => shell.to_owned(),
-            },
+            Ok(output) => parse_shell(&output),
             Err(e) => {
                 log::error!("Failed to get shell: {e}");
-                default_shell.to_owned()
+                DEFAULT_SHELL.to_owned()
             }
         }
     }
 }
 
+const DEFAULT_SHELL: &str = "sh";
+
+/// Parses the output of `uname -sm` to determine the remote platform.
+/// Takes the last line to skip possible shell initialization output.
+fn parse_platform(output: &str) -> Result<RemotePlatform> {
+    let output = output.trim();
+    let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last);
+    let Some((os, arch)) = uname.split_once(" ") else {
+        anyhow::bail!("unknown uname: {uname:?}")
+    };
+
+    let os = match os {
+        "Darwin" => "macos",
+        "Linux" => "linux",
+        _ => anyhow::bail!(
+            "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
+        ),
+    };
+
+    // exclude armv5,6,7 as they are 32-bit.
+    let arch = if arch.starts_with("armv8")
+        || arch.starts_with("armv9")
+        || arch.starts_with("arm64")
+        || arch.starts_with("aarch64")
+    {
+        "aarch64"
+    } else if arch.starts_with("x86") {
+        "x86_64"
+    } else {
+        anyhow::bail!(
+            "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
+        )
+    };
+
+    Ok(RemotePlatform { os, arch })
+}
+
+/// Parses the output of `echo $SHELL` to determine the remote shell.
+/// Takes the last line to skip possible shell initialization output.
+fn parse_shell(output: &str) -> String {
+    let output = output.trim();
+    let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last);
+    if shell.is_empty() {
+        log::error!("$SHELL is not set, falling back to {DEFAULT_SHELL}");
+        DEFAULT_SHELL.to_owned()
+    } else {
+        shell.to_owned()
+    }
+}
+
 fn parse_port_number(port_str: &str) -> Result<u16> {
     port_str
         .parse()
@@ -1502,12 +1519,63 @@ mod tests {
                 "-p".to_string(),
                 "2222".to_string(),
                 "-o".to_string(),
-                "StrictHostKeyChecking=no".to_string()
+                "StrictHostKeyChecking=no".to_string(),
             ]
         );
-        assert!(
-            scp_args.iter().all(|arg| !arg.starts_with("-L")),
-            "scp args should not contain port forward flags: {scp_args:?}"
+    }
+
+    #[test]
+    fn test_parse_platform() {
+        let result = parse_platform("Linux x86_64\n").unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "x86_64");
+
+        let result = parse_platform("Darwin arm64\n").unwrap();
+        assert_eq!(result.os, "macos");
+        assert_eq!(result.arch, "aarch64");
+
+        let result = parse_platform("Linux x86_64").unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "x86_64");
+
+        let result = parse_platform("some shell init output\nLinux aarch64\n").unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "aarch64");
+
+        let result = parse_platform("some shell init output\nLinux aarch64").unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "aarch64");
+
+        assert_eq!(parse_platform("Linux armv8l\n").unwrap().arch, "aarch64");
+        assert_eq!(parse_platform("Linux aarch64\n").unwrap().arch, "aarch64");
+        assert_eq!(parse_platform("Linux x86_64\n").unwrap().arch, "x86_64");
+
+        let result = parse_platform(
+            r#"Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n"#,
+        )
+        .unwrap();
+        assert_eq!(result.os, "linux");
+        assert_eq!(result.arch, "x86_64");
+
+        assert!(parse_platform("Windows x86_64\n").is_err());
+        assert!(parse_platform("Linux armv7l\n").is_err());
+    }
+
+    #[test]
+    fn test_parse_shell() {
+        assert_eq!(parse_shell("/bin/bash\n"), "/bin/bash");
+        assert_eq!(parse_shell("/bin/zsh\n"), "/bin/zsh");
+
+        assert_eq!(parse_shell("/bin/bash"), "/bin/bash");
+        assert_eq!(
+            parse_shell("some shell init output\n/bin/bash\n"),
+            "/bin/bash"
+        );
+        assert_eq!(
+            parse_shell("some shell init output\n/bin/bash"),
+            "/bin/bash"
         );
+        assert_eq!(parse_shell(""), "sh");
+        assert_eq!(parse_shell("\n"), "sh");
     }
 }