ssh remoting: Add setting to download binary on host (#19606)

Thorsten Ball created

This adds the following optional setting:

```json
{
  "remote_server": {
    "download_on_host": false
  }
}
```
Right now, it's **off by default** because I haven't tested it enough.

Release Notes:

- N/A

Change summary

crates/auto_update/src/auto_update.rs         |  51 ++++++++-
crates/recent_projects/src/ssh_connections.rs |  88 ++++++++++++----
crates/remote/src/ssh_session.rs              | 109 ++++++++++++++++++--
3 files changed, 207 insertions(+), 41 deletions(-)

Detailed changes

crates/auto_update/src/auto_update.rs 🔗

@@ -474,6 +474,39 @@ impl AutoUpdater {
         Ok(version_path)
     }
 
+    pub async fn get_latest_remote_server_release_url(
+        os: &str,
+        arch: &str,
+        mut release_channel: ReleaseChannel,
+        cx: &mut AsyncAppContext,
+    ) -> Result<(String, String)> {
+        let this = cx.update(|cx| {
+            cx.default_global::<GlobalAutoUpdate>()
+                .0
+                .clone()
+                .ok_or_else(|| anyhow!("auto-update not initialized"))
+        })??;
+
+        if release_channel == ReleaseChannel::Dev {
+            release_channel = ReleaseChannel::Nightly;
+        }
+
+        let release = Self::get_latest_release(
+            &this,
+            "zed-remote-server",
+            os,
+            arch,
+            Some(release_channel),
+            cx,
+        )
+        .await?;
+
+        let update_request_body = build_remote_server_update_request_body(cx)?;
+        let body = serde_json::to_string(&update_request_body)?;
+
+        Ok((release.url, body))
+    }
+
     async fn get_latest_release(
         this: &Model<Self>,
         asset: &str,
@@ -629,6 +662,15 @@ async fn download_remote_server_binary(
     cx: &AsyncAppContext,
 ) -> Result<()> {
     let mut target_file = File::create(&target_path).await?;
+    let update_request_body = build_remote_server_update_request_body(cx)?;
+    let request_body = AsyncBody::from(serde_json::to_string(&update_request_body)?);
+
+    let mut response = client.get(&release.url, request_body, true).await?;
+    smol::io::copy(response.body_mut(), &mut target_file).await?;
+    Ok(())
+}
+
+fn build_remote_server_update_request_body(cx: &AsyncAppContext) -> Result<UpdateRequestBody> {
     let (installation_id, release_channel, telemetry_enabled, is_staff) = cx.update(|cx| {
         let telemetry = Client::global(cx).telemetry().clone();
         let is_staff = telemetry.is_staff();
@@ -644,17 +686,14 @@ async fn download_remote_server_binary(
             is_staff,
         )
     })?;
-    let request_body = AsyncBody::from(serde_json::to_string(&UpdateRequestBody {
+
+    Ok(UpdateRequestBody {
         installation_id,
         release_channel,
         telemetry: telemetry_enabled,
         is_staff,
         destination: "remote",
-    })?);
-
-    let mut response = client.get(&release.url, request_body, true).await?;
-    smol::io::copy(response.body_mut(), &mut target_file).await?;
-    Ok(())
+    })
 }
 
 async fn download_release(

crates/recent_projects/src/ssh_connections.rs 🔗

@@ -14,6 +14,7 @@ use gpui::{AppContext, Model};
 use language::CursorShape;
 use markdown::{Markdown, MarkdownStyle};
 use release_channel::{AppVersion, ReleaseChannel};
+use remote::ssh_session::ServerBinary;
 use remote::{SshConnectionOptions, SshPlatform, SshRemoteClient};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -25,9 +26,15 @@ use ui::{
 };
 use workspace::{AppState, ModalView, Workspace};
 
+#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
+pub struct RemoteServerSettings {
+    pub download_on_host: Option<bool>,
+}
+
 #[derive(Deserialize)]
 pub struct SshSettings {
     pub ssh_connections: Option<Vec<SshConnection>>,
+    pub remote_server: Option<RemoteServerSettings>,
 }
 
 impl SshSettings {
@@ -107,6 +114,7 @@ pub struct SshProject {
 #[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
 pub struct RemoteSettingsContent {
     pub ssh_connections: Option<Vec<SshConnection>>,
+    pub remote_server: Option<RemoteServerSettings>,
 }
 
 impl Settings for SshSettings {
@@ -435,7 +443,7 @@ impl remote::SshClientDelegate for SshClientDelegate {
         &self,
         platform: SshPlatform,
         cx: &mut AsyncAppContext,
-    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
+    ) -> oneshot::Receiver<Result<(ServerBinary, SemanticVersion)>> {
         let (tx, rx) = oneshot::channel();
         let this = self.clone();
         cx.spawn(|mut cx| async move {
@@ -476,10 +484,18 @@ impl SshClientDelegate {
         &self,
         platform: SshPlatform,
         cx: &mut AsyncAppContext,
-    ) -> Result<(PathBuf, SemanticVersion)> {
-        let (version, release_channel) = cx.update(|cx| {
-            let global = AppVersion::global(cx);
-            (global, ReleaseChannel::global(cx))
+    ) -> Result<(ServerBinary, SemanticVersion)> {
+        let (version, release_channel, download_binary_on_host) = cx.update(|cx| {
+            let version = AppVersion::global(cx);
+            let channel = ReleaseChannel::global(cx);
+
+            let ssh_settings = SshSettings::get_global(cx);
+            let download_binary_on_host = ssh_settings
+                .remote_server
+                .as_ref()
+                .and_then(|server| server.download_on_host)
+                .unwrap_or(false);
+            (version, channel, download_binary_on_host)
         })?;
 
         // In dev mode, build the remote server binary from source
@@ -487,29 +503,55 @@ impl SshClientDelegate {
         if release_channel == ReleaseChannel::Dev {
             let result = self.build_local(cx, platform, version).await?;
             // Fall through to a remote binary if we're not able to compile a local binary
-            if let Some(result) = result {
-                return Ok(result);
+            if let Some((path, version)) = result {
+                return Ok((ServerBinary::LocalBinary(path), version));
             }
         }
 
-        self.update_status(Some("checking for latest version of remote server"), cx);
-        let binary_path = AutoUpdater::get_latest_remote_server_release(
-            platform.os,
-            platform.arch,
-            release_channel,
-            cx,
-        )
-        .await
-        .map_err(|e| {
-            anyhow!(
-                "failed to download remote server binary (os: {}, arch: {}): {}",
+        if download_binary_on_host {
+            let (request_url, request_body) = AutoUpdater::get_latest_remote_server_release_url(
                 platform.os,
                 platform.arch,
-                e
+                release_channel,
+                cx,
             )
-        })?;
+            .await
+            .map_err(|e| {
+                anyhow!(
+                    "failed to get remote server binary download url (os: {}, arch: {}): {}",
+                    platform.os,
+                    platform.arch,
+                    e
+                )
+            })?;
+
+            Ok((
+                ServerBinary::ReleaseUrl {
+                    url: request_url,
+                    body: request_body,
+                },
+                version,
+            ))
+        } else {
+            self.update_status(Some("checking for latest version of remote server"), cx);
+            let binary_path = AutoUpdater::get_latest_remote_server_release(
+                platform.os,
+                platform.arch,
+                release_channel,
+                cx,
+            )
+            .await
+            .map_err(|e| {
+                anyhow!(
+                    "failed to download remote server binary (os: {}, arch: {}): {}",
+                    platform.os,
+                    platform.arch,
+                    e
+                )
+            })?;
 
-        Ok((binary_path, version))
+            Ok((ServerBinary::LocalBinary(binary_path), version))
+        }
     }
 
     #[cfg(debug_assertions)]
@@ -517,8 +559,8 @@ impl SshClientDelegate {
         &self,
         cx: &mut AsyncAppContext,
         platform: SshPlatform,
-        version: SemanticVersion,
-    ) -> Result<Option<(PathBuf, SemanticVersion)>> {
+        version: gpui::SemanticVersion,
+    ) -> Result<Option<(PathBuf, gpui::SemanticVersion)>> {
         use smol::process::{Command, Stdio};
 
         async fn run_cmd(command: &mut Command) -> Result<()> {

crates/remote/src/ssh_session.rs 🔗

@@ -216,6 +216,11 @@ impl SshPlatform {
     }
 }
 
+pub enum ServerBinary {
+    LocalBinary(PathBuf),
+    ReleaseUrl { url: String, body: String },
+}
+
 pub trait SshClientDelegate: Send + Sync {
     fn ask_password(
         &self,
@@ -231,7 +236,7 @@ pub trait SshClientDelegate: Send + Sync {
         &self,
         platform: SshPlatform,
         cx: &mut AsyncAppContext,
-    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
+    ) -> oneshot::Receiver<Result<(ServerBinary, SemanticVersion)>>;
     fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
 }
 
@@ -1479,14 +1484,7 @@ impl SshRemoteConnection {
             }
         }
 
-        let mut dst_path_gz = dst_path.to_path_buf();
-        dst_path_gz.set_extension("gz");
-
-        if let Some(parent) = dst_path.parent() {
-            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
-        }
-
-        let (src_path, version) = delegate.get_server_binary(platform, cx).await??;
+        let (binary, version) = delegate.get_server_binary(platform, cx).await??;
 
         let mut server_binary_exists = false;
         if !server_binary_exists && cfg!(not(debug_assertions)) {
@@ -1504,9 +1502,82 @@ impl SshRemoteConnection {
             return Ok(());
         }
 
+        match binary {
+            ServerBinary::LocalBinary(src_path) => {
+                self.upload_local_server_binary(&src_path, dst_path, delegate, cx)
+                    .await
+            }
+            ServerBinary::ReleaseUrl { url, body } => {
+                self.download_binary_on_server(&url, &body, dst_path, delegate, cx)
+                    .await
+            }
+        }
+    }
+
+    async fn download_binary_on_server(
+        &self,
+        url: &str,
+        body: &str,
+        dst_path: &Path,
+        delegate: &Arc<dyn SshClientDelegate>,
+        cx: &mut AsyncAppContext,
+    ) -> Result<()> {
+        let mut dst_path_gz = dst_path.to_path_buf();
+        dst_path_gz.set_extension("gz");
+
+        if let Some(parent) = dst_path.parent() {
+            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
+        }
+
+        delegate.set_status(Some("Downloading remote development server on host..."), cx);
+
+        let script = format!(
+            r#"
+            if command -v wget >/dev/null 2>&1; then
+                wget --max-redirect=5 --method=GET --header="Content-Type: application/json" --body-data='{}' '{}' -O '{}' && echo "wget"
+            elif command -v curl >/dev/null 2>&1; then
+                curl -L -X GET -H "Content-Type: application/json" -d '{}' '{}' -o '{}' && echo "curl"
+            else
+                echo "Neither curl nor wget is available" >&2
+                exit 1
+            fi
+            "#,
+            body.replace("'", r#"\'"#),
+            url,
+            dst_path_gz.display(),
+            body.replace("'", r#"\'"#),
+            url,
+            dst_path_gz.display(),
+        );
+
+        let output = run_cmd(self.socket.ssh_command("bash").arg("-c").arg(script))
+            .await
+            .context("Failed to download server binary")?;
+
+        if !output.contains("curl") && !output.contains("wget") {
+            return Err(anyhow!("Failed to download server binary: {}", output));
+        }
+
+        self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
+            .await
+    }
+
+    async fn upload_local_server_binary(
+        &self,
+        src_path: &Path,
+        dst_path: &Path,
+        delegate: &Arc<dyn SshClientDelegate>,
+        cx: &mut AsyncAppContext,
+    ) -> Result<()> {
+        let mut dst_path_gz = dst_path.to_path_buf();
+        dst_path_gz.set_extension("gz");
+
+        if let Some(parent) = dst_path.parent() {
+            run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
+        }
+
         let src_stat = fs::metadata(&src_path).await?;
         let size = src_stat.len();
-        let server_mode = 0o755;
 
         let t0 = Instant::now();
         delegate.set_status(Some("Uploading remote development server"), cx);
@@ -1516,6 +1587,17 @@ impl SshRemoteConnection {
             .context("failed to upload server binary")?;
         log::info!("uploaded remote development server in {:?}", t0.elapsed());
 
+        self.extract_server_binary(dst_path, &dst_path_gz, delegate, cx)
+            .await
+    }
+
+    async fn extract_server_binary(
+        &self,
+        dst_path: &Path,
+        dst_path_gz: &Path,
+        delegate: &Arc<dyn SshClientDelegate>,
+        cx: &mut AsyncAppContext,
+    ) -> Result<()> {
         delegate.set_status(Some("Extracting remote development server"), cx);
         run_cmd(
             self.socket
@@ -1525,6 +1607,7 @@ impl SshRemoteConnection {
         )
         .await?;
 
+        let server_mode = 0o755;
         delegate.set_status(Some("Marking remote development server executable"), cx);
         run_cmd(
             self.socket
@@ -1894,7 +1977,8 @@ mod fake {
     use rpc::proto::Envelope;
 
     use super::{
-        ChannelClient, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteProcess,
+        ChannelClient, ServerBinary, SshClientDelegate, SshConnectionOptions, SshPlatform,
+        SshRemoteProcess,
     };
 
     pub(super) struct SshRemoteConnection {
@@ -2010,9 +2094,10 @@ mod fake {
             &self,
             _: SshPlatform,
             _: &mut AsyncAppContext,
-        ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
+        ) -> oneshot::Receiver<Result<(ServerBinary, SemanticVersion)>> {
             unreachable!()
         }
+
         fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {
             unreachable!()
         }