From d45b830412a9a3099c77a00bea1f9fc11de57580 Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Thu, 24 Oct 2024 14:37:54 -0600 Subject: [PATCH] SSH connection pooling (#19692) Co-Authored-By: Max Closes #ISSUE Release Notes: - SSH Remoting: Reuse connections across hosts --------- Co-authored-by: Max --- .../remote_editing_collaboration_tests.rs | 4 +- crates/recent_projects/src/remote_servers.rs | 11 +- crates/remote/src/ssh_session.rs | 804 +++++++++++------- .../remote_server/src/remote_editing_tests.rs | 4 +- 4 files changed, 484 insertions(+), 339 deletions(-) diff --git a/crates/collab/src/tests/remote_editing_collaboration_tests.rs b/crates/collab/src/tests/remote_editing_collaboration_tests.rs index 52086c856c2884c11403b74a333d177bdd20556f..0e13c88d9464ea53b2d9dc5a0d16067a05611108 100644 --- a/crates/collab/src/tests/remote_editing_collaboration_tests.rs +++ b/crates/collab/src/tests/remote_editing_collaboration_tests.rs @@ -26,7 +26,7 @@ async fn test_sharing_an_ssh_remote_project( .await; // Set up project on remote FS - let (port, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); + let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx); let remote_fs = FakeFs::new(server_cx.executor()); remote_fs .insert_tree( @@ -67,7 +67,7 @@ async fn test_sharing_an_ssh_remote_project( ) }); - let client_ssh = SshRemoteClient::fake_client(port, cx_a).await; + let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await; let (project_a, worktree_id) = client_a .build_ssh_project("/code/project1", client_ssh, cx_a) .await; diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index 7081afc903902d037c7768ce7fc90ea7f35bc2b4..d7f3beccb21388c2f5ef6c96181a2217ff21fbb3 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -17,6 +17,7 @@ use gpui::{ use picker::Picker; use project::Project; use remote::SshConnectionOptions; +use remote::SshRemoteClient; use settings::update_settings_file; use settings::Settings; use ui::{ @@ -46,6 +47,7 @@ pub struct RemoteServerProjects { scroll_handle: ScrollHandle, workspace: WeakView, selectable_items: SelectableItemList, + retained_connections: Vec>, } struct CreateRemoteServer { @@ -355,6 +357,7 @@ impl RemoteServerProjects { scroll_handle: ScrollHandle::new(), workspace, selectable_items: Default::default(), + retained_connections: Vec::new(), } } @@ -424,7 +427,7 @@ impl RemoteServerProjects { let address_editor = editor.clone(); let creating = cx.spawn(move |this, mut cx| async move { match connection.await { - Some(_) => this + Some(Some(client)) => this .update(&mut cx, |this, cx| { let _ = this.workspace.update(cx, |workspace, _| { workspace @@ -432,14 +435,14 @@ impl RemoteServerProjects { .telemetry() .report_app_event("create ssh server".to_string()) }); - + this.retained_connections.push(client); this.add_ssh_server(connection_options, cx); this.mode = Mode::default_mode(); this.selectable_items.reset_selection(); cx.notify() }) .log_err(), - None => this + _ => this .update(&mut cx, |this, cx| { address_editor.update(cx, |this, _| { this.set_read_only(false); @@ -1056,7 +1059,7 @@ impl RemoteServerProjects { ); cx.spawn(|mut cx| async move { - if confirmation.await.ok() == Some(1) { + if confirmation.await.ok() == Some(0) { remote_servers .update(&mut cx, |this, cx| { this.delete_ssh_server(index, cx); diff --git a/crates/remote/src/ssh_session.rs b/crates/remote/src/ssh_session.rs index f3baa5a286816a7247cbfe7a35ea759a8a2e31de..d47e0375ea75f5b359a42f9779f548e195b2d96b 100644 --- a/crates/remote/src/ssh_session.rs +++ b/crates/remote/src/ssh_session.rs @@ -13,17 +13,18 @@ use futures::{ mpsc::{self, Sender, UnboundedReceiver, UnboundedSender}, oneshot, }, - future::BoxFuture, + future::{BoxFuture, Shared}, select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _, }; use gpui::{ - AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task, - WeakModel, + AppContext, AsyncAppContext, BorrowAppContext, Context, EventEmitter, Global, Model, + ModelContext, SemanticVersion, Task, WeakModel, }; use parking_lot::Mutex; use rpc::{ proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage}, - AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError, + AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet, + RpcError, }; use smol::{ fs, @@ -56,7 +57,7 @@ pub struct SshSocket { socket_path: PathBuf, } -#[derive(Debug, Default, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] pub struct SshConnectionOptions { pub host: String, pub username: Option, @@ -290,7 +291,7 @@ const MAX_RECONNECT_ATTEMPTS: usize = 3; enum State { Connecting, Connected { - ssh_connection: Box, + ssh_connection: Arc, delegate: Arc, multiplex_task: Task>, @@ -299,7 +300,7 @@ enum State { HeartbeatMissed { missed_heartbeats: usize, - ssh_connection: Box, + ssh_connection: Arc, delegate: Arc, multiplex_task: Task>, @@ -307,7 +308,7 @@ enum State { }, Reconnecting, ReconnectFailed { - ssh_connection: Box, + ssh_connection: Arc, delegate: Arc, error: anyhow::Error, @@ -332,7 +333,7 @@ impl fmt::Display for State { } impl State { - fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> { + fn ssh_connection(&self) -> Option<&dyn RemoteConnection> { match self { Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()), Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()), @@ -462,7 +463,7 @@ impl SshRemoteClient { connection_options: SshConnectionOptions, cancellation: oneshot::Receiver<()>, delegate: Arc, - cx: &AppContext, + cx: &mut AppContext, ) -> Task>>> { cx.spawn(|mut cx| async move { let success = Box::pin(async move { @@ -479,17 +480,28 @@ impl SshRemoteClient { state: Arc::new(Mutex::new(Some(State::Connecting))), })?; - let (ssh_connection, io_task) = Self::establish_connection( + let ssh_connection = cx + .update(|cx| { + cx.update_default_global(|pool: &mut ConnectionPool, cx| { + pool.connect(connection_options, &delegate, cx) + }) + })? + .await + .map_err(|e| e.cloned())?; + let remote_binary_path = ssh_connection + .get_remote_binary_path(&delegate, false, &mut cx) + .await?; + + let io_task = ssh_connection.start_proxy( + remote_binary_path, unique_identifier, false, - connection_options, incoming_tx, outgoing_rx, connection_activity_tx, delegate.clone(), &mut cx, - ) - .await?; + ); let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx); @@ -578,7 +590,7 @@ impl SshRemoteClient { } let state = lock.take().unwrap(); - let (attempts, mut ssh_connection, delegate) = match state { + let (attempts, ssh_connection, delegate) = match state { State::Connected { ssh_connection, delegate, @@ -624,7 +636,7 @@ impl SshRemoteClient { log::info!("Trying to reconnect to ssh server... Attempt {}", attempts); - let identifier = self.unique_identifier.clone(); + let unique_identifier = self.unique_identifier.clone(); let client = self.client.clone(); let reconnect_task = cx.spawn(|this, mut cx| async move { macro_rules! failed { @@ -652,19 +664,33 @@ impl SshRemoteClient { let (incoming_tx, incoming_rx) = mpsc::unbounded::(); let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1); - let (ssh_connection, io_task) = match Self::establish_connection( - identifier, - true, - connection_options, - incoming_tx, - outgoing_rx, - connection_activity_tx, - delegate.clone(), - &mut cx, - ) + let (ssh_connection, io_task) = match async { + let ssh_connection = cx + .update_global(|pool: &mut ConnectionPool, cx| { + pool.connect(connection_options, &delegate, cx) + })? + .await + .map_err(|error| error.cloned())?; + + let remote_binary_path = ssh_connection + .get_remote_binary_path(&delegate, true, &mut cx) + .await?; + + let io_task = ssh_connection.start_proxy( + remote_binary_path, + unique_identifier, + true, + incoming_tx, + outgoing_rx, + connection_activity_tx, + delegate.clone(), + &mut cx, + ); + anyhow::Ok((ssh_connection, io_task)) + } .await { - Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process), + Ok((ssh_connection, io_task)) => (ssh_connection, io_task), Err(error) => { failed!(error, attempts, ssh_connection, delegate); } @@ -834,108 +860,6 @@ impl SshRemoteClient { } } - fn multiplex( - mut ssh_proxy_process: Child, - incoming_tx: UnboundedSender, - mut outgoing_rx: UnboundedReceiver, - mut connection_activity_tx: Sender<()>, - cx: &AsyncAppContext, - ) -> Task> { - let mut child_stderr = ssh_proxy_process.stderr.take().unwrap(); - let mut child_stdout = ssh_proxy_process.stdout.take().unwrap(); - let mut child_stdin = ssh_proxy_process.stdin.take().unwrap(); - - let mut stdin_buffer = Vec::new(); - let mut stdout_buffer = Vec::new(); - let mut stderr_buffer = Vec::new(); - let mut stderr_offset = 0; - - let stdin_task = cx.background_executor().spawn(async move { - while let Some(outgoing) = outgoing_rx.next().await { - write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?; - } - anyhow::Ok(()) - }); - - let stdout_task = cx.background_executor().spawn({ - let mut connection_activity_tx = connection_activity_tx.clone(); - async move { - loop { - stdout_buffer.resize(MESSAGE_LEN_SIZE, 0); - let len = child_stdout.read(&mut stdout_buffer).await?; - - if len == 0 { - return anyhow::Ok(()); - } - - if len < MESSAGE_LEN_SIZE { - child_stdout.read_exact(&mut stdout_buffer[len..]).await?; - } - - let message_len = message_len_from_buffer(&stdout_buffer); - let envelope = - read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len) - .await?; - connection_activity_tx.try_send(()).ok(); - incoming_tx.unbounded_send(envelope).ok(); - } - } - }); - - let stderr_task: Task> = cx.background_executor().spawn(async move { - loop { - stderr_buffer.resize(stderr_offset + 1024, 0); - - let len = child_stderr - .read(&mut stderr_buffer[stderr_offset..]) - .await?; - if len == 0 { - return anyhow::Ok(()); - } - - stderr_offset += len; - let mut start_ix = 0; - while let Some(ix) = stderr_buffer[start_ix..stderr_offset] - .iter() - .position(|b| b == &b'\n') - { - let line_ix = start_ix + ix; - let content = &stderr_buffer[start_ix..line_ix]; - start_ix = line_ix + 1; - if let Ok(record) = serde_json::from_slice::(content) { - record.log(log::logger()) - } else { - eprintln!("(remote) {}", String::from_utf8_lossy(content)); - } - } - stderr_buffer.drain(0..start_ix); - stderr_offset -= start_ix; - - connection_activity_tx.try_send(()).ok(); - } - }); - - cx.spawn(|_| async move { - let result = futures::select! { - result = stdin_task.fuse() => { - result.context("stdin") - } - result = stdout_task.fuse() => { - result.context("stdout") - } - result = stderr_task.fuse() => { - result.context("stderr") - } - }; - - let status = ssh_proxy_process.status().await?.code().unwrap_or(1); - match result { - Ok(_) => Ok(status), - Err(error) => Err(error), - } - }) - } - fn monitor( this: WeakModel, io_task: Task>, @@ -1005,75 +929,6 @@ impl SshRemoteClient { cx.notify(); } - #[allow(clippy::too_many_arguments)] - async fn establish_connection( - unique_identifier: String, - reconnect: bool, - connection_options: SshConnectionOptions, - incoming_tx: UnboundedSender, - outgoing_rx: UnboundedReceiver, - connection_activity_tx: Sender<()>, - delegate: Arc, - cx: &mut AsyncAppContext, - ) -> Result<(Box, Task>)> { - #[cfg(any(test, feature = "test-support"))] - if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) { - let io_task = fake::SshRemoteConnection::multiplex( - fake.connection_options(), - incoming_tx, - outgoing_rx, - connection_activity_tx, - cx, - ) - .await; - return Ok((fake, io_task)); - } - - let ssh_connection = - SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?; - - let platform = ssh_connection.query_platform().await?; - let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?; - if !reconnect { - ssh_connection - .ensure_server_binary(&delegate, &remote_binary_path, platform, cx) - .await?; - } - - let socket = ssh_connection.socket.clone(); - run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; - - delegate.set_status(Some("Starting proxy"), cx); - - let mut start_proxy_command = format!( - "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}", - std::env::var("RUST_LOG").unwrap_or_default(), - std::env::var("RUST_BACKTRACE").unwrap_or_default(), - remote_binary_path, - unique_identifier, - ); - if reconnect { - start_proxy_command.push_str(" --reconnect"); - } - - let ssh_proxy_process = socket - .ssh_command(start_proxy_command) - // IMPORTANT: we kill this process when we drop the task that uses it. - .kill_on_drop(true) - .spawn() - .context("failed to spawn remote server")?; - - let io_task = Self::multiplex( - ssh_proxy_process, - incoming_tx, - outgoing_rx, - connection_activity_tx, - &cx, - ); - - Ok((Box::new(ssh_connection), io_task)) - } - pub fn subscribe_to_entity(&self, remote_id: u64, entity: &Model) { self.client.subscribe_to_entity(remote_id, entity); } @@ -1112,15 +967,21 @@ impl SshRemoteClient { #[cfg(any(test, feature = "test-support"))] pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> { - let port = self.connection_options().port.unwrap(); + let opts = self.connection_options(); client_cx.spawn(|cx| async move { - let (channel, server_cx) = cx - .update_global(|c: &mut fake::ServerConnections, _| c.get(port)) + let connection = cx + .update_global(|c: &mut ConnectionPool, _| { + if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) { + c.clone() + } else { + panic!("missing test connection") + } + }) + .unwrap() + .await .unwrap(); - let (outgoing_tx, _) = mpsc::unbounded::(); - let (_, incoming_rx) = mpsc::unbounded::(); - channel.reconnect(incoming_rx, outgoing_tx, &server_cx); + connection.simulate_disconnect(&cx); }) } @@ -1128,78 +989,190 @@ impl SshRemoteClient { pub fn fake_server( client_cx: &mut gpui::TestAppContext, server_cx: &mut gpui::TestAppContext, - ) -> (u16, Arc) { - use gpui::BorrowAppContext; + ) -> (SshConnectionOptions, Arc) { + let port = client_cx + .update(|cx| cx.default_global::().connections.len() as u16 + 1); + let opts = SshConnectionOptions { + host: "".to_string(), + port: Some(port), + ..Default::default() + }; let (outgoing_tx, _) = mpsc::unbounded::(); let (_, incoming_rx) = mpsc::unbounded::(); let server_client = server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server")); - let port = client_cx.update(|cx| { - cx.update_default_global(|c: &mut fake::ServerConnections, _| { - c.push(server_client.clone(), server_cx.to_async()) + let connection: Arc = Arc::new(fake::FakeRemoteConnection { + connection_options: opts.clone(), + server_cx: fake::SendableCx::new(server_cx.to_async()), + server_channel: server_client.clone(), + }); + + client_cx.update(|cx| { + cx.update_default_global(|c: &mut ConnectionPool, cx| { + c.connections.insert( + opts.clone(), + ConnectionPoolEntry::Connecting( + cx.foreground_executor() + .spawn({ + let connection = connection.clone(); + async move { Ok(connection.clone()) } + }) + .shared(), + ), + ); }) }); - (port, server_client) + + (opts, server_client) } #[cfg(any(test, feature = "test-support"))] - pub async fn fake_client(port: u16, client_cx: &mut gpui::TestAppContext) -> Model { + pub async fn fake_client( + opts: SshConnectionOptions, + client_cx: &mut gpui::TestAppContext, + ) -> Model { let (_tx, rx) = oneshot::channel(); client_cx - .update(|cx| { - Self::new( - "fake".to_string(), - SshConnectionOptions { - host: "".to_string(), - port: Some(port), - ..Default::default() - }, - rx, - Arc::new(fake::Delegate), - cx, - ) - }) + .update(|cx| Self::new("fake".to_string(), opts, rx, Arc::new(fake::Delegate), cx)) .await .unwrap() .unwrap() } } +enum ConnectionPoolEntry { + Connecting(Shared, Arc>>>), + Connected(Weak), +} + +#[derive(Default)] +struct ConnectionPool { + connections: HashMap, +} + +impl Global for ConnectionPool {} + +impl ConnectionPool { + pub fn connect( + &mut self, + opts: SshConnectionOptions, + delegate: &Arc, + cx: &mut AppContext, + ) -> Shared, Arc>>> { + let connection = self.connections.get(&opts); + match connection { + Some(ConnectionPoolEntry::Connecting(task)) => { + let delegate = delegate.clone(); + cx.spawn(|mut cx| async move { + delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx); + }) + .detach(); + return task.clone(); + } + Some(ConnectionPoolEntry::Connected(ssh)) => { + if let Some(ssh) = ssh.upgrade() { + if !ssh.has_been_killed() { + return Task::ready(Ok(ssh)).shared(); + } + } + self.connections.remove(&opts); + } + None => {} + } + + let task = cx + .spawn({ + let opts = opts.clone(); + let delegate = delegate.clone(); + |mut cx| async move { + let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx) + .await + .map(|connection| Arc::new(connection) as Arc); + + cx.update_global(|pool: &mut Self, _| { + debug_assert!(matches!( + pool.connections.get(&opts), + Some(ConnectionPoolEntry::Connecting(_)) + )); + match connection { + Ok(connection) => { + pool.connections.insert( + opts.clone(), + ConnectionPoolEntry::Connected(Arc::downgrade(&connection)), + ); + Ok(connection) + } + Err(error) => { + pool.connections.remove(&opts); + Err(Arc::new(error)) + } + } + })? + } + }) + .shared(); + + self.connections + .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone())); + task + } +} + impl From for AnyProtoClient { fn from(client: SshRemoteClient) -> Self { AnyProtoClient::new(client.client.clone()) } } -#[async_trait] -trait SshRemoteProcess: Send + Sync { - async fn kill(&mut self) -> Result<()>; +#[async_trait(?Send)] +trait RemoteConnection: Send + Sync { + #[allow(clippy::too_many_arguments)] + fn start_proxy( + &self, + remote_binary_path: PathBuf, + unique_identifier: String, + reconnect: bool, + incoming_tx: UnboundedSender, + outgoing_rx: UnboundedReceiver, + connection_activity_tx: Sender<()>, + delegate: Arc, + cx: &mut AsyncAppContext, + ) -> Task>; + async fn get_remote_binary_path( + &self, + delegate: &Arc, + reconnect: bool, + cx: &mut AsyncAppContext, + ) -> Result; + async fn kill(&self) -> Result<()>; + fn has_been_killed(&self) -> bool; fn ssh_args(&self) -> Vec; fn connection_options(&self) -> SshConnectionOptions; + + #[cfg(any(test, feature = "test-support"))] + fn simulate_disconnect(&self, _: &AsyncAppContext) {} } struct SshRemoteConnection { socket: SshSocket, - master_process: process::Child, + master_process: Mutex>, + platform: SshPlatform, _temp_dir: TempDir, } -impl Drop for SshRemoteConnection { - fn drop(&mut self) { - if let Err(error) = self.master_process.kill() { - log::error!("failed to kill SSH master process: {}", error); - } +#[async_trait(?Send)] +impl RemoteConnection for SshRemoteConnection { + async fn kill(&self) -> Result<()> { + let Some(mut process) = self.master_process.lock().take() else { + return Ok(()); + }; + process.kill().ok(); + process.status().await?; + Ok(()) } -} -#[async_trait] -impl SshRemoteProcess for SshRemoteConnection { - async fn kill(&mut self) -> Result<()> { - self.master_process.kill()?; - - self.master_process.status().await?; - - Ok(()) + fn has_been_killed(&self) -> bool { + self.master_process.lock().is_none() } fn ssh_args(&self) -> Vec { @@ -1209,6 +1182,70 @@ impl SshRemoteProcess for SshRemoteConnection { fn connection_options(&self) -> SshConnectionOptions { self.socket.connection_options.clone() } + + async fn get_remote_binary_path( + &self, + delegate: &Arc, + reconnect: bool, + cx: &mut AsyncAppContext, + ) -> Result { + let platform = self.platform; + let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?; + if !reconnect { + self.ensure_server_binary(&delegate, &remote_binary_path, platform, cx) + .await?; + } + + let socket = self.socket.clone(); + run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?; + Ok(remote_binary_path) + } + + fn start_proxy( + &self, + remote_binary_path: PathBuf, + unique_identifier: String, + reconnect: bool, + incoming_tx: UnboundedSender, + outgoing_rx: UnboundedReceiver, + connection_activity_tx: Sender<()>, + delegate: Arc, + cx: &mut AsyncAppContext, + ) -> Task> { + delegate.set_status(Some("Starting proxy"), cx); + + let mut start_proxy_command = format!( + "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}", + std::env::var("RUST_LOG").unwrap_or_default(), + std::env::var("RUST_BACKTRACE").unwrap_or_default(), + remote_binary_path, + unique_identifier, + ); + if reconnect { + start_proxy_command.push_str(" --reconnect"); + } + + let ssh_proxy_process = match self + .socket + .ssh_command(start_proxy_command) + // IMPORTANT: we kill this process when we drop the task that uses it. + .kill_on_drop(true) + .spawn() + { + Ok(process) => process, + Err(error) => { + return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error))) + } + }; + + Self::multiplex( + ssh_proxy_process, + incoming_tx, + outgoing_rx, + connection_activity_tx, + &cx, + ) + } } impl SshRemoteConnection { @@ -1305,6 +1342,7 @@ impl SshRemoteConnection { ]) .arg(format!("ControlPath={}", socket_path.display())) .arg(&url) + .kill_on_drop(true) .spawn()?; // Wait for this ssh process to close its stdout, indicating that authentication @@ -1348,16 +1386,139 @@ impl SshRemoteConnection { Err(anyhow!(error_message))?; } + let socket = SshSocket { + connection_options, + socket_path, + }; + + let os = run_cmd(socket.ssh_command("uname").arg("-s")).await?; + let arch = run_cmd(socket.ssh_command("uname").arg("-m")).await?; + + let os = match os.trim() { + "Darwin" => "macos", + "Linux" => "linux", + _ => Err(anyhow!("unknown uname os {os:?}"))?, + }; + let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") { + "aarch64" + } else if arch.starts_with("x86") || arch.starts_with("i686") { + "x86_64" + } else { + Err(anyhow!("unknown uname architecture {arch:?}"))? + }; + + let platform = SshPlatform { os, arch }; + Ok(Self { - socket: SshSocket { - connection_options, - socket_path, - }, - master_process, + socket, + master_process: Mutex::new(Some(master_process)), + platform, _temp_dir: temp_dir, }) } + fn multiplex( + mut ssh_proxy_process: Child, + incoming_tx: UnboundedSender, + mut outgoing_rx: UnboundedReceiver, + mut connection_activity_tx: Sender<()>, + cx: &AsyncAppContext, + ) -> Task> { + let mut child_stderr = ssh_proxy_process.stderr.take().unwrap(); + let mut child_stdout = ssh_proxy_process.stdout.take().unwrap(); + let mut child_stdin = ssh_proxy_process.stdin.take().unwrap(); + + let mut stdin_buffer = Vec::new(); + let mut stdout_buffer = Vec::new(); + let mut stderr_buffer = Vec::new(); + let mut stderr_offset = 0; + + let stdin_task = cx.background_executor().spawn(async move { + while let Some(outgoing) = outgoing_rx.next().await { + write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?; + } + anyhow::Ok(()) + }); + + let stdout_task = cx.background_executor().spawn({ + let mut connection_activity_tx = connection_activity_tx.clone(); + async move { + loop { + stdout_buffer.resize(MESSAGE_LEN_SIZE, 0); + let len = child_stdout.read(&mut stdout_buffer).await?; + + if len == 0 { + return anyhow::Ok(()); + } + + if len < MESSAGE_LEN_SIZE { + child_stdout.read_exact(&mut stdout_buffer[len..]).await?; + } + + let message_len = message_len_from_buffer(&stdout_buffer); + let envelope = + read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len) + .await?; + connection_activity_tx.try_send(()).ok(); + incoming_tx.unbounded_send(envelope).ok(); + } + } + }); + + let stderr_task: Task> = cx.background_executor().spawn(async move { + loop { + stderr_buffer.resize(stderr_offset + 1024, 0); + + let len = child_stderr + .read(&mut stderr_buffer[stderr_offset..]) + .await?; + if len == 0 { + return anyhow::Ok(()); + } + + stderr_offset += len; + let mut start_ix = 0; + while let Some(ix) = stderr_buffer[start_ix..stderr_offset] + .iter() + .position(|b| b == &b'\n') + { + let line_ix = start_ix + ix; + let content = &stderr_buffer[start_ix..line_ix]; + start_ix = line_ix + 1; + if let Ok(record) = serde_json::from_slice::(content) { + record.log(log::logger()) + } else { + eprintln!("(remote) {}", String::from_utf8_lossy(content)); + } + } + stderr_buffer.drain(0..start_ix); + stderr_offset -= start_ix; + + connection_activity_tx.try_send(()).ok(); + } + }); + + cx.spawn(|_| async move { + let result = futures::select! { + result = stdin_task.fuse() => { + result.context("stdin") + } + result = stdout_task.fuse() => { + result.context("stdout") + } + result = stderr_task.fuse() => { + result.context("stderr") + } + }; + + let status = ssh_proxy_process.status().await?.code().unwrap_or(1); + match result { + Ok(_) => Ok(status), + Err(error) => Err(error), + } + }) + } + async fn ensure_server_binary( &self, delegate: &Arc, @@ -1621,26 +1782,6 @@ impl SshRemoteConnection { Ok(()) } - async fn query_platform(&self) -> Result { - let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?; - let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?; - - let os = match os.trim() { - "Darwin" => "macos", - "Linux" => "linux", - _ => Err(anyhow!("unknown uname os {os:?}"))?, - }; - let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") { - "aarch64" - } else if arch.starts_with("x86") || arch.starts_with("i686") { - "x86_64" - } else { - Err(anyhow!("unknown uname architecture {arch:?}"))? - }; - - Ok(SshPlatform { os, arch }) - } - async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> { let mut command = process::Command::new("scp"); let output = self @@ -1974,50 +2115,86 @@ mod fake { }, select_biased, FutureExt, SinkExt, StreamExt, }; - use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task}; + use gpui::{AsyncAppContext, SemanticVersion, Task}; use rpc::proto::Envelope; use super::{ - ChannelClient, ServerBinary, SshClientDelegate, SshConnectionOptions, SshPlatform, - SshRemoteProcess, + ChannelClient, RemoteConnection, ServerBinary, SshClientDelegate, SshConnectionOptions, + SshPlatform, }; - pub(super) struct SshRemoteConnection { - connection_options: SshConnectionOptions, + pub(super) struct FakeRemoteConnection { + pub(super) connection_options: SshConnectionOptions, + pub(super) server_channel: Arc, + pub(super) server_cx: SendableCx, } - impl SshRemoteConnection { - pub(super) fn new( - connection_options: &SshConnectionOptions, - ) -> Option> { - if connection_options.host == "" { - return Some(Box::new(Self { - connection_options: connection_options.clone(), - })); - } - return None; + pub(super) struct SendableCx(AsyncAppContext); + // safety: you can only get the other cx on the main thread. + impl SendableCx { + pub(super) fn new(cx: AsyncAppContext) -> Self { + Self(cx) + } + fn get(&self, _: &AsyncAppContext) -> AsyncAppContext { + self.0.clone() + } + } + unsafe impl Send for SendableCx {} + unsafe impl Sync for SendableCx {} + + #[async_trait(?Send)] + impl RemoteConnection for FakeRemoteConnection { + async fn kill(&self) -> Result<()> { + Ok(()) + } + + fn has_been_killed(&self) -> bool { + false + } + + fn ssh_args(&self) -> Vec { + Vec::new() + } + + fn connection_options(&self) -> SshConnectionOptions { + self.connection_options.clone() + } + + fn simulate_disconnect(&self, cx: &AsyncAppContext) { + let (outgoing_tx, _) = mpsc::unbounded::(); + let (_, incoming_rx) = mpsc::unbounded::(); + self.server_channel + .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx)); + } + + async fn get_remote_binary_path( + &self, + _delegate: &Arc, + _reconnect: bool, + _cx: &mut AsyncAppContext, + ) -> Result { + Ok(PathBuf::new()) } - pub(super) async fn multiplex( - connection_options: SshConnectionOptions, + + fn start_proxy( + &self, + _remote_binary_path: PathBuf, + _unique_identifier: String, + _reconnect: bool, mut client_incoming_tx: mpsc::UnboundedSender, mut client_outgoing_rx: mpsc::UnboundedReceiver, mut connection_activity_tx: Sender<()>, + _delegate: Arc, cx: &mut AsyncAppContext, ) -> Task> { let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::(); let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::(); - let (channel, server_cx) = cx - .update(|cx| { - cx.update_global(|conns: &mut ServerConnections, _| { - conns.get(connection_options.port.unwrap()) - }) - }) - .unwrap(); - channel.reconnect(server_incoming_rx, server_outgoing_tx, &server_cx); - - // send to proxy_tx to get to the server. - // receive from + self.server_channel.reconnect( + server_incoming_rx, + server_outgoing_tx, + &self.server_cx.get(cx), + ); cx.background_executor().spawn(async move { loop { @@ -2041,39 +2218,6 @@ mod fake { } } - #[async_trait] - impl SshRemoteProcess for SshRemoteConnection { - async fn kill(&mut self) -> Result<()> { - Ok(()) - } - - fn ssh_args(&self) -> Vec { - Vec::new() - } - - fn connection_options(&self) -> SshConnectionOptions { - self.connection_options.clone() - } - } - - #[derive(Default)] - pub(super) struct ServerConnections(Vec<(Arc, AsyncAppContext)>); - impl Global for ServerConnections {} - - impl ServerConnections { - pub(super) fn push(&mut self, server: Arc, cx: AsyncAppContext) -> u16 { - self.0.push((server.clone(), cx)); - self.0.len() as u16 - 1 - } - - pub(super) fn get(&mut self, port: u16) -> (Arc, AsyncAppContext) { - self.0 - .get(port as usize) - .expect("no fake server for port") - .clone() - } - } - pub(super) struct Delegate; impl SshClientDelegate for Delegate { @@ -2099,8 +2243,6 @@ mod fake { unreachable!() } - fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) { - unreachable!() - } + fn set_status(&self, _: Option<&str>, _: &mut AsyncAppContext) {} } } diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 32333def7fb7f10abce8902797c40ec70eba9506..f7420ef5b091b70c8036a77650eb81208edfa2d3 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -702,7 +702,7 @@ async fn init_test( ) -> (Model, Model, Arc) { init_logger(); - let (forwarder, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx); + let (opts, ssh_server_client) = SshRemoteClient::fake_server(cx, server_cx); let fs = FakeFs::new(server_cx.executor()); fs.insert_tree( "/code", @@ -744,7 +744,7 @@ async fn init_test( ) }); - let ssh = SshRemoteClient::fake_client(forwarder, cx).await; + let ssh = SshRemoteClient::fake_client(opts, cx).await; let project = build_project(ssh, cx); project .update(cx, {