transport.rs

  1use std::io::Write;
  2
  3use crate::{
  4    RemoteArch, RemoteOs, RemotePlatform,
  5    json_log::LogRecord,
  6    protocol::{MESSAGE_LEN_SIZE, message_len_from_buffer, read_message_with_len, write_message},
  7};
  8use anyhow::{Context as _, Result};
  9use futures::{
 10    AsyncReadExt as _, FutureExt as _, StreamExt as _,
 11    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
 12};
 13use gpui::{AppContext as _, AsyncApp, Task};
 14use rpc::proto::Envelope;
 15use util::command::Child;
 16
 17pub mod docker;
 18#[cfg(any(test, feature = "test-support"))]
 19pub mod mock;
 20pub mod ssh;
 21pub mod wsl;
 22
 23/// Parses the output of `uname -sm` to determine the remote platform.
 24/// Takes the last line to skip possible shell initialization output.
 25fn parse_platform(output: &str) -> Result<RemotePlatform> {
 26    let output = output.trim();
 27    let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last);
 28    let Some((os, arch)) = uname.split_once(" ") else {
 29        anyhow::bail!("unknown uname: {uname:?}")
 30    };
 31
 32    let os = match os {
 33        "Darwin" => RemoteOs::MacOs,
 34        "Linux" => RemoteOs::Linux,
 35        _ => anyhow::bail!(
 36            "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 37        ),
 38    };
 39
 40    // exclude armv5,6,7 as they are 32-bit.
 41    let arch = if arch.starts_with("armv8")
 42        || arch.starts_with("armv9")
 43        || arch.starts_with("arm64")
 44        || arch.starts_with("aarch64")
 45    {
 46        RemoteArch::Aarch64
 47    } else if arch.starts_with("x86") {
 48        RemoteArch::X86_64
 49    } else {
 50        anyhow::bail!(
 51            "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 52        )
 53    };
 54
 55    Ok(RemotePlatform { os, arch })
 56}
 57
 58/// Parses the output of `echo $SHELL` to determine the remote shell.
 59/// Takes the last line to skip possible shell initialization output.
 60fn parse_shell(output: &str, fallback_shell: &str) -> String {
 61    let output = output.trim();
 62    let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last);
 63    if shell.is_empty() {
 64        log::error!("$SHELL is not set, falling back to {fallback_shell}");
 65        fallback_shell.to_owned()
 66    } else {
 67        shell.to_owned()
 68    }
 69}
 70
 71fn handle_rpc_messages_over_child_process_stdio(
 72    mut remote_proxy_process: Child,
 73    incoming_tx: UnboundedSender<Envelope>,
 74    mut outgoing_rx: UnboundedReceiver<Envelope>,
 75    mut connection_activity_tx: Sender<()>,
 76    cx: &AsyncApp,
 77) -> Task<Result<i32>> {
 78    let mut child_stderr = remote_proxy_process.stderr.take().unwrap();
 79    let mut child_stdout = remote_proxy_process.stdout.take().unwrap();
 80    let mut child_stdin = remote_proxy_process.stdin.take().unwrap();
 81
 82    let mut stdin_buffer = Vec::new();
 83    let mut stdout_buffer = Vec::new();
 84    let mut stderr_buffer = Vec::new();
 85    let mut stderr_offset = 0;
 86
 87    let stdin_task = cx.background_spawn(async move {
 88        while let Some(outgoing) = outgoing_rx.next().await {
 89            write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 90        }
 91        anyhow::Ok(())
 92    });
 93
 94    let stdout_task = cx.background_spawn({
 95        let mut connection_activity_tx = connection_activity_tx.clone();
 96        async move {
 97            loop {
 98                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 99                let len = child_stdout.read(&mut stdout_buffer).await?;
100
101                if len == 0 {
102                    return anyhow::Ok(());
103                }
104
105                if len < MESSAGE_LEN_SIZE {
106                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
107                }
108
109                let message_len = message_len_from_buffer(&stdout_buffer);
110                let envelope =
111                    read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
112                        .await?;
113                connection_activity_tx.try_send(()).ok();
114                incoming_tx.unbounded_send(envelope).ok();
115            }
116        }
117    });
118
119    let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
120        loop {
121            stderr_buffer.resize(stderr_offset + 1024, 0);
122
123            let len = child_stderr
124                .read(&mut stderr_buffer[stderr_offset..])
125                .await?;
126            if len == 0 {
127                return anyhow::Ok(());
128            }
129
130            stderr_offset += len;
131            let mut start_ix = 0;
132            while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
133                .iter()
134                .position(|b| b == &b'\n')
135            {
136                let line_ix = start_ix + ix;
137                let content = &stderr_buffer[start_ix..line_ix];
138                start_ix = line_ix + 1;
139                if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
140                    record.log(log::logger())
141                } else {
142                    std::io::stderr()
143                        .write_fmt(format_args!(
144                            "(remote) {}\n",
145                            String::from_utf8_lossy(content)
146                        ))
147                        .ok();
148                }
149            }
150            stderr_buffer.drain(0..start_ix);
151            stderr_offset -= start_ix;
152
153            connection_activity_tx.try_send(()).ok();
154        }
155    });
156
157    cx.background_spawn(async move {
158        let result = futures::select! {
159            result = stdin_task.fuse() => {
160                result.context("stdin")
161            }
162            result = stdout_task.fuse() => {
163                result.context("stdout")
164            }
165            result = stderr_task.fuse() => {
166                result.context("stderr")
167            }
168        };
169        let exit_status = remote_proxy_process.status().await?;
170        let status = exit_status.code().unwrap_or_else(|| {
171            #[cfg(unix)]
172            let status = std::os::unix::process::ExitStatusExt::signal(&exit_status).unwrap_or(1);
173            #[cfg(not(unix))]
174            let status = 1;
175            status
176        });
177        match result {
178            Ok(_) => Ok(status),
179            Err(error) => Err(error),
180        }
181    })
182}
183
184#[cfg(any(debug_assertions, feature = "build-remote-server-binary"))]
185async fn build_remote_server_from_source(
186    platform: &crate::RemotePlatform,
187    delegate: &dyn crate::RemoteClientDelegate,
188    binary_exists_on_server: bool,
189    cx: &mut AsyncApp,
190) -> Result<Option<std::path::PathBuf>> {
191    use std::env::VarError;
192    use std::path::Path;
193    use util::command::{Command, Stdio, new_command};
194
195    if let Ok(path) = std::env::var("ZED_COPY_REMOTE_SERVER") {
196        let path = std::path::PathBuf::from(path);
197        if path.exists() {
198            return Ok(Some(path));
199        } else {
200            log::warn!(
201                "ZED_COPY_REMOTE_SERVER path does not exist, falling back to ZED_BUILD_REMOTE_SERVER: {}",
202                path.display()
203            );
204        }
205    }
206
207    // By default, we make building remote server from source opt-out and we do not force artifact compression
208    // for quicker builds.
209    let build_remote_server =
210        std::env::var("ZED_BUILD_REMOTE_SERVER").unwrap_or("nocompress".into());
211
212    if let "never" = &*build_remote_server {
213        return Ok(None);
214    } else if let "false" | "no" | "off" | "0" = &*build_remote_server {
215        if binary_exists_on_server {
216            return Ok(None);
217        }
218        log::warn!("ZED_BUILD_REMOTE_SERVER is disabled, but no server binary exists on the server")
219    }
220
221    async fn run_cmd(command: &mut Command) -> Result<()> {
222        let output = command
223            .kill_on_drop(true)
224            .stdout(Stdio::inherit())
225            .output()
226            .await?;
227        anyhow::ensure!(
228            output.status.success(),
229            "Failed to run command: {command:?}: output: {}",
230            String::from_utf8_lossy(&output.stderr)
231        );
232        Ok(())
233    }
234
235    let use_musl = !build_remote_server.contains("nomusl");
236    let triple = format!(
237        "{}-{}",
238        platform.arch,
239        match platform.os {
240            RemoteOs::Linux =>
241                if use_musl {
242                    "unknown-linux-musl"
243                } else {
244                    "unknown-linux-gnu"
245                },
246            RemoteOs::MacOs => "apple-darwin",
247            RemoteOs::Windows if cfg!(windows) => "pc-windows-msvc",
248            RemoteOs::Windows => "pc-windows-gnu",
249        }
250    );
251    let mut rust_flags = match std::env::var("RUSTFLAGS") {
252        Ok(val) => val,
253        Err(VarError::NotPresent) => String::new(),
254        Err(e) => {
255            log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
256            String::new()
257        }
258    };
259    if platform.os == RemoteOs::Linux && use_musl {
260        rust_flags.push_str(" -C target-feature=+crt-static");
261
262        if let Ok(path) = std::env::var("ZED_ZSTD_MUSL_LIB") {
263            rust_flags.push_str(&format!(" -C link-arg=-L{path}"));
264        }
265    }
266    if platform.arch.as_str() == std::env::consts::ARCH
267        && platform.os.as_str() == std::env::consts::OS
268    {
269        delegate.set_status(Some("Building remote server binary from source"), cx);
270        log::info!("building remote server binary from source");
271        run_cmd(
272            new_command("cargo")
273                .current_dir(concat!(env!("CARGO_MANIFEST_DIR"), "/../.."))
274                .args([
275                    "build",
276                    "--package",
277                    "remote_server",
278                    "--features",
279                    "debug-embed",
280                    "--target-dir",
281                    "target/remote_server",
282                    "--target",
283                    &triple,
284                ])
285                .env("RUSTFLAGS", &rust_flags),
286        )
287        .await?;
288    } else {
289        if which("zig", cx).await?.is_none() {
290            anyhow::bail!(if cfg!(not(windows)) {
291                "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup)"
292            } else {
293                "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup)"
294            });
295        }
296
297        let rustup = which("rustup", cx)
298            .await?
299            .context("rustup not found on $PATH, install rustup (see https://rustup.rs/)")?;
300        delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
301        log::info!("adding rustup target");
302        run_cmd(new_command(rustup).args(["target", "add"]).arg(&triple)).await?;
303
304        if which("cargo-zigbuild", cx).await?.is_none() {
305            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
306            log::info!("installing cargo-zigbuild");
307            run_cmd(new_command("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;
308        }
309
310        delegate.set_status(
311            Some(&format!(
312                "Building remote binary from source for {triple} with Zig"
313            )),
314            cx,
315        );
316        log::info!("building remote binary from source for {triple} with Zig");
317        run_cmd(
318            new_command("cargo")
319                .args([
320                    "zigbuild",
321                    "--package",
322                    "remote_server",
323                    "--features",
324                    "debug-embed",
325                    "--target-dir",
326                    "target/remote_server",
327                    "--target",
328                    &triple,
329                ])
330                .env("RUSTFLAGS", &rust_flags),
331        )
332        .await?;
333    };
334    let bin_path = Path::new("target")
335        .join("remote_server")
336        .join(&triple)
337        .join("debug")
338        .join("remote_server")
339        .with_extension(if platform.os.is_windows() { "exe" } else { "" });
340
341    let path = if !build_remote_server.contains("nocompress") {
342        delegate.set_status(Some("Compressing binary"), cx);
343
344        #[cfg(not(target_os = "windows"))]
345        let archive_path = {
346            run_cmd(new_command("gzip").arg("-f").arg(&bin_path)).await?;
347            bin_path.with_extension("gz")
348        };
349
350        #[cfg(target_os = "windows")]
351        let archive_path = {
352            let zip_path = bin_path.with_extension("zip");
353            if smol::fs::metadata(&zip_path).await.is_ok() {
354                smol::fs::remove_file(&zip_path).await?;
355            }
356            let compress_command = format!(
357                "Compress-Archive -Path '{}' -DestinationPath '{}' -Force",
358                bin_path.display(),
359                zip_path.display(),
360            );
361            run_cmd(new_command("powershell.exe").args([
362                "-NoProfile",
363                "-Command",
364                &compress_command,
365            ]))
366            .await?;
367            zip_path
368        };
369
370        std::env::current_dir()?.join(archive_path)
371    } else {
372        bin_path
373    };
374
375    Ok(Some(path))
376}
377
378#[cfg(any(debug_assertions, feature = "build-remote-server-binary"))]
379async fn which(
380    binary_name: impl AsRef<str>,
381    cx: &mut AsyncApp,
382) -> Result<Option<std::path::PathBuf>> {
383    let binary_name = binary_name.as_ref().to_string();
384    let binary_name_cloned = binary_name.clone();
385    let res = cx
386        .background_spawn(async move { which::which(binary_name_cloned) })
387        .await;
388    match res {
389        Ok(path) => Ok(Some(path)),
390        Err(which::Error::CannotFindBinaryPath) => Ok(None),
391        Err(err) => Err(anyhow::anyhow!(
392            "Failed to run 'which' to find the binary '{binary_name}': {err}"
393        )),
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `parse_platform` on accepted and rejected `uname -sm` outputs.
    #[test]
    fn test_parse_platform() {
        // Inputs for which both the OS and the architecture are checked.
        let os_and_arch_cases = [
            ("Linux x86_64\n", RemoteOs::Linux, RemoteArch::X86_64),
            ("Darwin arm64\n", RemoteOs::MacOs, RemoteArch::Aarch64),
            ("Linux x86_64", RemoteOs::Linux, RemoteArch::X86_64),
            (
                "some shell init output\nLinux aarch64\n",
                RemoteOs::Linux,
                RemoteArch::Aarch64,
            ),
            (
                "some shell init output\nLinux aarch64",
                RemoteOs::Linux,
                RemoteArch::Aarch64,
            ),
            // Trailing text after the architecture is tolerated.
            (
                r#"Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n"#,
                RemoteOs::Linux,
                RemoteArch::X86_64,
            ),
        ];
        for (uname_output, expected_os, expected_arch) in os_and_arch_cases {
            let platform = parse_platform(uname_output).unwrap();
            assert_eq!(platform.os, expected_os);
            assert_eq!(platform.arch, expected_arch);
        }

        // Inputs for which only the architecture is checked.
        let arch_only_cases = [
            ("Linux armv8l\n", RemoteArch::Aarch64),
            ("Linux aarch64\n", RemoteArch::Aarch64),
            ("Linux x86_64\n", RemoteArch::X86_64),
        ];
        for (uname_output, expected_arch) in arch_only_cases {
            assert_eq!(parse_platform(uname_output).unwrap().arch, expected_arch);
        }

        // Unsupported OS and 32-bit ARM must be rejected.
        for uname_output in ["Windows x86_64\n", "Linux armv7l\n"] {
            assert!(parse_platform(uname_output).is_err());
        }
    }

    /// Exercises `parse_shell`, including the fallback for empty output.
    #[test]
    fn test_parse_shell() {
        let cases = [
            ("/bin/bash\n", "/bin/bash"),
            ("/bin/zsh\n", "/bin/zsh"),
            ("/bin/bash", "/bin/bash"),
            ("some shell init output\n/bin/bash\n", "/bin/bash"),
            ("some shell init output\n/bin/bash", "/bin/bash"),
            // Empty output falls back to the provided shell.
            ("", "sh"),
            ("\n", "sh"),
        ];
        for (shell_output, expected_shell) in cases {
            assert_eq!(parse_shell(shell_output, "sh"), expected_shell);
        }
    }
}