1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6};
7use anyhow::{anyhow, Context as _, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 future::BoxFuture,
12 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
13};
14use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, Task};
15use parking_lot::Mutex;
16use rpc::{
17 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
18 EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
19};
20use smol::{
21 fs,
22 process::{self, Stdio},
23};
24use std::{
25 any::TypeId,
26 ffi::OsStr,
27 path::{Path, PathBuf},
28 sync::{
29 atomic::{AtomicU32, Ordering::SeqCst},
30 Arc,
31 },
32 time::Instant,
33};
34use tempfile::TempDir;
35
36#[derive(
37 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
38)]
39pub struct SshProjectId(pub u64);
40
41#[derive(Clone)]
42pub struct SshSocket {
43 connection_options: SshConnectionOptions,
44 socket_path: PathBuf,
45}
46
/// A message channel to a remote development server, exchanging protobuf `Envelope`s.
///
/// Created by `SshSession::client` (over a real SSH connection), `SshSession::server`,
/// or `SshSession::fake` (tests).
///
/// NOTE(review): fields drop in declaration order, so `_io_task` is dropped last.
pub struct SshSession {
    /// Source of ids assigned to outgoing envelopes.
    next_message_id: AtomicU32,
    /// Response senders for outstanding requests, keyed by request id.
    response_channels: ResponseChannels, // Lock
    /// Envelopes queued for delivery to the peer.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Requests for the I/O task to spawn extra commands over the SSH connection.
    spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
    /// Control socket of the SSH connection; only `client` sets this to `Some`.
    client_socket: Option<SshSocket>,
    /// Registered message handlers and entity subscriptions.
    state: Mutex<ProtoMessageHandlerSet>, // Lock
    /// Background task pumping the remote server's stdio (client sessions only).
    /// Held for the session's lifetime; presumably dropping a gpui `Task` cancels it — confirm.
    _io_task: Option<Task<Result<()>>>,
}
56
/// Resources backing an SSH master connection opened by `SshClientState::new`.
struct SshClientState {
    /// Control socket through which follow-up `ssh`/`scp` commands reuse the connection.
    socket: SshSocket,
    /// The `ssh -N -o ControlMaster=yes` process holding the connection open;
    /// killed by this type's `Drop` impl.
    master_process: process::Child,
    /// Temporary directory containing the control socket and askpass files.
    /// Kept alive for the connection's lifetime (presumably removed on drop — confirm).
    _temp_dir: TempDir,
}
62
/// Describes how to reach a remote host over SSH.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    pub host: String,
    pub username: Option<String>,
    pub port: Option<u16>,
    pub password: Option<String>,
}

impl SshConnectionOptions {
    /// Returns the `ssh://[user@]host[:port]` URL passed to the `ssh` binary.
    ///
    /// IPv6 literals (hosts containing `:`) are wrapped in brackets, as required by
    /// ssh's URI syntax; otherwise the host would be split at its first colon.
    pub fn ssh_url(&self) -> String {
        let mut result = String::from("ssh://");
        if let Some(username) = &self.username {
            result.push_str(username);
            result.push('@');
        }
        result.push_str(&self.url_host());
        if let Some(port) = self.port {
            result.push(':');
            result.push_str(&port.to_string());
        }
        result
    }

    /// Returns the `[user@]host` prefix of an scp destination (`<this>:<path>`).
    ///
    /// IPv6 literals are bracketed so scp does not treat their colons as the
    /// host/path separator.
    fn scp_url(&self) -> String {
        let host = self.url_host();
        match &self.username {
            Some(username) => format!("{}@{}", username, host),
            None => host,
        }
    }

    /// Returns `[user@]host[:port]` for display. The host is emitted verbatim (unbracketed).
    pub fn connection_string(&self) -> String {
        let host = if let Some(username) = &self.username {
            format!("{}@{}", username, self.host)
        } else {
            self.host.clone()
        };
        if let Some(port) = &self.port {
            format!("{}:{}", host, port)
        } else {
            host
        }
    }

