Cargo.lock 🔗
@@ -9119,6 +9119,7 @@ name = "remote"
version = "0.1.0"
dependencies = [
"anyhow",
+ "async-trait",
"collections",
"fs",
"futures 0.3.30",
Conrad Irwin and Nathan created
Before this change messages could be lost on reconnect, now they will
not be.
Release Notes:
- SSH Remoting: make reconnects smoother
---------
Co-authored-by: Nathan <nathan@zed.dev>
Cargo.lock | 1
crates/collab/src/tests/remote_editing_collaboration_tests.rs | 3
crates/project/src/project.rs | 4
crates/proto/proto/zed.proto | 8
crates/proto/src/macros.rs | 1
crates/proto/src/proto.rs | 2
crates/remote/Cargo.toml | 1
crates/remote/src/ssh_session.rs | 598 +++-
crates/remote_server/src/remote_editing_tests.rs | 46
crates/remote_server/src/unix.rs | 2
10 files changed, 467 insertions(+), 199 deletions(-)
@@ -9119,6 +9119,7 @@ name = "remote"
version = "0.1.0"
dependencies = [
"anyhow",
+ "async-trait",
"collections",
"fs",
"futures 0.3.30",
@@ -26,7 +26,7 @@ async fn test_sharing_an_ssh_remote_project(
.await;
// Set up project on remote FS
- let (client_ssh, server_ssh) = SshRemoteClient::fake(cx_a, server_cx);
+ let (port, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
let remote_fs = FakeFs::new(server_cx.executor());
remote_fs
.insert_tree(
@@ -67,6 +67,7 @@ async fn test_sharing_an_ssh_remote_project(
)
});
+ let client_ssh = SshRemoteClient::fake_client(port, cx_a).await;
let (project_a, worktree_id) = client_a
.build_ssh_project("/code/project1", client_ssh, cx_a)
.await;
@@ -1243,6 +1243,10 @@ impl Project {
self.client.clone()
}
+ pub fn ssh_client(&self) -> Option<Model<SshRemoteClient>> {
+ self.ssh_client.clone()
+ }
+
pub fn user_store(&self) -> Model<UserStore> {
self.user_store.clone()
}
@@ -12,6 +12,7 @@ message Envelope {
uint32 id = 1;
optional uint32 responding_to = 2;
optional PeerId original_sender_id = 3;
+ optional uint32 ack_id = 266;
oneof payload {
Hello hello = 4;
@@ -295,7 +296,9 @@ message Envelope {
OpenServerSettings open_server_settings = 263;
GetPermalinkToLine get_permalink_to_line = 264;
- GetPermalinkToLineResponse get_permalink_to_line_response = 265; // current max
+ GetPermalinkToLineResponse get_permalink_to_line_response = 265;
+
+ FlushBufferedMessages flush_buffered_messages = 267;
}
reserved 87 to 88;
@@ -2522,3 +2525,6 @@ message GetPermalinkToLine {
message GetPermalinkToLineResponse {
string permalink = 1;
}
+
+message FlushBufferedMessages {}
+message FlushBufferedMessagesResponse {}
@@ -32,6 +32,7 @@ macro_rules! messages {
responding_to,
original_sender_id,
payload: Some(envelope::Payload::$name(self)),
+ ack_id: None,
}
}
@@ -372,6 +372,7 @@ messages!(
(OpenServerSettings, Foreground),
(GetPermalinkToLine, Foreground),
(GetPermalinkToLineResponse, Foreground),
+ (FlushBufferedMessages, Foreground),
);
request_messages!(
@@ -498,6 +499,7 @@ request_messages!(
(RemoveWorktree, Ack),
(OpenServerSettings, OpenBufferResponse),
(GetPermalinkToLine, GetPermalinkToLineResponse),
+ (FlushBufferedMessages, Ack),
);
entity_messages!(
@@ -19,6 +19,7 @@ test-support = ["fs/test-support"]
[dependencies]
anyhow.workspace = true
+async-trait.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
@@ -6,6 +6,7 @@ use crate::{
proxy::ProxyLaunchError,
};
use anyhow::{anyhow, Context as _, Result};
+use async_trait::async_trait;
use collections::HashMap;
use futures::{
channel::{
@@ -13,7 +14,7 @@ use futures::{
oneshot,
},
future::BoxFuture,
- select_biased, AsyncReadExt as _, Future, FutureExt as _, SinkExt, StreamExt as _,
+ select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
};
use gpui::{
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
@@ -30,13 +31,14 @@ use smol::{
};
use std::{
any::TypeId,
+ collections::VecDeque,
ffi::OsStr,
fmt,
ops::ControlFlow,
path::{Path, PathBuf},
sync::{
atomic::{AtomicU32, Ordering::SeqCst},
- Arc,
+ Arc, Weak,
},
time::{Duration, Instant},
};
@@ -275,68 +277,6 @@ async fn run_cmd(command: &mut process::Command) -> Result<String> {
}
}
-struct ChannelForwarder {
- quit_tx: UnboundedSender<()>,
- forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
-}
-
-impl ChannelForwarder {
- fn new(
- mut incoming_tx: UnboundedSender<Envelope>,
- mut outgoing_rx: UnboundedReceiver<Envelope>,
- cx: &AsyncAppContext,
- ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
- let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
-
- let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
- let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
-
- let forwarding_task = cx.background_executor().spawn(async move {
- loop {
- select_biased! {
- _ = quit_rx.next().fuse() => {
- break;
- },
- incoming_envelope = proxy_incoming_rx.next().fuse() => {
- if let Some(envelope) = incoming_envelope {
- if incoming_tx.send(envelope).await.is_err() {
- break;
- }
- } else {
- break;
- }
- }
- outgoing_envelope = outgoing_rx.next().fuse() => {
- if let Some(envelope) = outgoing_envelope {
- if proxy_outgoing_tx.send(envelope).await.is_err() {
- break;
- }
- } else {
- break;
- }
- }
- }
- }
-
- (incoming_tx, outgoing_rx)
- });
-
- (
- Self {
- forwarding_task,
- quit_tx,
- },
- proxy_incoming_tx,
- proxy_outgoing_rx,
- )
- }
-
- async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
- let _ = self.quit_tx.send(()).await;
- self.forwarding_task.await
- }
-}
-
const MAX_MISSED_HEARTBEATS: usize = 5;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(5);
@@ -346,9 +286,8 @@ const MAX_RECONNECT_ATTEMPTS: usize = 3;
enum State {
Connecting,
Connected {
- ssh_connection: SshRemoteConnection,
+ ssh_connection: Box<dyn SshRemoteProcess>,
delegate: Arc<dyn SshClientDelegate>,
- forwarder: ChannelForwarder,
multiplex_task: Task<Result<()>>,
heartbeat_task: Task<Result<()>>,
@@ -356,18 +295,16 @@ enum State {
HeartbeatMissed {
missed_heartbeats: usize,
- ssh_connection: SshRemoteConnection,
+ ssh_connection: Box<dyn SshRemoteProcess>,
delegate: Arc<dyn SshClientDelegate>,
- forwarder: ChannelForwarder,
multiplex_task: Task<Result<()>>,
heartbeat_task: Task<Result<()>>,
},
Reconnecting,
ReconnectFailed {
- ssh_connection: SshRemoteConnection,
+ ssh_connection: Box<dyn SshRemoteProcess>,
delegate: Arc<dyn SshClientDelegate>,
- forwarder: ChannelForwarder,
error: anyhow::Error,
attempts: usize,
@@ -391,11 +328,11 @@ impl fmt::Display for State {
}
impl State {
- fn ssh_connection(&self) -> Option<&SshRemoteConnection> {
+ fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> {
match self {
- Self::Connected { ssh_connection, .. } => Some(ssh_connection),
- Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection),
- Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection),
+ Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
+ Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
+ Self::ReconnectFailed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
_ => None,
}
}
@@ -429,14 +366,12 @@ impl State {
Self::HeartbeatMissed {
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task,
..
} => Self::Connected {
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task,
},
@@ -449,14 +384,12 @@ impl State {
Self::Connected {
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task,
} => Self::HeartbeatMissed {
missed_heartbeats: 1,
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task,
},
@@ -464,14 +397,12 @@ impl State {
missed_heartbeats,
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task,
} => Self::HeartbeatMissed {
missed_heartbeats: missed_heartbeats + 1,
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task,
},
@@ -529,7 +460,8 @@ impl SshRemoteClient {
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
- let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
+ let client =
+ cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "client"))?;
let this = cx.new_model(|_| Self {
client: client.clone(),
unique_identifier: unique_identifier.clone(),
@@ -537,26 +469,19 @@ impl SshRemoteClient {
state: Arc::new(Mutex::new(Some(State::Connecting))),
})?;
- let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
- ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
-
- let (ssh_connection, ssh_proxy_process) = Self::establish_connection(
+ let (ssh_connection, io_task) = Self::establish_connection(
unique_identifier,
false,
connection_options,
+ incoming_tx,
+ outgoing_rx,
+ connection_activity_tx,
delegate.clone(),
&mut cx,
)
.await?;
- let multiplex_task = Self::multiplex(
- this.downgrade(),
- ssh_proxy_process,
- proxy_incoming_tx,
- proxy_outgoing_rx,
- connection_activity_tx,
- &mut cx,
- );
+ let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
log::error!("failed to establish connection: {}", error);
@@ -570,7 +495,6 @@ impl SshRemoteClient {
*this.state.lock() = Some(State::Connected {
ssh_connection,
delegate,
- forwarder: proxy,
multiplex_task,
heartbeat_task,
});
@@ -592,7 +516,6 @@ impl SshRemoteClient {
heartbeat_task,
ssh_connection,
delegate,
- forwarder,
} = state
else {
return None;
@@ -616,7 +539,6 @@ impl SshRemoteClient {
drop(heartbeat_task);
drop(ssh_connection);
drop(delegate);
- drop(forwarder);
})
}
@@ -638,33 +560,30 @@ impl SshRemoteClient {
}
let state = lock.take().unwrap();
- let (attempts, mut ssh_connection, delegate, forwarder) = match state {
+ let (attempts, mut ssh_connection, delegate) = match state {
State::Connected {
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task,
}
| State::HeartbeatMissed {
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task,
..
} => {
drop(multiplex_task);
drop(heartbeat_task);
- (0, ssh_connection, delegate, forwarder)
+ (0, ssh_connection, delegate)
}
State::ReconnectFailed {
attempts,
ssh_connection,
delegate,
- forwarder,
..
- } => (attempts, ssh_connection, delegate, forwarder),
+ } => (attempts, ssh_connection, delegate),
State::Connecting
| State::Reconnecting
| State::ReconnectExhausted
@@ -691,41 +610,37 @@ impl SshRemoteClient {
let client = self.client.clone();
let reconnect_task = cx.spawn(|this, mut cx| async move {
macro_rules! failed {
- ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr, $forwarder:expr) => {
+ ($error:expr, $attempts:expr, $ssh_connection:expr, $delegate:expr) => {
return State::ReconnectFailed {
error: anyhow!($error),
attempts: $attempts,
ssh_connection: $ssh_connection,
delegate: $delegate,
- forwarder: $forwarder,
};
};
}
- if let Err(error) = ssh_connection.master_process.kill() {
- failed!(error, attempts, ssh_connection, delegate, forwarder);
- };
-
if let Err(error) = ssh_connection
- .master_process
- .status()
+ .kill()
.await
.context("Failed to kill ssh process")
{
- failed!(error, attempts, ssh_connection, delegate, forwarder);
- }
+ failed!(error, attempts, ssh_connection, delegate);
+ };
- let connection_options = ssh_connection.socket.connection_options.clone();
+ let connection_options = ssh_connection.connection_options();
- let (incoming_tx, outgoing_rx) = forwarder.into_channels().await;
- let (forwarder, proxy_incoming_tx, proxy_outgoing_rx) =
- ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+ let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
+ let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
- let (ssh_connection, ssh_process) = match Self::establish_connection(
+ let (ssh_connection, io_task) = match Self::establish_connection(
identifier,
true,
connection_options,
+ incoming_tx,
+ outgoing_rx,
+ connection_activity_tx,
delegate.clone(),
&mut cx,
)
@@ -733,27 +648,20 @@ impl SshRemoteClient {
{
Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
Err(error) => {
- failed!(error, attempts, ssh_connection, delegate, forwarder);
+ failed!(error, attempts, ssh_connection, delegate);
}
};
- let multiplex_task = Self::multiplex(
- this.clone(),
- ssh_process,
- proxy_incoming_tx,
- proxy_outgoing_rx,
- connection_activity_tx,
- &mut cx,
- );
+ let multiplex_task = Self::monitor(this.clone(), io_task, &cx);
+ client.reconnect(incoming_rx, outgoing_tx, &cx);
- if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
- failed!(error, attempts, ssh_connection, delegate, forwarder);
+ if let Err(error) = client.resync(HEARTBEAT_TIMEOUT).await {
+ failed!(error, attempts, ssh_connection, delegate);
};
State::Connected {
ssh_connection,
delegate,
- forwarder,
multiplex_task,
heartbeat_task: Self::heartbeat(this.clone(), connection_activity_rx, &mut cx),
}
@@ -797,7 +705,7 @@ impl SshRemoteClient {
cx.emit(SshRemoteEvent::Disconnected);
Ok(())
} else {
- log::debug!("State has transition from Reconnecting into new state while attempting reconnect. Ignoring new state.");
+ log::debug!("State has transition from Reconnecting into new state while attempting reconnect.");
Ok(())
}
})
@@ -910,13 +818,12 @@ impl SshRemoteClient {
}
fn multiplex(
- this: WeakModel<Self>,
mut ssh_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>,
mut outgoing_rx: UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>,
cx: &AsyncAppContext,
- ) -> Task<Result<()>> {
+ ) -> Task<Result<i32>> {
let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
@@ -988,7 +895,7 @@ impl SshRemoteClient {
}
});
- cx.spawn(|mut cx| async move {
+ cx.spawn(|_| async move {
let result = futures::select! {
result = stdin_task.fuse() => {
result.context("stdin")
@@ -1002,9 +909,22 @@ impl SshRemoteClient {
};
match result {
- Ok(_) => {
- let exit_code = ssh_proxy_process.status().await?.code().unwrap_or(1);
+ Ok(_) => Ok(ssh_proxy_process.status().await?.code().unwrap_or(1)),
+ Err(error) => Err(error),
+ }
+ })
+ }
+ fn monitor(
+ this: WeakModel<Self>,
+ io_task: Task<Result<i32>>,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<()>> {
+ cx.spawn(|mut cx| async move {
+ let result = io_task.await;
+
+ match result {
+ Ok(exit_code) => {
if let Some(error) = ProxyLaunchError::from_exit_code(exit_code) {
match error {
ProxyLaunchError::ServerNotRunning => {
@@ -1058,21 +978,40 @@ impl SshRemoteClient {
cx.notify();
}
+ #[allow(clippy::too_many_arguments)]
async fn establish_connection(
unique_identifier: String,
reconnect: bool,
connection_options: SshConnectionOptions,
+ incoming_tx: UnboundedSender<Envelope>,
+ outgoing_rx: UnboundedReceiver<Envelope>,
+ connection_activity_tx: Sender<()>,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
- ) -> Result<(SshRemoteConnection, Child)> {
+ ) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<i32>>)> {
+ #[cfg(any(test, feature = "test-support"))]
+ if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) {
+ let io_task = fake::SshRemoteConnection::multiplex(
+ fake.connection_options(),
+ incoming_tx,
+ outgoing_rx,
+ connection_activity_tx,
+ cx,
+ )
+ .await;
+ return Ok((fake, io_task));
+ }
+
let ssh_connection =
SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
let platform = ssh_connection.query_platform().await?;
let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
- ssh_connection
- .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
- .await?;
+ if !reconnect {
+ ssh_connection
+ .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
+ .await?;
+ }
let socket = ssh_connection.socket.clone();
run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
@@ -1097,7 +1036,15 @@ impl SshRemoteClient {
.spawn()
.context("failed to spawn remote server")?;
- Ok((ssh_connection, ssh_proxy_process))
+ let io_task = Self::multiplex(
+ ssh_proxy_process,
+ incoming_tx,
+ outgoing_rx,
+ connection_activity_tx,
+ &cx,
+ );
+
+ Ok((Box::new(ssh_connection), io_task))
}
pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
@@ -1109,7 +1056,7 @@ impl SshRemoteClient {
.lock()
.as_ref()
.and_then(|state| state.ssh_connection())
- .map(|ssh_connection| ssh_connection.socket.ssh_args())
+ .map(|ssh_connection| ssh_connection.ssh_args())
}
pub fn proto_client(&self) -> AnyProtoClient {
@@ -1124,7 +1071,6 @@ impl SshRemoteClient {
self.connection_options.clone()
}
- #[cfg(not(any(test, feature = "test-support")))]
pub fn connection_state(&self) -> ConnectionState {
self.state
.lock()
@@ -1133,37 +1079,59 @@ impl SshRemoteClient {
.unwrap_or(ConnectionState::Disconnected)
}
- #[cfg(any(test, feature = "test-support"))]
- pub fn connection_state(&self) -> ConnectionState {
- ConnectionState::Connected
- }
-
pub fn is_disconnected(&self) -> bool {
self.connection_state() == ConnectionState::Disconnected
}
#[cfg(any(test, feature = "test-support"))]
- pub fn fake(
+ pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
+ let port = self.connection_options().port.unwrap();
+ client_cx.spawn(|cx| async move {
+ let (channel, server_cx) = cx
+ .update_global(|c: &mut fake::ServerConnections, _| c.get(port))
+ .unwrap();
+
+ let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
+ let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
+ channel.reconnect(incoming_rx, outgoing_tx, &server_cx);
+ })
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn fake_server(
client_cx: &mut gpui::TestAppContext,
server_cx: &mut gpui::TestAppContext,
- ) -> (Model<Self>, Arc<ChannelClient>) {
- use gpui::Context;
-
- let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
- let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
-
- (
- client_cx.update(|cx| {
- let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
- cx.new_model(|_| Self {
- client,
- unique_identifier: "fake".to_string(),
- connection_options: SshConnectionOptions::default(),
- state: Arc::new(Mutex::new(None)),
- })
- }),
- server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
- )
+ ) -> (u16, Arc<ChannelClient>) {
+ use gpui::BorrowAppContext;
+ let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
+ let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
+ let server_client =
+ server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
+ let port = client_cx.update(|cx| {
+ cx.update_default_global(|c: &mut fake::ServerConnections, _| {
+ c.push(server_client.clone(), server_cx.to_async())
+ })
+ });
+ (port, server_client)
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub async fn fake_client(port: u16, client_cx: &mut gpui::TestAppContext) -> Model<Self> {
+ client_cx
+ .update(|cx| {
+ Self::new(
+ "fake".to_string(),
+ SshConnectionOptions {
+ host: "<fake>".to_string(),
+ port: Some(port),
+ ..Default::default()
+ },
+ Arc::new(fake::Delegate),
+ cx,
+ )
+ })
+ .await
+ .unwrap()
}
}
@@ -1173,6 +1141,13 @@ impl From<SshRemoteClient> for AnyProtoClient {
}
}
+#[async_trait]
+trait SshRemoteProcess: Send + Sync {
+ async fn kill(&mut self) -> Result<()>;
+ fn ssh_args(&self) -> Vec<String>;
+ fn connection_options(&self) -> SshConnectionOptions;
+}
+
struct SshRemoteConnection {
socket: SshSocket,
master_process: process::Child,
@@ -1187,6 +1162,25 @@ impl Drop for SshRemoteConnection {
}
}
+#[async_trait]
+impl SshRemoteProcess for SshRemoteConnection {
+ async fn kill(&mut self) -> Result<()> {
+ self.master_process.kill()?;
+
+ self.master_process.status().await?;
+
+ Ok(())
+ }
+
+ fn ssh_args(&self) -> Vec<String> {
+ self.socket.ssh_args()
+ }
+
+ fn connection_options(&self) -> SshConnectionOptions {
+ self.socket.connection_options.clone()
+ }
+}
+
impl SshRemoteConnection {
#[cfg(not(unix))]
async fn new(
@@ -1469,9 +1463,13 @@ type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, ones
pub struct ChannelClient {
next_message_id: AtomicU32,
- outgoing_tx: mpsc::UnboundedSender<Envelope>,
- response_channels: ResponseChannels, // Lock
- message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock
+ outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
+ buffer: Mutex<VecDeque<Envelope>>,
+ response_channels: ResponseChannels,
+ message_handlers: Mutex<ProtoMessageHandlerSet>,
+ max_received: AtomicU32,
+ name: &'static str,
+ task: Mutex<Task<Result<()>>>,
}
impl ChannelClient {
@@ -1479,32 +1477,59 @@ impl ChannelClient {
incoming_rx: mpsc::UnboundedReceiver<Envelope>,
outgoing_tx: mpsc::UnboundedSender<Envelope>,
cx: &AppContext,
+ name: &'static str,
) -> Arc<Self> {
- let this = Arc::new(Self {
- outgoing_tx,
+ Arc::new_cyclic(|this| Self {
+ outgoing_tx: Mutex::new(outgoing_tx),
next_message_id: AtomicU32::new(0),
+ max_received: AtomicU32::new(0),
response_channels: ResponseChannels::default(),
message_handlers: Default::default(),
- });
-
- Self::start_handling_messages(this.clone(), incoming_rx, cx);
-
- this
+ buffer: Mutex::new(VecDeque::new()),
+ name,
+ task: Mutex::new(Self::start_handling_messages(
+ this.clone(),
+ incoming_rx,
+ &cx.to_async(),
+ )),
+ })
}
fn start_handling_messages(
- this: Arc<Self>,
+ this: Weak<Self>,
mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
- cx: &AppContext,
- ) {
+ cx: &AsyncAppContext,
+ ) -> Task<Result<()>> {
cx.spawn(|cx| {
- let this = Arc::downgrade(&this);
async move {
let peer_id = PeerId { owner_id: 0, id: 0 };
while let Some(incoming) = incoming_rx.next().await {
let Some(this) = this.upgrade() else {
return anyhow::Ok(());
};
+ if let Some(ack_id) = incoming.ack_id {
+ let mut buffer = this.buffer.lock();
+ while buffer.front().is_some_and(|msg| msg.id <= ack_id) {
+ buffer.pop_front();
+ }
+ }
+ if let Some(proto::envelope::Payload::FlushBufferedMessages(_)) =
+ &incoming.payload
+ {
+ log::debug!("{}:ssh message received. name:FlushBufferedMessages", this.name);
+ {
+ let buffer = this.buffer.lock();
+ for envelope in buffer.iter() {
+ this.outgoing_tx.lock().unbounded_send(envelope.clone()).ok();
+ }
+ }
+ let mut envelope = proto::Ack{}.into_envelope(0, Some(incoming.id), None);
+ envelope.id = this.next_message_id.fetch_add(1, SeqCst);
+ this.outgoing_tx.lock().unbounded_send(envelope).ok();
+ continue;
+ }
+
+ this.max_received.store(incoming.id, SeqCst);
if let Some(request_id) = incoming.responding_to {
let request_id = MessageId(request_id);
@@ -1526,26 +1551,37 @@ impl ChannelClient {
this.clone().into(),
cx.clone(),
) {
- log::debug!("ssh message received. name:{type_name}");
- match future.await {
- Ok(_) => {
- log::debug!("ssh message handled. name:{type_name}");
+ log::debug!("{}:ssh message received. name:{type_name}", this.name);
+ cx.foreground_executor().spawn(async move {
+ match future.await {
+ Ok(_) => {
+ log::debug!("{}:ssh message handled. name:{type_name}", this.name);
+ }
+ Err(error) => {
+ log::error!(
+ "{}:error handling message. type:{type_name}, error:{error}", this.name,
+ );
+ }
}
- Err(error) => {
- log::error!(
- "error handling message. type:{type_name}, error:{error}",
- );
- }
- }
+ }).detach()
} else {
- log::error!("unhandled ssh message name:{type_name}");
+ log::error!("{}:unhandled ssh message name:{type_name}", this.name);
}
}
}
anyhow::Ok(())
}
})
- .detach();
+ }
+
+ pub fn reconnect(
+ self: &Arc<Self>,
+ incoming_rx: UnboundedReceiver<Envelope>,
+ outgoing_tx: UnboundedSender<Envelope>,
+ cx: &AsyncAppContext,
+ ) {
+ *self.outgoing_tx.lock() = outgoing_tx;
+ *self.task.lock() = Self::start_handling_messages(Arc::downgrade(self), incoming_rx, cx);
}
pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
@@ -1581,6 +1617,26 @@ impl ChannelClient {
}
}
+ pub async fn resync(&self, timeout: Duration) -> Result<()> {
+ smol::future::or(
+ async {
+ self.request(proto::FlushBufferedMessages {}).await?;
+ for envelope in self.buffer.lock().iter() {
+ self.outgoing_tx
+ .lock()
+ .unbounded_send(envelope.clone())
+ .ok();
+ }
+ Ok(())
+ },
+ async {
+ smol::Timer::after(timeout).await;
+ Err(anyhow!("Timeout detected"))
+ },
+ )
+ .await
+ }
+
pub async fn ping(&self, timeout: Duration) -> Result<()> {
smol::future::or(
async {
@@ -1610,7 +1666,8 @@ impl ChannelClient {
let mut response_channels_lock = self.response_channels.lock();
response_channels_lock.insert(MessageId(envelope.id), tx);
drop(response_channels_lock);
- let result = self.outgoing_tx.unbounded_send(envelope);
+
+ let result = self.send_buffered(envelope);
async move {
if let Err(error) = &result {
log::error!("failed to send message: {}", error);
@@ -1627,7 +1684,15 @@ impl ChannelClient {
pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
envelope.id = self.next_message_id.fetch_add(1, SeqCst);
- self.outgoing_tx.unbounded_send(envelope)?;
+ self.send_buffered(envelope)
+ }
+
+ pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
+ envelope.ack_id = Some(self.max_received.load(SeqCst));
+ self.buffer.lock().push_back(envelope.clone());
+ // ignore errors on send (happen while we're reconnecting)
+ // assume that the global "disconnected" overlay is sufficient.
+ self.outgoing_tx.lock().unbounded_send(envelope).ok();
Ok(())
}
}
@@ -1657,3 +1722,148 @@ impl ProtoClient for ChannelClient {
false
}
}
+
+#[cfg(any(test, feature = "test-support"))]
+mod fake {
+ use std::{path::PathBuf, sync::Arc};
+
+ use anyhow::Result;
+ use async_trait::async_trait;
+ use futures::{
+ channel::{
+ mpsc::{self, Sender},
+ oneshot,
+ },
+ select_biased, FutureExt, SinkExt, StreamExt,
+ };
+ use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task};
+ use rpc::proto::Envelope;
+
+ use super::{
+ ChannelClient, SshClientDelegate, SshConnectionOptions, SshPlatform, SshRemoteProcess,
+ };
+
+ pub(super) struct SshRemoteConnection {
+ connection_options: SshConnectionOptions,
+ }
+
+ impl SshRemoteConnection {
+ pub(super) fn new(
+ connection_options: &SshConnectionOptions,
+ ) -> Option<Box<dyn SshRemoteProcess>> {
+ if connection_options.host == "<fake>" {
+ return Some(Box::new(Self {
+ connection_options: connection_options.clone(),
+ }));
+ }
+ return None;
+ }
+ pub(super) async fn multiplex(
+ connection_options: SshConnectionOptions,
+ mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
+ mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
+ mut connection_activity_tx: Sender<()>,
+ cx: &mut AsyncAppContext,
+ ) -> Task<Result<i32>> {
+ let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
+ let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
+
+ let (channel, server_cx) = cx
+ .update(|cx| {
+ cx.update_global(|conns: &mut ServerConnections, _| {
+ conns.get(connection_options.port.unwrap())
+ })
+ })
+ .unwrap();
+ channel.reconnect(server_incoming_rx, server_outgoing_tx, &server_cx);
+
+ // send to proxy_tx to get to the server.
+ // receive from
+
+ cx.background_executor().spawn(async move {
+ loop {
+ select_biased! {
+ server_to_client = server_outgoing_rx.next().fuse() => {
+ let Some(server_to_client) = server_to_client else {
+ return Ok(1)
+ };
+ connection_activity_tx.try_send(()).ok();
+ client_incoming_tx.send(server_to_client).await.ok();
+ }
+ client_to_server = client_outgoing_rx.next().fuse() => {
+ let Some(client_to_server) = client_to_server else {
+ return Ok(1)
+ };
+ server_incoming_tx.send(client_to_server).await.ok();
+ }
+ }
+ }
+ })
+ }
+ }
+
+ #[async_trait]
+ impl SshRemoteProcess for SshRemoteConnection {
+ async fn kill(&mut self) -> Result<()> {
+ Ok(())
+ }
+
+ fn ssh_args(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn connection_options(&self) -> SshConnectionOptions {
+ self.connection_options.clone()
+ }
+ }
+
+ #[derive(Default)]
+ pub(super) struct ServerConnections(Vec<(Arc<ChannelClient>, AsyncAppContext)>);
+ impl Global for ServerConnections {}
+
+ impl ServerConnections {
+ pub(super) fn push(&mut self, server: Arc<ChannelClient>, cx: AsyncAppContext) -> u16 {
+ self.0.push((server.clone(), cx));
+ self.0.len() as u16 - 1
+ }
+
+ pub(super) fn get(&mut self, port: u16) -> (Arc<ChannelClient>, AsyncAppContext) {
+ self.0
+ .get(port as usize)
+ .expect("no fake server for port")
+ .clone()
+ }
+ }
+
+ pub(super) struct Delegate;
+
+ impl SshClientDelegate for Delegate {
+ fn ask_password(
+ &self,
+ _: String,
+ _: &mut AsyncAppContext,
+ ) -> oneshot::Receiver<Result<String>> {
+ unreachable!()
+ }
+ fn remote_server_binary_path(
+ &self,
+ _: SshPlatform,
+ _: &mut AsyncAppContext,
+ ) -> Result<PathBuf> {
+ unreachable!()
+ }
+ fn get_server_binary(
+ &self,
+ _: SshPlatform,
+ _: &mut AsyncAppContext,
+ ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
+ unreachable!()
+ }
+ fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {
+ unreachable!()
+ }
+ fn set_error(&self, _: String, _: &mut AsyncAppContext) {
+ unreachable!()
+ }
+ }
+}
@@ -641,6 +641,47 @@ async fn test_open_server_settings(cx: &mut TestAppContext, server_cx: &mut Test
})
}
+#[gpui::test(iterations = 20)]
+async fn test_reconnect(cx: &mut TestAppContext, server_cx: &mut TestAppContext) {
+ let (project, _headless, fs) = init_test(cx, server_cx).await;
+
+ let (worktree, _) = project
+ .update(cx, |project, cx| {
+ project.find_or_create_worktree("/code/project1", true, cx)
+ })
+ .await
+ .unwrap();
+
+ let worktree_id = worktree.read_with(cx, |worktree, _| worktree.id());
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer((worktree_id, Path::new("src/lib.rs")), cx)
+ })
+ .await
+ .unwrap();
+
+ buffer.update(cx, |buffer, cx| {
+ assert_eq!(buffer.text(), "fn one() -> usize { 1 }");
+ let ix = buffer.text().find('1').unwrap();
+ buffer.edit([(ix..ix + 1, "100")], None, cx);
+ });
+
+ let client = cx.read(|cx| project.read(cx).ssh_client().unwrap());
+ client
+ .update(cx, |client, cx| client.simulate_disconnect(cx))
+ .detach();
+
+ project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+ .await
+ .unwrap();
+
+ assert_eq!(
+ fs.load("/code/project1/src/lib.rs".as_ref()).await.unwrap(),
+ "fn one() -> usize { 100 }"
+ );
+}
+
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::try_init().ok();
@@ -651,9 +692,9 @@ async fn init_test(
cx: &mut TestAppContext,
server_cx: &mut TestAppContext,
) -> (Model<Project>, Model<HeadlessProject>, Arc<FakeFs>) {
- let (ssh_remote_client, ssh_server_client) = SshRemoteClient::fake(cx, server_cx);
init_logger();
+ let (forwarder, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx);
let fs = FakeFs::new(server_cx.executor());
fs.insert_tree(
"/code",
@@ -694,8 +735,9 @@ async fn init_test(
cx,
)
});
- let project = build_project(ssh_remote_client, cx);
+ let ssh = SshRemoteClient::fake_client(forwarder, cx).await;
+ let project = build_project(ssh, cx);
project
.update(cx, {
let headless = headless.clone();
@@ -279,7 +279,7 @@ fn start_server(
})
.detach();
- ChannelClient::new(incoming_rx, outgoing_tx, cx)
+ ChannelClient::new(incoming_rx, outgoing_tx, cx, "server")
}
fn init_paths() -> anyhow::Result<()> {