transport.rs

  1use crate::{
  2    RemotePlatform,
  3    json_log::LogRecord,
  4    protocol::{MESSAGE_LEN_SIZE, message_len_from_buffer, read_message_with_len, write_message},
  5};
  6use anyhow::{Context as _, Result};
  7use futures::{
  8    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  9    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
 10};
 11use gpui::{AppContext as _, AsyncApp, Task};
 12use rpc::proto::Envelope;
 13use smol::process::Child;
 14
 15pub mod ssh;
 16pub mod wsl;
 17
 18/// Parses the output of `uname -sm` to determine the remote platform.
 19/// Takes the last line to skip possible shell initialization output.
 20fn parse_platform(output: &str) -> Result<RemotePlatform> {
 21    let output = output.trim();
 22    let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last);
 23    let Some((os, arch)) = uname.split_once(" ") else {
 24        anyhow::bail!("unknown uname: {uname:?}")
 25    };
 26
 27    let os = match os {
 28        "Darwin" => "macos",
 29        "Linux" => "linux",
 30        _ => anyhow::bail!(
 31            "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
 32        ),
 33    };
 34
 35    // exclude armv5,6,7 as they are 32-bit.
 36    let arch = if arch.starts_with("armv8")
 37        || arch.starts_with("armv9")
 38        || arch.starts_with("arm64")
 39        || arch.starts_with("aarch64")
 40    {
 41        "aarch64"
 42    } else if arch.starts_with("x86") {
 43        "x86_64"
 44    } else {
 45        anyhow::bail!(
 46            "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
 47        )
 48    };
 49
 50    Ok(RemotePlatform { os, arch })
 51}
 52
 53/// Parses the output of `echo $SHELL` to determine the remote shell.
 54/// Takes the last line to skip possible shell initialization output.
 55fn parse_shell(output: &str, fallback_shell: &str) -> String {
 56    let output = output.trim();
 57    let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last);
 58    if shell.is_empty() {
 59        log::error!("$SHELL is not set, falling back to {fallback_shell}");
 60        fallback_shell.to_owned()
 61    } else {
 62        shell.to_owned()
 63    }
 64}
 65
 66fn handle_rpc_messages_over_child_process_stdio(
 67    mut ssh_proxy_process: Child,
 68    incoming_tx: UnboundedSender<Envelope>,
 69    mut outgoing_rx: UnboundedReceiver<Envelope>,
 70    mut connection_activity_tx: Sender<()>,
 71    cx: &AsyncApp,
 72) -> Task<Result<i32>> {
 73    let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 74    let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 75    let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 76
 77    let mut stdin_buffer = Vec::new();
 78    let mut stdout_buffer = Vec::new();
 79    let mut stderr_buffer = Vec::new();
 80    let mut stderr_offset = 0;
 81
 82    let stdin_task = cx.background_spawn(async move {
 83        while let Some(outgoing) = outgoing_rx.next().await {
 84            write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 85        }
 86        anyhow::Ok(())
 87    });
 88
 89    let stdout_task = cx.background_spawn({
 90        let mut connection_activity_tx = connection_activity_tx.clone();
 91        async move {
 92            loop {
 93                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 94                let len = child_stdout.read(&mut stdout_buffer).await?;
 95
 96                if len == 0 {
 97                    return anyhow::Ok(());
 98                }
 99
100                if len < MESSAGE_LEN_SIZE {
101                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
102                }
103
104                let message_len = message_len_from_buffer(&stdout_buffer);
105                let envelope =
106                    read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
107                        .await?;
108                connection_activity_tx.try_send(()).ok();
109                incoming_tx.unbounded_send(envelope).ok();
110            }
111        }
112    });
113
114    let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
115        loop {
116            stderr_buffer.resize(stderr_offset + 1024, 0);
117
118            let len = child_stderr
119                .read(&mut stderr_buffer[stderr_offset..])
120                .await?;
121            if len == 0 {
122                return anyhow::Ok(());
123            }
124
125            stderr_offset += len;
126            let mut start_ix = 0;
127            while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
128                .iter()
129                .position(|b| b == &b'\n')
130            {
131                let line_ix = start_ix + ix;
132                let content = &stderr_buffer[start_ix..line_ix];
133                start_ix = line_ix + 1;
134                if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
135                    record.log(log::logger())
136                } else {
137                    eprintln!("(remote) {}", String::from_utf8_lossy(content));
138                }
139            }
140            stderr_buffer.drain(0..start_ix);
141            stderr_offset -= start_ix;
142
143            connection_activity_tx.try_send(()).ok();
144        }
145    });
146
147    cx.background_spawn(async move {
148        let result = futures::select! {
149            result = stdin_task.fuse() => {
150                result.context("stdin")
151            }
152            result = stdout_task.fuse() => {
153                result.context("stdout")
154            }
155            result = stderr_task.fuse() => {
156                result.context("stderr")
157            }
158        };
159        let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
160        match result {
161            Ok(_) => Ok(status),
162            Err(error) => Err(error),
163        }
164    })
165}
166
167#[cfg(debug_assertions)]
168async fn build_remote_server_from_source(
169    platform: &crate::RemotePlatform,
170    delegate: &dyn crate::RemoteClientDelegate,
171    cx: &mut AsyncApp,
172) -> Result<Option<std::path::PathBuf>> {
173    use smol::process::{Command, Stdio};
174    use std::env::VarError;
175    use std::path::Path;
176    use util::command::new_smol_command;
177
178    // By default, we make building remote server from source opt-out and we do not force artifact compression
179    // for quicker builds.
180    let build_remote_server =
181        std::env::var("ZED_BUILD_REMOTE_SERVER").unwrap_or("nocompress".into());
182
183    if let "false" | "no" | "off" | "0" = &*build_remote_server {
184        return Ok(None);
185    }
186
187    async fn run_cmd(command: &mut Command) -> Result<()> {
188        let output = command
189            .kill_on_drop(true)
190            .stderr(Stdio::inherit())
191            .output()
192            .await?;
193        anyhow::ensure!(
194            output.status.success(),
195            "Failed to run command: {command:?}"
196        );
197        Ok(())
198    }
199
200    let use_musl = !build_remote_server.contains("nomusl");
201    let triple = format!(
202        "{}-{}",
203        platform.arch,
204        match platform.os {
205            "linux" =>
206                if use_musl {
207                    "unknown-linux-musl"
208                } else {
209                    "unknown-linux-gnu"
210                },
211            "macos" => "apple-darwin",
212            _ => anyhow::bail!("can't cross compile for: {:?}", platform),
213        }
214    );
215    let mut rust_flags = match std::env::var("RUSTFLAGS") {
216        Ok(val) => val,
217        Err(VarError::NotPresent) => String::new(),
218        Err(e) => {
219            log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
220            String::new()
221        }
222    };
223    if platform.os == "linux" && use_musl {
224        rust_flags.push_str(" -C target-feature=+crt-static");
225
226        if let Ok(path) = std::env::var("ZED_ZSTD_MUSL_LIB") {
227            rust_flags.push_str(&format!(" -C link-arg=-L{path}"));
228        }
229    }
230    if build_remote_server.contains("mold") {
231        rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
232    }
233
234    if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
235        delegate.set_status(Some("Building remote server binary from source"), cx);
236        log::info!("building remote server binary from source");
237        run_cmd(
238            new_smol_command("cargo")
239                .current_dir(concat!(env!("CARGO_MANIFEST_DIR"), "/../.."))
240                .args([
241                    "build",
242                    "--package",
243                    "remote_server",
244                    "--features",
245                    "debug-embed",
246                    "--target-dir",
247                    "target/remote_server",
248                    "--target",
249                    &triple,
250                ])
251                .env("RUSTFLAGS", &rust_flags),
252        )
253        .await?;
254    } else {
255        if which("zig", cx).await?.is_none() {
256            anyhow::bail!(if cfg!(not(windows)) {
257                "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup)"
258            } else {
259                "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup)"
260            });
261        }
262
263        let rustup = which("rustup", cx)
264            .await?
265            .context("rustup not found on $PATH, install rustup (see https://rustup.rs/)")?;
266        delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
267        log::info!("adding rustup target");
268        run_cmd(
269            new_smol_command(rustup)
270                .args(["target", "add"])
271                .arg(&triple),
272        )
273        .await?;
274
275        if which("cargo-zigbuild", cx).await?.is_none() {
276            delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
277            log::info!("installing cargo-zigbuild");
278            run_cmd(new_smol_command("cargo").args(["install", "--locked", "cargo-zigbuild"]))
279                .await?;
280        }
281
282        delegate.set_status(
283            Some(&format!(
284                "Building remote binary from source for {triple} with Zig"
285            )),
286            cx,
287        );
288        log::info!("building remote binary from source for {triple} with Zig");
289        run_cmd(
290            new_smol_command("cargo")
291                .args([
292                    "zigbuild",
293                    "--package",
294                    "remote_server",
295                    "--features",
296                    "debug-embed",
297                    "--target-dir",
298                    "target/remote_server",
299                    "--target",
300                    &triple,
301                ])
302                .env("RUSTFLAGS", &rust_flags),
303        )
304        .await?;
305    };
306    let bin_path = Path::new("target")
307        .join("remote_server")
308        .join(&triple)
309        .join("debug")
310        .join("remote_server");
311
312    let path = if !build_remote_server.contains("nocompress") {
313        delegate.set_status(Some("Compressing binary"), cx);
314
315        #[cfg(not(target_os = "windows"))]
316        {
317            run_cmd(new_smol_command("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
318        }
319
320        #[cfg(target_os = "windows")]
321        {
322            // On Windows, we use 7z to compress the binary
323
324            let seven_zip = which("7z.exe",cx)
325                .await?
326                .context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
327            let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
328            if smol::fs::metadata(&gz_path).await.is_ok() {
329                smol::fs::remove_file(&gz_path).await?;
330            }
331            run_cmd(new_smol_command(seven_zip).args([
332                "a",
333                "-tgzip",
334                &gz_path,
335                &bin_path.to_string_lossy(),
336            ]))
337            .await?;
338        }
339
340        let mut archive_path = bin_path;
341        archive_path.set_extension("gz");
342        std::env::current_dir()?.join(archive_path)
343    } else {
344        bin_path
345    };
346
347    Ok(Some(path))
348}
349
350#[cfg(debug_assertions)]
351async fn which(
352    binary_name: impl AsRef<str>,
353    cx: &mut AsyncApp,
354) -> Result<Option<std::path::PathBuf>> {
355    let binary_name = binary_name.as_ref().to_string();
356    let binary_name_cloned = binary_name.clone();
357    let res = cx
358        .background_spawn(async move { which::which(binary_name_cloned) })
359        .await;
360    match res {
361        Ok(path) => Ok(Some(path)),
362        Err(which::Error::CannotFindBinaryPath) => Ok(None),
363        Err(err) => Err(anyhow::anyhow!(
364            "Failed to run 'which' to find the binary '{binary_name}': {err}"
365        )),
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_platform() {
        let result = parse_platform("Linux x86_64\n").unwrap();
        assert_eq!(result.os, "linux");
        assert_eq!(result.arch, "x86_64");

        let result = parse_platform("Darwin arm64\n").unwrap();
        assert_eq!(result.os, "macos");
        assert_eq!(result.arch, "aarch64");

        let result = parse_platform("Darwin x86_64\n").unwrap();
        assert_eq!(result.os, "macos");
        assert_eq!(result.arch, "x86_64");

        let result = parse_platform("Linux x86_64").unwrap();
        assert_eq!(result.os, "linux");
        assert_eq!(result.arch, "x86_64");

        let result = parse_platform("some shell init output\nLinux aarch64\n").unwrap();
        assert_eq!(result.os, "linux");
        assert_eq!(result.arch, "aarch64");

        let result = parse_platform("some shell init output\nLinux aarch64").unwrap();
        assert_eq!(result.os, "linux");
        assert_eq!(result.arch, "aarch64");

        // CRLF line endings from the remote shell.
        let result = parse_platform("some shell init output\r\nLinux x86_64\r\n").unwrap();
        assert_eq!(result.os, "linux");
        assert_eq!(result.arch, "x86_64");

        assert_eq!(parse_platform("Linux armv8l\n").unwrap().arch, "aarch64");
        assert_eq!(parse_platform("Linux armv9l\n").unwrap().arch, "aarch64");
        assert_eq!(parse_platform("Linux arm64\n").unwrap().arch, "aarch64");
        assert_eq!(parse_platform("Linux aarch64\n").unwrap().arch, "aarch64");
        assert_eq!(parse_platform("Linux x86_64\n").unwrap().arch, "x86_64");

        // Extra text after the architecture still matches by prefix.
        let result = parse_platform(
            "Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n",
        )
        .unwrap();
        assert_eq!(result.os, "linux");
        assert_eq!(result.arch, "x86_64");
    }

    #[test]
    fn test_parse_platform_errors() {
        // Malformed output.
        assert!(parse_platform("").is_err());
        assert!(parse_platform("Linux\n").is_err());

        // Unsupported operating systems.
        assert!(parse_platform("Windows x86_64\n").is_err());
        assert!(parse_platform("FreeBSD amd64\n").is_err());

        // Unsupported (32-bit) architectures.
        assert!(parse_platform("Linux armv7l\n").is_err());
        assert!(parse_platform("Linux armv6l\n").is_err());
        assert!(parse_platform("Linux i686\n").is_err());
    }

    #[test]
    fn test_parse_shell() {
        assert_eq!(parse_shell("/bin/bash\n", "sh"), "/bin/bash");
        assert_eq!(parse_shell("/bin/zsh\n", "sh"), "/bin/zsh");

        assert_eq!(parse_shell("/bin/bash", "sh"), "/bin/bash");
        assert_eq!(
            parse_shell("some shell init output\n/bin/bash\n", "sh"),
            "/bin/bash"
        );
        assert_eq!(
            parse_shell("some shell init output\n/bin/bash", "sh"),
            "/bin/bash"
        );
        assert_eq!(
            parse_shell("some shell init output\r\n/bin/zsh\r\n", "sh"),
            "/bin/zsh"
        );

        assert_eq!(parse_shell("", "sh"), "sh");
        assert_eq!(parse_shell("\n", "sh"), "sh");
        assert_eq!(parse_shell("   \n\t", "sh"), "sh");
        assert_eq!(parse_shell("", "/bin/sh"), "/bin/sh");
    }
}