1use crate::{
2 json_log::LogRecord,
3 protocol::{
4 message_len_from_buffer, read_message_with_len, write_message, MessageId, MESSAGE_LEN_SIZE,
5 },
6 proxy::ProxyLaunchError,
7};
8use anyhow::{anyhow, Context as _, Result};
9use collections::HashMap;
10use futures::{
11 channel::{
12 mpsc::{self, UnboundedReceiver, UnboundedSender},
13 oneshot,
14 },
15 future::BoxFuture,
16 select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
17 StreamExt as _,
18};
19use gpui::{
20 AppContext, AsyncAppContext, Context, Model, ModelContext, SemanticVersion, Task, WeakModel,
21};
22use parking_lot::Mutex;
23use rpc::{
24 proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
25 AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
26};
27use smol::{
28 fs,
29 process::{self, Child, Stdio},
30 Timer,
31};
32use std::{
33 any::TypeId,
34 ffi::OsStr,
35 fmt,
36 ops::ControlFlow,
37 path::{Path, PathBuf},
38 sync::{
39 atomic::{AtomicU32, Ordering::SeqCst},
40 Arc,
41 },
42 time::{Duration, Instant},
43};
44use tempfile::TempDir;
45use util::ResultExt;
46
47#[derive(
48 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, serde::Serialize, serde::Deserialize,
49)]
50pub struct SshProjectId(pub u64);
51
52#[derive(Clone)]
53pub struct SshSocket {
54 connection_options: SshConnectionOptions,
55 socket_path: PathBuf,
56}
57
/// Parameters describing how to reach a remote host over SSH.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct SshConnectionOptions {
    pub host: String,
    pub username: Option<String>,
    pub port: Option<u16>,
    pub password: Option<String>,
}

impl SshConnectionOptions {
    /// Returns `ssh://[user@]host[:port]`.
    pub fn ssh_url(&self) -> String {
        format!("ssh://{}", self.connection_string())
    }

    /// Returns `[user@]host`, the form used as an scp destination.
    fn scp_url(&self) -> String {
        match &self.username {
            Some(user) => format!("{user}@{}", self.host),
            None => self.host.clone(),
        }
    }

    /// Returns `[user@]host[:port]`.
    pub fn connection_string(&self) -> String {
        let target = self.scp_url();
        match self.port {
            Some(port) => format!("{target}:{port}"),
            None => target,
        }
    }

