transport.rs

  1use crate::{
  2    json_log::LogRecord,
  3    protocol::{MESSAGE_LEN_SIZE, message_len_from_buffer, read_message_with_len, write_message},
  4};
  5use anyhow::{Context as _, Result};
  6use futures::{
  7    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  8    channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
  9};
 10use gpui::{AppContext as _, AsyncApp, Task};
 11use rpc::proto::Envelope;
 12use smol::process::Child;
 13
 14pub mod ssh;
 15pub mod wsl;
 16
 17fn handle_rpc_messages_over_child_process_stdio(
 18    mut ssh_proxy_process: Child,
 19    incoming_tx: UnboundedSender<Envelope>,
 20    mut outgoing_rx: UnboundedReceiver<Envelope>,
 21    mut connection_activity_tx: Sender<()>,
 22    cx: &AsyncApp,
 23) -> Task<Result<i32>> {
 24    let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
 25    let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
 26    let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
 27
 28    let mut stdin_buffer = Vec::new();
 29    let mut stdout_buffer = Vec::new();
 30    let mut stderr_buffer = Vec::new();
 31    let mut stderr_offset = 0;
 32
 33    let stdin_task = cx.background_spawn(async move {
 34        while let Some(outgoing) = outgoing_rx.next().await {
 35            write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
 36        }
 37        anyhow::Ok(())
 38    });
 39
 40    let stdout_task = cx.background_spawn({
 41        let mut connection_activity_tx = connection_activity_tx.clone();
 42        async move {
 43            loop {
 44                stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
 45                let len = child_stdout.read(&mut stdout_buffer).await?;
 46
 47                if len == 0 {
 48                    return anyhow::Ok(());
 49                }
 50
 51                if len < MESSAGE_LEN_SIZE {
 52                    child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
 53                }
 54
 55                let message_len = message_len_from_buffer(&stdout_buffer);
 56                let envelope =
 57                    read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
 58                        .await?;
 59                connection_activity_tx.try_send(()).ok();
 60                incoming_tx.unbounded_send(envelope).ok();
 61            }
 62        }
 63    });
 64
 65    let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
 66        loop {
 67            stderr_buffer.resize(stderr_offset + 1024, 0);
 68
 69            let len = child_stderr
 70                .read(&mut stderr_buffer[stderr_offset..])
 71                .await?;
 72            if len == 0 {
 73                return anyhow::Ok(());
 74            }
 75
 76            stderr_offset += len;
 77            let mut start_ix = 0;
 78            while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
 79                .iter()
 80                .position(|b| b == &b'\n')
 81            {
 82                let line_ix = start_ix + ix;
 83                let content = &stderr_buffer[start_ix..line_ix];
 84                start_ix = line_ix + 1;
 85                if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
 86                    record.log(log::logger())
 87                } else {
 88                    eprintln!("(remote) {}", String::from_utf8_lossy(content));
 89                }
 90            }
 91            stderr_buffer.drain(0..start_ix);
 92            stderr_offset -= start_ix;
 93
 94            connection_activity_tx.try_send(()).ok();
 95        }
 96    });
 97
 98    cx.background_spawn(async move {
 99        let result = futures::select! {
100            result = stdin_task.fuse() => {
101                result.context("stdin")
102            }
103            result = stdout_task.fuse() => {
104                result.context("stdout")
105            }
106            result = stderr_task.fuse() => {
107                result.context("stderr")
108            }
109        };
110
111        let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
112        match result {
113            Ok(_) => Ok(status),
114            Err(error) => Err(error),
115        }
116    })
117}
118
119#[cfg(debug_assertions)]
120async fn build_remote_server_from_source(
121    platform: &crate::RemotePlatform,
122    delegate: &dyn crate::RemoteClientDelegate,
123    cx: &mut AsyncApp,
124) -> Result<Option<std::path::PathBuf>> {
125    use std::path::Path;
126
127    let Some(build_remote_server) = std::env::var("ZED_BUILD_REMOTE_SERVER").ok() else {
128        return Ok(None);
129    };
130
131    use smol::process::{Command, Stdio};
132    use std::env::VarError;
133
134    async fn run_cmd(command: &mut Command) -> Result<()> {
135        let output = command
136            .kill_on_drop(true)
137            .stderr(Stdio::inherit())
138            .output()
139            .await?;
140        anyhow::ensure!(
141            output.status.success(),
142            "Failed to run command: {command:?}"
143        );
144        Ok(())
145    }
146
147    let use_musl = !build_remote_server.contains("nomusl");
148    let triple = format!(
149        "{}-{}",
150        platform.arch,
151        match platform.os {
152            "linux" =>
153                if use_musl {
154                    "unknown-linux-musl"
155                } else {
156                    "unknown-linux-gnu"
157                },
158            "macos" => "apple-darwin",
159            _ => anyhow::bail!("can't cross compile for: {:?}", platform),
160        }
161    );
162    let mut rust_flags = match std::env::var("RUSTFLAGS") {
163        Ok(val) => val,
164        Err(VarError::NotPresent) => String::new(),
165        Err(e) => {
166            log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
167            String::new()
168        }
169    };
170    if platform.os == "linux" && use_musl {
171        rust_flags.push_str(" -C target-feature=+crt-static");
172    }
173    if build_remote_server.contains("mold") {
174        rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
175    }
176
177    if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
178        delegate.set_status(Some("Building remote server binary from source"), cx);
179        log::info!("building remote server binary from source");
180        run_cmd(
181            Command::new("cargo")
182                .args([
183                    "build",
184                    "--package",
185                    "remote_server",
186                    "--features",
187                    "debug-embed",
188                    "--target-dir",
189                    "target/remote_server",
190                    "--target",
191                    &triple,
192                ])
193                .env("RUSTFLAGS", &rust_flags),
194        )
195        .await?;
196    } else if build_remote_server.contains("cross") {
197        #[cfg(target_os = "windows")]
198        use util::paths::SanitizedPath;
199
200        delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
201        log::info!("installing cross");
202        run_cmd(Command::new("cargo").args([
203            "install",
204            "cross",
205            "--git",
206            "https://github.com/cross-rs/cross",
207        ]))
208        .await?;
209
210        delegate.set_status(
211            Some(&format!(
212                "Building remote server binary from source for {} with Docker",
213                &triple
214            )),
215            cx,
216        );
217        log::info!("building remote server binary from source for {}", &triple);
218
219        // On Windows, the binding needs to be set to the canonical path
220        #[cfg(target_os = "windows")]
221        let src = SanitizedPath::new(&smol::fs::canonicalize("./target").await?).to_glob_string();
222        #[cfg(not(target_os = "windows"))]
223        let src = "./target";
224
225        run_cmd(
226            Command::new("cross")
227                .args([
228                    "build",
229                    "--package",
230                    "remote_server",
231                    "--features",
232                    "debug-embed",
233                    "--target-dir",
234                    "target/remote_server",
235                    "--target",
236                    &triple,
237                ])
238                .env(
239                    "CROSS_CONTAINER_OPTS",
240                    format!("--mount type=bind,src={src},dst=/app/target"),
241                )
242                .env("RUSTFLAGS", &rust_flags),
243        )
244        .await?;
245    } else {
246        let which = cx
247            .background_spawn(async move { which::which("zig") })
248            .await;
249
250        if which.is_err() {
251            #[cfg(not(target_os = "windows"))]
252            {
253                anyhow::bail!(
254                    "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
255                )
256            }
257            #[cfg(target_os = "windows")]
258            {
259                anyhow::bail!(
260                    "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
261                )
262            }
263        }
264
265        delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
266        log::info!("adding rustup target");
267        run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;
268
269        delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
270        log::info!("installing cargo-zigbuild");
271        run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;
272
273        delegate.set_status(
274            Some(&format!(
275                "Building remote binary from source for {triple} with Zig"
276            )),
277            cx,
278        );
279        log::info!("building remote binary from source for {triple} with Zig");
280        run_cmd(
281            Command::new("cargo")
282                .args([
283                    "zigbuild",
284                    "--package",
285                    "remote_server",
286                    "--features",
287                    "debug-embed",
288                    "--target-dir",
289                    "target/remote_server",
290                    "--target",
291                    &triple,
292                ])
293                .env("RUSTFLAGS", &rust_flags),
294        )
295        .await?;
296    };
297    let bin_path = Path::new("target")
298        .join("remote_server")
299        .join(&triple)
300        .join("debug")
301        .join("remote_server");
302
303    let path = if !build_remote_server.contains("nocompress") {
304        delegate.set_status(Some("Compressing binary"), cx);
305
306        #[cfg(not(target_os = "windows"))]
307        {
308            run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
309        }
310
311        #[cfg(target_os = "windows")]
312        {
313            // On Windows, we use 7z to compress the binary
314            let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
315            let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
316            if smol::fs::metadata(&gz_path).await.is_ok() {
317                smol::fs::remove_file(&gz_path).await?;
318            }
319            run_cmd(Command::new(seven_zip).args([
320                "a",
321                "-tgzip",
322                &gz_path,
323                &bin_path.to_string_lossy(),
324            ]))
325            .await?;
326        }
327
328        let mut archive_path = bin_path;
329        archive_path.set_extension("gz");
330        std::env::current_dir()?.join(archive_path)
331    } else {
332        bin_path
333    };
334
335    Ok(Some(path))
336}