1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use collections::HashMap;
10use futures::{
11 channel::{
12 mpsc::{self, UnboundedReceiver, UnboundedSender},
13 oneshot,
14 },
15 future::BoxFuture,
16 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
17 StreamExt as _,
18};
19use gpui::{
20 AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel,
21};
22use parking_lot::Mutex;
23use rpc::{
24 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
25 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
26};
27use smol::{
28 fs,
29 process::{self, Child, Stdio},
30 Timer,
31};
32use std::{
33 any::TypeId,
34 ffi::OsStr,
35 fmt,
36 ops::ControlFlow,
37 path::{Path, PathBuf},
38 sync::{
39 atomic::{AtomicU32, Ordering::SeqCst},
40 Arc,
41 },
42 time::{Duration, Instant},
43};
44use tempfile::TempDir;
45use util::ResultExt;
46
47#[derive(
48 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
49)]
50pub struct SshProjectId(pub u64);
51
52#[derive(Clone)]
53pub struct SshSocket {
54 connection_options: SshConnectionOptions,
55 socket_path: PathBuf,
56}
57
/// Options describing how to reach a remote host over ssh.
///
/// `Debug` is implemented by hand (instead of derived) so that the password is never
/// printed in logs or error messages.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
    pub host: String,
    pub username: Option<String>,
    pub port: Option<u16>,
    pub password: Option<String>,
}

impl fmt::Debug for SshConnectionOptions {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SshConnectionOptions")
            .field("host", &self.host)
            .field("username", &self.username)
            .field("port", &self.port)
            // Only reveal whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl SshConnectionOptions {
    /// Returns an `ssh://[user@]host[:port]` URL for this host.
    pub fn ssh_url(&self) -> String {
        let mut result = String::from("ssh://");
        if let Some(username) = &self.username {
            result.push_str(username);
            result.push('@');
        }
        result.push_str(&self.host);
        if let Some(port) = self.port {
            result.push(':');
            result.push_str(&port.to_string());
        }
        result
    }

    /// Returns `[user@]host`, without scheme or port.
    fn scp_url(&self) -> String {
        if let Some(username) = &self.username {
            format!("{}@{}", username, self.host)
        } else {
            self.host.clone()
        }
    }

    /// Returns `[user@]host[:port]`, a human-readable description of the destination.
    pub fn connection_string(&self) -> String {
        let host = self.scp_url();
        match self.port {
            Some(port) => format!("{}:{}", host, port),
            None => host,
        }
    }