    /// Uniquely identifies dev server projects on a remote host. Needs to be
    /// stable for the same dev server project.
    ///
    /// The host is rendered with `{:?}` (i.e. wrapped in quotes); keep that
    /// as-is, since changing the format would change existing identifiers.
    pub fn dev_server_identifier(&self) -> String {
        let base = format!("dev-server-{:?}", self.host);
        match self.username.as_deref() {
            Some(user) => format!("{base}-{user}"),
            None => base,
        }
    }
}
113
/// Operating system and CPU architecture of a remote host, normalized from
/// `uname` output.
#[derive(Clone, Copy, Debug)]
pub struct SshPlatform {
    /// Normalized OS name, e.g. `"macos"` or `"linux"`.
    pub os: &'static str,
    /// CPU architecture name.
    pub arch: &'static str,
}
119
120pub trait SshClientDelegate: Send + Sync {
121 fn ask_password(
122 &self,
123 prompt: String,
124 cx: &mut AsyncAppContext,
125 ) -> oneshot::Receiver<Result<String>>;
126 fn remote_server_binary_path(&self, cx: &mut AsyncAppContext) -> Result<PathBuf>;
127 fn get_server_binary(
128 &self,
129 platform: SshPlatform,
130 cx: &mut AsyncAppContext,
131 ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>>;
132 fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext);
133 fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
134}
135
136impl SshSocket {
137 fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
138 let mut command = process::Command::new("ssh");
139 self.ssh_options(&mut command)
140 .arg(self.connection_options.ssh_url())
141 .arg(program);
142 command
143 }
144
145 fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
146 command
147 .stdin(Stdio::piped())
148 .stdout(Stdio::piped())
149 .stderr(Stdio::piped())
150 .args(["-o", "ControlMaster=no", "-o"])
151 .arg(format!("ControlPath={}", self.socket_path.display()))
152 }
153
154 fn ssh_args(&self) -> Vec<String> {
155 vec![
156 "-o".to_string(),
157 "ControlMaster=no".to_string(),
158 "-o".to_string(),
159 format!("ControlPath={}", self.socket_path.display()),
160 self.connection_options.ssh_url(),
161 ]
162 }
163}
164
165async fn run_cmd(command: &mut process::Command) -> Result<String> {
166 let output = command.output().await?;
167 if output.status.success() {
168 Ok(String::from_utf8_lossy(&output.stdout).to_string())
169 } else {
170 Err(anyhow!(
171 "failed to run command: {}",
172 String::from_utf8_lossy(&output.stderr)
173 ))
174 }
175}
176
/// Relays envelopes between a long-lived channel pair and a per-connection pair
/// of "proxy" channels, so that the long-lived pair can be recovered (via
/// `into_channels`) and reattached to a new connection.
struct ChannelForwarder {
    // Sending on this (or dropping it) ends the relay loop.
    quit_tx: UnboundedSender<()>,
    // Background relay task; it resolves to the long-lived channel ends once
    // the loop has stopped.
    forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
}
181
182impl ChannelForwarder {
183 fn new(
184 mut incoming_tx: UnboundedSender<Envelope>,
185 mut outgoing_rx: UnboundedReceiver<Envelope>,
186 cx: &AsyncAppContext,
187 ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
188 let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
189
190 let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
191 let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
192
193 let forwarding_task = cx.background_executor().spawn(async move {
194 loop {
195 select_biased! {
196 _ = quit_rx.next().fuse() => {
197 break;
198 },
199 incoming_envelope = proxy_incoming_rx.next().fuse() => {
200 if let Some(envelope) = incoming_envelope {
201 if incoming_tx.send(envelope).await.is_err() {
202 break;
203 }
204 } else {
205 break;
206 }
207 }
208 outgoing_envelope = outgoing_rx.next().fuse() => {
209 if let Some(envelope) = outgoing_envelope {
210 if proxy_outgoing_tx.send(envelope).await.is_err() {
211 break;
212 }
213 } else {
214 break;
215 }
216 }
217 }
218 }
219
220 (incoming_tx, outgoing_rx)
221 });
222
223 (
224 Self {
225 forwarding_task,
226 quit_tx,
227 },
228 proxy_incoming_tx,
229 proxy_outgoing_rx,
230 )
231 }
232
233 async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
234 let _ = self.quit_tx.send(()).await;
235 self.forwarding_task.await
236 }
237}
238
/// Number of consecutive missed heartbeats that triggers a reconnect.
const MAX_MISSED_HEARTBEATS: usize = 5;
/// Delay between heartbeat pings.
const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(5_000);
/// Timeout applied to each ping sent to the server.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Number of reconnect attempts made before giving up.
const MAX_RECONNECT_ATTEMPTS: usize = 3;
244
/// Internal connection state machine of `SshRemoteClient`.
///
/// Note: when a whole `State` value is dropped, Rust drops its fields in
/// declaration order, so `ssh_connection` (whose `Drop` kills the master ssh
/// process) goes before `multiplex_task`. Code that needs the tasks gone first
/// moves them out and drops them explicitly.
enum State {
    // Initial connection is being established.
    Connecting,
    // Connection established; proxy I/O and heartbeats are running.
    Connected {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    // Like `Connected`, but the most recent heartbeat(s) failed.
    HeartbeatMissed {
        // Consecutive missed heartbeats so far.
        missed_heartbeats: usize,

        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        multiplex_task: Task<Result<()>>,
        heartbeat_task: Task<Result<()>>,
    },
    // A reconnect attempt is in flight.
    Reconnecting,
    // The last reconnect attempt failed. The old connection, delegate and
    // forwarder are kept so the next attempt can reuse them.
    ReconnectFailed {
        ssh_connection: SshRemoteConnection,
        delegate: Arc<dyn SshClientDelegate>,
        forwarder: ChannelForwarder,

        error: anyhow::Error,
        // Number of attempts made so far.
        attempts: usize,
    },
    // Gave up after `MAX_RECONNECT_ATTEMPTS` failed attempts.
    ReconnectExhausted,
    // The proxy reported that the remote server is not running.
    ServerNotRunning,
}
277
278impl fmt::Display for State {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 match self {
281 Self::Connecting => write!(f, "connecting"),
282 Self::Connected { .. } => write!(f, "connected"),
283 Self::Reconnecting => write!(f, "reconnecting"),
284 Self::ReconnectFailed { .. } => write!(f, "reconnect failed"),
285 Self::ReconnectExhausted => write!(f, "reconnect exhausted"),
286 Self::HeartbeatMissed { .. } => write!(f, "heartbeat missed"),
287 Self::ServerNotRunning { .. } => write!(f, "server not running"),
288 }
289 }
290}
291
292impl State {
293 fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
294 match self {
295 Self::Connected { ssh_connection, .. } => Some(ssh_connection),
296 Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
297 Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
298 _ => None,
299 }
300 }
301
302 fn can_reconnect(&self) -> bool {
303 match self {
304 Self::Connected { .. }
305 | Self::HeartbeatMissed { .. }
306 | Self::ReconnectFailed { .. } => true,
307 State::Connecting
308 | State::Reconnecting
309 | State::ReconnectExhausted
310 | State::ServerNotRunning => false,
311 }
312 }
313
314 fn is_reconnect_failed(&self) -> bool {
315 matches!(self, Self::ReconnectFailed { .. })
316 }
317
318 fn is_reconnecting(&self) -> bool {
319 matches!(self, Self::Reconnecting { .. })
320 }
321
322 fn heartbeat_recovered(self) -> Self {
323 match self {
324 Self::HeartbeatMissed {
325 ssh_connection,
326 delegate,
327 forwarder,
328 multiplex_task,
329 heartbeat_task,
330 ..
331 } => Self::Connected {
332 ssh_connection,
333 delegate,
334 forwarder,
335 multiplex_task,
336 heartbeat_task,
337 },
338 _ => self,
339 }
340 }
341
342 fn heartbeat_missed(self) -> Self {
343 match self {
344 Self::Connected {
345 ssh_connection,
346 delegate,
347 forwarder,
348 multiplex_task,
349 heartbeat_task,
350 } => Self::HeartbeatMissed {
351 missed_heartbeats: 1,
352 ssh_connection,
353 delegate,
354 forwarder,
355 multiplex_task,
356 heartbeat_task,
357 },
358 Self::HeartbeatMissed {
359 missed_heartbeats,
360 ssh_connection,
361 delegate,
362 forwarder,
363 multiplex_task,
364 heartbeat_task,
365 } => Self::HeartbeatMissed {
366 missed_heartbeats: missed_heartbeats + 1,
367 ssh_connection,
368 delegate,
369 forwarder,
370 multiplex_task,
371 heartbeat_task,
372 },
373 _ => self,
374 }
375 }
376}
377
/// The state of the ssh connection, as exposed to callers.
#[derive(Debug, Copy, Clone)]
pub enum ConnectionState {
    /// The initial connection is being established.
    Connecting,
    /// Connected with no outstanding missed heartbeats.
    Connected,
    /// Connected, but at least one recent heartbeat was missed.
    HeartbeatMissed,
    /// A reconnect attempt is running or has just failed.
    Reconnecting,
    /// No usable connection: reconnecting was abandoned, the server is not
    /// running, or the client has no state.
    Disconnected,
}
387
388impl From<&State> for ConnectionState {
389 fn from(value: &State) -> Self {
390 match value {
391 State::Connecting => Self::Connecting,
392 State::Connected { .. } => Self::Connected,
393 State::Reconnecting | State::ReconnectFailed { .. } => Self::Reconnecting,
394 State::HeartbeatMissed { .. } => Self::HeartbeatMissed,
395 State::ReconnectExhausted => Self::Disconnected,
396 State::ServerNotRunning => Self::Disconnected,
397 }
398 }
399}
400
/// Client side of an SSH remote-development connection.
///
/// Owns the RPC `ChannelClient` and the connection state machine. Its `Drop`
/// impl calls `shutdown_processes`.
pub struct SshRemoteClient {
    // RPC client whose channels are relayed to the current ssh proxy process.
    client: Arc<ChannelClient>,
    // Passed to the remote proxy as `--identifier`; presumably lets a reconnect
    // reach the same server instance — confirm against the proxy side.
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    // `None` after `shutdown_processes` has taken it (and for `fake` clients).
    state: Arc<Mutex<Option<State>>>,
}
407
408impl Drop for SshRemoteClient {
409 fn drop(&mut self) {
410 self.shutdown_processes();
411 }
412}
413
414impl SshRemoteClient {
/// Connects to the host described by `connection_options` and resolves to a
/// model for the resulting client.
///
/// Steps:
/// 1. Create the `ChannelClient` and the model (starting in `State::Connecting`).
/// 2. Establish the ssh connection and launch the proxy.
/// 3. Start multiplexing the proxy's stdio and verify the link with a ping.
/// 4. Start heartbeats and switch to `State::Connected`.
///
/// A failed ping is reported through `delegate.set_error` and returned as the error.
pub fn new(
    unique_identifier: String,
    connection_options: SshConnectionOptions,
    delegate: Arc<dyn SshClientDelegate>,
    cx: &AppContext,
) -> Task<Result<Model<Self>>> {
    cx.spawn(|mut cx| async move {
        // Long-lived channels between the ChannelClient and whichever connection is current.
        let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
        let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();

        let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
        let this = cx.new_model(|cx| {
            // Tear the ssh processes down when the application quits.
            cx.on_app_quit(|this: &mut Self, _| {
                this.shutdown_processes();
                futures::future::ready(())
            })
            .detach();

            Self {
                client: client.clone(),
                unique_identifier: unique_identifier.clone(),
                connection_options: connection_options.clone(),
                state: Arc::new(Mutex::new(Some(State::Connecting))),
            }
        })?;

        // Relay the long-lived channels through per-connection proxy channels so
        // that `reconnect` can recover them via `ChannelForwarder::into_channels`.
        let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
            ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);

        let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
            unique_identifier,
            false,
            connection_options,
            delegate.clone(),
            &mut cx,
        )
        .await?;

        let multiplex_task = Self::multiplex(
            this.downgrade(),
            ssh_proxy_process,
            proxy_incoming_tx,
            proxy_outgoing_rx,
            &mut cx,
        );

        // Verify a full round trip to the server before reporting success.
        if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
            log::error!("failed to establish connection: {}", error);
            delegate.set_error(error.to_string(), &mut cx);
            return Err(error);
        }

        let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx);

        this.update(&mut cx, |this, _| {
            *this.state.lock() = Some(State::Connected {
                ssh_connection,
                delegate,
                forwarder: proxy,
                multiplex_task,
                heartbeat_task,
            });
        })?;

        Ok(this)
    })
}
482
483 fn shutdown_processes(&self) {
484 let Some(state) = self.state.lock().take() else {
485 return;
486 };
487 log::info!("shutting down ssh processes");
488
489 let State::Connected {
490 multiplex_task,
491 heartbeat_task,
492 ..
493 } = state
494 else {
495 return;
496 };
497 // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a
498 // child of master_process.
499 drop(multiplex_task);
500 // Now drop the rest of state, which kills master process.
501 drop(heartbeat_task);
502 }
503
/// Starts an asynchronous reconnect attempt, if the current state allows it.
///
/// If the state is not reconnectable (see `State::can_reconnect`), this
/// returns an error and leaves the state unchanged. Otherwise it takes the
/// state, stops its tasks, and then does one of two things:
/// - After `MAX_RECONNECT_ATTEMPTS` attempts, switches to `State::ReconnectExhausted`.
/// - Otherwise switches to `State::Reconnecting` and spawns an attempt.
///
/// The attempt kills the old master process, re-establishes the connection
/// with `--reconnect`, reattaches the forwarded channels and pings the server.
/// Its outcome is applied only if the state is still `Reconnecting`. If the
/// resulting state is `ReconnectFailed`, `reconnect` is called again.
fn reconnect(&mut self, cx: &mut ModelContext<Self>) -> Result<()> {
    let mut lock = self.state.lock();

    let can_reconnect = lock
        .as_ref()
        .map(|state| state.can_reconnect())
        .unwrap_or(false);
    if !can_reconnect {
        let error = if let Some(state) = lock.as_ref() {
            format!("invalid state, cannot reconnect while in state {state}")
        } else {
            "no state set".to_string()
        };
        log::info!("aborting reconnect, because not in state that allows reconnecting");
        return Err(anyhow!(error));
    }

    let state = lock.take().unwrap();
    let (attempts, mut ssh_connection, delegate, forwarder) = match state {
        State::Connected {
            ssh_connection,
            delegate,
            forwarder,
            multiplex_task,
            heartbeat_task,
        }
        | State::HeartbeatMissed {
            ssh_connection,
            delegate,
            forwarder,
            multiplex_task,
            heartbeat_task,
            ..
        } => {
            // Stop relaying over, and pinging through, the old connection.
            drop(multiplex_task);
            drop(heartbeat_task);
            (0, ssh_connection, delegate, forwarder)
        }
        State::ReconnectFailed {
            attempts,
            ssh_connection,
            delegate,
            forwarder,
            ..
        } => (attempts, ssh_connection, delegate, forwarder),
        State::Connecting
        | State::Reconnecting
        | State::ReconnectExhausted
        | State::ServerNotRunning => unreachable!(),
    };

    let attempts = attempts + 1;
    if attempts > MAX_RECONNECT_ATTEMPTS {
        // NOTE(review): the message below reads "to after" — likely a typo.
        log::error!(
            "Failed to reconnect to after {} attempts, giving up",
            MAX_RECONNECT_ATTEMPTS
        );
        // Release the lock before `set_state`, which locks `self.state` again.
        drop(lock);
        self.set_state(State::ReconnectExhausted, cx);
        return Ok(());
    }
    drop(lock);
    self.set_state(State::Reconnecting, cx);

    log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);

    let identifier = self.unique_identifier.clone();
    let client = self.client.clone();
    let reconnect_task = cx.spawn(|this, mut cx| async move {
        // Returns early with a `ReconnectFailed` state that keeps the connection,
        // delegate and forwarder so the next attempt can reuse them.
        macro_rules! failed {
            ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
                return State::ReconnectFailed {
                    error: anyhow!($error),
                    attempts: $attempts,
                    ssh_connection: $ssh_connection,
                    delegate: $delegate,
                    forwarder: $forwarder,
                };
            };
        }

        if let Err(error) = ssh_connection.master_process.kill() {
            failed!(error, attempts, ssh_connection, delegate, forwarder);
        };

        // Wait for the old master process to actually exit.
        if let Err(error) = ssh_connection
            .master_process
            .status()
            .await
            .context("Failed to kill ssh process")
        {
            failed!(error, attempts, ssh_connection, delegate, forwarder);
        }

        let connection_options = ssh_connection.socket.connection_options.clone();

        // Recover the long-lived channels and relay them through fresh proxy channels.
        let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
        let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
            ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);

        let (ssh_connection, ssh_process) = match Self::establish_connection(
            identifier,
            true,
            connection_options,
            delegate.clone(),
            &mut cx,
        )
        .await
        {
            Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
            Err(error) => {
                // `ssh_connection` here is still the old (killed) connection.
                failed!(error, attempts, ssh_connection, delegate, forwarder);
            }
        };

        let multiplex_task = Self::multiplex(
            this.clone(),
            ssh_process,
            proxy_incoming_tx,
            proxy_outgoing_rx,
            &mut cx,
        );

        if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
            failed!(error, attempts, ssh_connection, delegate, forwarder);
        };

        State::Connected {
            ssh_connection,
            delegate,
            forwarder,
            multiplex_task,
            heartbeat_task: Self::heartbeat(this.clone(), &mut cx),
        }
    });

    cx.spawn(|this, mut cx| async move {
        let new_state = reconnect_task.await;
        this.update(&mut cx, |this, cx| {
            // Apply the outcome only if nothing else moved the state away from
            // `Reconnecting` in the meantime.
            this.try_set_state(cx, |old_state| {
                if old_state.is_reconnecting() {
                    match &new_state {
                        State::Connecting
                        | State::Reconnecting { .. }
                        | State::HeartbeatMissed { .. }
                        | State::ServerNotRunning => {}
                        State::Connected { .. } => {
                            log::info!("Successfully reconnected");
                        }
                        State::ReconnectFailed {
                            error, attempts, ..
                        } => {
                            log::error!(
                                "Reconnect attempt {} failed: {:?}. Starting new attempt...",
                                attempts,
                                error
                            );
                        }
                        State::ReconnectExhausted => {
                            log::error!("Reconnect attempt failed and all attempts exhausted");
                        }
                    }
                    Some(new_state)
                } else {
                    None
                }
            });

            if this.state_is(State::is_reconnect_failed) {
                this.reconnect(cx)
            } else {
                // NOTE(review): this branch is also taken after a successful
                // reconnect (state `Connected`), not only when the state changed.
                log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
                Ok(())
            }
        })
    })
    .detach_and_log_err(cx);

    Ok(())
}
684
685 fn heartbeat(this: WeakModel<Self>, cx: &mut AsyncAppContext) -> Task<Result<()>> {
686 let Ok(client) = this.update(cx, |this, _| this.client.clone()) else {
687 return Task::ready(Err(anyhow!("SshRemoteClient lost")));
688 };
689 cx.spawn(|mut cx| {
690 let this = this.clone();
691 async move {
692 let mut missed_heartbeats = 0;
693
694 let mut timer = Timer::interval(HEARTBEAT_INTERVAL);
695 loop {
696 timer.next().await;
697
698 log::debug!("Sending heartbeat to server...");
699
700 let result = client.ping(HEARTBEAT_TIMEOUT).await;
701 if result.is_err() {
702 missed_heartbeats += 1;
703 log::warn!(
704 "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.",
705 HEARTBEAT_TIMEOUT,
706 missed_heartbeats,
707 MAX_MISSED_HEARTBEATS
708 );
709 } else if missed_heartbeats != 0 {
710 missed_heartbeats = 0;
711 } else {
712 continue;
713 }
714
715 let result = this.update(&mut cx, |this, mut cx| {
716 this.handle_heartbeat_result(missed_heartbeats, &mut cx)
717 })?;
718 if result.is_break() {
719 return Ok(());
720 }
721 }
722 }
723 })
724 }
725
726 fn handle_heartbeat_result(
727 &mut self,
728 missed_heartbeats: usize,
729 cx: &mut ModelContext<Self>,
730 ) -> ControlFlow<()> {
731 let state = self.state.lock().take().unwrap();
732 let next_state = if missed_heartbeats > 0 {
733 state.heartbeat_missed()
734 } else {
735 state.heartbeat_recovered()
736 };
737 self.set_state(next_state, cx);
738
739 if missed_heartbeats >= MAX_MISSED_HEARTBEATS {
740 log::error!(
741 "Missed last {} heartbeats. Reconnecting...",
742 missed_heartbeats
743 );
744
745 self.reconnect(cx)
746 .context("failed to start reconnect process after missing heartbeats")
747 .log_err();
748 ControlFlow::Break(())
749 } else {
750 ControlFlow::Continue(())
751 }
752 }
753
/// Pumps envelopes between the channel client and the ssh proxy process.
///
/// A background I/O task does three things:
/// - Writes envelopes from `outgoing_rx` to the proxy's stdin.
/// - Decodes messages from its stdout into `incoming_tx`.
/// - Forwards its stderr line by line to the local log. Lines that parse as a
///   JSON `LogRecord` are re-logged with a "(remote)" prefix; anything else is
///   printed to stderr.
///
/// The returned task reacts to how the I/O loop ended:
/// - A recognized `ProxyLaunchError` exit code sets `State::ServerNotRunning`.
/// - An I/O error triggers `reconnect`.
/// - `Ok(None)` (outgoing channel closed, or no exit code, e.g. the proxy was
///   killed by a signal on Unix) is ignored.
fn multiplex(
    this: WeakModel<Self>,
    mut ssh_proxy_process: Child,
    incoming_tx: UnboundedSender<Envelope>,
    mut outgoing_rx: UnboundedReceiver<Envelope>,
    cx: &AsyncAppContext,
) -> Task<Result<()>> {
    let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
    let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
    let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();

    let io_task = cx.background_executor().spawn(async move {
        let mut stdin_buffer = Vec::new();
        let mut stdout_buffer = Vec::new();
        let mut stderr_buffer = Vec::new();
        // Bytes at the start of `stderr_buffer` that form a not-yet-complete line.
        let mut stderr_offset = 0;

        loop {
            // stdout: room for just the message-length prefix.
            // stderr: up to 1024 more bytes after any pending partial line.
            stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
            stderr_buffer.resize(stderr_offset + 1024, 0);

            select_biased! {
                outgoing = outgoing_rx.next().fuse() => {
                    // All senders are gone: stop without an exit code.
                    let Some(outgoing) = outgoing else {
                        return anyhow::Ok(None);
                    };

                    write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
                }

                result = child_stdout.read(&mut stdout_buffer).fuse() => {
                    match result {
                        // EOF on stdout: shut down and report the proxy's exit code.
                        Ok(0) => {
                            child_stdin.close().await?;
                            outgoing_rx.close();
                            let status = ssh_proxy_process.status().await?;
                            return Ok(status.code());
                        }
                        Ok(len) => {
                            // Finish reading the length prefix if the read was short.
                            if len < stdout_buffer.len() {
                                child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
                            }

                            let message_len = message_len_from_buffer(&stdout_buffer);
                            match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await {
                                Ok(envelope) => {
                                    incoming_tx.unbounded_send(envelope).ok();
                                }
                                Err(error) => {
                                    log::error!("error decoding message {error:?}");
                                }
                            }
                        }
                        Err(error) => {
                            Err(anyhow!("error reading stdout: {error:?}"))?;
                        }
                    }
                }

                result = child_stderr.read(&mut stderr_buffer[stderr_offset..]).fuse() => {
                    match result {
                        // NOTE(review): `Ok(0)` (stderr EOF) also lands here and changes
                        // nothing, so a closed stderr is polled again on every iteration;
                        // confirm this cannot spin while stdout is still open.
                        Ok(len) => {
                            stderr_offset += len;
                            // Emit every complete line, then keep the trailing partial line.
                            let mut start_ix = 0;
                            while let Some(ix) = stderr_buffer[start_ix..stderr_offset].iter().position(|b| b == &b'\n') {
                                let line_ix = start_ix + ix;
                                let content = &stderr_buffer[start_ix..line_ix];
                                start_ix = line_ix + 1;
                                if let Ok(mut record) = serde_json::from_slice::<LogRecord>(content) {
                                    record.message = format!("(remote) {}", record.message);
                                    record.log(log::logger())
                                } else {
                                    eprintln!("(remote) {}", String::from_utf8_lossy(content));
                                }
                            }
                            stderr_buffer.drain(0..start_ix);
                            stderr_offset -= start_ix;
                        }
                        Err(error) => {
                            Err(anyhow!("error reading stderr: {error:?}"))?;
                        }
                    }
                }
            }
        }
    });

    cx.spawn(|mut cx| async move {
        let result = io_task.await;

        match result {
            Ok(Some(exit_code)) => {
                if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
                    match error {
                        ProxyLaunchError::ServerNotRunning => {
                            log::error!("failed to reconnect because server is not running");
                            this.update(&mut cx, |this, cx| {
                                this.set_state(State::ServerNotRunning, cx);
                            })?;
                        }
                    }
                } else if exit_code > 0 {
                    log::error!("proxy process terminated unexpectedly");
                }
            }
            Ok(None) => {}
            Err(error) => {
                log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
                this.update(&mut cx, |this, cx| {
                    this.reconnect(cx).ok();
                })?;
            }
        }
        Ok(())
    })
}
870
871 fn state_is(&self, check: impl FnOnce(&State) -> bool) -> bool {
872 self.state.lock().as_ref().map_or(false, check)
873 }
874
875 fn try_set_state(
876 &self,
877 cx: &mut ModelContext<Self>,
878 map: impl FnOnce(&State) -> Option<State>,
879 ) {
880 if let Some(new_state) = self.state.lock().as_ref().and_then(map) {
881 self.set_state(new_state, cx);
882 }
883 }
884
885 fn set_state(&self, state: State, cx: &mut ModelContext<Self>) {
886 log::info!("setting state to '{}'", &state);
887 self.state.lock().replace(state);
888 cx.notify();
889 }
890
891 async fn establish_connection(
892 unique_identifier: String,
893 reconnect: bool,
894 connection_options: SshConnectionOptions,
895 delegate: Arc<dyn SshClientDelegate>,
896 cx: &mut AsyncAppContext,
897 ) -> Result<(SshRemoteConnection, Child)> {
898 let ssh_connection =
899 SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
900
901 let platform = ssh_connection.query_platform().await?;
902 let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
903 let remote_binary_path = delegate.remote_server_binary_path(cx)?;
904 ssh_connection
905 .ensure_server_binary(
906 &delegate,
907 &local_binary_path,
908 &remote_binary_path,
909 version,
910 cx,
911 )
912 .await?;
913
914 let socket = ssh_connection.socket.clone();
915 run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
916
917 delegate.set_status(Some("Starting proxy"), cx);
918
919 let mut start_proxy_command = format!(
920 "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
921 std::env::var("RUST_LOG").unwrap_or_default(),
922 std::env::var("RUST_BACKTRACE").unwrap_or_default(),
923 remote_binary_path,
924 unique_identifier,
925 );
926 if reconnect {
927 start_proxy_command.push_str(" --reconnect");
928 }
929
930 let ssh_proxy_process = socket
931 .ssh_command(start_proxy_command)
932 // IMPORTANT: we kill this process when we drop the task that uses it.
933 .kill_on_drop(true)
934 .spawn()
935 .context("failed to spawn remote server")?;
936
937 Ok((ssh_connection, ssh_proxy_process))
938 }
939
940 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
941 self.client.subscribe_to_entity(remote_id, entity);
942 }
943
944 pub fn ssh_args(&self) -> Option<Vec<String>> {
945 self.state
946 .lock()
947 .as_ref()
948 .and_then(|state| state.ssh_connection())
949 .map(|ssh_connection| ssh_connection.socket.ssh_args())
950 }
951
952 pub fn to_proto_client(&self) -> AnyProtoClient {
953 self.client.clone().into()
954 }
955
956 pub fn connection_string(&self) -> String {
957 self.connection_options.connection_string()
958 }
959
960 pub fn connection_state(&self) -> ConnectionState {
961 self.state
962 .lock()
963 .as_ref()
964 .map(ConnectionState::from)
965 .unwrap_or(ConnectionState::Disconnected)
966 }
967
/// Test helper: builds a client model plus a server-side `ChannelClient`,
/// wired to each other through in-memory channels (no ssh involved).
///
/// The client model has no state, so it reports `Disconnected`.
#[cfg(any(test, feature = "test-support"))]
pub fn fake(
    client_cx: &mut gpui::TestAppContext,
    server_cx: &mut gpui::TestAppContext,
) -> (Model<Self>, Arc<ChannelClient>) {
    use gpui::Context;

    let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
    let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();

    let client_model = client_cx.update(|cx| {
        let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
        cx.new_model(|_| Self {
            client,
            unique_identifier: "fake".to_string(),
            connection_options: SshConnectionOptions::default(),
            state: Arc::new(Mutex::new(None)),
        })
    });
    let server_client =
        server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx));

    (client_model, server_client)
}
991}
992
993impl From<SshRemoteClient> for AnyProtoClient {
994 fn from(client: SshRemoteClient) -> Self {
995 AnyProtoClient::new(client.client.clone())
996 }
997}
998
/// A control-master ssh connection to a remote host.
///
/// Its `Drop` impl kills the master process.
struct SshRemoteConnection {
    // Used to run further ssh commands over the master connection.
    socket: SshSocket,
    // The `ssh -N -o ControlMaster=yes ...` process that holds the connection open.
    master_process: process::Child,
    // Temporary directory holding the control socket and askpass files;
    // `TempDir` removes it when dropped.
    _temp_dir: TempDir,
}
1004
1005impl Drop for SshRemoteConnection {
1006 fn drop(&mut self) {
1007 if let Err(error) = self.master_process.kill() {
1008 log::error!("failed to kill SSH master process: {}", error);
1009 }
1010 }
1011}
1012
1013impl SshRemoteConnection {
/// SSH connections are only implemented on Unix; on other platforms this
/// always fails.
#[cfg(not(unix))]
async fn new(
    _connection_options: SshConnectionOptions,
    _delegate: Arc<dyn SshClientDelegate>,
    _cx: &mut AsyncAppContext,
) -> Result<Self> {
    anyhow::bail!("ssh is not supported on this platform")
}
1022
1023 #[cfg(unix)]
1024 async fn new(
1025 connection_options: SshConnectionOptions,
1026 delegate: Arc<dyn SshClientDelegate>,
1027 cx: &mut AsyncAppContext,
1028 ) -> Result<Self> {
1029 use futures::{io::BufReader, AsyncBufReadExt as _};
1030 use smol::{fs::unix::PermissionsExt as _, net::unix::UnixListener};
1031 use util::ResultExt as _;
1032
1033 delegate.set_status(Some("connecting"), cx);
1034
1035 let url = connection_options.ssh_url();
1036 let temp_dir = tempfile::Builder::new()
1037 .prefix("zed-ssh-session")
1038 .tempdir()?;
1039
1040 // Create a domain socket listener to handle requests from the askpass program.
1041 let askpass_socket = temp_dir.path().join("askpass.sock");
1042 let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
1043 let listener =
1044 UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
1045
1046 let askpass_task = cx.spawn({
1047 let delegate = delegate.clone();
1048 |mut cx| async move {
1049 let mut askpass_opened_tx = Some(askpass_opened_tx);
1050
1051 while let Ok((mut stream, _)) = listener.accept().await {
1052 if let Some(askpass_opened_tx) = askpass_opened_tx.take() {
1053 askpass_opened_tx.send(()).ok();
1054 }
1055 let mut buffer = Vec::new();
1056 let mut reader = BufReader::new(&mut stream);
1057 if reader.read_until(b'\0', &mut buffer).await.is_err() {
1058 buffer.clear();
1059 }
1060 let password_prompt = String::from_utf8_lossy(&buffer);
1061 if let Some(password) = delegate
1062 .ask_password(password_prompt.to_string(), &mut cx)
1063 .await
1064 .context("failed to get ssh password")
1065 .and_then(|p| p)
1066 .log_err()
1067 {
1068 stream.write_all(password.as_bytes()).await.log_err();
1069 }
1070 }
1071 }
1072 });
1073
1074 // Create an askpass script that communicates back to this process.
1075 let askpass_script = format!(
1076 "{shebang}\n{print_args} | nc -U {askpass_socket} 2> /dev/null \n",
1077 askpass_socket = askpass_socket.display(),
1078 print_args = "printf '%s\\0' \"$@\"",
1079 shebang = "#!/bin/sh",
1080 );
1081 let askpass_script_path = temp_dir.path().join("askpass.sh");
1082 fs::write(&askpass_script_path, askpass_script).await?;
1083 fs::set_permissions(&askpass_script_path, std::fs::Permissions::from_mode(0o755)).await?;
1084
1085 // Start the master SSH process, which does not do anything except for establish
1086 // the connection and keep it open, allowing other ssh commands to reuse it
1087 // via a control socket.
1088 let socket_path = temp_dir.path().join("ssh.sock");
1089 let mut master_process = process::Command::new("ssh")
1090 .stdin(Stdio::null())
1091 .stdout(Stdio::piped())
1092 .stderr(Stdio::piped())
1093 .env("SSH_ASKPASS_REQUIRE", "force")
1094 .env("SSH_ASKPASS", &askpass_script_path)
1095 .args(["-N", "-o", "ControlMaster=yes", "-o"])
1096 .arg(format!("ControlPath={}", socket_path.display()))
1097 .arg(&url)
1098 .spawn()?;
1099
1100 // Wait for this ssh process to close its stdout, indicating that authentication
1101 // has completed.
1102 let stdout = master_process.stdout.as_mut().unwrap();
1103 let mut output = Vec::new();
1104 let connection_timeout = Duration::from_secs(10);
1105
1106 let result = select_biased! {
1107 _ = askpass_opened_rx.fuse() => {
1108 // If the askpass script has opened, that means the user is typing
1109 // their password, in which case we don't want to timeout anymore,
1110 // since we know a connection has been established.
1111 stdout.read_to_end(&mut output).await?;
1112 Ok(())
1113 }
1114 result = stdout.read_to_end(&mut output).fuse() => {
1115 result?;
1116 Ok(())
1117 }
1118 _ = futures::FutureExt::fuse(smol::Timer::after(connection_timeout)) => {
1119 Err(anyhow!("Exceeded {:?} timeout trying to connect to host", connection_timeout))
1120 }
1121 };
1122
1123 if let Err(e) = result {
1124 let error_message = format!("Failed to connect to host: {}.", e);
1125 delegate.set_error(error_message, cx);
1126 return Err(e);
1127 }
1128
1129 drop(askpass_task);
1130
1131 if master_process.try_status()?.is_some() {
1132 output.clear();
1133 let mut stderr = master_process.stderr.take().unwrap();
1134 stderr.read_to_end(&mut output).await?;
1135
1136 let error_message = format!("failed to connect: {}", String::from_utf8_lossy(&output));
1137 delegate.set_error(error_message.clone(), cx);
1138 Err(anyhow!(error_message))?;
1139 }
1140
1141 Ok(Self {
1142 socket: SshSocket {
1143 connection_options,
1144 socket_path,
1145 },
1146 master_process,
1147 _temp_dir: temp_dir,
1148 })
1149 }
1150
1151 async fn ensure_server_binary(
1152 &self,
1153 delegate: &Arc<dyn SshClientDelegate>,
1154 src_path: &Path,
1155 dst_path: &Path,
1156 version: SemanticVersion,
1157 cx: &mut AsyncAppContext,
1158 ) -> Result<()> {
1159 let mut dst_path_gz = dst_path.to_path_buf();
1160 dst_path_gz.set_extension("gz");
1161
1162 if let Some(parent) = dst_path.parent() {
1163 run_cmd(self.socket.ssh_command("mkdir").arg("-p").arg(parent)).await?;
1164 }
1165
1166 let mut server_binary_exists = false;
1167 if cfg!(not(debug_assertions)) {
1168 if let Ok(installed_version) =
1169 run_cmd(self.socket.ssh_command(dst_path).arg("version")).await
1170 {
1171 if installed_version.trim() == version.to_string() {
1172 server_binary_exists = true;
1173 }
1174 }
1175 }
1176
1177 if server_binary_exists {
1178 log::info!("remote development server already present",);
1179 return Ok(());
1180 }
1181
1182 let src_stat = fs::metadata(src_path).await?;
1183 let size = src_stat.len();
1184 let server_mode = 0o755;
1185
1186 let t0 = Instant::now();
1187 delegate.set_status(Some("uploading remote development server"), cx);
1188 log::info!("uploading remote development server ({}kb)", size / 1024);
1189 self.upload_file(src_path, &dst_path_gz)
1190 .await
1191 .context("failed to upload server binary")?;
1192 log::info!("uploaded remote development server in {:?}", t0.elapsed());
1193
1194 delegate.set_status(Some("extracting remote development server"), cx);
1195 run_cmd(
1196 self.socket
1197 .ssh_command("gunzip")
1198 .arg("--force")
1199 .arg(&dst_path_gz),
1200 )
1201 .await?;
1202
1203 delegate.set_status(Some("unzipping remote development server"), cx);
1204 run_cmd(
1205 self.socket
1206 .ssh_command("chmod")
1207 .arg(format!("{:o}", server_mode))
1208 .arg(dst_path),
1209 )
1210 .await?;
1211
1212 Ok(())
1213 }
1214
1215 async fn query_platform(&self) -> Result<SshPlatform> {
1216 let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
1217 let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
1218
1219 let os = match os.trim() {
1220 "Darwin" => "macos",
1221 "Linux" => "linux",
1222 _ => Err(anyhow!("unknown uname os {os:?}"))?,
1223 };
1224 let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
1225 "aarch64"
1226 } else if arch.starts_with("x86") || arch.starts_with("i686") {
1227 "x86_64"
1228 } else {
1229 Err(anyhow!("unknown uname architecture {arch:?}"))?
1230 };
1231
1232 Ok(SshPlatform { os, arch })
1233 }
1234
1235 async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
1236 let mut command = process::Command::new("scp");
1237 let output = self
1238 .socket
1239 .ssh_options(&mut command)
1240 .args(
1241 self.socket
1242 .connection_options
1243 .port
1244 .map(|port| vec!["-P".to_string(), port.to_string()])
1245 .unwrap_or_default(),
1246 )
1247 .arg(src_path)
1248 .arg(format!(
1249 "{}:{}",
1250 self.socket.connection_options.scp_url(),
1251 dest_path.display()
1252 ))
1253 .output()
1254 .await?;
1255
1256 if output.status.success() {
1257 Ok(())
1258 } else {
1259 Err(anyhow!(
1260 "failed to upload file {} -> {}: {}",
1261 src_path.display(),
1262 dest_path.display(),
1263 String::from_utf8_lossy(&output.stderr)
1264 ))
1265 }
1266 }
1267}
1268
/// Requests that are still waiting for a response, keyed by the id of the outgoing request.
///
/// Each sender is completed with the response envelope plus a `oneshot::Sender<()>`. The
/// incoming-message loop waits on the matching receiver until the requester drops that sender
/// after taking the response.
type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;

