1use crate::{
2 RemoteClientDelegate, RemotePlatform,
3 json_log::LogRecord,
4 protocol::{MESSAGE_LEN_SIZE, message_len_from_buffer, read_message_with_len, write_message},
5 remote_client::{CommandTemplate, RemoteConnection},
6};
7use anyhow::{Context as _, Result, anyhow};
8use async_trait::async_trait;
9use collections::HashMap;
10use futures::{
11 AsyncReadExt as _, FutureExt as _, StreamExt as _,
12 channel::mpsc::{Sender, UnboundedReceiver, UnboundedSender},
13 select_biased,
14};
15use gpui::{App, AppContext as _, AsyncApp, SemanticVersion, Task};
16use itertools::Itertools;
17use parking_lot::Mutex;
18use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
19use rpc::proto::Envelope;
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize};
22use smol::{
23 fs,
24 process::{self, Child, Stdio},
25};
26use std::{
27 iter,
28 path::{Path, PathBuf},
29 sync::Arc,
30 time::Instant,
31};
32use tempfile::TempDir;
33use util::{
34 get_default_system_shell,
35 paths::{PathStyle, RemotePathBuf},
36};
37
/// An SSH-backed [`RemoteConnection`].
///
/// On non-Windows platforms a long-lived "master" `ssh -N` process owns the
/// authenticated connection, and other `ssh`/`scp` invocations reuse it through a
/// control socket created inside `_temp_dir`.
pub(crate) struct SshRemoteConnection {
    /// Options and control-socket details used to run commands on the remote host.
    socket: SshSocket,
    /// The master `ssh` process; `None` once [`RemoteConnection::kill`] has taken it.
    master_process: Mutex<Option<Child>>,
    /// Remote path of the server binary; set after `ensure_server_binary` succeeds.
    remote_binary_path: Option<RemotePathBuf>,
    /// OS/arch of the remote host as detected from `uname -sm`.
    ssh_platform: RemotePlatform,
    /// Path style derived from `ssh_platform.os`.
    ssh_path_style: PathStyle,
    /// Remote login shell from `$SHELL` (falls back to `sh`).
    ssh_shell: String,
    /// Temporary directory holding the control socket (and, on Windows, the askpass
    /// script); held for the lifetime of the connection.
    _temp_dir: TempDir,
}
47
/// Parameters describing how to reach an SSH host.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
    /// Host name or address to connect to.
    pub host: String,
    /// Optional user name (presumably combined with `host` by `ssh_url` — confirm).
    pub username: Option<String>,
    /// Optional port; also passed to `scp` as `-P <port>`.
    pub port: Option<u16>,
    /// Optional password (not read in this section of the file — TODO confirm usage).
    pub password: Option<String>,
    /// Extra `ssh` arguments (presumably emitted by `additional_args` — verify).
    pub args: Option<Vec<String>>,
    /// Port forwards requested for this connection.
    pub port_forwards: Option<Vec<SshPortForwardOption>>,

    /// Optional user-chosen name for the host (not read in this section).
    pub nickname: Option<String>,
    /// When `true`, never ask the remote host to download the server binary itself;
    /// always download it locally and upload it over SSH.
    pub upload_binary_over_ssh: bool,
}
60
// A single SSH port forward, as produced by `parse_port_forward_spec` from
// `[local_host:]local_port:remote_host:remote_port`.
//
// NOTE: plain `//` comments are used on purpose: `JsonSchema` turns `///` doc
// comments into schema descriptions, which would change the generated schema.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, JsonSchema)]
pub struct SshPortForwardOption {
    // Optional local bind host; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub local_host: Option<String>,
    // Local port of the forward.
    pub local_port: u16,
    // Destination host of the forward; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub remote_host: Option<String>,
    // Destination port of the forward.
    pub remote_port: u16,
}
70
/// Everything needed to spawn `ssh`/`scp` processes against an established connection.
#[derive(Clone)]
struct SshSocket {
    /// Options used to build the `ssh`/`scp` command lines.
    connection_options: SshConnectionOptions,
    /// Control socket of the master process (Windows OpenSSH has no control sockets).
    #[cfg(not(target_os = "windows"))]
    socket_path: PathBuf,
    /// Extra environment for spawned processes: the askpass setup on Windows, empty
    /// elsewhere.
    envs: HashMap<String, String>,
}
78
/// Formats a shell command line, shell-quoting every *named* argument with
/// `shlex::try_quote`.
///
/// Only the `name = value` arguments passed to the macro are quoted. Identifiers that
/// the format string captures implicitly (a `{local}` with no matching named argument)
/// are inserted verbatim, without quoting.
///
/// # Panics
/// Panics if an argument cannot be quoted (`try_quote` returns an error, e.g. for a
/// string containing a NUL byte).
macro_rules! shell_script {
    ($fmt:expr, $($name:ident = $arg:expr),+ $(,)?) => {{
        format!(
            $fmt,
            $(
                $name = shlex::try_quote($arg).unwrap()
            ),+
        )
    }};
}
89
90#[async_trait(?Send)]
91impl RemoteConnection for SshRemoteConnection {
92 async fn kill(&self) -> Result<()> {
93 let Some(mut process) = self.master_process.lock().take() else {
94 return Ok(());
95 };
96 process.kill().ok();
97 process.status().await?;
98 Ok(())
99 }
100
101 fn has_been_killed(&self) -> bool {
102 self.master_process.lock().is_none()
103 }
104
105 fn connection_options(&self) -> SshConnectionOptions {
106 self.socket.connection_options.clone()
107 }
108
109 fn shell(&self) -> String {
110 self.ssh_shell.clone()
111 }
112
113 fn build_command(
114 &self,
115 input_program: Option<String>,
116 input_args: &[String],
117 input_env: &HashMap<String, String>,
118 working_dir: Option<String>,
119 activation_script: Option<String>,
120 port_forward: Option<(u16, String, u16)>,
121 ) -> Result<CommandTemplate> {
122 use std::fmt::Write as _;
123
124 let mut script = String::new();
125 if let Some(working_dir) = working_dir {
126 let working_dir =
127 RemotePathBuf::new(working_dir.into(), self.ssh_path_style).to_string();
128
129 // shlex will wrap the command in single quotes (''), disabling ~ expansion,
130 // replace ith with something that works
131 const TILDE_PREFIX: &'static str = "~/";
132 if working_dir.starts_with(TILDE_PREFIX) {
133 let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/");
134 write!(&mut script, "cd \"$HOME/{working_dir}\"; ").unwrap();
135 } else {
136 write!(&mut script, "cd \"{working_dir}\"; ").unwrap();
137 }
138 } else {
139 write!(&mut script, "cd; ").unwrap();
140 };
141 if let Some(activation_script) = activation_script {
142 write!(&mut script, " {activation_script};").unwrap();
143 }
144
145 for (k, v) in input_env.iter() {
146 if let Some((k, v)) = shlex::try_quote(k).ok().zip(shlex::try_quote(v).ok()) {
147 write!(&mut script, "{}={} ", k, v).unwrap();
148 }
149 }
150
151 let shell = &self.ssh_shell;
152
153 if let Some(input_program) = input_program {
154 let command = shlex::try_quote(&input_program)?;
155 script.push_str(&command);
156 for arg in input_args {
157 let arg = shlex::try_quote(&arg)?;
158 script.push_str(" ");
159 script.push_str(&arg);
160 }
161 } else {
162 write!(&mut script, "exec {shell} -l").unwrap();
163 };
164
165 let sys_shell = get_default_system_shell();
166 let shell_invocation = format!("{sys_shell} -c {}", shlex::try_quote(&script).unwrap());
167
168 let mut args = Vec::new();
169 args.extend(self.socket.ssh_args());
170
171 if let Some((local_port, host, remote_port)) = port_forward {
172 args.push("-L".into());
173 args.push(format!("{local_port}:{host}:{remote_port}"));
174 }
175
176 args.push("-t".into());
177 args.push(shell_invocation);
178 Ok(CommandTemplate {
179 program: "ssh".into(),
180 args,
181 env: self.socket.envs.clone(),
182 })
183 }
184
185 fn upload_directory(
186 &self,
187 src_path: PathBuf,
188 dest_path: RemotePathBuf,
189 cx: &App,
190 ) -> Task<Result<()>> {
191 let mut command = util::command::new_smol_command("scp");
192 let output = self
193 .socket
194 .ssh_options(&mut command)
195 .args(
196 self.socket
197 .connection_options
198 .port
199 .map(|port| vec!["-P".to_string(), port.to_string()])
200 .unwrap_or_default(),
201 )
202 .arg("-C")
203 .arg("-r")
204 .arg(&src_path)
205 .arg(format!(
206 "{}:{}",
207 self.socket.connection_options.scp_url(),
208 dest_path
209 ))
210 .output();
211
212 cx.background_spawn(async move {
213 let output = output.await?;
214
215 anyhow::ensure!(
216 output.status.success(),
217 "failed to upload directory {} -> {}: {}",
218 src_path.display(),
219 dest_path.to_string(),
220 String::from_utf8_lossy(&output.stderr)
221 );
222
223 Ok(())
224 })
225 }
226
227 fn start_proxy(
228 &self,
229 unique_identifier: String,
230 reconnect: bool,
231 incoming_tx: UnboundedSender<Envelope>,
232 outgoing_rx: UnboundedReceiver<Envelope>,
233 connection_activity_tx: Sender<()>,
234 delegate: Arc<dyn RemoteClientDelegate>,
235 cx: &mut AsyncApp,
236 ) -> Task<Result<i32>> {
237 delegate.set_status(Some("Starting proxy"), cx);
238
239 let Some(remote_binary_path) = self.remote_binary_path.clone() else {
240 return Task::ready(Err(anyhow!("Remote binary path not set")));
241 };
242
243 let mut start_proxy_command = shell_script!(
244 "exec {binary_path} proxy --identifier {identifier}",
245 binary_path = &remote_binary_path.to_string(),
246 identifier = &unique_identifier,
247 );
248
249 for env_var in ["RUST_LOG", "RUST_BACKTRACE", "ZED_GENERATE_MINIDUMPS"] {
250 if let Some(value) = std::env::var(env_var).ok() {
251 start_proxy_command = format!(
252 "{}={} {} ",
253 env_var,
254 shlex::try_quote(&value).unwrap(),
255 start_proxy_command,
256 );
257 }
258 }
259
260 if reconnect {
261 start_proxy_command.push_str(" --reconnect");
262 }
263
264 let ssh_proxy_process = match self
265 .socket
266 .ssh_command("sh", &["-lc", &start_proxy_command])
267 // IMPORTANT: we kill this process when we drop the task that uses it.
268 .kill_on_drop(true)
269 .spawn()
270 {
271 Ok(process) => process,
272 Err(error) => {
273 return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)));
274 }
275 };
276
277 Self::multiplex(
278 ssh_proxy_process,
279 incoming_tx,
280 outgoing_rx,
281 connection_activity_tx,
282 cx,
283 )
284 }
285
286 fn path_style(&self) -> PathStyle {
287 self.ssh_path_style
288 }
289}
290
291impl SshRemoteConnection {
292 pub(crate) async fn new(
293 connection_options: SshConnectionOptions,
294 delegate: Arc<dyn RemoteClientDelegate>,
295 cx: &mut AsyncApp,
296 ) -> Result<Self> {
297 use askpass::AskPassResult;
298
299 delegate.set_status(Some("Connecting"), cx);
300
301 let url = connection_options.ssh_url();
302
303 let temp_dir = tempfile::Builder::new()
304 .prefix("zed-ssh-session")
305 .tempdir()?;
306 let askpass_delegate = askpass::AskPassDelegate::new(cx, {
307 let delegate = delegate.clone();
308 move |prompt, tx, cx| delegate.ask_password(prompt, tx, cx)
309 });
310
311 let mut askpass =
312 askpass::AskPassSession::new(cx.background_executor(), askpass_delegate).await?;
313
314 // Start the master SSH process, which does not do anything except for establish
315 // the connection and keep it open, allowing other ssh commands to reuse it
316 // via a control socket.
317 #[cfg(not(target_os = "windows"))]
318 let socket_path = temp_dir.path().join("ssh.sock");
319
320 let mut master_process = {
321 #[cfg(not(target_os = "windows"))]
322 let args = [
323 "-N",
324 "-o",
325 "ControlPersist=no",
326 "-o",
327 "ControlMaster=yes",
328 "-o",
329 ];
330 // On Windows, `ControlMaster` and `ControlPath` are not supported:
331 // https://github.com/PowerShell/Win32-OpenSSH/issues/405
332 // https://github.com/PowerShell/Win32-OpenSSH/wiki/Project-Scope
333 #[cfg(target_os = "windows")]
334 let args = ["-N"];
335 let mut master_process = util::command::new_smol_command("ssh");
336 master_process
337 .kill_on_drop(true)
338 .stdin(Stdio::null())
339 .stdout(Stdio::piped())
340 .stderr(Stdio::piped())
341 .env("SSH_ASKPASS_REQUIRE", "force")
342 .env("SSH_ASKPASS", askpass.script_path())
343 .args(connection_options.additional_args())
344 .args(args);
345 #[cfg(not(target_os = "windows"))]
346 master_process.arg(format!("ControlPath={}", socket_path.display()));
347 master_process.arg(&url).spawn()?
348 };
349 // Wait for this ssh process to close its stdout, indicating that authentication
350 // has completed.
351 let mut stdout = master_process.stdout.take().unwrap();
352 let mut output = Vec::new();
353
354 let result = select_biased! {
355 result = askpass.run().fuse() => {
356 match result {
357 AskPassResult::CancelledByUser => {
358 master_process.kill().ok();
359 anyhow::bail!("SSH connection canceled")
360 }
361 AskPassResult::Timedout => {
362 anyhow::bail!("connecting to host timed out")
363 }
364 }
365 }
366 _ = stdout.read_to_end(&mut output).fuse() => {
367 anyhow::Ok(())
368 }
369 };
370
371 if let Err(e) = result {
372 return Err(e.context("Failed to connect to host"));
373 }
374
375 if master_process.try_status()?.is_some() {
376 output.clear();
377 let mut stderr = master_process.stderr.take().unwrap();
378 stderr.read_to_end(&mut output).await?;
379
380 let error_message = format!(
381 "failed to connect: {}",
382 String::from_utf8_lossy(&output).trim()
383 );
384 anyhow::bail!(error_message);
385 }
386
387 #[cfg(not(target_os = "windows"))]
388 let socket = SshSocket::new(connection_options, socket_path)?;
389 #[cfg(target_os = "windows")]
390 let socket = SshSocket::new(connection_options, &temp_dir, askpass.get_password())?;
391 drop(askpass);
392
393 let ssh_platform = socket.platform().await?;
394 let ssh_path_style = match ssh_platform.os {
395 "windows" => PathStyle::Windows,
396 _ => PathStyle::Posix,
397 };
398 let ssh_shell = socket.shell().await;
399
400 let mut this = Self {
401 socket,
402 master_process: Mutex::new(Some(master_process)),
403 _temp_dir: temp_dir,
404 remote_binary_path: None,
405 ssh_path_style,
406 ssh_platform,
407 ssh_shell,
408 };
409
410 let (release_channel, version, commit) = cx.update(|cx| {
411 (
412 ReleaseChannel::global(cx),
413 AppVersion::global(cx),
414 AppCommitSha::try_global(cx),
415 )
416 })?;
417 this.remote_binary_path = Some(
418 this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
419 .await?,
420 );
421
422 Ok(this)
423 }
424
425 fn multiplex(
426 mut ssh_proxy_process: Child,
427 incoming_tx: UnboundedSender<Envelope>,
428 mut outgoing_rx: UnboundedReceiver<Envelope>,
429 mut connection_activity_tx: Sender<()>,
430 cx: &AsyncApp,
431 ) -> Task<Result<i32>> {
432 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
433 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
434 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
435
436 let mut stdin_buffer = Vec::new();
437 let mut stdout_buffer = Vec::new();
438 let mut stderr_buffer = Vec::new();
439 let mut stderr_offset = 0;
440
441 let stdin_task = cx.background_spawn(async move {
442 while let Some(outgoing) = outgoing_rx.next().await {
443 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
444 }
445 anyhow::Ok(())
446 });
447
448 let stdout_task = cx.background_spawn({
449 let mut connection_activity_tx = connection_activity_tx.clone();
450 async move {
451 loop {
452 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
453 let len = child_stdout.read(&mut stdout_buffer).await?;
454
455 if len == 0 {
456 return anyhow::Ok(());
457 }
458
459 if len < MESSAGE_LEN_SIZE {
460 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
461 }
462
463 let message_len = message_len_from_buffer(&stdout_buffer);
464 let envelope =
465 read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
466 .await?;
467 connection_activity_tx.try_send(()).ok();
468 incoming_tx.unbounded_send(envelope).ok();
469 }
470 }
471 });
472
473 let stderr_task: Task<anyhow::Result<()>> = cx.background_spawn(async move {
474 loop {
475 stderr_buffer.resize(stderr_offset + 1024, 0);
476
477 let len = child_stderr
478 .read(&mut stderr_buffer[stderr_offset..])
479 .await?;
480 if len == 0 {
481 return anyhow::Ok(());
482 }
483
484 stderr_offset += len;
485 let mut start_ix = 0;
486 while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
487 .iter()
488 .position(|b| b == &b'\n')
489 {
490 let line_ix = start_ix + ix;
491 let content = &stderr_buffer[start_ix..line_ix];
492 start_ix = line_ix + 1;
493 if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
494 record.log(log::logger())
495 } else {
496 eprintln!("(remote) {}", String::from_utf8_lossy(content));
497 }
498 }
499 stderr_buffer.drain(0..start_ix);
500 stderr_offset -= start_ix;
501
502 connection_activity_tx.try_send(()).ok();
503 }
504 });
505
506 cx.background_spawn(async move {
507 let result = futures::select! {
508 result = stdin_task.fuse() => {
509 result.context("stdin")
510 }
511 result = stdout_task.fuse() => {
512 result.context("stdout")
513 }
514 result = stderr_task.fuse() => {
515 result.context("stderr")
516 }
517 };
518
519 let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
520 match result {
521 Ok(_) => Ok(status),
522 Err(error) => Err(error),
523 }
524 })
525 }
526
527 #[allow(unused)]
528 async fn ensure_server_binary(
529 &self,
530 delegate: &Arc<dyn RemoteClientDelegate>,
531 release_channel: ReleaseChannel,
532 version: SemanticVersion,
533 commit: Option<AppCommitSha>,
534 cx: &mut AsyncApp,
535 ) -> Result<RemotePathBuf> {
536 let version_str = match release_channel {
537 ReleaseChannel::Nightly => {
538 let commit = commit.map(|s| s.full()).unwrap_or_default();
539 format!("{}-{}", version, commit)
540 }
541 ReleaseChannel::Dev => "build".to_string(),
542 _ => version.to_string(),
543 };
544 let binary_name = format!(
545 "zed-remote-server-{}-{}",
546 release_channel.dev_name(),
547 version_str
548 );
549 let dst_path = RemotePathBuf::new(
550 paths::remote_server_dir_relative().join(binary_name),
551 self.ssh_path_style,
552 );
553
554 let build_remote_server = std::env::var("ZED_BUILD_REMOTE_SERVER").ok();
555 #[cfg(debug_assertions)]
556 if let Some(build_remote_server) = build_remote_server {
557 let src_path = self.build_local(build_remote_server, delegate, cx).await?;
558 let tmp_path = RemotePathBuf::new(
559 paths::remote_server_dir_relative().join(format!(
560 "download-{}-{}",
561 std::process::id(),
562 src_path.file_name().unwrap().to_string_lossy()
563 )),
564 self.ssh_path_style,
565 );
566 self.upload_local_server_binary(&src_path, &tmp_path, delegate, cx)
567 .await?;
568 self.extract_server_binary(&dst_path, &tmp_path, delegate, cx)
569 .await?;
570 return Ok(dst_path);
571 }
572
573 if self
574 .socket
575 .run_command(&dst_path.to_string(), &["version"])
576 .await
577 .is_ok()
578 {
579 return Ok(dst_path);
580 }
581
582 let wanted_version = cx.update(|cx| match release_channel {
583 ReleaseChannel::Nightly => Ok(None),
584 ReleaseChannel::Dev => {
585 anyhow::bail!(
586 "ZED_BUILD_REMOTE_SERVER is not set and no remote server exists at ({:?})",
587 dst_path
588 )
589 }
590 _ => Ok(Some(AppVersion::global(cx))),
591 })??;
592
593 let tmp_path_gz = RemotePathBuf::new(
594 PathBuf::from(format!("{}-download-{}.gz", dst_path, std::process::id())),
595 self.ssh_path_style,
596 );
597 if !self.socket.connection_options.upload_binary_over_ssh
598 && let Some((url, body)) = delegate
599 .get_download_params(self.ssh_platform, release_channel, wanted_version, cx)
600 .await?
601 {
602 match self
603 .download_binary_on_server(&url, &body, &tmp_path_gz, delegate, cx)
604 .await
605 {
606 Ok(_) => {
607 self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
608 .await?;
609 return Ok(dst_path);
610 }
611 Err(e) => {
612 log::error!(
613 "Failed to download binary on server, attempting to upload server: {}",
614 e
615 )
616 }
617 }
618 }
619
620 let src_path = delegate
621 .download_server_binary_locally(self.ssh_platform, release_channel, wanted_version, cx)
622 .await?;
623 self.upload_local_server_binary(&src_path, &tmp_path_gz, delegate, cx)
624 .await?;
625 self.extract_server_binary(&dst_path, &tmp_path_gz, delegate, cx)
626 .await?;
627 Ok(dst_path)
628 }
629
630 async fn download_binary_on_server(
631 &self,
632 url: &str,
633 body: &str,
634 tmp_path_gz: &RemotePathBuf,
635 delegate: &Arc<dyn RemoteClientDelegate>,
636 cx: &mut AsyncApp,
637 ) -> Result<()> {
638 if let Some(parent) = tmp_path_gz.parent() {
639 self.socket
640 .run_command(
641 "sh",
642 &[
643 "-lc",
644 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
645 ],
646 )
647 .await?;
648 }
649
650 delegate.set_status(Some("Downloading remote development server on host"), cx);
651
652 match self
653 .socket
654 .run_command(
655 "curl",
656 &[
657 "-f",
658 "-L",
659 "-X",
660 "GET",
661 "-H",
662 "Content-Type: application/json",
663 "-d",
664 body,
665 url,
666 "-o",
667 &tmp_path_gz.to_string(),
668 ],
669 )
670 .await
671 {
672 Ok(_) => {}
673 Err(e) => {
674 if self.socket.run_command("which", &["curl"]).await.is_ok() {
675 return Err(e);
676 }
677
678 match self
679 .socket
680 .run_command(
681 "wget",
682 &[
683 "--method=GET",
684 "--header=Content-Type: application/json",
685 "--body-data",
686 body,
687 url,
688 "-O",
689 &tmp_path_gz.to_string(),
690 ],
691 )
692 .await
693 {
694 Ok(_) => {}
695 Err(e) => {
696 if self.socket.run_command("which", &["wget"]).await.is_ok() {
697 return Err(e);
698 } else {
699 anyhow::bail!("Neither curl nor wget is available");
700 }
701 }
702 }
703 }
704 }
705
706 Ok(())
707 }
708
709 async fn upload_local_server_binary(
710 &self,
711 src_path: &Path,
712 tmp_path_gz: &RemotePathBuf,
713 delegate: &Arc<dyn RemoteClientDelegate>,
714 cx: &mut AsyncApp,
715 ) -> Result<()> {
716 if let Some(parent) = tmp_path_gz.parent() {
717 self.socket
718 .run_command(
719 "sh",
720 &[
721 "-lc",
722 &shell_script!("mkdir -p {parent}", parent = parent.to_string().as_ref()),
723 ],
724 )
725 .await?;
726 }
727
728 let src_stat = fs::metadata(&src_path).await?;
729 let size = src_stat.len();
730
731 let t0 = Instant::now();
732 delegate.set_status(Some("Uploading remote development server"), cx);
733 log::info!(
734 "uploading remote development server to {:?} ({}kb)",
735 tmp_path_gz,
736 size / 1024
737 );
738 self.upload_file(src_path, tmp_path_gz)
739 .await
740 .context("failed to upload server binary")?;
741 log::info!("uploaded remote development server in {:?}", t0.elapsed());
742 Ok(())
743 }
744
745 async fn extract_server_binary(
746 &self,
747 dst_path: &RemotePathBuf,
748 tmp_path: &RemotePathBuf,
749 delegate: &Arc<dyn RemoteClientDelegate>,
750 cx: &mut AsyncApp,
751 ) -> Result<()> {
752 delegate.set_status(Some("Extracting remote development server"), cx);
753 let server_mode = 0o755;
754
755 let orig_tmp_path = tmp_path.to_string();
756 let script = if let Some(tmp_path) = orig_tmp_path.strip_suffix(".gz") {
757 shell_script!(
758 "gunzip -f {orig_tmp_path} && chmod {server_mode} {tmp_path} && mv {tmp_path} {dst_path}",
759 server_mode = &format!("{:o}", server_mode),
760 dst_path = &dst_path.to_string(),
761 )
762 } else {
763 shell_script!(
764 "chmod {server_mode} {orig_tmp_path} && mv {orig_tmp_path} {dst_path}",
765 server_mode = &format!("{:o}", server_mode),
766 dst_path = &dst_path.to_string()
767 )
768 };
769 self.socket.run_command("sh", &["-lc", &script]).await?;
770 Ok(())
771 }
772
773 async fn upload_file(&self, src_path: &Path, dest_path: &RemotePathBuf) -> Result<()> {
774 log::debug!("uploading file {:?} to {:?}", src_path, dest_path);
775 let mut command = util::command::new_smol_command("scp");
776 let output = self
777 .socket
778 .ssh_options(&mut command)
779 .args(
780 self.socket
781 .connection_options
782 .port
783 .map(|port| vec!["-P".to_string(), port.to_string()])
784 .unwrap_or_default(),
785 )
786 .arg(src_path)
787 .arg(format!(
788 "{}:{}",
789 self.socket.connection_options.scp_url(),
790 dest_path
791 ))
792 .output()
793 .await?;
794
795 anyhow::ensure!(
796 output.status.success(),
797 "failed to upload file {} -> {}: {}",
798 src_path.display(),
799 dest_path.to_string(),
800 String::from_utf8_lossy(&output.stderr)
801 );
802 Ok(())
803 }
804
805 #[cfg(debug_assertions)]
806 async fn build_local(
807 &self,
808 build_remote_server: String,
809 delegate: &Arc<dyn RemoteClientDelegate>,
810 cx: &mut AsyncApp,
811 ) -> Result<PathBuf> {
812 use smol::process::{Command, Stdio};
813 use std::env::VarError;
814
815 async fn run_cmd(command: &mut Command) -> Result<()> {
816 let output = command
817 .kill_on_drop(true)
818 .stderr(Stdio::inherit())
819 .output()
820 .await?;
821 anyhow::ensure!(
822 output.status.success(),
823 "Failed to run command: {command:?}"
824 );
825 Ok(())
826 }
827
828 let use_musl = !build_remote_server.contains("nomusl");
829 let triple = format!(
830 "{}-{}",
831 self.ssh_platform.arch,
832 match self.ssh_platform.os {
833 "linux" =>
834 if use_musl {
835 "unknown-linux-musl"
836 } else {
837 "unknown-linux-gnu"
838 },
839 "macos" => "apple-darwin",
840 _ => anyhow::bail!("can't cross compile for: {:?}", self.ssh_platform),
841 }
842 );
843 let mut rust_flags = match std::env::var("RUSTFLAGS") {
844 Ok(val) => val,
845 Err(VarError::NotPresent) => String::new(),
846 Err(e) => {
847 log::error!("Failed to get env var `RUSTFLAGS` value: {e}");
848 String::new()
849 }
850 };
851 if self.ssh_platform.os == "linux" && use_musl {
852 rust_flags.push_str(" -C target-feature=+crt-static");
853 }
854 if build_remote_server.contains("mold") {
855 rust_flags.push_str(" -C link-arg=-fuse-ld=mold");
856 }
857
858 if self.ssh_platform.arch == std::env::consts::ARCH
859 && self.ssh_platform.os == std::env::consts::OS
860 {
861 delegate.set_status(Some("Building remote server binary from source"), cx);
862 log::info!("building remote server binary from source");
863 run_cmd(
864 Command::new("cargo")
865 .args([
866 "build",
867 "--package",
868 "remote_server",
869 "--features",
870 "debug-embed",
871 "--target-dir",
872 "target/remote_server",
873 "--target",
874 &triple,
875 ])
876 .env("RUSTFLAGS", &rust_flags),
877 )
878 .await?;
879 } else if build_remote_server.contains("cross") {
880 #[cfg(target_os = "windows")]
881 use util::paths::SanitizedPath;
882
883 delegate.set_status(Some("Installing cross.rs for cross-compilation"), cx);
884 log::info!("installing cross");
885 run_cmd(Command::new("cargo").args([
886 "install",
887 "cross",
888 "--git",
889 "https://github.com/cross-rs/cross",
890 ]))
891 .await?;
892
893 delegate.set_status(
894 Some(&format!(
895 "Building remote server binary from source for {} with Docker",
896 &triple
897 )),
898 cx,
899 );
900 log::info!("building remote server binary from source for {}", &triple);
901
902 // On Windows, the binding needs to be set to the canonical path
903 #[cfg(target_os = "windows")]
904 let src =
905 SanitizedPath::new(&smol::fs::canonicalize("./target").await?).to_glob_string();
906 #[cfg(not(target_os = "windows"))]
907 let src = "./target";
908 run_cmd(
909 Command::new("cross")
910 .args([
911 "build",
912 "--package",
913 "remote_server",
914 "--features",
915 "debug-embed",
916 "--target-dir",
917 "target/remote_server",
918 "--target",
919 &triple,
920 ])
921 .env(
922 "CROSS_CONTAINER_OPTS",
923 format!("--mount type=bind,src={src},dst=/app/target"),
924 )
925 .env("RUSTFLAGS", &rust_flags),
926 )
927 .await?;
928 } else {
929 let which = cx
930 .background_spawn(async move { which::which("zig") })
931 .await;
932
933 if which.is_err() {
934 #[cfg(not(target_os = "windows"))]
935 {
936 anyhow::bail!(
937 "zig not found on $PATH, install zig (see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
938 )
939 }
940 #[cfg(target_os = "windows")]
941 {
942 anyhow::bail!(
943 "zig not found on $PATH, install zig (use `winget install -e --id zig.zig` or see https://ziglang.org/learn/getting-started or use zigup) or pass ZED_BUILD_REMOTE_SERVER=cross to use cross"
944 )
945 }
946 }
947
948 delegate.set_status(Some("Adding rustup target for cross-compilation"), cx);
949 log::info!("adding rustup target");
950 run_cmd(Command::new("rustup").args(["target", "add"]).arg(&triple)).await?;
951
952 delegate.set_status(Some("Installing cargo-zigbuild for cross-compilation"), cx);
953 log::info!("installing cargo-zigbuild");
954 run_cmd(Command::new("cargo").args(["install", "--locked", "cargo-zigbuild"])).await?;
955
956 delegate.set_status(
957 Some(&format!(
958 "Building remote binary from source for {triple} with Zig"
959 )),
960 cx,
961 );
962 log::info!("building remote binary from source for {triple} with Zig");
963 run_cmd(
964 Command::new("cargo")
965 .args([
966 "zigbuild",
967 "--package",
968 "remote_server",
969 "--features",
970 "debug-embed",
971 "--target-dir",
972 "target/remote_server",
973 "--target",
974 &triple,
975 ])
976 .env("RUSTFLAGS", &rust_flags),
977 )
978 .await?;
979 };
980 let bin_path = Path::new("target")
981 .join("remote_server")
982 .join(&triple)
983 .join("debug")
984 .join("remote_server");
985
986 let path = if !build_remote_server.contains("nocompress") {
987 delegate.set_status(Some("Compressing binary"), cx);
988
989 #[cfg(not(target_os = "windows"))]
990 {
991 run_cmd(Command::new("gzip").args(["-f", &bin_path.to_string_lossy()])).await?;
992 }
993 #[cfg(target_os = "windows")]
994 {
995 // On Windows, we use 7z to compress the binary
996 let seven_zip = which::which("7z.exe").context("7z.exe not found on $PATH, install it (e.g. with `winget install -e --id 7zip.7zip`) or, if you don't want this behaviour, set $env:ZED_BUILD_REMOTE_SERVER=\"nocompress\"")?;
997 let gz_path = format!("target/remote_server/{}/debug/remote_server.gz", triple);
998 if smol::fs::metadata(&gz_path).await.is_ok() {
999 smol::fs::remove_file(&gz_path).await?;
1000 }
1001 run_cmd(Command::new(seven_zip).args([
1002 "a",
1003 "-tgzip",
1004 &gz_path,
1005 &bin_path.to_string_lossy(),
1006 ]))
1007 .await?;
1008 }
1009
1010 let mut archive_path = bin_path;
1011 archive_path.set_extension("gz");
1012 std::env::current_dir()?.join(archive_path)
1013 } else {
1014 bin_path
1015 };
1016
1017 Ok(path)
1018 }
1019}
1020
1021impl SshSocket {
1022 #[cfg(not(target_os = "windows"))]
1023 fn new(options: SshConnectionOptions, socket_path: PathBuf) -> Result<Self> {
1024 Ok(Self {
1025 connection_options: options,
1026 envs: HashMap::default(),
1027 socket_path,
1028 })
1029 }
1030
1031 #[cfg(target_os = "windows")]
1032 fn new(options: SshConnectionOptions, temp_dir: &TempDir, secret: String) -> Result<Self> {
1033 let askpass_script = temp_dir.path().join("askpass.bat");
1034 std::fs::write(&askpass_script, "@ECHO OFF\necho %ZED_SSH_ASKPASS%")?;
1035 let mut envs = HashMap::default();
1036 envs.insert("SSH_ASKPASS_REQUIRE".into(), "force".into());
1037 envs.insert("SSH_ASKPASS".into(), askpass_script.display().to_string());
1038 envs.insert("ZED_SSH_ASKPASS".into(), secret);
1039 Ok(Self {
1040 connection_options: options,
1041 envs,
1042 })
1043 }
1044
1045 // :WARNING: ssh unquotes arguments when executing on the remote :WARNING:
1046 // e.g. $ ssh host sh -c 'ls -l' is equivalent to $ ssh host sh -c ls -l
1047 // and passes -l as an argument to sh, not to ls.
1048 // Furthermore, some setups (e.g. Coder) will change directory when SSH'ing
1049 // into a machine. You must use `cd` to get back to $HOME.
1050 // You need to do it like this: $ ssh host "cd; sh -c 'ls -l /tmp'"
1051 fn ssh_command(&self, program: &str, args: &[&str]) -> process::Command {
1052 let mut command = util::command::new_smol_command("ssh");
1053 let to_run = iter::once(&program)
1054 .chain(args.iter())
1055 .map(|token| {
1056 // We're trying to work with: sh, bash, zsh, fish, tcsh, ...?
1057 debug_assert!(
1058 !token.contains('\n'),
1059 "multiline arguments do not work in all shells"
1060 );
1061 shlex::try_quote(token).unwrap()
1062 })
1063 .join(" ");
1064 let to_run = format!("cd; {to_run}");
1065 log::debug!("ssh {} {:?}", self.connection_options.ssh_url(), to_run);
1066 self.ssh_options(&mut command)
1067 .arg(self.connection_options.ssh_url())
1068 .arg(to_run);
1069 command
1070 }
1071
1072 async fn run_command(&self, program: &str, args: &[&str]) -> Result<String> {
1073 let output = self.ssh_command(program, args).output().await?;
1074 anyhow::ensure!(
1075 output.status.success(),
1076 "failed to run command: {}",
1077 String::from_utf8_lossy(&output.stderr)
1078 );
1079 Ok(String::from_utf8_lossy(&output.stdout).to_string())
1080 }
1081
1082 #[cfg(not(target_os = "windows"))]
1083 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
1084 command
1085 .stdin(Stdio::piped())
1086 .stdout(Stdio::piped())
1087 .stderr(Stdio::piped())
1088 .args(self.connection_options.additional_args())
1089 .args(["-o", "ControlMaster=no", "-o"])
1090 .arg(format!("ControlPath={}", self.socket_path.display()))
1091 }
1092
1093 #[cfg(target_os = "windows")]
1094 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
1095 command
1096 .stdin(Stdio::piped())
1097 .stdout(Stdio::piped())
1098 .stderr(Stdio::piped())
1099 .args(self.connection_options.additional_args())
1100 .envs(self.envs.clone())
1101 }
1102
1103 // On Windows, we need to use `SSH_ASKPASS` to provide the password to ssh.
1104 // On Linux, we use the `ControlPath` option to create a socket file that ssh can use to
1105 #[cfg(not(target_os = "windows"))]
1106 fn ssh_args(&self) -> Vec<String> {
1107 let mut arguments = self.connection_options.additional_args();
1108 arguments.extend(vec![
1109 "-o".to_string(),
1110 "ControlMaster=no".to_string(),
1111 "-o".to_string(),
1112 format!("ControlPath={}", self.socket_path.display()),
1113 self.connection_options.ssh_url(),
1114 ]);
1115 arguments
1116 }
1117
1118 #[cfg(target_os = "windows")]
1119 fn ssh_args(&self) -> Vec<String> {
1120 let mut arguments = self.connection_options.additional_args();
1121 arguments.push(self.connection_options.ssh_url());
1122 arguments
1123 }
1124
1125 async fn platform(&self) -> Result<RemotePlatform> {
1126 let uname = self.run_command("sh", &["-lc", "uname -sm"]).await?;
1127 let Some((os, arch)) = uname.split_once(" ") else {
1128 anyhow::bail!("unknown uname: {uname:?}")
1129 };
1130
1131 let os = match os.trim() {
1132 "Darwin" => "macos",
1133 "Linux" => "linux",
1134 _ => anyhow::bail!(
1135 "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
1136 ),
1137 };
1138 // exclude armv5,6,7 as they are 32-bit.
1139 let arch = if arch.starts_with("armv8")
1140 || arch.starts_with("armv9")
1141 || arch.starts_with("arm64")
1142 || arch.starts_with("aarch64")
1143 {
1144 "aarch64"
1145 } else if arch.starts_with("x86") {
1146 "x86_64"
1147 } else {
1148 anyhow::bail!(
1149 "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
1150 )
1151 };
1152
1153 Ok(RemotePlatform { os, arch })
1154 }
1155
1156 async fn shell(&self) -> String {
1157 match self.run_command("sh", &["-lc", "echo $SHELL"]).await {
1158 Ok(shell) => shell.trim().to_owned(),
1159 Err(e) => {
1160 log::error!("Failed to get shell: {e}");
1161 "sh".to_owned()
1162 }
1163 }
1164 }
1165}
1166
1167fn parse_port_number(port_str: &str) -> Result<u16> {
1168 port_str
1169 .parse()
1170 .with_context(|| format!("parsing port number: {port_str}"))
1171}
1172
1173fn parse_port_forward_spec(spec: &str) -> Result<SshPortForwardOption> {
1174 let parts: Vec<&str> = spec.split(':').collect();
1175
1176 match parts.len() {
1177 4 => {
1178 let local_port = parse_port_number(parts[1])?;
1179 let remote_port = parse_port_number(parts[3])?;
1180
1181 Ok(SshPortForwardOption {
1182 local_host: Some(parts[0].to_string()),
1183 local_port,
1184 remote_host: Some(parts[2].to_string()),
1185 remote_port,
1186 })
1187 }
1188 3 => {
1189 let local_port = parse_port_number(parts[0])?;
1190 let remote_port = parse_port_number(parts[2])?;
1191
1192 Ok(SshPortForwardOption {
1193 local_host: None,
1194 local_port,
1195 remote_host: Some(parts[1].to_string()),
1196 remote_port,
1197 })
1198 }
1199 _ => anyhow::bail!("Invalid port forward format"),
1200 }
1201}
1202
1203impl SshConnectionOptions {
1204 pub fn parse_command_line(input: &str) -> Result<Self> {
1205 let input = input.trim_start_matches("ssh ");
1206 let mut hostname: Option<String> = None;
1207 let mut username: Option<String> = None;
1208 let mut port: Option<u16> = None;
1209 let mut args = Vec::new();
1210 let mut port_forwards: Vec<SshPortForwardOption> = Vec::new();
1211
1212 // disallowed: -E, -e, -F, -f, -G, -g, -M, -N, -n, -O, -q, -S, -s, -T, -t, -V, -v, -W
1213 const ALLOWED_OPTS: &[&str] = &[
1214 "-4", "-6", "-A", "-a", "-C", "-K", "-k", "-X", "-x", "-Y", "-y",
1215 ];
1216 const ALLOWED_ARGS: &[&str] = &[
1217 "-B", "-b", "-c", "-D", "-F", "-I", "-i", "-J", "-l", "-m", "-o", "-P", "-p", "-R",
1218 "-w",
1219 ];
1220
1221 let mut tokens = shlex::split(input).context("invalid input")?.into_iter();
1222
1223 'outer: while let Some(arg) = tokens.next() {
1224 if ALLOWED_OPTS.contains(&(&arg as &str)) {
1225 args.push(arg.to_string());
1226 continue;
1227 }
1228 if arg == "-p" {
1229 port = tokens.next().and_then(|arg| arg.parse().ok());
1230 continue;
1231 } else if let Some(p) = arg.strip_prefix("-p") {
1232 port = p.parse().ok();
1233 continue;
1234 }
1235 if arg == "-l" {
1236 username = tokens.next();
1237 continue;
1238 } else if let Some(l) = arg.strip_prefix("-l") {
1239 username = Some(l.to_string());
1240 continue;
1241 }
1242 if arg == "-L" || arg.starts_with("-L") {
1243 let forward_spec = if arg == "-L" {
1244 tokens.next()
1245 } else {
1246 Some(arg.strip_prefix("-L").unwrap().to_string())
1247 };
1248
1249 if let Some(spec) = forward_spec {
1250 port_forwards.push(parse_port_forward_spec(&spec)?);
1251 } else {
1252 anyhow::bail!("Missing port forward format");
1253 }
1254 }
1255
1256 for a in ALLOWED_ARGS {
1257 if arg == *a {
1258 args.push(arg);
1259 if let Some(next) = tokens.next() {
1260 args.push(next);
1261 }
1262 continue 'outer;
1263 } else if arg.starts_with(a) {
1264 args.push(arg);
1265 continue 'outer;
1266 }
1267 }
1268 if arg.starts_with("-") || hostname.is_some() {
1269 anyhow::bail!("unsupported argument: {:?}", arg);
1270 }
1271 let mut input = &arg as &str;
1272 // Destination might be: username1@username2@ip2@ip1
1273 if let Some((u, rest)) = input.rsplit_once('@') {
1274 input = rest;
1275 username = Some(u.to_string());
1276 }
1277 if let Some((rest, p)) = input.split_once(':') {
1278 input = rest;
1279 port = p.parse().ok()
1280 }
1281 hostname = Some(input.to_string())
1282 }
1283
1284 let Some(hostname) = hostname else {
1285 anyhow::bail!("missing hostname");
1286 };
1287
1288 let port_forwards = match port_forwards.len() {
1289 0 => None,
1290 _ => Some(port_forwards),
1291 };
1292
1293 Ok(Self {
1294 host: hostname,
1295 username,
1296 port,
1297 port_forwards,
1298 args: Some(args),
1299 password: None,
1300 nickname: None,
1301 upload_binary_over_ssh: false,
1302 })
1303 }
1304
1305 pub fn ssh_url(&self) -> String {
1306 let mut result = String::from("ssh://");
1307 if let Some(username) = &self.username {
1308 // Username might be: username1@username2@ip2
1309 let username = urlencoding::encode(username);
1310 result.push_str(&username);
1311 result.push('@');
1312 }
1313 result.push_str(&self.host);
1314 if let Some(port) = self.port {
1315 result.push(':');
1316 result.push_str(&port.to_string());
1317 }
1318 result
1319 }
1320
1321 pub fn additional_args(&self) -> Vec<String> {
1322 let mut args = self.args.iter().flatten().cloned().collect::<Vec<String>>();
1323
1324 if let Some(forwards) = &self.port_forwards {
1325 args.extend(forwards.iter().map(|pf| {
1326 let local_host = match &pf.local_host {
1327 Some(host) => host,
1328 None => "localhost",
1329 };
1330 let remote_host = match &pf.remote_host {
1331 Some(host) => host,
1332 None => "localhost",
1333 };
1334
1335 format!(
1336 "-L{}:{}:{}:{}",
1337 local_host, pf.local_port, remote_host, pf.remote_port
1338 )
1339 }));
1340 }
1341
1342 args
1343 }
1344
1345 fn scp_url(&self) -> String {
1346 if let Some(username) = &self.username {
1347 format!("{}@{}", username, self.host)
1348 } else {
1349 self.host.clone()
1350 }
1351 }
1352
1353 pub fn connection_string(&self) -> String {
1354 let host = if let Some(username) = &self.username {
1355 format!("{}@{}", username, self.host)
1356 } else {
1357 self.host.clone()
1358 };
1359 if let Some(port) = &self.port {
1360 format!("{}:{}", host, port)
1361 } else {
1362 host
1363 }
1364 }
1365}