diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index ccb66a0c626e046dccb0d3828d45c6a082ffe821..2db494ecbe15a9898ab88d6cf679a9b875f543cd 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -9,7 +9,7 @@ use anyhow::{anyhow, Context as _, Result}; use collections::HashMap; use futures::{ channel::{ - mpsc::{self, UnboundedReceiver, UnboundedSender}, + mpsc::{self, Sender, UnboundedReceiver, UnboundedSender}, oneshot, }, future::BoxFuture, @@ -28,7 +28,6 @@ use rpc::{ use smol::{ fs, process::{self, Child, Stdio}, - Timer, }; use std::{ any::TypeId, @@ -441,6 +440,7 @@ impl SshRemoteClient { cx.spawn(|mut cx| async move { let (outgoing_tx, outgoing_rx) = mpsc::unbounded::(); let (incoming_tx, incoming_rx) = mpsc::unbounded::(); + let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1); let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?; let this = cx.new_model(|_| Self { @@ -467,6 +467,7 @@ impl SshRemoteClient { ssh_proxy_process, proxy_incoming_tx, proxy_outgoing_rx, + connection_activity_tx, &mut cx, ); @@ -476,7 +477,7 @@ impl SshRemoteClient { return Err(error); } - let heartbeat_task = Self::heartbeat(this.downgrade(), &mut cx); + let heartbeat_task = Self::heartbeat(this.downgrade(), connection_activity_rx, &mut cx); this.update(&mut cx, |this, _| { *this.state.lock() = Some(State::Connected { @@ -518,7 +519,7 @@ impl SshRemoteClient { // We wait 50ms instead of waiting for a response, because // waiting for a response would require us to wait on the main thread // which we want to avoid in an `on_app_quit` callback. - Timer::after(Duration::from_millis(50)).await; + smol::Timer::after(Duration::from_millis(50)).await; } // Drop `multiplex_task` because it owns our ssh_proxy_process, which is a @@ -632,6 +633,7 @@ impl SshRemoteClient { let (incoming_tx, outgoing_rx) = forwarder.into_channels().await; let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) = ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx); + let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1); let (ssh_connection, ssh_process) = match Self::establish_connection( identifier, @@ -653,6 +655,7 @@ impl SshRemoteClient { ssh_process, proxy_incoming_tx, proxy_outgoing_rx, + connection_activity_tx, &mut cx, ); @@ -665,7 +668,7 @@ impl SshRemoteClient { delegate, forwarder, multiplex_task, - heartbeat_task: Self::heartbeat(this.clone(), &mut cx), + heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx), } }); @@ -717,41 +720,60 @@ impl SshRemoteClient { Ok(()) } - fn heartbeat(this: WeakModel, cx: &mut AsyncAppContext) -> Task> { + fn heartbeat( + this: WeakModel, + mut connection_activity_rx: mpsc::Receiver<()>, + cx: &mut AsyncAppContext, + ) -> Task> { let Ok(client) = this.update(cx, |this, _| this.client.clone()) else { return Task::ready(Err(anyhow!("SshRemoteClient lost"))); }; + cx.spawn(|mut cx| { let this = this.clone(); async move { let mut missed_heartbeats = 0; - let mut timer = Timer::interval(HEARTBEAT_INTERVAL); + let keepalive_timer = cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse(); + futures::pin_mut!(keepalive_timer); + loop { - timer.next().await; - - log::debug!("Sending heartbeat to server..."); - - let result = client.ping(HEARTBEAT_TIMEOUT).await; - if result.is_err() { - missed_heartbeats += 1; - log::warn!( - "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.", - HEARTBEAT_TIMEOUT, - missed_heartbeats, - MAX_MISSED_HEARTBEATS - ); - } else if missed_heartbeats != 0 { - missed_heartbeats = 0; - } else { - continue; - } + select_biased! { + _ = connection_activity_rx.next().fuse() => { + keepalive_timer.set(cx.background_executor().timer(HEARTBEAT_INTERVAL).fuse()); + } + _ = keepalive_timer => { + log::debug!("Sending heartbeat to server..."); - let result = this.update(&mut cx, |this, mut cx| { - this.handle_heartbeat_result(missed_heartbeats, &mut cx) - })?; - if result.is_break() { - return Ok(()); + let result = select_biased! { + _ = connection_activity_rx.next().fuse() => { + Ok(()) + } + ping_result = client.ping(HEARTBEAT_TIMEOUT).fuse() => { + ping_result + } + }; + if result.is_err() { + missed_heartbeats += 1; + log::warn!( + "No heartbeat from server after {:?}. Missed heartbeat {} out of {}.", + HEARTBEAT_TIMEOUT, + missed_heartbeats, + MAX_MISSED_HEARTBEATS + ); + } else if missed_heartbeats != 0 { + missed_heartbeats = 0; + } else { + continue; + } + + let result = this.update(&mut cx, |this, mut cx| { + this.handle_heartbeat_result(missed_heartbeats, &mut cx) + })?; + if result.is_break() { + return Ok(()); + } + } } } } @@ -792,6 +814,7 @@ impl SshRemoteClient { mut ssh_proxy_process: Child, incoming_tx: UnboundedSender, mut outgoing_rx: UnboundedReceiver, + mut connection_activity_tx: Sender<()>, cx: &AsyncAppContext, ) -> Task> { let mut child_stderr = ssh_proxy_process.stderr.take().unwrap(); @@ -833,6 +856,7 @@ impl SshRemoteClient { let message_len = message_len_from_buffer(&stdout_buffer); match read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len).await { Ok(envelope) => { + connection_activity_tx.try_send(()).ok(); incoming_tx.unbounded_send(envelope).ok(); } Err(error) => { @@ -863,6 +887,8 @@ impl SshRemoteClient { } stderr_buffer.drain(0..start_ix); stderr_offset -= start_ix; + + connection_activity_tx.try_send(()).ok(); } Err(error) => { Err(anyhow!("error reading stderr: {error:?}"))?; @@ -1392,16 +1418,19 @@ impl ChannelClient { cx.clone(), ) { log::debug!("ssh message received. name:{type_name}"); - match future.await { - Ok(_) => { - log::debug!("ssh message handled. name:{type_name}"); - } - Err(error) => { - log::error!( - "error handling message. type:{type_name}, error:{error}", - ); + cx.foreground_executor().spawn(async move { + match future.await { + Ok(_) => { + log::debug!("ssh message handled. name:{type_name}"); + } + Err(error) => { + log::error!( + "error handling message. type:{type_name}, error:{error}", + ); + } } - } + }).detach(); + } else { log::error!("unhandled ssh message name:{type_name}"); }