/// A [`ProtoClient`] that exchanges protobuf [`Envelope`]s over a pair of unbounded channels.
pub struct ChannelClient {
    /// Source of ids for outgoing envelopes; incremented for every request and message sent.
    next_message_id: AtomicU32,
    /// Queue that outgoing envelopes are pushed onto.
    outgoing_tx: mpsc::UnboundedSender<Envelope>,
    /// Pending requests awaiting their response (mutex-guarded).
    response_channels: ResponseChannels,
    /// Registered message handlers and entity subscriptions (mutex-guarded).
    message_handlers: Mutex<ProtoMessageHandlerSet>,
}
1277
1278impl ChannelClient {
1279 pub fn new(
1280 incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1281 outgoing_tx: mpsc::UnboundedSender<Envelope>,
1282 cx: &AppContext,
1283 ) -> Arc<Self> {
1284 let this = Arc::new(Self {
1285 outgoing_tx,
1286 next_message_id: AtomicU32::new(0),
1287 response_channels: ResponseChannels::default(),
1288 message_handlers: Default::default(),
1289 });
1290
1291 Self::start_handling_messages(this.clone(), incoming_rx, cx);
1292
1293 this
1294 }
1295
1296 fn start_handling_messages(
1297 this: Arc<Self>,
1298 mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
1299 cx: &AppContext,
1300 ) {
1301 cx.spawn(|cx| {
1302 let this = Arc::downgrade(&this);
1303 async move {
1304 let peer_id = PeerId { owner_id: 0, id: 0 };
1305 while let Some(incoming) = incoming_rx.next().await {
1306 let Some(this) = this.upgrade() else {
1307 return anyhow::Ok(());
1308 };
1309
1310 if let Some(request_id) = incoming.responding_to {
1311 let request_id = MessageId(request_id);
1312 let sender = this.response_channels.lock().remove(&request_id);
1313 if let Some(sender) = sender {
1314 let (tx, rx) = oneshot::channel();
1315 if incoming.payload.is_some() {
1316 sender.send((incoming, tx)).ok();
1317 }
1318 rx.await.ok();
1319 }
1320 } else if let Some(envelope) =
1321 build_typed_envelope(peer_id, Instant::now(), incoming)
1322 {
1323 let type_name = envelope.payload_type_name();
1324 if let Some(future) = ProtoMessageHandlerSet::handle_message(
1325 &this.message_handlers,
1326 envelope,
1327 this.clone().into(),
1328 cx.clone(),
1329 ) {
1330 log::debug!("ssh message received. name:{type_name}");
1331 match future.await {
1332 Ok(_) => {
1333 log::debug!("ssh message handled. name:{type_name}");
1334 }
1335 Err(error) => {
1336 log::error!(
1337 "error handling message. type:{type_name}, error:{error}",
1338 );
1339 }
1340 }
1341 } else {
1342 log::error!("unhandled ssh message name:{type_name}");
1343 }
1344 }
1345 }
1346 anyhow::Ok(())
1347 }
1348 })
1349 .detach();
1350 }
1351
1352 pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
1353 let id = (TypeId::of::<E>(), remote_id);
1354
1355 let mut message_handlers = self.message_handlers.lock();
1356 if message_handlers
1357 .entities_by_type_and_remote_id
1358 .contains_key(&id)
1359 {
1360 panic!("already subscribed to entity");
1361 }
1362
1363 message_handlers.entities_by_type_and_remote_id.insert(
1364 id,
1365 EntityMessageSubscriber::Entity {
1366 handle: entity.downgrade().into(),
1367 },
1368 );
1369 }
1370
1371 pub fn request<T: RequestMessage>(
1372 &self,
1373 payload: T,
1374 ) -> impl 'static + Future<Output = Result<T::Response>> {
1375 log::debug!("ssh request start. name:{}", T::NAME);
1376 let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
1377 async move {
1378 let response = response.await?;
1379 log::debug!("ssh request finish. name:{}", T::NAME);
1380 T::Response::from_envelope(response)
1381 .ok_or_else(|| anyhow!("received a response of the wrong type"))
1382 }
1383 }
1384
1385 pub async fn ping(&self, timeout: Duration) -> Result<()> {
1386 smol::future::or(
1387 async {
1388 self.request(proto::Ping {}).await?;
1389 Ok(())
1390 },
1391 async {
1392 smol::Timer::after(timeout).await;
1393 Err(anyhow!("Timeout detected"))
1394 },
1395 )
1396 .await
1397 }
1398
1399 pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
1400 log::debug!("ssh send name:{}", T::NAME);
1401 self.send_dynamic(payload.into_envelope(0, None, None))
1402 }
1403
1404 pub fn request_dynamic(
1405 &self,
1406 mut envelope: proto::Envelope,
1407 type_name: &'static str,
1408 ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
1409 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1410 let (tx, rx) = oneshot::channel();
1411 let mut response_channels_lock = self.response_channels.lock();
1412 response_channels_lock.insert(MessageId(envelope.id), tx);
1413 drop(response_channels_lock);
1414 let result = self.outgoing_tx.unbounded_send(envelope);
1415 async move {
1416 if let Err(error) = &result {
1417 log::error!("failed to send message: {}", error);
1418 return Err(anyhow!("failed to send message: {}", error));
1419 }
1420
1421 let response = rx.await.context("connection lost")?.0;
1422 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
1423 return Err(RpcError::from_proto(error, type_name));
1424 }
1425 Ok(response)
1426 }
1427 }
1428
1429 pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
1430 envelope.id = self.next_message_id.fetch_add(1, SeqCst);
1431 self.outgoing_tx.unbounded_send(envelope)?;
1432 Ok(())
1433 }
1434}
1435
1436impl ProtoClient for ChannelClient {
1437 fn request(
1438 &self,
1439 envelope: proto::Envelope,
1440 request_type: &'static str,
1441 ) -> BoxFuture<'static, Result<proto::Envelope>> {
1442 self.request_dynamic(envelope, request_type).boxed()
1443 }
1444
1445 fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
1446 self.send_dynamic(envelope)
1447 }
1448
1449 fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
1450 self.send_dynamic(envelope)
1451 }
1452
1453 fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
1454 &self.message_handlers
1455 }
1456
1457 fn is_via_collab(&self) -> bool {
1458 false
1459 }
1460}