    /// The host as it must appear inside a URL or scp destination: bracketed when it
    /// is a bare IPv6 literal, unchanged otherwise (including already-bracketed hosts).
    fn url_host(&self) -> String {
        if self.host.contains(':') && !self.host.starts_with('[') {
            format!("[{}]", self.host)
        } else {
            self.host.clone()
        }
    }
}
107
108struct SpawnRequest {
109 command: String,
110 process_tx: oneshot::Sender<process::Child>,
111}
112
/// Operating system and CPU architecture of the remote host, as normalized by
/// `SshClientState::query_platform` (`"macos"`/`"linux"`, `"aarch64"`/`"x86_64"`).
#[derive(Clone, Copy, Debug)]
pub struct SshPlatform {
    pub os: &'static str,
    pub arch: &'static str,
}
118
119pub trait SshClientDelegate {
120 fn ask_password(
121 &self,
122 prompt: String,
123 cx: &mut AsyncAppContext,
124 ) -> oneshot::Receiver<Result<String>>;
125 fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
126 fn get_server_binary(
127 &self,
128 platform: SshPlatform,
129 cx: &mut AsyncAppContext,
130 ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
131 fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
132 fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
133}
134
/// Outstanding requests keyed by message id. Each sender receives the response
/// envelope plus a oneshot sender that the incoming-message loop awaits (completion
/// or drop) before it processes the next incoming message.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
136
137impl SshSession {
138 pub async fn client(
139 connection_options: SshConnectionOptions,
140 delegate: Arc<dyn SshClientDelegate>,
141 cx: &mut AsyncAppContext,
142 ) -> Result<Arc<Self>> {
143 let client_state = SshClientState::new(connection_options, delegate.clone(), cx).await?;
144
145 let platform = client_state.query_platform().await?;
146 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
147 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
148 client_state
149 .ensure_server_binary(
150 &delegate,
151 &local_binary_path,
152 &remote_binary_path,
153 version,
154 cx,
155 )
156 .await?;
157
158 let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
159 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
160 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
161
162 let socket = client_state.socket.clone();
163 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
164
165 let mut remote_server_child = socket
166 .ssh_command(format!(
167 "RUST_LOG={} RUST_BACKTRACE={} {:?} run",
168 std::env::var("RUST_LOG").unwrap_or_default(),
169 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
170 remote_binary_path,
171 ))
172 .spawn()
173 .context("failed to spawn remote server")?;
174 let mut child_stderr = remote_server_child.stderr.take().unwrap();
175 let mut child_stdout = remote_server_child.stdout.take().unwrap();
176 let mut child_stdin = remote_server_child.stdin.take().unwrap();
177
178 let io_task = cx.background_executor().spawn(async move {
179 let mut stdin_buffer = Vec::new();
180 let mut stdout_buffer = Vec::new();
181 let mut stderr_buffer = Vec::new();
182 let mut stderr_offset = 0;
183
184 loop {
185 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
186 stderr_buffer.resize(stderr_offset + 1024, 0);
187
188 select_biased! {
189 outgoing = outgoing_rx.next().fuse() => {
190 let Some(outgoing) = outgoing else {
191 return anyhow::Ok(());
192 };
193
194 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
195 }
196
197 request = spawn_process_rx.next().fuse() => {
198 let Some(request) = request else {
199 return Ok(());
200 };
201
202 log::info!("spawn process: {:?}", request.command);
203 let child = client_state.socket
204 .ssh_command(&request.command)
205 .spawn()
206 .context("failed to create channel")?;
207 request.process_tx.send(child).ok();
208 }
209
210 result = child_stdout.read(&mut stdout_buffer).fuse() => {
211 match result {
212 Ok(0) => {
213 child_stdin.close().await?;
214 outgoing_rx.close();
215 let status = remote_server_child.status().await?;
216 if !status.success() {
217 log::error!("channel exited with status: {status:?}");
218 }
219 return Ok(());
220 }
221 Ok(len) => {
222 if len < stdout_buffer.len() {
223 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
224 }
225
226 let message_len = message_len_from_buffer(&stdout_buffer);
227 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
228 Ok(envelope) => {
229 incoming_tx.unbounded_send(envelope).ok();
230 }
231 Err(error) => {
232 log::error!("error decoding message {error:?}");
233 }
234 }
235 }
236 Err(error) => {
237 Err(anyhow!("error reading stdout: {error:?}"))?;
238 }
239 }
240 }
241
242 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
243 match result {
244 Ok(len) => {
245 stderr_offset += len;
246 let mut start_ix = 0;
247 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
248 let line_ix = start_ix + ix;
249 let content = &stderr_buffer[start_ix..line_ix];
250 start_ix = line_ix + 1;
251 if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
252 record.message = format!("(remote) {}", record.message);
253 record.log(log::logger())
254 } else {
255 eprintln!("(remote) {}", String::from_utf8_lossy(content));
256 }
257 }
258 stderr_buffer.drain(0..start_ix);
259 stderr_offset -= start_ix;
260 }
261 Err(error) => {
262 Err(anyhow!("error reading stderr: {error:?}"))?;
263 }
264 }
265 }
266 }
267 }
268 });
269
270 cx.update(|cx| {
271 Self::new(
272 incoming_rx,
273 outgoing_tx,
274 spawn_process_tx,
275 Some(socket),
276 Some(io_task),
277 cx,
278 )
279 })
280 }
281
282 pub fn server(
283 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
284 outgoing_tx: mpsc::UnboundedSender<Envelope>,
285 cx: &AppContext,
286 ) -> Arc<SshSession> {
287 let (tx, _rx) = mpsc::unbounded();
288 Self::new(incoming_rx, outgoing_tx, tx, None, None, cx)
289 }
290
291 #[cfg(any(test, feature = "test-support"))]
292 pub fn fake(
293 client_cx: &mut gpui::TestAppContext,
294 server_cx: &mut gpui::TestAppContext,
295 ) -> (Arc<Self>, Arc<Self>) {
296 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
297 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
298 let (tx, _rx) = mpsc::unbounded();
299 (
300 client_cx.update(|cx| {
301 Self::new(
302 server_to_client_rx,
303 client_to_server_tx,
304 tx.clone(),
305 None, // todo()
306 None,
307 cx,
308 )
309 }),
310 server_cx.update(|cx| {
311 Self::new(
312 client_to_server_rx,
313 server_to_client_tx,
314 tx.clone(),
315 None,
316 None,
317 cx,
318 )
319 }),
320 )
321 }
322
323 fn new(
324 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
325 outgoing_tx: mpsc::UnboundedSender<Envelope>,
326 spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
327 client_socket: Option<SshSocket>,
328 io_task: Option<Task<Result<()>>>,
329 cx: &AppContext,
330 ) -> Arc<SshSession> {
331 let this = Arc::new(Self {
332 next_message_id: AtomicU32::new(0),
333 response_channels: ResponseChannels::default(),
334 outgoing_tx,
335 spawn_process_tx,
336 client_socket,
337 state: Default::default(),
338 _io_task: io_task,
339 });
340
341 cx.spawn(|cx| {
342 let this = Arc::downgrade(&this);
343 async move {
344 let peer_id = PeerId { owner_id: 0, id: 0 };
345 while let Some(incoming) = incoming_rx.next().await {
346 let Some(this) = this.upgrade() else {
347 return anyhow::Ok(());
348 };
349
350 if let Some(request_id) = incoming.responding_to {
351 let request_id = MessageId(request_id);
352 let sender = this.response_channels.lock().remove(&request_id);
353 if let Some(sender) = sender {
354 let (tx, rx) = oneshot::channel();
355 if incoming.payload.is_some() {
356 sender.send((incoming, tx)).ok();
357 }
358 rx.await.ok();
359 }
360 } else if let Some(envelope) =
361 build_typed_envelope(peer_id, Instant::now(), incoming)
362 {
363 let type_name = envelope.payload_type_name();
364 if let Some(future) = ProtoMessageHandlerSet::handle_message(
365 &this.state,
366 envelope,
367 this.clone().into(),
368 cx.clone(),
369 ) {
370 log::debug!("ssh message received. name:{type_name}");
371 match future.await {
372 Ok(_) => {
373 log::debug!("ssh message handled. name:{type_name}");
374 }
375 Err(error) => {
376 log::error!(
377 "error handling message. type:{type_name}, error:{error}",
378 );
379 }
380 }
381 } else {
382 log::error!("unhandled ssh message name:{type_name}");
383 }
384 }
385 }
386 anyhow::Ok(())
387 }
388 })
389 .detach();
390
391 this
392 }
393
394 pub fn request<T: RequestMessage>(
395 &self,
396 payload: T,
397 ) -> impl 'static + Future<Output = Result<T::Response>> {
398 log::debug!("ssh request start. name:{}", T::NAME);
399 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
400 async move {
401 let response = response.await?;
402 log::debug!("ssh request finish. name:{}", T::NAME);
403 T::Response::from_envelope(response)
404 .ok_or_else(|| anyhow!("received a response of the wrong type"))
405 }
406 }
407
408 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
409 log::debug!("ssh send name:{}", T::NAME);
410 self.send_dynamic(payload.into_envelope(0, None, None))
411 }
412
413 pub fn request_dynamic(
414 &self,
415 mut envelope: proto::Envelope,
416 type_name: &'static str,
417 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
418 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
419 let (tx, rx) = oneshot::channel();
420 let mut response_channels_lock = self.response_channels.lock();
421 response_channels_lock.insert(MessageId(envelope.id), tx);
422 drop(response_channels_lock);
423 let result = self.outgoing_tx.unbounded_send(envelope);
424 async move {
425 if let Err(error) = &result {
426 log::error!("failed to send message: {}", error);
427 return Err(anyhow!("failed to send message: {}", error));
428 }
429
430 let response = rx.await.context("connection lost")?.0;
431 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
432 return Err(RpcError::from_proto(error, type_name));
433 }
434 Ok(response)
435 }
436 }
437
438 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
439 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
440 self.outgoing_tx.unbounded_send(envelope)?;
441 Ok(())
442 }
443
444 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
445 let id = (TypeId::of::<E>(), remote_id);
446
447 let mut state = self.state.lock();
448 if state.entities_by_type_and_remote_id.contains_key(&id) {
449 panic!("already subscribed to entity");
450 }
451
452 state.entities_by_type_and_remote_id.insert(
453 id,
454 EntityMessageSubscriber::Entity {
455 handle: entity.downgrade().into(),
456 },
457 );
458 }
459
460 pub async fn spawn_process(&self, command: String) -> process::Child {
461 let (process_tx, process_rx) = oneshot::channel();
462 self.spawn_process_tx
463 .unbounded_send(SpawnRequest {
464 command,
465 process_tx,
466 })
467 .ok();
468 process_rx.await.unwrap()
469 }
470
471 pub fn ssh_args(&self) -> Vec<String> {
472 self.client_socket.as_ref().unwrap().ssh_args()
473 }
474}
475
476impl ProtoClient for SshSession {
477 fn request(
478 &self,
479 envelope: proto::Envelope,
480 request_type: &'static str,
481 ) -> BoxFuture<'static, Result<proto::Envelope>> {
482 self.request_dynamic(envelope, request_type).boxed()
483 }
484
485 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
486 self.send_dynamic(envelope)
487 }
488
489 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
490 self.send_dynamic(envelope)
491 }
492
493 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
494 &self.state
495 }
496
497 fn is_via_collab(&self) -> bool {
498 false
499 }
500}
501
502impl SshClientState {
503 #[cfg(not(unix))]
504 async fn new(
505 _connection_options: SshConnectionOptions,
506 _delegate: Arc<dyn SshClientDelegate>,
507 _cx: &mut AsyncAppContext,
508 ) -> Result<Self> {
509 Err(anyhow!("ssh is not supported on this platform"))
510 }
511
512 #[cfg(unix)]
513 async fn new(
514 connection_options: SshConnectionOptions,
515 delegate: Arc<dyn SshClientDelegate>,
516 cx: &mut AsyncAppContext,
517 ) -> Result<Self> {
518 use futures::{io::BufReader, AsyncBufReadExt as _};
519 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
520 use util::ResultExt as _;
521
522 delegate.set_status(Some("connecting"), cx);
523
524 let url = connection_options.ssh_url();
525 let temp_dir = tempfile::Builder::new()
526 .prefix("zed-ssh-session")
527 .tempdir()?;
528
529 // Create a domain socket listener to handle requests from the askpass program.
530 let askpass_socket = temp_dir.path().join("askpass.sock");
531 let listener =
532 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
533
534 let askpass_task = cx.spawn({
535 let delegate = delegate.clone();
536 |mut cx| async move {
537 while let Ok((mut stream, _)) = listener.accept().await {
538 let mut buffer = Vec::new();
539 let mut reader = BufReader::new(&mut stream);
540 if reader.read_until(b'\0', &mut buffer).await.is_err() {
541 buffer.clear();
542 }
543 let password_prompt = String::from_utf8_lossy(&buffer);
544 if let Some(password) = delegate
545 .ask_password(password_prompt.to_string(), &mut cx)
546 .await
547 .context("failed to get ssh password")
548 .and_then(|p| p)
549 .log_err()
550 {
551 stream.write_all(password.as_bytes()).await.log_err();
552 }
553 }
554 }
555 });
556
557 // Create an askpass script that communicates back to this process.
558 let askpass_script = format!(
559 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
560 askpass_socket = askpass_socket.display(),
561 print_args = "printf '%s\\0' \"$@\"",
562 shebang = "#!/bin/sh",
563 );
564 let askpass_script_path = temp_dir.path().join("askpass.sh");
565 fs::write(&askpass_script_path, askpass_script).await?;
566 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
567
568 // Start the master SSH process, which does not do anything except for establish
569 // the connection and keep it open, allowing other ssh commands to reuse it
570 // via a control socket.
571 let socket_path = temp_dir.path().join("ssh.sock");
572 let mut master_process = process::Command::new("ssh")
573 .stdin(Stdio::null())
574 .stdout(Stdio::piped())
575 .stderr(Stdio::piped())
576 .env("SSH_ASKPASS_REQUIRE", "force")
577 .env("SSH_ASKPASS", &askpass_script_path)
578 .args(["-N", "-o", "ControlMaster=yes", "-o"])
579 .arg(format!("ControlPath={}", socket_path.display()))
580 .arg(&url)
581 .spawn()?;
582
583 // Wait for this ssh process to close its stdout, indicating that authentication
584 // has completed.
585 let stdout = master_process.stdout.as_mut().unwrap();
586 let mut output = Vec::new();
587 let connection_timeout = std::time::Duration::from_secs(10);
588 let result = read_with_timeout(stdout, connection_timeout, &mut output).await;
589 if let Err(e) = result {
590 let error_message = if e.kind() == std::io::ErrorKind::TimedOut {
591 format!(
592 "Failed to connect to host. Timed out after {:?}.",
593 connection_timeout
594 )
595 } else {
596 format!("Failed to connect to host: {}.", e)
597 };
598
599 delegate.set_error(error_message, cx);
600 return Err(e.into());
601 }
602
603 drop(askpass_task);
604
605 if master_process.try_status()?.is_some() {
606 output.clear();
607 let mut stderr = master_process.stderr.take().unwrap();
608 stderr.read_to_end(&mut output).await?;
609 Err(anyhow!(
610 "failed to connect: {}",
611 String::from_utf8_lossy(&output)
612 ))?;
613 }
614
615 Ok(Self {
616 socket: SshSocket {
617 connection_options,
618 socket_path,
619 },
620 master_process,
621 _temp_dir: temp_dir,
622 })
623 }
624
625 async fn ensure_server_binary(
626 &self,
627 delegate: &Arc<dyn SshClientDelegate>,
628 src_path: &Path,
629 dst_path: &Path,
630 version: SemanticVersion,
631 cx: &mut AsyncAppContext,
632 ) -> Result<()> {
633 let mut dst_path_gz = dst_path.to_path_buf();
634 dst_path_gz.set_extension("gz");
635
636 if let Some(parent) = dst_path.parent() {
637 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
638 }
639
640 let mut server_binary_exists = false;
641 if cfg!(not(debug_assertions)) {
642 if let Ok(installed_version) =
643 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
644 {
645 if installed_version.trim() == version.to_string() {
646 server_binary_exists = true;
647 }
648 }
649 }
650
651 if server_binary_exists {
652 log::info!("remote development server already present",);
653 return Ok(());
654 }
655
656 let src_stat = fs::metadata(src_path).await?;
657 let size = src_stat.len();
658 let server_mode = 0o755;
659
660 let t0 = Instant::now();
661 delegate.set_status(Some("uploading remote development server"), cx);
662 log::info!("uploading remote development server ({}kb)", size / 1024);
663 self.upload_file(src_path, &dst_path_gz)
664 .await
665 .context("failed to upload server binary")?;
666 log::info!("uploaded remote development server in {:?}", t0.elapsed());
667
668 delegate.set_status(Some("extracting remote development server"), cx);
669 run_cmd(
670 self.socket
671 .ssh_command("gunzip")
672 .arg("--force")
673 .arg(&dst_path_gz),
674 )
675 .await?;
676
677 delegate.set_status(Some("unzipping remote development server"), cx);
678 run_cmd(
679 self.socket
680 .ssh_command("chmod")
681 .arg(format!("{:o}", server_mode))
682 .arg(dst_path),
683 )
684 .await?;
685
686 Ok(())
687 }
688
689 async fn query_platform(&self) -> Result<SshPlatform> {
690 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
691 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
692
693 let os = match os.trim() {
694 "Darwin" => "macos",
695 "Linux" => "linux",
696 _ => Err(anyhow!("unknown uname os {os:?}"))?,
697 };
698 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
699 "aarch64"
700 } else if arch.starts_with("x86") || arch.starts_with("i686") {
701 "x86_64"
702 } else {
703 Err(anyhow!("unknown uname architecture {arch:?}"))?
704 };
705
706 Ok(SshPlatform { os, arch })
707 }
708
709 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
710 let mut command = process::Command::new("scp");
711 let output = self
712 .socket
713 .ssh_options(&mut command)
714 .args(
715 self.socket
716 .connection_options
717 .port
718 .map(|port| vec!["-P".to_string(), port.to_string()])
719 .unwrap_or_default(),
720 )
721 .arg(src_path)
722 .arg(format!(
723 "{}:{}",
724 self.socket.connection_options.scp_url(),
725 dest_path.display()
726 ))
727 .output()
728 .await?;
729
730 if output.status.success() {
731 Ok(())
732 } else {
733 Err(anyhow!(
734 "failed to upload file {} -> {}: {}",
735 src_path.display(),
736 dest_path.display(),
737 String::from_utf8_lossy(&output.stderr)
738 ))
739 }
740 }
741}
742
743#[cfg(unix)]
744async fn read_with_timeout(
745 stdout: &mut process::ChildStdout,
746 timeout: std::time::Duration,
747 output: &mut Vec<u8>,
748) -> Result<(), std::io::Error> {
749 smol::future::or(
750 async {
751 stdout.read_to_end(output).await?;
752 Ok::<_, std::io::Error>(())
753 },
754 async {
755 smol::Timer::after(timeout).await;
756
757 Err(std::io::Error::new(
758 std::io::ErrorKind::TimedOut,
759 "Read operation timed out",
760 ))
761 },
762 )
763 .await
764}
765
766impl Drop for SshClientState {
767 fn drop(&mut self) {
768 if let Err(error) = self.master_process.kill() {
769 log::error!("failed to kill SSH master process: {}", error);
770 }
771 }
772}
773
774impl SshSocket {
775 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
776 let mut command = process::Command::new("ssh");
777 self.ssh_options(&mut command)
778 .arg(self.connection_options.ssh_url())
779 .arg(program);
780 command
781 }
782
783 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
784 command
785 .stdin(Stdio::piped())
786 .stdout(Stdio::piped())
787 .stderr(Stdio::piped())
788 .args(["-o", "ControlMaster=no", "-o"])
789 .arg(format!("ControlPath={}", self.socket_path.display()))
790 }
791
792 fn ssh_args(&self) -> Vec<String> {
793 vec![
794 "-o".to_string(),
795 "ControlMaster=no".to_string(),
796 "-o".to_string(),
797 format!("ControlPath={}", self.socket_path.display()),
798 self.connection_options.ssh_url(),
799 ]
800 }
801}
802
803async fn run_cmd(command: &mut process::Command) -> Result<String> {
804 let output = command.output().await?;
805 if output.status.success() {
806 Ok(String::from_utf8_lossy(&output.stdout).to_string())
807 } else {
808 Err(anyhow!(
809 "failed to run command: {}",
810 String::from_utf8_lossy(&output.stderr)
811 ))
812 }
813}