SSH installation refactor (#19991)

Conrad Irwin and Mikayala created

This also cleans up logic for deciding how to do things.

Release Notes:

- Remoting: If downloading the binary on the remote fails, fall back to
uploading it.

---------

Co-authored-by: Mikayala <mikayla@zed.dev>

Change summary

Cargo.lock                                    |   1 
crates/recent_projects/src/ssh_connections.rs | 270 +++--------------
crates/remote/Cargo.toml                      |   1 
crates/remote/src/ssh_session.rs              | 322 ++++++++++++++++----
4 files changed, 306 insertions(+), 288 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -9536,6 +9536,7 @@ dependencies = [
  "log",
  "parking_lot",
  "prost",
+ "release_channel",
  "rpc",
  "serde",
  "serde_json",

crates/recent_projects/src/ssh_connections.rs 🔗

@@ -13,8 +13,7 @@ use gpui::{AppContext, Model};
 
 use language::CursorShape;
 use markdown::{Markdown, MarkdownStyle};
-use release_channel::{AppVersion, ReleaseChannel};
-use remote::ssh_session::{ServerBinary, ServerVersion};
+use release_channel::ReleaseChannel;
 use remote::{SshConnectionOptions, SshPlatform, SshRemoteClient};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -441,23 +440,66 @@ impl remote::SshClientDelegate for SshClientDelegate {
         self.update_status(status, cx)
     }
 
-    fn get_server_binary(
+    fn download_server_binary_locally(
         &self,
         platform: SshPlatform,
-        upload_binary_over_ssh: bool,
+        release_channel: ReleaseChannel,
+        version: Option<SemanticVersion>,
         cx: &mut AsyncAppContext,
-    ) -> oneshot::Receiver<Result<(ServerBinary, ServerVersion)>> {
-        let (tx, rx) = oneshot::channel();
-        let this = self.clone();
+    ) -> Task<anyhow::Result<PathBuf>> {
         cx.spawn(|mut cx| async move {
-            tx.send(
-                this.get_server_binary_impl(platform, upload_binary_over_ssh, &mut cx)
-                    .await,
+            let binary_path = AutoUpdater::download_remote_server_release(
+                platform.os,
+                platform.arch,
+                release_channel,
+                version,
+                &mut cx,
             )
-            .ok();
+            .await
+            .map_err(|e| {
+                anyhow!(
+                    "Failed to download remote server binary (version: {}, os: {}, arch: {}): {}",
+                    version
+                        .map(|v| format!("{}", v))
+                        .unwrap_or("unknown".to_string()),
+                    platform.os,
+                    platform.arch,
+                    e
+                )
+            })?;
+            Ok(binary_path)
         })
-        .detach();
-        rx
+    }
+
+    fn get_download_params(
+        &self,
+        platform: SshPlatform,
+        release_channel: ReleaseChannel,
+        version: Option<SemanticVersion>,
+        cx: &mut AsyncAppContext,
+    ) -> Task<Result<(String, String)>> {
+        cx.spawn(|mut cx| async move {
+                let (release, request_body) = AutoUpdater::get_remote_server_release_url(
+                            platform.os,
+                            platform.arch,
+                            release_channel,
+                            version,
+                            &mut cx,
+                        )
+                        .await
+                        .map_err(|e| {
+                            anyhow!(
+                                "Failed to get remote server binary download url (version: {}, os: {}, arch: {}): {}",
+                                version.map(|v| format!("{}", v)).unwrap_or("unknown".to_string()),
+                                platform.os,
+                                platform.arch,
+                                e
+                            )
+                        })?;
+
+                Ok((release.url, request_body))
+            }
+        )
     }
 
     fn remote_server_binary_path(
@@ -485,208 +527,6 @@ impl SshClientDelegate {
             })
             .ok();
     }
-
-    async fn get_server_binary_impl(
-        &self,
-        platform: SshPlatform,
-        upload_binary_via_ssh: bool,
-        cx: &mut AsyncAppContext,
-    ) -> Result<(ServerBinary, ServerVersion)> {
-        let (version, release_channel) = cx.update(|cx| {
-            let version = AppVersion::global(cx);
-            let channel = ReleaseChannel::global(cx);
-
-            (version, channel)
-        })?;
-
-        // In dev mode, build the remote server binary from source
-        #[cfg(debug_assertions)]
-        if release_channel == ReleaseChannel::Dev {
-            let result = self.build_local(cx, platform, version).await?;
-            // Fall through to a remote binary if we're not able to compile a local binary
-            if let Some((path, version)) = result {
-                return Ok((
-                    ServerBinary::LocalBinary(path),
-                    ServerVersion::Semantic(version),
-                ));
-            }
-        }
-
-        // For nightly channel, always get latest
-        let current_version = if release_channel == ReleaseChannel::Nightly {
-            None
-        } else {
-            Some(version)
-        };
-
-        self.update_status(
-            Some(&format!("Checking remote server release {}", version)),
-            cx,
-        );
-
-        if upload_binary_via_ssh {
-            let binary_path = AutoUpdater::download_remote_server_release(
-                platform.os,
-                platform.arch,
-                release_channel,
-                current_version,
-                cx,
-            )
-            .await
-            .map_err(|e| {
-                anyhow!(
-                    "Failed to download remote server binary (version: {}, os: {}, arch: {}): {}",
-                    version,
-                    platform.os,
-                    platform.arch,
-                    e
-                )
-            })?;
-
-            Ok((
-                ServerBinary::LocalBinary(binary_path),
-                ServerVersion::Semantic(version),
-            ))
-        } else {
-            let (release, request_body) = AutoUpdater::get_remote_server_release_url(
-                    platform.os,
-                    platform.arch,
-                    release_channel,
-                    current_version,
-                    cx,
-                )
-                .await
-                .map_err(|e| {
-                    anyhow!(
-                        "Failed to get remote server binary download url (version: {}, os: {}, arch: {}): {}",
-                        version,
-                        platform.os,
-                        platform.arch,
-                        e
-                    )
-                })?;
-
-            let version = release
-                .version
-                .parse::<SemanticVersion>()
-                .map(ServerVersion::Semantic)
-                .unwrap_or_else(|_| ServerVersion::Commit(release.version));
-            Ok((
-                ServerBinary::ReleaseUrl {
-                    url: release.url,
-                    body: request_body,
-                },
-                version,
-            ))
-        }
-    }
-
-    #[cfg(debug_assertions)]
-    async fn build_local(
-        &self,
-        cx: &mut AsyncAppContext,
-        platform: SshPlatform,
-        version: gpui::SemanticVersion,
-    ) -> Result<Option<(PathBuf, gpui::SemanticVersion)>> {
-        use smol::process::{Command, Stdio};
-
-        async fn run_cmd(command: &mut Command) -> Result<()> {
-            let output = command
-                .kill_on_drop(true)
-                .stderr(Stdio::inherit())
-                .output()
-                .await?;
-            if !output.status.success() {
-                Err(anyhow!("Failed to run command: {:?}", command))?;
-            }
-            Ok(())
-        }
-
-        if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
-            self.update_status(Some("Building remote server binary from source"), cx);
-            log::info!("building remote server binary from source");
-            run_cmd(Command::new("cargo").args([
-                "build",
-                "--package",
-                "remote_server",
-                "--features",
-                "debug-embed",
-                "--target-dir",
-                "target/remote_server",
-            ]))
-            .await?;
-
-            self.update_status(Some("Compressing binary"), cx);
-
-            run_cmd(Command::new("gzip").args([
-                "-9",
-                "-f",
-                "target/remote_server/debug/remote_server",
-            ]))
-            .await?;
-
-            let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
-            return Ok(Some((path, version)));
-        } else if let Some(triple) = platform.triple() {
-            smol::fs::create_dir_all("target/remote_server").await?;
-
-            self.update_status(Some("Installing cross.rs for cross-compilation"), cx);
-            log::info!("installing cross");
-            run_cmd(Command::new("cargo").args([
-                "install",
-                "cross",
-                "--git",
-                "https://github.com/cross-rs/cross",
-            ]))
-            .await?;
-
-            self.update_status(
-                Some(&format!(
-                    "Building remote server binary from source for {} with Docker",
-                    &triple
-                )),
-                cx,
-            );
-            log::info!("building remote server binary from source for {}", &triple);
-            run_cmd(
-                Command::new("cross")
-                    .args([
-                        "build",
-                        "--package",
-                        "remote_server",
-                        "--features",
-                        "debug-embed",
-                        "--target-dir",
-                        "target/remote_server",
-                        "--target",
-                        &triple,
-                    ])
-                    .env(
-                        "CROSS_CONTAINER_OPTS",
-                        "--mount type=bind,src=./target,dst=/app/target",
-                    ),
-            )
-            .await?;
-
-            self.update_status(Some("Compressing binary"), cx);
-
-            run_cmd(Command::new("gzip").args([
-                "-9",
-                "-f",
-                &format!("target/remote_server/{}/debug/remote_server", triple),
-            ]))
-            .await?;
-
-            let path = std::env::current_dir()?.join(format!(
-                "target/remote_server/{}/debug/remote_server.gz",
-                triple
-            ));
-
-            return Ok(Some((path, version)));
-        } else {
-            return Ok(None);
-        }
-    }
 }
 
 pub fn is_connecting_over_ssh(workspace: &Workspace, cx: &AppContext) -> bool {

crates/remote/Cargo.toml 🔗

@@ -35,6 +35,7 @@ smol.workspace = true
 tempfile.workspace = true
 thiserror.workspace = true
 util.workspace = true
+release_channel.workspace = true
 
 [dev-dependencies]
 gpui = { workspace = true, features = ["test-support"] }

crates/remote/src/ssh_session.rs 🔗

@@ -21,6 +21,7 @@ use gpui::{
     ModelContext, SemanticVersion, Task, WeakModel,
 };
 use parking_lot::Mutex;
+use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
 use rpc::{
     proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
     AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
@@ -227,10 +228,19 @@ pub enum ServerBinary {
     ReleaseUrl { url: String, body: String },
 }
 
+#[derive(Clone, Debug, PartialEq, Eq)]
 pub enum ServerVersion {
     Semantic(SemanticVersion),
     Commit(String),
 }
+impl ServerVersion {
+    pub fn semantic_version(&self) -> Option<SemanticVersion> {
+        match self {
+            Self::Semantic(version) => Some(*version),
+            _ => None,
+        }
+    }
+}
 
 impl std::fmt::Display for ServerVersion {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -252,12 +262,21 @@ pub trait SshClientDelegate: Send + Sync {
         platform: SshPlatform,
         cx: &mut AsyncAppContext,
     ) -> Result<PathBuf>;
-    fn get_server_binary(
+    fn get_download_params(
+        &self,
+        platform: SshPlatform,
+        release_channel: ReleaseChannel,
+        version: Option<SemanticVersion>,
+        cx: &mut AsyncAppContext,
+    ) -> Task<Result<(String, String)>>;
+
+    fn download_server_binary_locally(
         &self,
         platform: SshPlatform,
-        upload_binary_over_ssh: bool,
+        release_channel: ReleaseChannel,
+        version: Option<SemanticVersion>,
         cx: &mut AsyncAppContext,
-    ) -> oneshot::Receiver<Result<(ServerBinary, ServerVersion)>>;
+    ) -> Task<Result<PathBuf>>;
     fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 }
 
@@ -1727,86 +1746,123 @@ impl SshRemoteConnection {
         platform: SshPlatform,
         cx: &mut AsyncAppContext,
     ) -> Result<()> {
-        if std::env::var("ZED_USE_CACHED_REMOTE_SERVER").is_ok() {
-            if let Ok(installed_version) =
-                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
-            {
-                log::info!("using cached server binary version {}", installed_version);
+        let current_version = match run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
+        {
+            Ok(version_output) => {
+                if let Ok(version) = version_output.trim().parse::<SemanticVersion>() {
+                    Some(ServerVersion::Semantic(version))
+                } else {
+                    Some(ServerVersion::Commit(version_output.trim().to_string()))
+                }
+            }
+            Err(_) => None,
+        };
+        let (release_channel, wanted_version) = cx.update(|cx| {
+            let release_channel = ReleaseChannel::global(cx);
+            let wanted_version = match release_channel {
+                ReleaseChannel::Nightly => {
+                    AppCommitSha::try_global(cx).map(|sha| ServerVersion::Commit(sha.0))
+                }
+                ReleaseChannel::Dev => None,
+                _ => Some(ServerVersion::Semantic(AppVersion::global(cx))),
+            };
+            (release_channel, wanted_version)
+        })?;
+
+        match (&current_version, &wanted_version) {
+            (Some(current), Some(wanted)) if current == wanted => {
+                log::info!("remote development server present and matching client version");
                 return Ok(());
             }
+            (Some(ServerVersion::Semantic(current)), Some(ServerVersion::Semantic(wanted)))
+                if current > wanted =>
+            {
+                anyhow::bail!("The version of the remote server ({}) is newer than the Zed version ({}). Please update Zed.", current, wanted);
+            }
+            _ => {
+                log::info!("Installing remote development server");
+            }
         }
 
-        if cfg!(not(debug_assertions)) {
+        if self.is_binary_in_use(dst_path).await? {
             // When we're not in dev mode, we don't want to switch out the binary if it's
             // still open.
             // In dev mode, that's fine, since we often kill Zed processes with Ctrl-C and want
             // to still replace the binary.
-            if self.is_binary_in_use(dst_path).await? {
-                log::info!("server binary is opened by another process. not updating");
-                delegate.set_status(
-                    Some("Skipping update of remote development server, since it's still in use"),
-                    cx,
-                );
-                return Ok(());
+            if cfg!(not(debug_assertions)) {
+                anyhow::bail!("The remote server version ({:?}) does not match the wanted version ({:?}), but is in use by another Zed client so cannot be upgraded.", &current_version, &wanted_version)
+            } else {
+                log::info!("Binary is currently in use, ignoring because this is a dev build")
             }
         }
 
-        let upload_binary_over_ssh = self.socket.connection_options.upload_binary_over_ssh;
-        let (binary, new_server_version) = delegate
-            .get_server_binary(platform, upload_binary_over_ssh, cx)
-            .await??;
+        if wanted_version.is_none() {
+            if std::env::var("ZED_BUILD_REMOTE_SERVER").is_err() {
+                if let Some(current_version) = current_version {
+                    log::warn!(
+                        "In development, using cached remote server binary version ({})",
+                        current_version
+                    );
 
-        if cfg!(not(debug_assertions)) {
-            let installed_version = if let Ok(version_output) =
-                run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
-            {
-                if let Ok(version) = version_output.trim().parse::<SemanticVersion>() {
-                    Some(ServerVersion::Semantic(version))
+                    return Ok(());
                 } else {
-                    Some(ServerVersion::Commit(version_output.trim().to_string()))
+                    anyhow::bail!(
+                        "ZED_BUILD_REMOTE_SERVER is not set, but no remote server exists at ({:?})",
+                        dst_path
+                    )
                 }
-            } else {
-                None
-            };
+            }
 
-            if let Some(installed_version) = installed_version {
-                use ServerVersion::*;
-                match (installed_version, new_server_version) {
-                    (Semantic(installed), Semantic(new)) if installed == new => {
-                        log::info!("remote development server present and matching client version");
-                        return Ok(());
-                    }
-                    (Semantic(installed), Semantic(new)) if installed > new => {
-                        let error = anyhow!("The version of the remote server ({}) is newer than the Zed version ({}). Please update Zed.", installed, new);
-                        return Err(error);
-                    }
-                    (Commit(installed), Commit(new)) if installed == new => {
-                        log::info!(
-                            "remote development server present and matching client version {}",
-                            installed
-                        );
-                        return Ok(());
-                    }
-                    (installed, _) => {
-                        log::info!(
-                            "remote development server has version: {}. updating...",
-                            installed
-                        );
-                    }
-                }
+            #[cfg(debug_assertions)]
+            {
+                let src_path = self.build_local(platform, delegate, cx).await?;
+
+                return self
+                    .upload_local_server_binary(&src_path, dst_path, delegate, cx)
+                    .await;
             }
+
+            #[cfg(not(debug_assertions))]
+            anyhow::bail!("Running development build in release mode, cannot cross compile (unset ZED_BUILD_REMOTE_SERVER)")
         }
 
-        match binary {
-            ServerBinary::LocalBinary(src_path) => {
-                self.upload_local_server_binary(&src_path, dst_path, delegate, cx)
-                    .await
-            }
-            ServerBinary::ReleaseUrl { url, body } => {
-                self.download_binary_on_server(&url, &body, dst_path, delegate, cx)
-                    .await
+        let upload_binary_over_ssh = self.socket.connection_options.upload_binary_over_ssh;
+
+        if !upload_binary_over_ssh {
+            let (url, body) = delegate
+                .get_download_params(
+                    platform,
+                    release_channel,
+                    wanted_version.clone().and_then(|v| v.semantic_version()),
+                    cx,
+                )
+                .await?;
+
+            match self
+                .download_binary_on_server(&url, &body, dst_path, delegate, cx)
+                .await
+            {
+                Ok(_) => return Ok(()),
+                Err(e) => {
+                    log::error!(
+                        "Failed to download binary on server, attempting to upload server: {}",
+                        e
+                    )
+                }
             }
         }
+
+        let src_path = delegate
+            .download_server_binary_locally(
+                platform,
+                release_channel,
+                wanted_version.and_then(|v| v.semantic_version()),
+                cx,
+            )
+            .await?;
+
+        self.upload_local_server_binary(&src_path, dst_path, delegate, cx)
+            .await
     }
 
     async fn is_binary_in_use(&self, binary_path: &Path) -> Result<bool> {
@@ -1973,6 +2029,113 @@ impl SshRemoteConnection {
             ))
         }
     }
+
+    #[cfg(debug_assertions)]
+    async fn build_local(
+        &self,
+        platform: SshPlatform,
+        delegate: &Arc<dyn SshClientDelegate>,
+        cx: &mut AsyncAppContext,
+    ) -> Result<PathBuf> {
+        use smol::process::{Command, Stdio};
+
+        async fn run_cmd(command: &mut Command) -> Result<()> {
+            let output = command
+                .kill_on_drop(true)
+                .stderr(Stdio::inherit())
+                .output()
+                .await?;
+            if !output.status.success() {
+                Err(anyhow!("Failed to run command: {:?}", command))?;
+            }
+            Ok(())
+        }
+
+        if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
+            delegate.set_status(Some("Building remote server binary from source"), cx);
+            log::info!("building remote server binary from source");
+            run_cmd(Command::new("cargo").args([
+                "build",
+                "--package",
+                "remote_server",
+                "--features",
+                "debug-embed",
+                "--target-dir",
+                "target/remote_server",
+            ]))
+            .await?;
+
+            delegate.set_status(Some("Compressing binary"), cx);
+
+            run_cmd(Command::new("gzip").args([
+                "-9",
+                "-f",
+                "target/remote_server/debug/remote_server",
+            ]))
+            .await?;
+
+            let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
+            return Ok(path);
+        }
+        let Some(triple) = platform.triple() else {
+            anyhow::bail!("can't cross compile for: {:?}", platform);
+        };
+        smol::fs::create_dir_all("target/remote_server").await?;
+
+        delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
+        log::info!("installing cross");
+        run_cmd(Command::new("cargo").args([
+            "install",
+            "cross",
+            "--git",
+            "https://github.com/cross-rs/cross",
+        ]))
+        .await?;
+
+        delegate.set_status(
+            Some(&format!(
+                "Building remote server binary from source for {} with Docker",
+                &triple
+            )),
+            cx,
+        );
+        log::info!("building remote server binary from source for {}", &triple);
+        run_cmd(
+            Command::new("cross")
+                .args([
+                    "build",
+                    "--package",
+                    "remote_server",
+                    "--features",
+                    "debug-embed",
+                    "--target-dir",
+                    "target/remote_server",
+                    "--target",
+                    &triple,
+                ])
+                .env(
+                    "CROSS_CONTAINER_OPTS",
+                    "--mount type=bind,src=./target,dst=/app/target",
+                ),
+        )
+        .await?;
+
+        delegate.set_status(Some("Compressing binary"), cx);
+
+        run_cmd(Command::new("gzip").args([
+            "-9",
+            "-f",
+            &format!("target/remote_server/{}/debug/remote_server", triple),
+        ]))
+        .await?;
+
+        let path = std::env::current_dir()?.join(format!(
+            "target/remote_server/{}/debug/remote_server.gz",
+            triple
+        ));
+
+        return Ok(path);
+    }
 }
 
 type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
@@ -2294,12 +2457,12 @@ mod fake {
         },
         select_biased, FutureExt, SinkExt, StreamExt,
     };
-    use gpui::{AsyncAppContext, Task, TestAppContext};
+    use gpui::{AsyncAppContext, SemanticVersion, Task, TestAppContext};
+    use release_channel::ReleaseChannel;
     use rpc::proto::Envelope;
 
     use super::{
-        ChannelClient, RemoteConnection, ServerBinary, ServerVersion, SshClientDelegate,
-        SshConnectionOptions, SshPlatform,
+        ChannelClient, RemoteConnection, SshClientDelegate, SshConnectionOptions, SshPlatform,
     };
 
     pub(super) struct FakeRemoteConnection {
@@ -2411,23 +2574,36 @@ mod fake {
         ) -> oneshot::Receiver<Result<String>> {
             unreachable!()
         }
-        fn remote_server_binary_path(
+
+        fn download_server_binary_locally(
             &self,
             _: SshPlatform,
+            _: ReleaseChannel,
+            _: Option<SemanticVersion>,
             _: &mut AsyncAppContext,
-        ) -> Result<PathBuf> {
+        ) -> Task<Result<PathBuf>> {
             unreachable!()
         }
-        fn get_server_binary(
+
+        fn get_download_params(
             &self,
-            _: SshPlatform,
-            _: bool,
-            _: &mut AsyncAppContext,
-        ) -> oneshot::Receiver<Result<(ServerBinary, ServerVersion)>> {
+            _platform: SshPlatform,
+            _release_channel: ReleaseChannel,
+            _version: Option<SemanticVersion>,
+            _cx: &mut AsyncAppContext,
+        ) -> Task<Result<(String, String)>> {
             unreachable!()
         }
 
         fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {}
+
+        fn remote_server_binary_path(
+            &self,
+            _platform: SshPlatform,
+            _cx: &mut AsyncAppContext,
+        ) -> Result<PathBuf> {
+            unreachable!()
+        }
     }
 }