    // Uniquely identifies dev server projects on a remote host. Needs to be
    // stable for the same dev server project.
    //
    // The host is formatted with `{:?}`, so it appears quoted in the identifier; this is
    // kept as-is because changing it would change identifiers of existing projects.
    pub fn dev_server_identifier(&self) -> String {
        let mut identifier = format!("dev-server-{:?}", self.host);
        if let Some(username) = self.username.as_ref() {
            identifier.push('-');
            identifier.push_str(username);
        }
        identifier
    }
}
113
/// Operating system and CPU architecture of the remote host, used to select the
/// matching server binary.
#[derive(Clone, Copy, Debug)]
pub struct SshPlatform {
    /// Operating system name.
    pub os: &'static str,
    /// CPU architecture name.
    pub arch: &'static str,
}
119
120pub trait SshClientDelegate: Send + Sync {
121 fn ask_password(
122 &self,
123 prompt: String,
124 cx: &mut AsyncAppContext,
125 ) -> oneshot::Receiver<Result<String>>;
126 fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
127 fn get_server_binary(
128 &self,
129 platform: SshPlatform,
130 cx: &mut AsyncAppContext,
131 ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
132 fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
133 fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
134}
135
136impl SshSocket {
137 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
138 let mut command = process::Command::new("ssh");
139 self.ssh_options(&mut command)
140 .arg(self.connection_options.ssh_url())
141 .arg(program);
142 command
143 }
144
145 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
146 command
147 .stdin(Stdio::piped())
148 .stdout(Stdio::piped())
149 .stderr(Stdio::piped())
150 .args(["-o", "ControlMaster=no", "-o"])
151 .arg(format!("ControlPath={}", self.socket_path.display()))
152 }
153
154 fn ssh_args(&self) -> Vec<String> {
155 vec![
156 "-o".to_string(),
157 "ControlMaster=no".to_string(),
158 "-o".to_string(),
159 format!("ControlPath={}", self.socket_path.display()),
160 self.connection_options.ssh_url(),
161 ]
162 }
163}
164
165async fn run_cmd(command: &mut process::Command) -> Result<String> {
166 let output = command.output().await?;
167 if output.status.success() {
168 Ok(String::from_utf8_lossy(&output.stdout).to_string())
169 } else {
170 Err(anyhow!(
171 "failed to run command: {}",
172 String::from_utf8_lossy(&output.stderr)
173 ))
174 }
175}
176
/// Relays envelopes between the client's long-lived channels and the channels of the
/// current proxy process, so that the proxy side can be replaced on reconnect.
struct ChannelForwarder {
    /// Tells the forwarding loop to stop.
    quit_tx: UnboundedSender<()>,
    /// Background forwarding loop; resolves to the client-side endpoints once it stops so
    /// they can be handed to a new forwarder.
    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
}
181
182impl ChannelForwarder {
183 fn new(
184 mut incoming_tx: UnboundedSender<Envelope>,
185 mut outgoing_rx: UnboundedReceiver<Envelope>,
186 cx: &AsyncAppContext,
187 ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
188 let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
189
190 let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
191 let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
192
193 let forwarding_task = cx.background_executor().spawn(async move {
194 loop {
195 select_biased! {
196 _ = quit_rx.next().fuse() => {
197 break;
198 },
199 incoming_envelope = proxy_incoming_rx.next().fuse() => {
200 if let Some(envelope) = incoming_envelope {
201 if incoming_tx.send(envelope).await.is_err() {
202 break;
203 }
204 } else {
205 break;
206 }
207 }
208 outgoing_envelope = outgoing_rx.next().fuse() => {
209 if let Some(envelope) = outgoing_envelope {
210 if proxy_outgoing_tx.send(envelope).await.is_err() {
211 break;
212 }
213 } else {
214 break;
215 }
216 }
217 }
218 }
219
220 (incoming_tx, outgoing_rx)
221 });
222
223 (
224 Self {
225 forwarding_task,
226 quit_tx,
227 },
228 proxy_incoming_tx,
229 proxy_outgoing_rx,
230 )
231 }
232
233 async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
234 let _ = self.quit_tx.send(()).await;
235 self.forwarding_task.await
236 }
237}
238
/// Consecutive unanswered heartbeat pings after which a reconnect is started.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Period of the heartbeat timer.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// How long a single ping may take before it fails.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Reconnect attempts made before the client gives up.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
244
/// Internal connection state machine of `SshRemoteClient`.
///
/// NOTE: when a state value is dropped as a whole, the fields of its variant are dropped
/// in declaration order, so do not reorder them casually.
enum State {
    /// Initial state while the first connection is being established.
    Connecting,
    /// Connected; heartbeats are being answered.
    Connected {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// Connected, but the last `missed_heartbeats` pings went unanswered.
    HeartbeatMissed {
        missed_heartbeats: usize,

        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    /// A reconnect attempt is in flight.
    Reconnecting,
    /// A reconnect attempt failed; keeps what is needed to start the next attempt.
    ReconnectFailed {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        error: anyhow::Error,
        attempts: usize,
    },
    /// More than `MAX_RECONNECT_ATTEMPTS` attempts failed; no further attempts are made.
    ReconnectExhausted,
    /// The proxy's exit code reported that the remote server is not running.
    ServerNotRunning,
}
277
278impl fmt::Display for State {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 match self {
281 Self::Connecting => write!(f, "connecting"),
282 Self::Connected { .. } => write!(f, "connected"),
283 Self::Reconnecting => write!(f, "reconnecting"),
284 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
285 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
286 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
287 Self::ServerNotRunning { .. } => write!(f, "server not running"),
288 }
289 }
290}
291
292impl State {
293 fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
294 match self {
295 Self::Connected { ssh_connection, .. } => Some(ssh_connection),
296 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
297 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
298 _ => None,
299 }
300 }
301
302 fn can_reconnect(&self) -> bool {
303 match self {
304 Self::Connected { .. }
305 | Self::HeartbeatMissed { .. }
306 | Self::ReconnectFailed { .. } => true,
307 State::Connecting
308 | State::Reconnecting
309 | State::ReconnectExhausted
310 | State::ServerNotRunning => false,
311 }
312 }
313
314 fn is_reconnect_failed(&self) -> bool {
315 matches!(self, Self::ReconnectFailed { .. })
316 }
317
318 fn is_reconnecting(&self) -> bool {
319 matches!(self, Self::Reconnecting { .. })
320 }
321
322 fn heartbeat_recovered(self) -> Self {
323 match self {
324 Self::HeartbeatMissed {
325 ssh_connection,
326 delegate,
327 forwarder,
328 multiplex_task,
329 heartbeat_task,
330 ..
331 } => Self::Connected {
332 ssh_connection,
333 delegate,
334 forwarder,
335 multiplex_task,
336 heartbeat_task,
337 },
338 _ => self,
339 }
340 }
341
342 fn heartbeat_missed(self) -> Self {
343 match self {
344 Self::Connected {
345 ssh_connection,
346 delegate,
347 forwarder,
348 multiplex_task,
349 heartbeat_task,
350 } => Self::HeartbeatMissed {
351 missed_heartbeats: 1,
352 ssh_connection,
353 delegate,
354 forwarder,
355 multiplex_task,
356 heartbeat_task,
357 },
358 Self::HeartbeatMissed {
359 missed_heartbeats,
360 ssh_connection,
361 delegate,
362 forwarder,
363 multiplex_task,
364 heartbeat_task,
365 } => Self::HeartbeatMissed {
366 missed_heartbeats: missed_heartbeats + 1,
367 ssh_connection,
368 delegate,
369 forwarder,
370 multiplex_task,
371 heartbeat_task,
372 },
373 _ => self,
374 }
375 }
376}
377
/// The state of the ssh connection, as exposed to users of `SshRemoteClient`.
#[derive(Debug, Clone, Copy)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// Connected and answering heartbeats.
    Connected,
    /// Connected, but recent heartbeats went unanswered.
    HeartbeatMissed,
    /// A reconnect is in progress, or the last attempt failed and another follows.
    Reconnecting,
    /// No connection, and none is being re-established.
    Disconnected,
}
387
388impl From<&State> for ConnectionState {
389 fn from(value: &State) -> Self {
390 match value {
391 State::Connecting => Self::Connecting,
392 State::Connected { .. } => Self::Connected,
393 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
394 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
395 State::ReconnectExhausted => Self::Disconnected,
396 State::ServerNotRunning => Self::Disconnected,
397 }
398 }
399}
400
/// Client for a project hosted on a remote machine reached over ssh.
pub struct SshRemoteClient {
    /// Protocol client whose channels are bridged to the remote proxy process.
    client: Arc<ChannelClient>,
    /// Passed to the remote proxy as its `--identifier` argument.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    /// Current connection state. `None` after `shutdown_processes` has taken it (and, for
    /// clients created by `fake`, from the start).
    state: Arc<Mutex<Option<State>>>,
}
407
408impl Drop for SshRemoteClient {
409 fn drop(&mut self) {
410 self.shutdown_processes();
411 }
412}
413
414impl SshRemoteClient {
415 pub fn new(
416 unique_identifier: String,
417 connection_options: SshConnectionOptions,
418 delegate: Arc<dyn SshClientDelegate>,
419 cx: &AppContext,
420 ) -> Task<Result<Model<Self>>> {
421 cx.spawn(|mut cx| async move {
422 let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
423 let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
424
425 let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
426 let this = cx.new_model(|cx| {
427 cx.on_app_quit(|this: &mut Self, _| {
428 this.shutdown_processes();
429 futures::future::ready(())
430 })
431 .detach();
432
433 Self {
434 client: client.clone(),
435 unique_identifier: unique_identifier.clone(),
436 connection_options: connection_options.clone(),
437 state: Arc::new(Mutex::new(Some(State::Connecting))),
438 }
439 })?;
440
441 let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
442 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
443
444 let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
445 unique_identifier,
446 false,
447 connection_options,
448 delegate.clone(),
449 &mut cx,
450 )
451 .await?;
452
453 let multiplex_task = Self::multiplex(
454 this.downgrade(),
455 ssh_proxy_process,
456 proxy_incoming_tx,
457 proxy_outgoing_rx,
458 &mut cx,
459 );
460
461 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
462 log::error!("failed to establish connection: {}", error);
463 delegate.set_error(error.to_string(), &mut cx);
464 return Err(error);
465 }
466
467 let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);
468
469 this.update(&mut cx, |this, _| {
470 *this.state.lock() = Some(State::Connected {
471 ssh_connection,
472 delegate,
473 forwarder: proxy,
474 multiplex_task,
475 heartbeat_task,
476 });
477 })?;
478
479 Ok(this)
480 })
481 }
482
483 fn shutdown_processes(&self) {
484 let Some(state) = self.state.lock().take() else {
485 return;
486 };
487 log::info!("shutting down ssh processes");
488
489 let State::Connected {
490 multiplex_task,
491 heartbeat_task,
492 ..
493 } = state
494 else {
495 return;
496 };
497 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
498 // child of master_process.
499 drop(multiplex_task);
500 // Now drop the rest of state, which kills master process.
501 drop(heartbeat_task);
502 }
503
504 fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
505 let mut lock = self.state.lock();
506
507 let can_reconnect = lock
508 .as_ref()
509 .map(|state| state.can_reconnect())
510 .unwrap_or(false);
511 if !can_reconnect {
512 let error = if let Some(state) = lock.as_ref() {
513 format!("invalid state, cannot reconnect while in state {state}")
514 } else {
515 "no state set".to_string()
516 };
517 log::info!("aborting reconnect, because not in state that allows reconnecting");
518 return Err(anyhow!(error));
519 }
520
521 let state = lock.take().unwrap();
522 let (attempts, mut ssh_connection, delegate, forwarder) = match state {
523 State::Connected {
524 ssh_connection,
525 delegate,
526 forwarder,
527 multiplex_task,
528 heartbeat_task,
529 }
530 | State::HeartbeatMissed {
531 ssh_connection,
532 delegate,
533 forwarder,
534 multiplex_task,
535 heartbeat_task,
536 ..
537 } => {
538 drop(multiplex_task);
539 drop(heartbeat_task);
540 (0, ssh_connection, delegate, forwarder)
541 }
542 State::ReconnectFailed {
543 attempts,
544 ssh_connection,
545 delegate,
546 forwarder,
547 ..
548 } => (attempts, ssh_connection, delegate, forwarder),
549 State::Connecting
550 | State::Reconnecting
551 | State::ReconnectExhausted
552 | State::ServerNotRunning => unreachable!(),
553 };
554
555 let attempts = attempts + 1;
556 if attempts > MAX_RECONNECT_ATTEMPTS {
557 log::error!(
558 "Failed to reconnect to after {} attempts, giving up",
559 MAX_RECONNECT_ATTEMPTS
560 );
561 drop(lock);
562 self.set_state(State::ReconnectExhausted, cx);
563 return Ok(());
564 }
565 drop(lock);
566
567 self.set_state(State::Reconnecting, cx);
568
569 log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
570
571 let identifier = self.unique_identifier.clone();
572 let client = self.client.clone();
573 let reconnect_task = cx.spawn(|this, mut cx| async move {
574 macro_rules! failed {
575 ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
576 return State::ReconnectFailed {
577 error: anyhow!($error),
578 attempts: $attempts,
579 ssh_connection: $ssh_connection,
580 delegate: $delegate,
581 forwarder: $forwarder,
582 };
583 };
584 }
585
586 if let Err(error) = ssh_connection.master_process.kill() {
587 failed!(error, attempts, ssh_connection, delegate, forwarder);
588 };
589
590 if let Err(error) = ssh_connection
591 .master_process
592 .status()
593 .await
594 .context("Failed to kill ssh process")
595 {
596 failed!(error, attempts, ssh_connection, delegate, forwarder);
597 }
598
599 let connection_options = ssh_connection.socket.connection_options.clone();
600
601 let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
602 let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
603 ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
604
605 let (ssh_connection, ssh_process) = match Self::establish_connection(
606 identifier,
607 true,
608 connection_options,
609 delegate.clone(),
610 &mut cx,
611 )
612 .await
613 {
614 Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
615 Err(error) => {
616 failed!(error, attempts, ssh_connection, delegate, forwarder);
617 }
618 };
619
620 let multiplex_task = Self::multiplex(
621 this.clone(),
622 ssh_process,
623 proxy_incoming_tx,
624 proxy_outgoing_rx,
625 &mut cx,
626 );
627
628 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
629 failed!(error, attempts, ssh_connection, delegate, forwarder);
630 };
631
632 State::Connected {
633 ssh_connection,
634 delegate,
635 forwarder,
636 multiplex_task,
637 heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
638 }
639 });
640
641 cx.spawn(|this, mut cx| async move {
642 let new_state = reconnect_task.await;
643 this.update(&mut cx, |this, cx| {
644 this.try_set_state(cx, |old_state| {
645 if old_state.is_reconnecting() {
646 match &new_state {
647 State::Connecting
648 | State::Reconnecting { .. }
649 | State::HeartbeatMissed { .. }
650 | State::ServerNotRunning => {}
651 State::Connected { .. } => {
652 log::info!("Successfully reconnected");
653 }
654 State::ReconnectFailed {
655 error, attempts, ..
656 } => {
657 log::error!(
658 "Reconnect attempt {} failed: {:?}. Starting new attempt...",
659 attempts,
660 error
661 );
662 }
663 State::ReconnectExhausted => {
664 log::error!("Reconnect attempt failed and all attempts exhausted");
665 }
666 }
667 Some(new_state)
668 } else {
669 None
670 }
671 });
672
673 if this.state_is(State::is_reconnect_failed) {
674 this.reconnect(cx)
675 } else {
676 log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
677 Ok(())
678 }
679 })
680 })
681 .detach_and_log_err(cx);
682
683 Ok(())
684 }
685
686 fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
687 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
688 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
689 };
690 cx.spawn(|mut cx| {
691 let this = this.clone();
692 async move {
693 let mut missed_heartbeats = 0;
694
695 let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
696 loop {
697 timer.next().await;
698
699 log::debug!("Sending heartbeat to server...");
700
701 let result = client.ping(HEARTBEAT_TIMEOUT).await;
702 if result.is_err() {
703 missed_heartbeats += 1;
704 log::warn!(
705 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
706 HEARTBEAT_TIMEOUT,
707 missed_heartbeats,
708 MAX_MISSED_HEARTBEATS
709 );
710 } else if missed_heartbeats != 0 {
711 missed_heartbeats = 0;
712 } else {
713 continue;
714 }
715
716 let result = this.update(&mut cx, |this, mut cx| {
717 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
718 })?;
719 if result.is_break() {
720 return Ok(());
721 }
722 }
723 }
724 })
725 }
726
727 fn handle_heartbeat_result(
728 &mut self,
729 missed_heartbeats: usize,
730 cx: &mut ModelContext<Self>,
731 ) -> ControlFlow<()> {
732 let state = self.state.lock().take().unwrap();
733 let next_state = if missed_heartbeats > 0 {
734 state.heartbeat_missed()
735 } else {
736 state.heartbeat_recovered()
737 };
738
739 self.set_state(next_state, cx);
740
741 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
742 log::error!(
743 "Missed last {} heartbeats. Reconnecting...",
744 missed_heartbeats
745 );
746
747 self.reconnect(cx)
748 .context("failed to start reconnect process after missing heartbeats")
749 .log_err();
750 ControlFlow::Break(())
751 } else {
752 ControlFlow::Continue(())
753 }
754 }
755
756 fn multiplex(
757 this: WeakModel<Self>,
758 mut ssh_proxy_process: Child,
759 incoming_tx: UnboundedSender<Envelope>,
760 mut outgoing_rx: UnboundedReceiver<Envelope>,
761 cx: &AsyncAppContext,
762 ) -> Task<Result<()>> {
763 let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
764 let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
765 let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
766
767 let io_task = cx.background_executor().spawn(async move {
768 let mut stdin_buffer = Vec::new();
769 let mut stdout_buffer = Vec::new();
770 let mut stderr_buffer = Vec::new();
771 let mut stderr_offset = 0;
772
773 loop {
774 stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
775 stderr_buffer.resize(stderr_offset + 1024, 0);
776
777 select_biased! {
778 outgoing = outgoing_rx.next().fuse() => {
779 let Some(outgoing) = outgoing else {
780 return anyhow::Ok(None);
781 };
782
783 write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
784 }
785
786 result = child_stdout.read(&mut stdout_buffer).fuse() => {
787 match result {
788 Ok(0) => {
789 child_stdin.close().await?;
790 outgoing_rx.close();
791 let status = ssh_proxy_process.status().await?;
792 return Ok(status.code());
793 }
794 Ok(len) => {
795 if len < stdout_buffer.len() {
796 child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
797 }
798
799 let message_len = message_len_from_buffer(&stdout_buffer);
800 match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
801 Ok(envelope) => {
802 incoming_tx.unbounded_send(envelope).ok();
803 }
804 Err(error) => {
805 log::error!("error decoding message {error:?}");
806 }
807 }
808 }
809 Err(error) => {
810 Err(anyhow!("error reading stdout: {error:?}"))?;
811 }
812 }
813 }
814
815 result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
816 match result {
817 Ok(len) => {
818 stderr_offset += len;
819 let mut start_ix = 0;
820 while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
821 let line_ix = start_ix + ix;
822 let content = &stderr_buffer[start_ix..line_ix];
823 start_ix = line_ix + 1;
824 if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
825 record.message = format!("(remote) {}", record.message);
826 record.log(log::logger())
827 } else {
828 eprintln!("(remote) {}", String::from_utf8_lossy(content));
829 }
830 }
831 stderr_buffer.drain(0..start_ix);
832 stderr_offset -= start_ix;
833 }
834 Err(error) => {
835 Err(anyhow!("error reading stderr: {error:?}"))?;
836 }
837 }
838 }
839 }
840 }
841 });
842
843 cx.spawn(|mut cx| async move {
844 let result = io_task.await;
845
846 match result {
847 Ok(Some(exit_code)) => {
848 if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
849 match error {
850 ProxyLaunchError::ServerNotRunning => {
851 log::error!("failed to reconnect because server is not running");
852 this.update(&mut cx, |this, cx| {
853 this.set_state(State::ServerNotRunning, cx);
854 })?;
855 }
856 }
857 } else if exit_code > 0 {
858 log::error!("proxy process terminated unexpectedly");
859 }
860 }
861 Ok(None) => {}
862 Err(error) => {
863 log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
864 this.update(&mut cx, |this, cx| {
865 this.reconnect(cx).ok();
866 })?;
867 }
868 }
869 Ok(())
870 })
871 }
872
873 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
874 self.state.lock().as_ref().map_or(false, check)
875 }
876
877 fn try_set_state(
878 &self,
879 cx: &mut ModelContext<Self>,
880 map: impl FnOnce(&State) -> Option<State>,
881 ) {
882 let mut lock = self.state.lock();
883 let new_state = lock.as_ref().and_then(map);
884
885 if let Some(new_state) = new_state {
886 lock.replace(new_state);
887 cx.notify();
888 }
889 }
890
891 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
892 log::info!("setting state to '{}'", &state);
893 self.state.lock().replace(state);
894 cx.notify();
895 }
896
897 async fn establish_connection(
898 unique_identifier: String,
899 reconnect: bool,
900 connection_options: SshConnectionOptions,
901 delegate: Arc<dyn SshClientDelegate>,
902 cx: &mut AsyncAppContext,
903 ) -> Result<(SshRemoteConnection, Child)> {
904 let ssh_connection =
905 SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
906
907 let platform = ssh_connection.query_platform().await?;
908 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
909 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
910 ssh_connection
911 .ensure_server_binary(
912 &delegate,
913 &local_binary_path,
914 &remote_binary_path,
915 version,
916 cx,
917 )
918 .await?;
919
920 let socket = ssh_connection.socket.clone();
921 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
922
923 delegate.set_status(Some("Starting proxy"), cx);
924
925 let mut start_proxy_command = format!(
926 "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
927 std::env::var("RUST_LOG").unwrap_or_default(),
928 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
929 remote_binary_path,
930 unique_identifier,
931 );
932 if reconnect {
933 start_proxy_command.push_str(" --reconnect");
934 }
935
936 let ssh_proxy_process = socket
937 .ssh_command(start_proxy_command)
938 // IMPORTANT: we kill this process when we drop the task that uses it.
939 .kill_on_drop(true)
940 .spawn()
941 .context("failed to spawn remote server")?;
942
943 Ok((ssh_connection, ssh_proxy_process))
944 }
945
946 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
947 self.client.subscribe_to_entity(remote_id, entity);
948 }
949
950 pub fn ssh_args(&self) -> Option<Vec<String>> {
951 self.state
952 .lock()
953 .as_ref()
954 .and_then(|state| state.ssh_connection())
955 .map(|ssh_connection| ssh_connection.socket.ssh_args())
956 }
957
958 pub fn to_proto_client(&self) -> AnyProtoClient {
959 self.client.clone().into()
960 }
961
962 pub fn connection_string(&self) -> String {
963 self.connection_options.connection_string()
964 }
965
966 pub fn connection_state(&self) -> ConnectionState {
967 self.state
968 .lock()
969 .as_ref()
970 .map(ConnectionState::from)
971 .unwrap_or(ConnectionState::Disconnected)
972 }
973
974 #[cfg(any(test, feature = "test-support"))]
975 pub fn fake(
976 client_cx: &mut gpui::TestAppContext,
977 server_cx: &mut gpui::TestAppContext,
978 ) -> (Model<Self>, Arc<ChannelClient>) {
979 use gpui::Context;
980
981 let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
982 let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
983
984 (
985 client_cx.update(|cx| {
986 let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
987 cx.new_model(|_| Self {
988 client,
989 unique_identifier: "fake".to_string(),
990 connection_options: SshConnectionOptions::default(),
991 state: Arc::new(Mutex::new(None)),
992 })
993 }),
994 server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
995 )
996 }
997}
998
999impl From<SshRemoteClient> for AnyProtoClient {
1000 fn from(client: SshRemoteClient) -> Self {
1001 AnyProtoClient::new(client.client.clone())
1002 }
1003}
1004
/// A live ssh master connection, plus the temporary directory that holds its control
/// socket and askpass helper.
struct SshRemoteConnection {
    /// Used to run further commands over the master connection.
    socket: SshSocket,
    /// The `ssh -N -o ControlMaster=yes` process; killed in `Drop`.
    master_process: process::Child,
    /// Owns the directory containing the control socket and askpass files; held so that it
    /// lives exactly as long as the connection.
    _temp_dir: TempDir,
}
1010
1011impl Drop for SshRemoteConnection {
1012 fn drop(&mut self) {
1013 if let Err(error) = self.master_process.kill() {
1014 log::error!("failed to kill SSH master process: {}", error);
1015 }
1016 }
1017}
1018
1019impl SshRemoteConnection {
/// Fallback for non-unix targets: always fails, because the unix implementation relies on
/// unix domain sockets for its askpass helper.
#[cfg(not(unix))]
async fn new(
    _connection_options: SshConnectionOptions,
    _delegate: Arc<dyn SshClientDelegate>,
    _cx: &mut AsyncAppContext,
) -> Result<Self> {
    let unsupported = anyhow!("ssh is not supported on this platform");
    Err(unsupported)
}
1028
1029 #[cfg(unix)]
1030 async fn new(
1031 connection_options: SshConnectionOptions,
1032 delegate: Arc<dyn SshClientDelegate>,
1033 cx: &mut AsyncAppContext,
1034 ) -> Result<Self> {
1035 use futures::{io::BufReader, AsyncBufReadExt as _};
1036 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1037 use util::ResultExt as _;
1038
1039 delegate.set_status(Some("connecting"), cx);
1040
1041 let url = connection_options.ssh_url();
1042 let temp_dir = tempfile::Builder::new()
1043 .prefix("zed-ssh-session")
1044 .tempdir()?;
1045
1046 // Create a domain socket listener to handle requests from the askpass program.
1047 let askpass_socket = temp_dir.path().join("askpass.sock");
1048 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1049 let listener =
1050 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1051
1052 let askpass_task = cx.spawn({
1053 let delegate = delegate.clone();
1054 |mut cx| async move {
1055 let mut askpass_opened_tx = Some(askpass_opened_tx);
1056
1057 while let Ok((mut stream, _)) = listener.accept().await {
1058 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1059 askpass_opened_tx.send(()).ok();
1060 }
1061 let mut buffer = Vec::new();
1062 let mut reader = BufReader::new(&mut stream);
1063 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1064 buffer.clear();
1065 }
1066 let password_prompt = String::from_utf8_lossy(&buffer);
1067 if let Some(password) = delegate
1068 .ask_password(password_prompt.to_string(), &mut cx)
1069 .await
1070 .context("failed to get ssh password")
1071 .and_then(|p| p)
1072 .log_err()
1073 {
1074 stream.write_all(password.as_bytes()).await.log_err();
1075 }
1076 }
1077 }
1078 });
1079
1080 // Create an askpass script that communicates back to this process.
1081 let askpass_script = format!(
1082 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1083 askpass_socket = askpass_socket.display(),
1084 print_args = "printf '%s\\0' \"$@\"",
1085 shebang = "#!/bin/sh",
1086 );
1087 let askpass_script_path = temp_dir.path().join("askpass.sh");
1088 fs::write(&askpass_script_path, askpass_script).await?;
1089 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1090
1091 // Start the master SSH process, which does not do anything except for establish
1092 // the connection and keep it open, allowing other ssh commands to reuse it
1093 // via a control socket.
1094 let socket_path = temp_dir.path().join("ssh.sock");
1095 let mut master_process = process::Command::new("ssh")
1096 .stdin(Stdio::null())
1097 .stdout(Stdio::piped())
1098 .stderr(Stdio::piped())
1099 .env("SSH_ASKPASS_REQUIRE", "force")
1100 .env("SSH_ASKPASS", &askpass_script_path)
1101 .args(["-N", "-o", "ControlMaster=yes", "-o"])
1102 .arg(format!("ControlPath={}", socket_path.display()))
1103 .arg(&url)
1104 .spawn()?;
1105
1106 // Wait for this ssh process to close its stdout, indicating that authentication
1107 // has completed.
1108 let stdout = master_process.stdout.as_mut().unwrap();
1109 let mut output = Vec::new();
1110 let connection_timeout = Duration::from_secs(10);
1111
1112 let result = select_biased! {
1113 _ = askpass_opened_rx.fuse() => {
1114 // If the askpass script has opened, that means the user is typing
1115 // their password, in which case we don't want to timeout anymore,
1116 // since we know a connection has been established.
1117 stdout.read_to_end(&mut output).await?;
1118 Ok(())
1119 }
1120 result = stdout.read_to_end(&mut output).fuse() => {
1121 result?;
1122 Ok(())
1123 }
1124 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1125 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1126 }
1127 };
1128
1129 if let Err(e) = result {
1130 let error_message = format!("Failed to connect to host: {}.", e);
1131 delegate.set_error(error_message, cx);
1132 return Err(e);
1133 }
1134
1135 drop(askpass_task);
1136
1137 if master_process.try_status()?.is_some() {
1138 output.clear();
1139 let mut stderr = master_process.stderr.take().unwrap();
1140 stderr.read_to_end(&mut output).await?;
1141
1142 let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1143 delegate.set_error(error_message.clone(), cx);
1144 Err(anyhow!(error_message))?;
1145 }
1146
1147 Ok(Self {
1148 socket: SshSocket {
1149 connection_options,
1150 socket_path,
1151 },
1152 master_process,
1153 _temp_dir: temp_dir,
1154 })
1155 }
1156
1157 async fn ensure_server_binary(
1158 &self,
1159 delegate: &Arc<dyn SshClientDelegate>,
1160 src_path: &Path,
1161 dst_path: &Path,
1162 version: SemanticVersion,
1163 cx: &mut AsyncAppContext,
1164 ) -> Result<()> {
1165 let mut dst_path_gz = dst_path.to_path_buf();
1166 dst_path_gz.set_extension("gz");
1167
1168 if let Some(parent) = dst_path.parent() {
1169 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1170 }
1171
1172 let mut server_binary_exists = false;
1173 if cfg!(not(debug_assertions)) {
1174 if let Ok(installed_version) =
1175 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1176 {
1177 if installed_version.trim() == version.to_string() {
1178 server_binary_exists = true;
1179 }
1180 }
1181 }
1182
1183 if server_binary_exists {
1184 log::info!("remote development server already present",);
1185 return Ok(());
1186 }
1187
1188 let src_stat = fs::metadata(src_path).await?;
1189 let size = src_stat.len();
1190 let server_mode = 0o755;
1191
1192 let t0 = Instant::now();
1193 delegate.set_status(Some("uploading remote development server"), cx);
1194 log::info!("uploading remote development server ({}kb)", size / 1024);
1195 self.upload_file(src_path, &dst_path_gz)
1196 .await
1197 .context("failed to upload server binary")?;
1198 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1199
1200 delegate.set_status(Some("extracting remote development server"), cx);
1201 run_cmd(
1202 self.socket
1203 .ssh_command("gunzip")
1204 .arg("--force")
1205 .arg(&dst_path_gz),
1206 )
1207 .await?;
1208
1209 delegate.set_status(Some("unzipping remote development server"), cx);
1210 run_cmd(
1211 self.socket
1212 .ssh_command("chmod")
1213 .arg(format!("{:o}", server_mode))
1214 .arg(dst_path),
1215 )
1216 .await?;
1217
1218 Ok(())
1219 }
1220
1221 async fn query_platform(&self) -> Result<SshPlatform> {
1222 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1223 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1224
1225 let os = match os.trim() {
1226 "Darwin" => "macos",
1227 "Linux" => "linux",
1228 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1229 };
1230 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1231 "aarch64"
1232 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1233 "x86_64"
1234 } else {
1235 Err(anyhow!("unknown uname architecture {arch:?}"))?
1236 };
1237
1238 Ok(SshPlatform { os, arch })
1239 }
1240
1241 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1242 let mut command = process::Command::new("scp");
1243 let output = self
1244 .socket
1245 .ssh_options(&mut command)
1246 .args(
1247 self.socket
1248 .connection_options
1249 .port
1250 .map(|port| vec!["-P".to_string(), port.to_string()])
1251 .unwrap_or_default(),
1252 )
1253 .arg(src_path)
1254 .arg(format!(
1255 "{}:{}",
1256 self.socket.connection_options.scp_url(),
1257 dest_path.display()
1258 ))
1259 .output()
1260 .await?;
1261
1262 if output.status.success() {
1263 Ok(())
1264 } else {
1265 Err(anyhow!(
1266 "failed to upload file {} -> {}: {}",
1267 src_path.display(),
1268 dest_path.display(),
1269 String::from_utf8_lossy(&output.stderr)
1270 ))
1271 }
1272 }
1273}
1274
/// Table of in-flight requests: outgoing message id -> sender for the response envelope.
///
/// The response is delivered together with a `oneshot::Sender<()>`. The incoming-message
/// loop awaits the paired receiver, so it does not move on until the requester has dropped
/// that sender (normally right after taking the envelope) or the delivery itself failed.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
1276
/// Proto RPC client whose transport is a pair of in-memory channels: envelopes are read
/// from an incoming `UnboundedReceiver` handed to `ChannelClient::new` and written to
/// `outgoing_tx`.
pub struct ChannelClient {
    /// Source of ids for outgoing envelopes; bumped with `fetch_add` for every message.
    next_message_id: AtomicU32,
    /// Sink for every outgoing envelope (requests, one-way messages, and responses).
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Pending requests keyed by message id; an entry is removed when its response arrives.
    response_channels: ResponseChannels, // guarded by a `parking_lot::Mutex`
    /// Registered message handlers and entity subscriptions, used to dispatch incoming
    /// messages that are not responses.
    message_handlers: Mutex<ProtoMessageHandlerSet>, // guarded by a `parking_lot::Mutex`
}
1283
1284impl ChannelClient {
1285 pub fn new(
1286 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1287 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1288 cx: &AppContext,
1289 ) -> Arc<Self> {
1290 let this = Arc::new(Self {
1291 outgoing_tx,
1292 next_message_id: AtomicU32::new(0),
1293 response_channels: ResponseChannels::default(),
1294 message_handlers: Default::default(),
1295 });
1296
1297 Self::start_handling_messages(this.clone(), incoming_rx, cx);
1298
1299 this
1300 }
1301
1302 fn start_handling_messages(
1303 this: Arc<Self>,
1304 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1305 cx: &AppContext,
1306 ) {
1307 cx.spawn(|cx| {
1308 let this = Arc::downgrade(&this);
1309 async move {
1310 let peer_id = PeerId { owner_id: 0, id: 0 };
1311 while let Some(incoming) = incoming_rx.next().await {
1312 let Some(this) = this.upgrade() else {
1313 return anyhow::Ok(());
1314 };
1315
1316 if let Some(request_id) = incoming.responding_to {
1317 let request_id = MessageId(request_id);
1318 let sender = this.response_channels.lock().remove(&request_id);
1319 if let Some(sender) = sender {
1320 let (tx, rx) = oneshot::channel();
1321 if incoming.payload.is_some() {
1322 sender.send((incoming, tx)).ok();
1323 }
1324 rx.await.ok();
1325 }
1326 } else if let Some(envelope) =
1327 build_typed_envelope(peer_id, Instant::now(), incoming)
1328 {
1329 let type_name = envelope.payload_type_name();
1330 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1331 &this.message_handlers,
1332 envelope,
1333 this.clone().into(),
1334 cx.clone(),
1335 ) {
1336 log::debug!("ssh message received. name:{type_name}");
1337 match future.await {
1338 Ok(_) => {
1339 log::debug!("ssh message handled. name:{type_name}");
1340 }
1341 Err(error) => {
1342 log::error!(
1343 "error handling message. type:{type_name}, error:{error}",
1344 );
1345 }
1346 }
1347 } else {
1348 log::error!("unhandled ssh message name:{type_name}");
1349 }
1350 }
1351 }
1352 anyhow::Ok(())
1353 }
1354 })
1355 .detach();
1356 }
1357
1358 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1359 let id = (TypeId::of::<E>(), remote_id);
1360
1361 let mut message_handlers = self.message_handlers.lock();
1362 if message_handlers
1363 .entities_by_type_and_remote_id
1364 .contains_key(&id)
1365 {
1366 panic!("already subscribed to entity");
1367 }
1368
1369 message_handlers.entities_by_type_and_remote_id.insert(
1370 id,
1371 EntityMessageSubscriber::Entity {
1372 handle: entity.downgrade().into(),
1373 },
1374 );
1375 }
1376
1377 pub fn request<T: RequestMessage>(
1378 &self,
1379 payload: T,
1380 ) -> impl 'static + Future<Output = Result<T::Response>> {
1381 log::debug!("ssh request start. name:{}", T::NAME);
1382 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1383 async move {
1384 let response = response.await?;
1385 log::debug!("ssh request finish. name:{}", T::NAME);
1386 T::Response::from_envelope(response)
1387 .ok_or_else(|| anyhow!("received a response of the wrong type"))
1388 }
1389 }
1390
1391 pub async fn ping(&self, timeout: Duration) -> Result<()> {
1392 smol::future::or(
1393 async {
1394 self.request(proto::Ping {}).await?;
1395 Ok(())
1396 },
1397 async {
1398 smol::Timer::after(timeout).await;
1399 Err(anyhow!("Timeout detected"))
1400 },
1401 )
1402 .await
1403 }
1404
1405 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1406 log::debug!("ssh send name:{}", T::NAME);
1407 self.send_dynamic(payload.into_envelope(0, None, None))
1408 }
1409
1410 pub fn request_dynamic(
1411 &self,
1412 mut envelope: proto::Envelope,
1413 type_name: &'static str,
1414 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1415 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1416 let (tx, rx) = oneshot::channel();
1417 let mut response_channels_lock = self.response_channels.lock();
1418 response_channels_lock.insert(MessageId(envelope.id), tx);
1419 drop(response_channels_lock);
1420 let result = self.outgoing_tx.unbounded_send(envelope);
1421 async move {
1422 if let Err(error) = &result {
1423 log::error!("failed to send message: {}", error);
1424 return Err(anyhow!("failed to send message: {}", error));
1425 }
1426
1427 let response = rx.await.context("connection lost")?.0;
1428 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1429 return Err(RpcError::from_proto(error, type_name));
1430 }
1431 Ok(response)
1432 }
1433 }
1434
1435 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1436 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1437 self.outgoing_tx.unbounded_send(envelope)?;
1438 Ok(())
1439 }
1440}
1441
1442impl ProtoClient for ChannelClient {
1443 fn request(
1444 &self,
1445 envelope: proto::Envelope,
1446 request_type: &'static str,
1447 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1448 self.request_dynamic(envelope, request_type).boxed()
1449 }
1450
1451 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1452 self.send_dynamic(envelope)
1453 }
1454
1455 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1456 self.send_dynamic(envelope)
1457 }
1458
1459 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1460 &self.message_handlers
1461 }
1462
1463 fn is_via_collab(&self) -> bool {
1464 false
1465 }
1466}