diff --git a/crates/remote/src/transport/ssh.rs b/crates/remote/src/transport/ssh.rs index 20cd0c5ff4b427d3a37882603ce2962db9e4e1e0..56f29be092b5ed6ab4993664eb256056837047f5 100644 --- a/crates/remote/src/transport/ssh.rs +++ b/crates/remote/src/transport/ssh.rs @@ -1055,57 +1055,74 @@ impl SshSocket { } async fn platform(&self, shell: ShellKind) -> Result { - let uname = self.run_command(shell, "uname", &["-sm"], false).await?; - let Some((os, arch)) = uname.split_once(" ") else { - anyhow::bail!("unknown uname: {uname:?}") - }; - - let os = match os.trim() { - "Darwin" => "macos", - "Linux" => "linux", - _ => anyhow::bail!( - "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development" - ), - }; - // exclude armv5,6,7 as they are 32-bit. - let arch = if arch.starts_with("armv8") - || arch.starts_with("armv9") - || arch.starts_with("arm64") - || arch.starts_with("aarch64") - { - "aarch64" - } else if arch.starts_with("x86") { - "x86_64" - } else { - anyhow::bail!( - "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development" - ) - }; - - Ok(RemotePlatform { os, arch }) + let output = self.run_command(shell, "uname", &["-sm"], false).await?; + parse_platform(&output) } async fn shell(&self) -> String { - let default_shell = "sh"; match self .run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false) .await { - Ok(shell) => match shell.trim() { - "" => { - log::error!("$SHELL is not set, falling back to {default_shell}"); - default_shell.to_owned() - } - shell => shell.to_owned(), - }, + Ok(output) => parse_shell(&output), Err(e) => { log::error!("Failed to get shell: {e}"); - default_shell.to_owned() + DEFAULT_SHELL.to_owned() } } } } +const DEFAULT_SHELL: &str = "sh"; + +/// Parses the output of `uname -sm` to determine the remote platform. +/// Takes the last line to skip possible shell initialization output. +fn parse_platform(output: &str) -> Result { + let output = output.trim(); + let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last); + let Some((os, arch)) = uname.split_once(" ") else { + anyhow::bail!("unknown uname: {uname:?}") + }; + + let os = match os { + "Darwin" => "macos", + "Linux" => "linux", + _ => anyhow::bail!( + "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development" + ), + }; + + // exclude armv5,6,7 as they are 32-bit. + let arch = if arch.starts_with("armv8") + || arch.starts_with("armv9") + || arch.starts_with("arm64") + || arch.starts_with("aarch64") + { + "aarch64" + } else if arch.starts_with("x86") { + "x86_64" + } else { + anyhow::bail!( + "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development" + ) + }; + + Ok(RemotePlatform { os, arch }) +} + +/// Parses the output of `echo $SHELL` to determine the remote shell. +/// Takes the last line to skip possible shell initialization output. +fn parse_shell(output: &str) -> String { + let output = output.trim(); + let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last); + if shell.is_empty() { + log::error!("$SHELL is not set, falling back to {DEFAULT_SHELL}"); + DEFAULT_SHELL.to_owned() + } else { + shell.to_owned() + } +} + fn parse_port_number(port_str: &str) -> Result { port_str .parse() @@ -1502,12 +1519,63 @@ mod tests { "-p".to_string(), "2222".to_string(), "-o".to_string(), - "StrictHostKeyChecking=no".to_string() + "StrictHostKeyChecking=no".to_string(), ] ); - assert!( - scp_args.iter().all(|arg| !arg.starts_with("-L")), - "scp args should not contain port forward flags: {scp_args:?}" + } + + #[test] + fn test_parse_platform() { + let result = parse_platform("Linux x86_64\n").unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "x86_64"); + + let result = parse_platform("Darwin arm64\n").unwrap(); + assert_eq!(result.os, "macos"); + assert_eq!(result.arch, "aarch64"); + + let result = parse_platform("Linux x86_64").unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "x86_64"); + + let result = parse_platform("some shell init output\nLinux aarch64\n").unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "aarch64"); + + let result = parse_platform("some shell init output\nLinux aarch64").unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "aarch64"); + + assert_eq!(parse_platform("Linux armv8l\n").unwrap().arch, "aarch64"); + assert_eq!(parse_platform("Linux aarch64\n").unwrap().arch, "aarch64"); + assert_eq!(parse_platform("Linux x86_64\n").unwrap().arch, "x86_64"); + + let result = parse_platform( + r#"Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n"#, + ) + .unwrap(); + assert_eq!(result.os, "linux"); + assert_eq!(result.arch, "x86_64"); + + assert!(parse_platform("Windows x86_64\n").is_err()); + assert!(parse_platform("Linux armv7l\n").is_err()); + } + + #[test] + fn test_parse_shell() { + assert_eq!(parse_shell("/bin/bash\n"), "/bin/bash"); + assert_eq!(parse_shell("/bin/zsh\n"), "/bin/zsh"); + + assert_eq!(parse_shell("/bin/bash"), "/bin/bash"); + assert_eq!( + parse_shell("some shell init output\n/bin/bash\n"), + "/bin/bash" + ); + assert_eq!( + parse_shell("some shell init output\n/bin/bash"), + "/bin/bash" ); + assert_eq!(parse_shell(""), "sh"); + assert_eq!(parse_shell("\n"), "sh"); } }