transport.rs

  1use std::io::Write;
  2
  3use crate::{
  4    RemoteArch, RemoteOs, RemotePlatform,
  5    json_log::LogRecord,
  6    protocol::{MESSAGE_LEN_SIZE, message_len_from_buffer, read_message_with_len, write_message},
  7};
  8use anyhow::{Context as _, Result};
  9use futures::{
 10    AsyncReadExt as _, FutureExt as _, StreamExt as _,
 11    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
 12};
 13use gpui::{AppContext as _, AsyncApp, Task};
 14use rpc::proto::Envelope;
 15use util::command::Child;
 16
 17pub mod docker;
 18#[cfg(any(test, feature = "test-support"))]
 19pub mod mock;
 20pub mod ssh;
 21pub mod wsl;
 22
 23/// Parses the output of `uname -sm` to determine the remote platform.
 24/// Takes the last line to skip possible shell initialization output.
 25fn parse_platform(output: &str) -> Result<RemotePlatform> {
 26    let output = output.trim();
 27    let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last);
 28    let Some((os, arch)) = uname.split_once(" ") else {
 29        anyhow::bail!("unknown uname: {uname:?}")
 30    };
 31
 32    let os = match os {
 33        "Darwin" => RemoteOs::MacOs,
 34        "Linux" => RemoteOs::Linux,
 35        _ => anyhow::bail!(
 36            "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 37        ),
 38    };
 39
 40    // exclude armv5,6,7 as they are 32-bit.
 41    let arch = if arch.starts_with("armv8")
 42        || arch.starts_with("armv9")
 43        || arch.starts_with("arm64")
 44        || arch.starts_with("aarch64")
 45    {
 46        RemoteArch::Aarch64
 47    } else if arch.starts_with("x86") {
 48        RemoteArch::X86_64
 49    } else {
 50        anyhow::bail!(
 51            "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 52        )
 53    };
 54
 55    Ok(RemotePlatform { os, arch })
 56}
 57
 58/// Parses the output of `echo $SHELL` to determine the remote shell.
 59/// Takes the last line to skip possible shell initialization output.
 60fn parse_shell(output: &str, fallback_shell: &str) -> String {
 61    let output = output.trim();
 62    let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last);
 63    if shell.is_empty() {
 64        log::error!("$SHELL is not set, falling back to {fallback_shell}");
 65        fallback_shell.to_owned()
 66    } else {
 67        shell.to_owned()
 68    }
 69}
 70
 71fn handle_rpc_messages_over_child_process_stdio(
 72    mut remote_proxy_process: Child,
 73    incoming_tx: UnboundedSender<Envelope>,
 74    mut outgoing_rx: UnboundedReceiver<Envelope>,
 75    mut connection_activity_tx: Sender<()>,
 76    cx: &AsyncApp,
 77) -> Task<Result<i32>> {
 78    let mut child_stderr = remote_proxy_process.stderr.take().unwrap();
 79    let mut child_stdout = remote_proxy_process.stdout.take().unwrap();
 80    let mut child_stdin = remote_proxy_process.stdin.take().unwrap();
 81
 82    let mut stdin_buffer = Vec::new();
 83    let mut stdout_buffer = Vec::new();
 84    let mut stderr_buffer = Vec::new();
 85    let mut stderr_offset = 0;
 86
 87    let stdin_task = cx.background_spawn(async move {
 88        while let Some(outgoing) = outgoing_rx.next().await {
 89            write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 90        }
 91        anyhow::Ok(())
 92    });
 93
 94    let stdout_task = cx.background_spawn({
 95        let mut connection_activity_tx = connection_activity_tx.clone();
 96        async move {
 97            loop {
 98                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 99                let len = child_stdout.read(&mut stdout_buffer).await?;
100
101                if len == 0 {
102                    return anyhow::Ok(());
103                }
104
105                if len < MESSAGE_LEN_SIZE {
106                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
107                }
108
109                let message_len = message_len_from_buffer(&stdout_buffer);
110                let envelope =
111                    read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
112                        .await?;
113                connection_activity_tx.try_send(()).ok();
114                incoming_tx.unbounded_send(envelope).ok();
115            }
116        }
117    });
118
119    let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
120        loop {
121            stderr_buffer.resize(stderr_offset + 1024, 0);
122
123            let len = child_stderr
124                .read(&mut stderr_buffer[stderr_offset..])
125                .await?;
126            if len == 0 {
127                return anyhow::Ok(());
128            }
129
130            stderr_offset += len;
131            let mut start_ix = 0;
132            while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
133                .iter()
134                .position(|b| b == &b'\n')
135            {
136                let line_ix = start_ix + ix;
137                let content = &stderr_buffer[start_ix..line_ix];
138                start_ix = line_ix + 1;
139                if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
140                    record.log(log::logger())
141                } else {
142                    std::io::stderr()
143                        .write_fmt(format_args!(
144                            "(remote) {}\n",
145                            String::from_utf8_lossy(content)
146                        ))
147                        .ok();
148                }
149            }
150            stderr_buffer.drain(0..start_ix);
151            stderr_offset -= start_ix;
152
153            connection_activity_tx.try_send(()).ok();
154        }
155    });
156
157    cx.background_spawn(async move {
158        let result = futures::select! {
159            result = stdin_task.fuse() => {
160                result.context("stdin")
161            }
162            result = stdout_task.fuse() => {
163                result.context("stdout")
164            }
165            result = stderr_task.fuse() => {
166                result.context("stderr")
167            }
168        };
169        let exit_status = remote_proxy_process.status().await?;
170        let status = exit_status.code().unwrap_or_else(|| {
171            #[cfg(unix)]
172            let status = std::os::unix::process::ExitStatusExt::signal(&exit_status).unwrap_or(1);
173            #[cfg(not(unix))]
174            let status = 1;
175            status
176        });
177        match result {
178            Ok(_) => Ok(status),
179            Err(error) => Err(error),
180        }
181    })
182}
183
184#[cfg(any(debug_assertions, feature = "build-remote-server-binary"))]
185async fn build_remote_server_from_source(
186    platform: &crate::RemotePlatform,
187    delegate: &dyn crate::RemoteClientDelegate,
188    binary_exists_on_server: bool,
189    cx: &mut AsyncApp,
190) -> Result<Option<std::path::PathBuf>> {
191    use std::env::VarError;
192    use std::path::Path;
193    use util::command::{Command, Stdio, new_command};
194
195    if let Ok(path) = std::env::var("ZED_COPY_REMOTE_SERVER") {
196        let path = std::path::PathBuf::from(path);
197        if path.exists() {
198            return Ok(Some(path));
199        } else {
200            log::warn!(
201                "ZED_COPY_REMOTE_SERVER path does not exist, falling back to ZED_BUILD_REMOTE_SERVER: {}",
202                path.display()
203            );
204        }
205    }
206
207    // By default, we make building remote server from source opt-out and we do not force artifact compression
208    // for quicker builds.
209    let build_remote_server =
210        std::env::var("ZED_BUILD_REMOTE_SERVER").unwrap_or("nocompress".into());
211
212    if let "never" = &*build_remote_server {
213        return Ok(None);
214    } else if let "false" | "no" | "off" | "0" = &*build_remote_server {
215        if binary_exists_on_server {
216            return Ok(None);
217        }
218        log::warn!("ZED_BUILD_REMOTE_SERVER is disabled, but no server binary exists on the server")
219    }
220
221    async fn run_cmd(command: &mut Command) -> Result<()> {
222        let output = command
223            .kill_on_drop(true)
224            .stdout(Stdio::inherit())
225            .output()
226            .await?;
227        anyhow::ensure!(
228            output.status.success(),
229            "Failed to run command: {command:?}: output: {}",
230            String::from_utf8_lossy(&output.stderr)
231        );
232        Ok(())
233    }
234
235    let use_musl = !build_remote_server.contains("nomusl");
236    let triple = format!(
237        "{}-{}",
238        platform.arch,
239        match platform.os {
240            RemoteOs::Linux =>
241                if use_musl {
242                    "unknown-linux-musl"
243                } else {
244                    "unknown-linux-gnu"
245                },
246            RemoteOs::MacOs => "apple-darwin",
247            RemoteOs::Windows if cfg!(windows) => "pc-windows-msvc",
248            RemoteOs::Windows => "pc-windows-gnu",
249        }
250    );
251    let mut rust_flags = match std::env::var("RUSTFLAGS") {
252        Ok(val) => val,
253        Err(VarError::NotPresent) => String::new(),
254        Err(e) => {
255            log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
256            String::new()
257        }
258    };
259    if platform.os == RemoteOs::Linux && use_musl {
260        rust_flags.push_str(" -C target-feature=+crt-static");
261
262        if let Ok(path) = std::env::var("ZED_ZSTD_MUSL_LIB") {
263            rust_flags.push_str(&format!(" -C link-arg=-L{path}"));
264        }
265    }
266    if build_remote_server.contains("mold") {
267        rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
268    }
269
270    if platform.arch.as_str() == std::env::consts::ARCH
271        && platform.os.as_str() == std::env::consts::OS
272    {
273        delegate.set_status(Some("Building remote server binary from source"), cx);
274        log::info!("building remote server binary from source");
275        run_cmd(
276            new_command("cargo")
277                .current_dir(concat!(env!("CARGO_MANIFEST_DIR"), "/../.."))
278                .args([
279                    "build",
280                    "--package",
281                    "remote_server",
282                    "--features",
283                    "debug-embed",
284                    "--target-dir",
285                    "target/remote_server",
286                    "--target",
287                    &triple,
288                ])
289                .env("RUSTFLAGS", &rust_flags),
290        )
291        .await?;
292    } else {
293        if which("zig", cx).await?.is_none() {
294            anyhow::bail!(if cfg!(not(windows)) {
295                "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup)"
296            } else {
297                "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup)"
298            });
299        }
300
301        let rustup = which("rustup", cx)
302            .await?
303            .context("rustup not found on $PATH, install rustup (see https://rustup.rs/)")?;
304        delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
305        log::info!("adding rustup target");
306        run_cmd(new_command(rustup).args(["target", "add"]).arg(&triple)).await?;
307
308        if which("cargo-zigbuild", cx).await?.is_none() {
309            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
310            log::info!("installing cargo-zigbuild");
311            run_cmd(new_command("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;
312        }
313
314        delegate.set_status(
315            Some(&format!(
316                "Building remote binary from source for {triple} with Zig"
317            )),
318            cx,
319        );
320        log::info!("building remote binary from source for {triple} with Zig");
321        run_cmd(
322            new_command("cargo")
323                .args([
324                    "zigbuild",
325                    "--package",
326                    "remote_server",
327                    "--features",
328                    "debug-embed",
329                    "--target-dir",
330                    "target/remote_server",
331                    "--target",
332                    &triple,
333                ])
334                .env("RUSTFLAGS", &rust_flags),
335        )
336        .await?;
337    };
338    let bin_path = Path::new("target")
339        .join("remote_server")
340        .join(&triple)
341        .join("debug")
342        .join("remote_server")
343        .with_extension(if platform.os.is_windows() { "exe" } else { "" });
344
345    let path = if !build_remote_server.contains("nocompress") {
346        delegate.set_status(Some("Compressing binary"), cx);
347
348        #[cfg(not(target_os = "windows"))]
349        let archive_path = {
350            run_cmd(new_command("gzip").arg("-f").arg(&bin_path)).await?;
351            bin_path.with_extension("gz")
352        };
353
354        #[cfg(target_os = "windows")]
355        let archive_path = {
356            let zip_path = bin_path.with_extension("zip");
357            if smol::fs::metadata(&zip_path).await.is_ok() {
358                smol::fs::remove_file(&zip_path).await?;
359            }
360            let compress_command = format!(
361                "Compress-Archive -Path '{}' -DestinationPath '{}' -Force",
362                bin_path.display(),
363                zip_path.display(),
364            );
365            run_cmd(new_command("powershell.exe").args([
366                "-NoProfile",
367                "-Command",
368                &compress_command,
369            ]))
370            .await?;
371            zip_path
372        };
373
374        std::env::current_dir()?.join(archive_path)
375    } else {
376        bin_path
377    };
378
379    Ok(Some(path))
380}
381
382#[cfg(any(debug_assertions, feature = "build-remote-server-binary"))]
383async fn which(
384    binary_name: impl AsRef<str>,
385    cx: &mut AsyncApp,
386) -> Result<Option<std::path::PathBuf>> {
387    let binary_name = binary_name.as_ref().to_string();
388    let binary_name_cloned = binary_name.clone();
389    let res = cx
390        .background_spawn(async move { which::which(binary_name_cloned) })
391        .await;
392    match res {
393        Ok(path) => Ok(Some(path)),
394        Err(which::Error::CannotFindBinaryPath) => Ok(None),
395        Err(err) => Err(anyhow::anyhow!(
396            "Failed to run 'which' to find the binary '{binary_name}': {err}"
397        )),
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_platform() {
        let supported = [
            ("Linux x86_64\n", RemoteOs::Linux, RemoteArch::X86_64),
            ("Linux x86_64", RemoteOs::Linux, RemoteArch::X86_64),
            ("Darwin arm64\n", RemoteOs::MacOs, RemoteArch::Aarch64),
            ("Darwin x86_64\n", RemoteOs::MacOs, RemoteArch::X86_64),
            // Output printed by shell initialization before `uname` is ignored.
            (
                "some shell init output\nLinux aarch64\n",
                RemoteOs::Linux,
                RemoteArch::Aarch64,
            ),
            (
                "some shell init output\nLinux aarch64",
                RemoteOs::Linux,
                RemoteArch::Aarch64,
            ),
            ("Linux aarch64\n", RemoteOs::Linux, RemoteArch::Aarch64),
            ("Linux arm64\n", RemoteOs::Linux, RemoteArch::Aarch64),
            ("Linux armv8l\n", RemoteOs::Linux, RemoteArch::Aarch64),
            ("Linux armv9l\n", RemoteOs::Linux, RemoteArch::Aarch64),
            // Only the architecture prefix matters; trailing text is tolerated.
            (
                r#"Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n"#,
                RemoteOs::Linux,
                RemoteArch::X86_64,
            ),
        ];
        for (input, expected_os, expected_arch) in supported {
            let platform = parse_platform(input)
                .unwrap_or_else(|err| panic!("failed to parse {input:?}: {err}"));
            assert_eq!(platform.os, expected_os, "os for {input:?}");
            assert_eq!(platform.arch, expected_arch, "arch for {input:?}");
        }

        let unsupported = [
            "Windows x86_64\n",
            "FreeBSD amd64\n",
            "some shell init output\nWindows x86_64\n",
            // 32-bit architectures.
            "Linux armv7l\n",
            "Linux armv6l\n",
            "Linux i686\n",
            // Missing architecture, or no output at all.
            "Linux\n",
            "",
        ];
        for input in unsupported {
            assert!(
                parse_platform(input).is_err(),
                "expected an error for {input:?}"
            );
        }
    }

    #[test]
    fn test_parse_shell() {
        let cases = [
            ("/bin/bash\n", "/bin/bash"),
            ("/bin/zsh\n", "/bin/zsh"),
            ("/bin/bash", "/bin/bash"),
            // Output printed by shell initialization before `echo $SHELL` is ignored.
            ("some shell init output\n/bin/bash\n", "/bin/bash"),
            ("some shell init output\n/bin/bash", "/bin/bash"),
            ("some shell init output\r\n/bin/zsh\r\n", "/bin/zsh"),
            ("  /bin/bash  \n", "/bin/bash"),
            // An empty `$SHELL` falls back to the provided shell.
            ("", "sh"),
            ("\n", "sh"),
            ("   \n\n", "sh"),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_shell(input, "sh"), expected, "input: {input:?}");
        }
    }
}