@@ -13,17 +13,18 @@ use futures::{
mpsc::{self, Sender, UnboundedReceiver, UnboundedSender},
oneshot,
},
- future::BoxFuture,
+ future::{BoxFuture, Shared},
select, select_biased, AsyncReadExt as _, Future, FutureExt as _, StreamExt as _,
};
use gpui::{
- AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, SemanticVersion, Task,
- WeakModel,
+ AppContext, AsyncAppContext, BorrowAppContext, Context, EventEmitter, Global, Model,
+ ModelContext, SemanticVersion, Task, WeakModel,
};
use parking_lot::Mutex;
use rpc::{
proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
- AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
+ AnyProtoClient, EntityMessageSubscriber, ErrorExt, ProtoClient, ProtoMessageHandlerSet,
+ RpcError,
};
use smol::{
fs,
@@ -56,7 +57,7 @@ pub struct SshSocket {
socket_path: PathBuf,
}
-#[derive(Debug, Default, Clone, PartialEq, Eq)]
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct SshConnectionOptions {
pub host: String,
pub username: Option<String>,
@@ -290,7 +291,7 @@ const MAX_RECONNECT_ATTEMPTS: usize = 3;
enum State {
Connecting,
Connected {
- ssh_connection: Box<dyn SshRemoteProcess>,
+ ssh_connection: Arc<dyn RemoteConnection>,
delegate: Arc<dyn SshClientDelegate>,
multiplex_task: Task<Result<()>>,
@@ -299,7 +300,7 @@ enum State {
HeartbeatMissed {
missed_heartbeats: usize,
- ssh_connection: Box<dyn SshRemoteProcess>,
+ ssh_connection: Arc<dyn RemoteConnection>,
delegate: Arc<dyn SshClientDelegate>,
multiplex_task: Task<Result<()>>,
@@ -307,7 +308,7 @@ enum State {
},
Reconnecting,
ReconnectFailed {
- ssh_connection: Box<dyn SshRemoteProcess>,
+ ssh_connection: Arc<dyn RemoteConnection>,
delegate: Arc<dyn SshClientDelegate>,
error: anyhow::Error,
@@ -332,7 +333,7 @@ impl fmt::Display for State {
}
impl State {
- fn ssh_connection(&self) -> Option<&dyn SshRemoteProcess> {
+ fn ssh_connection(&self) -> Option<&dyn RemoteConnection> {
match self {
Self::Connected { ssh_connection, .. } => Some(ssh_connection.as_ref()),
Self::HeartbeatMissed { ssh_connection, .. } => Some(ssh_connection.as_ref()),
@@ -462,7 +463,7 @@ impl SshRemoteClient {
connection_options: SshConnectionOptions,
cancellation: oneshot::Receiver<()>,
delegate: Arc<dyn SshClientDelegate>,
- cx: &AppContext,
+ cx: &mut AppContext,
) -> Task<Result<Option<Model<Self>>>> {
cx.spawn(|mut cx| async move {
let success = Box::pin(async move {
@@ -479,17 +480,28 @@ impl SshRemoteClient {
state: Arc::new(Mutex::new(Some(State::Connecting))),
})?;
- let (ssh_connection, io_task) = Self::establish_connection(
+ let ssh_connection = cx
+ .update(|cx| {
+ cx.update_default_global(|pool: &mut ConnectionPool, cx| {
+ pool.connect(connection_options, &delegate, cx)
+ })
+ })?
+ .await
+ .map_err(|e| e.cloned())?;
+ let remote_binary_path = ssh_connection
+ .get_remote_binary_path(&delegate, false, &mut cx)
+ .await?;
+
+ let io_task = ssh_connection.start_proxy(
+ remote_binary_path,
unique_identifier,
false,
- connection_options,
incoming_tx,
outgoing_rx,
connection_activity_tx,
delegate.clone(),
&mut cx,
- )
- .await?;
+ );
let multiplex_task = Self::monitor(this.downgrade(), io_task, &cx);
@@ -578,7 +590,7 @@ impl SshRemoteClient {
}
let state = lock.take().unwrap();
- let (attempts, mut ssh_connection, delegate) = match state {
+ let (attempts, ssh_connection, delegate) = match state {
State::Connected {
ssh_connection,
delegate,
@@ -624,7 +636,7 @@ impl SshRemoteClient {
log::info!("Trying to reconnect to ssh server... Attempt {}", attempts);
- let identifier = self.unique_identifier.clone();
+ let unique_identifier = self.unique_identifier.clone();
let client = self.client.clone();
let reconnect_task = cx.spawn(|this, mut cx| async move {
macro_rules! failed {
@@ -652,19 +664,33 @@ impl SshRemoteClient {
let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
let (connection_activity_tx, connection_activity_rx) = mpsc::channel::<()>(1);
- let (ssh_connection, io_task) = match Self::establish_connection(
- identifier,
- true,
- connection_options,
- incoming_tx,
- outgoing_rx,
- connection_activity_tx,
- delegate.clone(),
- &mut cx,
- )
+ let (ssh_connection, io_task) = match async {
+ let ssh_connection = cx
+ .update_global(|pool: &mut ConnectionPool, cx| {
+ pool.connect(connection_options, &delegate, cx)
+ })?
+ .await
+ .map_err(|error| error.cloned())?;
+
+ let remote_binary_path = ssh_connection
+ .get_remote_binary_path(&delegate, true, &mut cx)
+ .await?;
+
+ let io_task = ssh_connection.start_proxy(
+ remote_binary_path,
+ unique_identifier,
+ true,
+ incoming_tx,
+ outgoing_rx,
+ connection_activity_tx,
+ delegate.clone(),
+ &mut cx,
+ );
+ anyhow::Ok((ssh_connection, io_task))
+ }
.await
{
- Ok((ssh_connection, ssh_process)) => (ssh_connection, ssh_process),
+ Ok((ssh_connection, io_task)) => (ssh_connection, io_task),
Err(error) => {
failed!(error, attempts, ssh_connection, delegate);
}
@@ -834,108 +860,6 @@ impl SshRemoteClient {
}
}
- fn multiplex(
- mut ssh_proxy_process: Child,
- incoming_tx: UnboundedSender<Envelope>,
- mut outgoing_rx: UnboundedReceiver<Envelope>,
- mut connection_activity_tx: Sender<()>,
- cx: &AsyncAppContext,
- ) -> Task<Result<i32>> {
- let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
- let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
- let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
-
- let mut stdin_buffer = Vec::new();
- let mut stdout_buffer = Vec::new();
- let mut stderr_buffer = Vec::new();
- let mut stderr_offset = 0;
-
- let stdin_task = cx.background_executor().spawn(async move {
- while let Some(outgoing) = outgoing_rx.next().await {
- write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
- }
- anyhow::Ok(())
- });
-
- let stdout_task = cx.background_executor().spawn({
- let mut connection_activity_tx = connection_activity_tx.clone();
- async move {
- loop {
- stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
- let len = child_stdout.read(&mut stdout_buffer).await?;
-
- if len == 0 {
- return anyhow::Ok(());
- }
-
- if len < MESSAGE_LEN_SIZE {
- child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
- }
-
- let message_len = message_len_from_buffer(&stdout_buffer);
- let envelope =
- read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
- .await?;
- connection_activity_tx.try_send(()).ok();
- incoming_tx.unbounded_send(envelope).ok();
- }
- }
- });
-
- let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
- loop {
- stderr_buffer.resize(stderr_offset + 1024, 0);
-
- let len = child_stderr
- .read(&mut stderr_buffer[stderr_offset..])
- .await?;
- if len == 0 {
- return anyhow::Ok(());
- }
-
- stderr_offset += len;
- let mut start_ix = 0;
- while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
- .iter()
- .position(|b| b == &b'\n')
- {
- let line_ix = start_ix + ix;
- let content = &stderr_buffer[start_ix..line_ix];
- start_ix = line_ix + 1;
- if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
- record.log(log::logger())
- } else {
- eprintln!("(remote) {}", String::from_utf8_lossy(content));
- }
- }
- stderr_buffer.drain(0..start_ix);
- stderr_offset -= start_ix;
-
- connection_activity_tx.try_send(()).ok();
- }
- });
-
- cx.spawn(|_| async move {
- let result = futures::select! {
- result = stdin_task.fuse() => {
- result.context("stdin")
- }
- result = stdout_task.fuse() => {
- result.context("stdout")
- }
- result = stderr_task.fuse() => {
- result.context("stderr")
- }
- };
-
- let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
- match result {
- Ok(_) => Ok(status),
- Err(error) => Err(error),
- }
- })
- }
-
fn monitor(
this: WeakModel<Self>,
io_task: Task<Result<i32>>,
@@ -1005,75 +929,6 @@ impl SshRemoteClient {
cx.notify();
}
- #[allow(clippy::too_many_arguments)]
- async fn establish_connection(
- unique_identifier: String,
- reconnect: bool,
- connection_options: SshConnectionOptions,
- incoming_tx: UnboundedSender<Envelope>,
- outgoing_rx: UnboundedReceiver<Envelope>,
- connection_activity_tx: Sender<()>,
- delegate: Arc<dyn SshClientDelegate>,
- cx: &mut AsyncAppContext,
- ) -> Result<(Box<dyn SshRemoteProcess>, Task<Result<i32>>)> {
- #[cfg(any(test, feature = "test-support"))]
- if let Some(fake) = fake::SshRemoteConnection::new(&connection_options) {
- let io_task = fake::SshRemoteConnection::multiplex(
- fake.connection_options(),
- incoming_tx,
- outgoing_rx,
- connection_activity_tx,
- cx,
- )
- .await;
- return Ok((fake, io_task));
- }
-
- let ssh_connection =
- SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
-
- let platform = ssh_connection.query_platform().await?;
- let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
- if !reconnect {
- ssh_connection
- .ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
- .await?;
- }
-
- let socket = ssh_connection.socket.clone();
- run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
-
- delegate.set_status(Some("Starting proxy"), cx);
-
- let mut start_proxy_command = format!(
- "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
- std::env::var("RUST_LOG").unwrap_or_default(),
- std::env::var("RUST_BACKTRACE").unwrap_or_default(),
- remote_binary_path,
- unique_identifier,
- );
- if reconnect {
- start_proxy_command.push_str(" --reconnect");
- }
-
- let ssh_proxy_process = socket
- .ssh_command(start_proxy_command)
- // IMPORTANT: we kill this process when we drop the task that uses it.
- .kill_on_drop(true)
- .spawn()
- .context("failed to spawn remote server")?;
-
- let io_task = Self::multiplex(
- ssh_proxy_process,
- incoming_tx,
- outgoing_rx,
- connection_activity_tx,
- &cx,
- );
-
- Ok((Box::new(ssh_connection), io_task))
- }
-
pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
self.client.subscribe_to_entity(remote_id, entity);
}
@@ -1112,15 +967,21 @@ impl SshRemoteClient {
#[cfg(any(test, feature = "test-support"))]
pub fn simulate_disconnect(&self, client_cx: &mut AppContext) -> Task<()> {
- let port = self.connection_options().port.unwrap();
+ let opts = self.connection_options();
client_cx.spawn(|cx| async move {
- let (channel, server_cx) = cx
- .update_global(|c: &mut fake::ServerConnections, _| c.get(port))
+ let connection = cx
+ .update_global(|c: &mut ConnectionPool, _| {
+ if let Some(ConnectionPoolEntry::Connecting(c)) = c.connections.get(&opts) {
+ c.clone()
+ } else {
+ panic!("missing test connection")
+ }
+ })
+ .unwrap()
+ .await
.unwrap();
- let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
- let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
- channel.reconnect(incoming_rx, outgoing_tx, &server_cx);
+ connection.simulate_disconnect(&cx);
})
}
@@ -1128,78 +989,190 @@ impl SshRemoteClient {
pub fn fake_server(
client_cx: &mut gpui::TestAppContext,
server_cx: &mut gpui::TestAppContext,
- ) -> (u16, Arc<ChannelClient>) {
- use gpui::BorrowAppContext;
+ ) -> (SshConnectionOptions, Arc<ChannelClient>) {
+ let port = client_cx
+ .update(|cx| cx.default_global::<ConnectionPool>().connections.len() as u16 + 1);
+ let opts = SshConnectionOptions {
+ host: "<fake>".to_string(),
+ port: Some(port),
+ ..Default::default()
+ };
let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
let server_client =
server_cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx, "fake-server"));
- let port = client_cx.update(|cx| {
- cx.update_default_global(|c: &mut fake::ServerConnections, _| {
- c.push(server_client.clone(), server_cx.to_async())
+ let connection: Arc<dyn RemoteConnection> = Arc::new(fake::FakeRemoteConnection {
+ connection_options: opts.clone(),
+ server_cx: fake::SendableCx::new(server_cx.to_async()),
+ server_channel: server_client.clone(),
+ });
+
+ client_cx.update(|cx| {
+ cx.update_default_global(|c: &mut ConnectionPool, cx| {
+ c.connections.insert(
+ opts.clone(),
+ ConnectionPoolEntry::Connecting(
+ cx.foreground_executor()
+ .spawn({
+ let connection = connection.clone();
+ async move { Ok(connection.clone()) }
+ })
+ .shared(),
+ ),
+ );
})
});
- (port, server_client)
+
+ (opts, server_client)
}
#[cfg(any(test, feature = "test-support"))]
- pub async fn fake_client(port: u16, client_cx: &mut gpui::TestAppContext) -> Model<Self> {
+ pub async fn fake_client(
+ opts: SshConnectionOptions,
+ client_cx: &mut gpui::TestAppContext,
+ ) -> Model<Self> {
let (_tx, rx) = oneshot::channel();
client_cx
- .update(|cx| {
- Self::new(
- "fake".to_string(),
- SshConnectionOptions {
- host: "<fake>".to_string(),
- port: Some(port),
- ..Default::default()
- },
- rx,
- Arc::new(fake::Delegate),
- cx,
- )
- })
+ .update(|cx| Self::new("fake".to_string(), opts, rx, Arc::new(fake::Delegate), cx))
.await
.unwrap()
.unwrap()
}
}
+enum ConnectionPoolEntry {
+ Connecting(Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>>),
+ Connected(Weak<dyn RemoteConnection>),
+}
+
+#[derive(Default)]
+struct ConnectionPool {
+ connections: HashMap<SshConnectionOptions, ConnectionPoolEntry>,
+}
+
+impl Global for ConnectionPool {}
+
+impl ConnectionPool {
+ pub fn connect(
+ &mut self,
+ opts: SshConnectionOptions,
+ delegate: &Arc<dyn SshClientDelegate>,
+ cx: &mut AppContext,
+ ) -> Shared<Task<Result<Arc<dyn RemoteConnection>, Arc<anyhow::Error>>>> {
+ let connection = self.connections.get(&opts);
+ match connection {
+ Some(ConnectionPoolEntry::Connecting(task)) => {
+ let delegate = delegate.clone();
+ cx.spawn(|mut cx| async move {
+ delegate.set_status(Some("Waiting for existing connection attempt"), &mut cx);
+ })
+ .detach();
+ return task.clone();
+ }
+ Some(ConnectionPoolEntry::Connected(ssh)) => {
+ if let Some(ssh) = ssh.upgrade() {
+ if !ssh.has_been_killed() {
+ return Task::ready(Ok(ssh)).shared();
+ }
+ }
+ self.connections.remove(&opts);
+ }
+ None => {}
+ }
+
+ let task = cx
+ .spawn({
+ let opts = opts.clone();
+ let delegate = delegate.clone();
+ |mut cx| async move {
+ let connection = SshRemoteConnection::new(opts.clone(), delegate, &mut cx)
+ .await
+ .map(|connection| Arc::new(connection) as Arc<dyn RemoteConnection>);
+
+ cx.update_global(|pool: &mut Self, _| {
+ debug_assert!(matches!(
+ pool.connections.get(&opts),
+ Some(ConnectionPoolEntry::Connecting(_))
+ ));
+ match connection {
+ Ok(connection) => {
+ pool.connections.insert(
+ opts.clone(),
+ ConnectionPoolEntry::Connected(Arc::downgrade(&connection)),
+ );
+ Ok(connection)
+ }
+ Err(error) => {
+ pool.connections.remove(&opts);
+ Err(Arc::new(error))
+ }
+ }
+ })?
+ }
+ })
+ .shared();
+
+ self.connections
+ .insert(opts.clone(), ConnectionPoolEntry::Connecting(task.clone()));
+ task
+ }
+}
+
impl From<SshRemoteClient> for AnyProtoClient {
fn from(client: SshRemoteClient) -> Self {
AnyProtoClient::new(client.client.clone())
}
}
-#[async_trait]
-trait SshRemoteProcess: Send + Sync {
- async fn kill(&mut self) -> Result<()>;
+#[async_trait(?Send)]
+trait RemoteConnection: Send + Sync {
+ #[allow(clippy::too_many_arguments)]
+ fn start_proxy(
+ &self,
+ remote_binary_path: PathBuf,
+ unique_identifier: String,
+ reconnect: bool,
+ incoming_tx: UnboundedSender<Envelope>,
+ outgoing_rx: UnboundedReceiver<Envelope>,
+ connection_activity_tx: Sender<()>,
+ delegate: Arc<dyn SshClientDelegate>,
+ cx: &mut AsyncAppContext,
+ ) -> Task<Result<i32>>;
+ async fn get_remote_binary_path(
+ &self,
+ delegate: &Arc<dyn SshClientDelegate>,
+ reconnect: bool,
+ cx: &mut AsyncAppContext,
+ ) -> Result<PathBuf>;
+ async fn kill(&self) -> Result<()>;
+ fn has_been_killed(&self) -> bool;
fn ssh_args(&self) -> Vec<String>;
fn connection_options(&self) -> SshConnectionOptions;
+
+ #[cfg(any(test, feature = "test-support"))]
+ fn simulate_disconnect(&self, _: &AsyncAppContext) {}
}
struct SshRemoteConnection {
socket: SshSocket,
- master_process: process::Child,
+ master_process: Mutex<Option<process::Child>>,
+ platform: SshPlatform,
_temp_dir: TempDir,
}
-impl Drop for SshRemoteConnection {
- fn drop(&mut self) {
- if let Err(error) = self.master_process.kill() {
- log::error!("failed to kill SSH master process: {}", error);
- }
+#[async_trait(?Send)]
+impl RemoteConnection for SshRemoteConnection {
+ async fn kill(&self) -> Result<()> {
+ let Some(mut process) = self.master_process.lock().take() else {
+ return Ok(());
+ };
+ process.kill().ok();
+ process.status().await?;
+ Ok(())
}
-}
-#[async_trait]
-impl SshRemoteProcess for SshRemoteConnection {
- async fn kill(&mut self) -> Result<()> {
- self.master_process.kill()?;
-
- self.master_process.status().await?;
-
- Ok(())
+ fn has_been_killed(&self) -> bool {
+ self.master_process.lock().is_none()
}
fn ssh_args(&self) -> Vec<String> {
@@ -1209,6 +1182,70 @@ impl SshRemoteProcess for SshRemoteConnection {
fn connection_options(&self) -> SshConnectionOptions {
self.socket.connection_options.clone()
}
+
+ async fn get_remote_binary_path(
+ &self,
+ delegate: &Arc<dyn SshClientDelegate>,
+ reconnect: bool,
+ cx: &mut AsyncAppContext,
+ ) -> Result<PathBuf> {
+ let platform = self.platform;
+ let remote_binary_path = delegate.remote_server_binary_path(platform, cx)?;
+ if !reconnect {
+ self.ensure_server_binary(&delegate, &remote_binary_path, platform, cx)
+ .await?;
+ }
+
+ let socket = self.socket.clone();
+ run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
+ Ok(remote_binary_path)
+ }
+
+ fn start_proxy(
+ &self,
+ remote_binary_path: PathBuf,
+ unique_identifier: String,
+ reconnect: bool,
+ incoming_tx: UnboundedSender<Envelope>,
+ outgoing_rx: UnboundedReceiver<Envelope>,
+ connection_activity_tx: Sender<()>,
+ delegate: Arc<dyn SshClientDelegate>,
+ cx: &mut AsyncAppContext,
+ ) -> Task<Result<i32>> {
+ delegate.set_status(Some("Starting proxy"), cx);
+
+ let mut start_proxy_command = format!(
+ "RUST_LOG={} RUST_BACKTRACE={} {:?} proxy --identifier {}",
+ std::env::var("RUST_LOG").unwrap_or_default(),
+ std::env::var("RUST_BACKTRACE").unwrap_or_default(),
+ remote_binary_path,
+ unique_identifier,
+ );
+ if reconnect {
+ start_proxy_command.push_str(" --reconnect");
+ }
+
+ let ssh_proxy_process = match self
+ .socket
+ .ssh_command(start_proxy_command)
+ // IMPORTANT: we kill this process when we drop the task that uses it.
+ .kill_on_drop(true)
+ .spawn()
+ {
+ Ok(process) => process,
+ Err(error) => {
+ return Task::ready(Err(anyhow!("failed to spawn remote server: {}", error)))
+ }
+ };
+
+ Self::multiplex(
+ ssh_proxy_process,
+ incoming_tx,
+ outgoing_rx,
+ connection_activity_tx,
+ &cx,
+ )
+ }
}
impl SshRemoteConnection {
@@ -1305,6 +1342,7 @@ impl SshRemoteConnection {
])
.arg(format!("ControlPath={}", socket_path.display()))
.arg(&url)
+ .kill_on_drop(true)
.spawn()?;
// Wait for this ssh process to close its stdout, indicating that authentication
@@ -1348,16 +1386,139 @@ impl SshRemoteConnection {
Err(anyhow!(error_message))?;
}
+ let socket = SshSocket {
+ connection_options,
+ socket_path,
+ };
+
+ let os = run_cmd(socket.ssh_command("uname").arg("-s")).await?;
+ let arch = run_cmd(socket.ssh_command("uname").arg("-m")).await?;
+
+ let os = match os.trim() {
+ "Darwin" => "macos",
+ "Linux" => "linux",
+ _ => Err(anyhow!("unknown uname os {os:?}"))?,
+ };
+ let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
+ "aarch64"
+ } else if arch.starts_with("x86") || arch.starts_with("i686") {
+ "x86_64"
+ } else {
+ Err(anyhow!("unknown uname architecture {arch:?}"))?
+ };
+
+ let platform = SshPlatform { os, arch };
+
Ok(Self {
- socket: SshSocket {
- connection_options,
- socket_path,
- },
- master_process,
+ socket,
+ master_process: Mutex::new(Some(master_process)),
+ platform,
_temp_dir: temp_dir,
})
}
+ fn multiplex(
+ mut ssh_proxy_process: Child,
+ incoming_tx: UnboundedSender<Envelope>,
+ mut outgoing_rx: UnboundedReceiver<Envelope>,
+ mut connection_activity_tx: Sender<()>,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<i32>> {
+ let mut child_stderr = ssh_proxy_process.stderr.take().unwrap();
+ let mut child_stdout = ssh_proxy_process.stdout.take().unwrap();
+ let mut child_stdin = ssh_proxy_process.stdin.take().unwrap();
+
+ let mut stdin_buffer = Vec::new();
+ let mut stdout_buffer = Vec::new();
+ let mut stderr_buffer = Vec::new();
+ let mut stderr_offset = 0;
+
+ let stdin_task = cx.background_executor().spawn(async move {
+ while let Some(outgoing) = outgoing_rx.next().await {
+ write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
+ }
+ anyhow::Ok(())
+ });
+
+ let stdout_task = cx.background_executor().spawn({
+ let mut connection_activity_tx = connection_activity_tx.clone();
+ async move {
+ loop {
+ stdout_buffer.resize(MESSAGE_LEN_SIZE, 0);
+ let len = child_stdout.read(&mut stdout_buffer).await?;
+
+ if len == 0 {
+ return anyhow::Ok(());
+ }
+
+ if len < MESSAGE_LEN_SIZE {
+ child_stdout.read_exact(&mut stdout_buffer[len..]).await?;
+ }
+
+ let message_len = message_len_from_buffer(&stdout_buffer);
+ let envelope =
+ read_message_with_len(&mut child_stdout, &mut stdout_buffer, message_len)
+ .await?;
+ connection_activity_tx.try_send(()).ok();
+ incoming_tx.unbounded_send(envelope).ok();
+ }
+ }
+ });
+
+ let stderr_task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async move {
+ loop {
+ stderr_buffer.resize(stderr_offset + 1024, 0);
+
+ let len = child_stderr
+ .read(&mut stderr_buffer[stderr_offset..])
+ .await?;
+ if len == 0 {
+ return anyhow::Ok(());
+ }
+
+ stderr_offset += len;
+ let mut start_ix = 0;
+ while let Some(ix) = stderr_buffer[start_ix..stderr_offset]
+ .iter()
+ .position(|b| b == &b'\n')
+ {
+ let line_ix = start_ix + ix;
+ let content = &stderr_buffer[start_ix..line_ix];
+ start_ix = line_ix + 1;
+ if let Ok(record) = serde_json::from_slice::<LogRecord>(content) {
+ record.log(log::logger())
+ } else {
+ eprintln!("(remote) {}", String::from_utf8_lossy(content));
+ }
+ }
+ stderr_buffer.drain(0..start_ix);
+ stderr_offset -= start_ix;
+
+ connection_activity_tx.try_send(()).ok();
+ }
+ });
+
+ cx.spawn(|_| async move {
+ let result = futures::select! {
+ result = stdin_task.fuse() => {
+ result.context("stdin")
+ }
+ result = stdout_task.fuse() => {
+ result.context("stdout")
+ }
+ result = stderr_task.fuse() => {
+ result.context("stderr")
+ }
+ };
+
+ let status = ssh_proxy_process.status().await?.code().unwrap_or(1);
+ match result {
+ Ok(_) => Ok(status),
+ Err(error) => Err(error),
+ }
+ })
+ }
+
async fn ensure_server_binary(
&self,
delegate: &Arc<dyn SshClientDelegate>,
@@ -1621,26 +1782,6 @@ impl SshRemoteConnection {
Ok(())
}
- async fn query_platform(&self) -> Result<SshPlatform> {
- let os = run_cmd(self.socket.ssh_command("uname").arg("-s")).await?;
- let arch = run_cmd(self.socket.ssh_command("uname").arg("-m")).await?;
-
- let os = match os.trim() {
- "Darwin" => "macos",
- "Linux" => "linux",
- _ => Err(anyhow!("unknown uname os {os:?}"))?,
- };
- let arch = if arch.starts_with("arm") || arch.starts_with("aarch64") {
- "aarch64"
- } else if arch.starts_with("x86") || arch.starts_with("i686") {
- "x86_64"
- } else {
- Err(anyhow!("unknown uname architecture {arch:?}"))?
- };
-
- Ok(SshPlatform { os, arch })
- }
-
async fn upload_file(&self, src_path: &Path, dest_path: &Path) -> Result<()> {
let mut command = process::Command::new("scp");
let output = self
@@ -1974,50 +2115,86 @@ mod fake {
},
select_biased, FutureExt, SinkExt, StreamExt,
};
- use gpui::{AsyncAppContext, BorrowAppContext, Global, SemanticVersion, Task};
+ use gpui::{AsyncAppContext, SemanticVersion, Task};
use rpc::proto::Envelope;
use super::{
- ChannelClient, ServerBinary, SshClientDelegate, SshConnectionOptions, SshPlatform,
- SshRemoteProcess,
+ ChannelClient, RemoteConnection, ServerBinary, SshClientDelegate, SshConnectionOptions,
+ SshPlatform,
};
- pub(super) struct SshRemoteConnection {
- connection_options: SshConnectionOptions,
+ pub(super) struct FakeRemoteConnection {
+ pub(super) connection_options: SshConnectionOptions,
+ pub(super) server_channel: Arc<ChannelClient>,
+ pub(super) server_cx: SendableCx,
}
- impl SshRemoteConnection {
- pub(super) fn new(
- connection_options: &SshConnectionOptions,
- ) -> Option<Box<dyn SshRemoteProcess>> {
- if connection_options.host == "<fake>" {
- return Some(Box::new(Self {
- connection_options: connection_options.clone(),
- }));
- }
- return None;
+ pub(super) struct SendableCx(AsyncAppContext);
+ // safety: you can only get the other cx on the main thread.
+ impl SendableCx {
+ pub(super) fn new(cx: AsyncAppContext) -> Self {
+ Self(cx)
+ }
+ fn get(&self, _: &AsyncAppContext) -> AsyncAppContext {
+ self.0.clone()
+ }
+ }
+ unsafe impl Send for SendableCx {}
+ unsafe impl Sync for SendableCx {}
+
+ #[async_trait(?Send)]
+ impl RemoteConnection for FakeRemoteConnection {
+ async fn kill(&self) -> Result<()> {
+ Ok(())
+ }
+
+ fn has_been_killed(&self) -> bool {
+ false
+ }
+
+ fn ssh_args(&self) -> Vec<String> {
+ Vec::new()
+ }
+
+ fn connection_options(&self) -> SshConnectionOptions {
+ self.connection_options.clone()
+ }
+
+ fn simulate_disconnect(&self, cx: &AsyncAppContext) {
+ let (outgoing_tx, _) = mpsc::unbounded::<Envelope>();
+ let (_, incoming_rx) = mpsc::unbounded::<Envelope>();
+ self.server_channel
+ .reconnect(incoming_rx, outgoing_tx, &self.server_cx.get(&cx));
+ }
+
+ async fn get_remote_binary_path(
+ &self,
+ _delegate: &Arc<dyn SshClientDelegate>,
+ _reconnect: bool,
+ _cx: &mut AsyncAppContext,
+ ) -> Result<PathBuf> {
+ Ok(PathBuf::new())
}
- pub(super) async fn multiplex(
- connection_options: SshConnectionOptions,
+
+ fn start_proxy(
+ &self,
+ _remote_binary_path: PathBuf,
+ _unique_identifier: String,
+ _reconnect: bool,
mut client_incoming_tx: mpsc::UnboundedSender<Envelope>,
mut client_outgoing_rx: mpsc::UnboundedReceiver<Envelope>,
mut connection_activity_tx: Sender<()>,
+ _delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
) -> Task<Result<i32>> {
let (mut server_incoming_tx, server_incoming_rx) = mpsc::unbounded::<Envelope>();
let (server_outgoing_tx, mut server_outgoing_rx) = mpsc::unbounded::<Envelope>();
- let (channel, server_cx) = cx
- .update(|cx| {
- cx.update_global(|conns: &mut ServerConnections, _| {
- conns.get(connection_options.port.unwrap())
- })
- })
- .unwrap();
- channel.reconnect(server_incoming_rx, server_outgoing_tx, &server_cx);
-
- // send to proxy_tx to get to the server.
- // receive from
+ self.server_channel.reconnect(
+ server_incoming_rx,
+ server_outgoing_tx,
+ &self.server_cx.get(cx),
+ );
cx.background_executor().spawn(async move {
loop {
@@ -2041,39 +2218,6 @@ mod fake {
}
}
- #[async_trait]
- impl SshRemoteProcess for SshRemoteConnection {
- async fn kill(&mut self) -> Result<()> {
- Ok(())
- }
-
- fn ssh_args(&self) -> Vec<String> {
- Vec::new()
- }
-
- fn connection_options(&self) -> SshConnectionOptions {
- self.connection_options.clone()
- }
- }
-
- #[derive(Default)]
- pub(super) struct ServerConnections(Vec<(Arc<ChannelClient>, AsyncAppContext)>);
- impl Global for ServerConnections {}
-
- impl ServerConnections {
- pub(super) fn push(&mut self, server: Arc<ChannelClient>, cx: AsyncAppContext) -> u16 {
- self.0.push((server.clone(), cx));
- self.0.len() as u16 - 1
- }
-
- pub(super) fn get(&mut self, port: u16) -> (Arc<ChannelClient>, AsyncAppContext) {
- self.0
- .get(port as usize)
- .expect("no fake server for port")
- .clone()
- }
- }
-
pub(super) struct Delegate;
impl SshClientDelegate for Delegate {