@@ -54,7 +54,7 @@ use parking_lot::{Mutex, RwLock};
use paths::{local_tasks_file_relative_path, local_vscode_tasks_file_relative_path};
pub use prettier_store::PrettierStore;
use project_settings::{ProjectSettings, SettingsObserver, SettingsObserverEvent};
-use remote::SshSession;
+use remote::SshRemoteClient;
use rpc::{proto::SSH_PROJECT_ID, AnyProtoClient, ErrorCode};
use search::{SearchInputKind, SearchQuery, SearchResult};
use search_history::SearchHistory;
@@ -138,7 +138,7 @@ pub struct Project {
join_project_response_message_id: u32,
user_store: Model<UserStore>,
fs: Arc<dyn Fs>,
- ssh_session: Option<Arc<SshSession>>,
+ ssh_client: Option<Arc<SshRemoteClient>>,
client_state: ProjectClientState,
collaborators: HashMap<proto::PeerId, Collaborator>,
client_subscriptions: Vec<client::Subscription>,
@@ -643,7 +643,7 @@ impl Project {
user_store,
settings_observer,
fs,
- ssh_session: None,
+ ssh_client: None,
buffers_needing_diff: Default::default(),
git_diff_debouncer: DebouncedDelay::new(),
terminals: Terminals {
@@ -664,7 +664,7 @@ impl Project {
}
pub fn ssh(
- ssh: Arc<SshSession>,
+ ssh: Arc<SshRemoteClient>,
client: Arc<Client>,
node: NodeRuntime,
user_store: Model<UserStore>,
@@ -682,14 +682,14 @@ impl Project {
SnippetProvider::new(fs.clone(), BTreeSet::from_iter([global_snippets_dir]), cx);
let worktree_store =
- cx.new_model(|_| WorktreeStore::remote(false, ssh.clone().into(), 0, None));
+ cx.new_model(|_| WorktreeStore::remote(false, ssh.to_proto_client(), 0, None));
cx.subscribe(&worktree_store, Self::on_worktree_store_event)
.detach();
let buffer_store = cx.new_model(|cx| {
BufferStore::remote(
worktree_store.clone(),
- ssh.clone().into(),
+ ssh.to_proto_client(),
SSH_PROJECT_ID,
cx,
)
@@ -698,7 +698,7 @@ impl Project {
.detach();
let settings_observer = cx.new_model(|cx| {
- SettingsObserver::new_ssh(ssh.clone().into(), worktree_store.clone(), cx)
+ SettingsObserver::new_ssh(ssh.to_proto_client(), worktree_store.clone(), cx)
});
cx.subscribe(&settings_observer, Self::on_settings_observer_event)
.detach();
@@ -709,7 +709,7 @@ impl Project {
buffer_store.clone(),
worktree_store.clone(),
languages.clone(),
- ssh.clone().into(),
+ ssh.to_proto_client(),
SSH_PROJECT_ID,
cx,
)
@@ -733,7 +733,7 @@ impl Project {
user_store,
settings_observer,
fs,
- ssh_session: Some(ssh.clone()),
+ ssh_client: Some(ssh.clone()),
buffers_needing_diff: Default::default(),
git_diff_debouncer: DebouncedDelay::new(),
terminals: Terminals {
@@ -751,7 +751,7 @@ impl Project {
search_excluded_history: Self::new_search_history(),
};
- let client: AnyProtoClient = ssh.clone().into();
+ let client: AnyProtoClient = ssh.to_proto_client();
ssh.subscribe_to_entity(SSH_PROJECT_ID, &cx.handle());
ssh.subscribe_to_entity(SSH_PROJECT_ID, &this.buffer_store);
@@ -907,7 +907,7 @@ impl Project {
user_store: user_store.clone(),
snippets,
fs,
- ssh_session: None,
+ ssh_client: None,
settings_observer: settings_observer.clone(),
client_subscriptions: Default::default(),
_subscriptions: vec![cx.on_release(Self::release)],
@@ -1230,7 +1230,7 @@ impl Project {
match self.client_state {
ProjectClientState::Remote { replica_id, .. } => replica_id,
_ => {
- if self.ssh_session.is_some() {
+ if self.ssh_client.is_some() {
1
} else {
0
@@ -1638,7 +1638,7 @@ impl Project {
pub fn is_local(&self) -> bool {
match &self.client_state {
ProjectClientState::Local | ProjectClientState::Shared { .. } => {
- self.ssh_session.is_none()
+ self.ssh_client.is_none()
}
ProjectClientState::Remote { .. } => false,
}
@@ -1647,7 +1647,7 @@ impl Project {
pub fn is_via_ssh(&self) -> bool {
match &self.client_state {
ProjectClientState::Local | ProjectClientState::Shared { .. } => {
- self.ssh_session.is_some()
+ self.ssh_client.is_some()
}
ProjectClientState::Remote { .. } => false,
}
@@ -1933,8 +1933,9 @@ impl Project {
}
BufferStoreEvent::BufferChangedFilePath { .. } => {}
BufferStoreEvent::BufferDropped(buffer_id) => {
- if let Some(ref ssh_session) = self.ssh_session {
- ssh_session
+ if let Some(ref ssh_client) = self.ssh_client {
+ ssh_client
+ .to_proto_client()
.send(proto::CloseBuffer {
project_id: 0,
buffer_id: buffer_id.to_proto(),
@@ -2139,13 +2140,14 @@ impl Project {
} => {
let operation = language::proto::serialize_operation(operation);
- if let Some(ssh) = &self.ssh_session {
- ssh.send(proto::UpdateBuffer {
- project_id: 0,
- buffer_id: buffer_id.to_proto(),
- operations: vec![operation.clone()],
- })
- .ok();
+ if let Some(ssh) = &self.ssh_client {
+ ssh.to_proto_client()
+ .send(proto::UpdateBuffer {
+ project_id: 0,
+ buffer_id: buffer_id.to_proto(),
+ operations: vec![operation.clone()],
+ })
+ .ok();
}
self.enqueue_buffer_ordered_message(BufferOrderedMessage::Operation {
@@ -2825,14 +2827,13 @@ impl Project {
) -> Receiver<Model<Buffer>> {
let (tx, rx) = smol::channel::unbounded();
- let (client, remote_id): (AnyProtoClient, _) =
- if let Some(ssh_session) = self.ssh_session.clone() {
- (ssh_session.into(), 0)
- } else if let Some(remote_id) = self.remote_id() {
- (self.client.clone().into(), remote_id)
- } else {
- return rx;
- };
+ let (client, remote_id): (AnyProtoClient, _) = if let Some(ssh_client) = &self.ssh_client {
+ (ssh_client.to_proto_client(), 0)
+ } else if let Some(remote_id) = self.remote_id() {
+ (self.client.clone().into(), remote_id)
+ } else {
+ return rx;
+ };
let request = client.request(proto::FindSearchCandidates {
project_id: remote_id,
@@ -2961,11 +2962,13 @@ impl Project {
exists.then(|| ResolvedPath::AbsPath(expanded))
})
- } else if let Some(ssh_session) = self.ssh_session.as_ref() {
- let request = ssh_session.request(proto::CheckFileExists {
- project_id: SSH_PROJECT_ID,
- path: path.to_string(),
- });
+ } else if let Some(ssh_client) = self.ssh_client.as_ref() {
+ let request = ssh_client
+ .to_proto_client()
+ .request(proto::CheckFileExists {
+ project_id: SSH_PROJECT_ID,
+ path: path.to_string(),
+ });
cx.background_executor().spawn(async move {
let response = request.await.log_err()?;
if response.exists {
@@ -3035,13 +3038,13 @@ impl Project {
) -> Task<Result<Vec<PathBuf>>> {
if self.is_local() {
DirectoryLister::Local(self.fs.clone()).list_directory(query, cx)
- } else if let Some(session) = self.ssh_session.as_ref() {
+ } else if let Some(session) = self.ssh_client.as_ref() {
let request = proto::ListRemoteDirectory {
dev_server_id: SSH_PROJECT_ID,
path: query,
};
- let response = session.request(request);
+ let response = session.to_proto_client().request(request);
cx.background_executor().spawn(async move {
let response = response.await?;
Ok(response.entries.into_iter().map(PathBuf::from).collect())
@@ -3465,11 +3468,11 @@ impl Project {
cx: AsyncAppContext,
) -> Result<proto::Ack> {
let buffer_store = this.read_with(&cx, |this, cx| {
- if let Some(ssh) = &this.ssh_session {
+ if let Some(ssh) = &this.ssh_client {
let mut payload = envelope.payload.clone();
payload.project_id = 0;
cx.background_executor()
- .spawn(ssh.request(payload))
+ .spawn(ssh.to_proto_client().request(payload))
.detach_and_log_err(cx);
}
this.buffer_store.clone()
@@ -7,19 +7,23 @@ use crate::{
use anyhow::{anyhow, Context as _, Result};
use collections::HashMap;
use futures::{
- channel::{mpsc, oneshot},
+ channel::{
+ mpsc::{self, UnboundedReceiver, UnboundedSender},
+ oneshot,
+ },
future::BoxFuture,
- select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, StreamExt as _,
+ select_biased, AsyncReadExt as _, AsyncWriteExt as _, Future, FutureExt as _, SinkExt,
+ StreamExt as _,
};
use gpui::{AppContext, AsyncAppContext, Model, SemanticVersion, Task};
use parking_lot::Mutex;
use rpc::{
proto::{self, build_typed_envelope, Envelope, EnvelopedMessage, PeerId, RequestMessage},
- EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
+ AnyProtoClient, EntityMessageSubscriber, ProtoClient, ProtoMessageHandlerSet, RpcError,
};
use smol::{
fs,
- process::{self, Stdio},
+ process::{self, Child, Stdio},
};
use std::{
any::TypeId,
@@ -44,22 +48,6 @@ pub struct SshSocket {
socket_path: PathBuf,
}
-pub struct SshSession {
- next_message_id: AtomicU32,
- response_channels: ResponseChannels, // Lock
- outgoing_tx: mpsc::UnboundedSender<Envelope>,
- spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
- client_socket: Option<SshSocket>,
- state: Mutex<ProtoMessageHandlerSet>, // Lock
- _io_task: Option<Task<Result<()>>>,
-}
-
-struct SshClientState {
- socket: SshSocket,
- master_process: process::Child,
- _temp_dir: TempDir,
-}
-
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshConnectionOptions {
pub host: String,
@@ -105,18 +93,13 @@ impl SshConnectionOptions {
}
}
-struct SpawnRequest {
- command: String,
- process_tx: oneshot::Sender<process::Child>,
-}
-
#[derive(Copy, Clone, Debug)]
pub struct SshPlatform {
pub os: &'static str,
pub arch: &'static str,
}
-pub trait SshClientDelegate {
+pub trait SshClientDelegate: Send + Sync {
fn ask_password(
&self,
prompt: String,
@@ -132,48 +115,249 @@ pub trait SshClientDelegate {
fn set_error(&self, error_message: String, cx: &mut AsyncAppContext);
}
-type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
+impl SshSocket {
+ fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
+ let mut command = process::Command::new("ssh");
+ self.ssh_options(&mut command)
+ .arg(self.connection_options.ssh_url())
+ .arg(program);
+ command
+ }
+
+ fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
+ command
+ .stdin(Stdio::piped())
+ .stdout(Stdio::piped())
+ .stderr(Stdio::piped())
+ .args(["-o", "ControlMaster=no", "-o"])
+ .arg(format!("ControlPath={}", self.socket_path.display()))
+ }
+
+ fn ssh_args(&self) -> Vec<String> {
+ vec![
+ "-o".to_string(),
+ "ControlMaster=no".to_string(),
+ "-o".to_string(),
+ format!("ControlPath={}", self.socket_path.display()),
+ self.connection_options.ssh_url(),
+ ]
+ }
+}
-impl SshSession {
- pub async fn client(
+async fn run_cmd(command: &mut process::Command) -> Result<String> {
+ let output = command.output().await?;
+ if output.status.success() {
+ Ok(String::from_utf8_lossy(&output.stdout).to_string())
+ } else {
+ Err(anyhow!(
+ "failed to run command: {}",
+ String::from_utf8_lossy(&output.stderr)
+ ))
+ }
+}
+#[cfg(unix)]
+async fn read_with_timeout(
+ stdout: &mut process::ChildStdout,
+ timeout: std::time::Duration,
+ output: &mut Vec<u8>,
+) -> Result<(), std::io::Error> {
+ smol::future::or(
+ async {
+ stdout.read_to_end(output).await?;
+ Ok::<_, std::io::Error>(())
+ },
+ async {
+ smol::Timer::after(timeout).await;
+
+ Err(std::io::Error::new(
+ std::io::ErrorKind::TimedOut,
+ "Read operation timed out",
+ ))
+ },
+ )
+ .await
+}
+
+struct ChannelForwarder {
+ quit_tx: UnboundedSender<()>,
+ forwarding_task: Task<(UnboundedSender<Envelope>, UnboundedReceiver<Envelope>)>,
+}
+
+impl ChannelForwarder {
+ fn new(
+ mut incoming_tx: UnboundedSender<Envelope>,
+ mut outgoing_rx: UnboundedReceiver<Envelope>,
+ cx: &mut AsyncAppContext,
+ ) -> (Self, UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
+ let (quit_tx, mut quit_rx) = mpsc::unbounded::<()>();
+
+ let (proxy_incoming_tx, mut proxy_incoming_rx) = mpsc::unbounded::<Envelope>();
+ let (mut proxy_outgoing_tx, proxy_outgoing_rx) = mpsc::unbounded::<Envelope>();
+
+ let forwarding_task = cx.background_executor().spawn(async move {
+ loop {
+ select_biased! {
+ _ = quit_rx.next().fuse() => {
+ break;
+ },
+ incoming_envelope = proxy_incoming_rx.next().fuse() => {
+ if let Some(envelope) = incoming_envelope {
+ if incoming_tx.send(envelope).await.is_err() {
+ break;
+ }
+ } else {
+ break;
+ }
+ }
+ outgoing_envelope = outgoing_rx.next().fuse() => {
+ if let Some(envelope) = outgoing_envelope {
+ if proxy_outgoing_tx.send(envelope).await.is_err() {
+ break;
+ }
+ } else {
+ break;
+ }
+ }
+ }
+ }
+
+ (incoming_tx, outgoing_rx)
+ });
+
+ (
+ Self {
+ forwarding_task,
+ quit_tx,
+ },
+ proxy_incoming_tx,
+ proxy_outgoing_rx,
+ )
+ }
+
+ async fn into_channels(mut self) -> (UnboundedSender<Envelope>, UnboundedReceiver<Envelope>) {
+ let _ = self.quit_tx.send(()).await;
+ self.forwarding_task.await
+ }
+}
+
+struct SshRemoteClientState {
+ ssh_connection: SshRemoteConnection,
+ delegate: Arc<dyn SshClientDelegate>,
+ forwarder: ChannelForwarder,
+ _multiplex_task: Task<Result<()>>,
+}
+
+pub struct SshRemoteClient {
+ client: Arc<ChannelClient>,
+ inner_state: Arc<Mutex<Option<SshRemoteClientState>>>,
+}
+
+impl SshRemoteClient {
+ pub async fn new(
connection_options: SshConnectionOptions,
delegate: Arc<dyn SshClientDelegate>,
cx: &mut AsyncAppContext,
) -> Result<Arc<Self>> {
- let client_state = SshClientState::new(connection_options, delegate.clone(), cx).await?;
+ let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<Envelope>();
+ let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
- let platform = client_state.query_platform().await?;
- let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
- let remote_binary_path = delegate.remote_server_binary_path(cx)?;
- client_state
- .ensure_server_binary(
- &delegate,
- &local_binary_path,
- &remote_binary_path,
- version,
+ let client = cx.update(|cx| ChannelClient::new(incoming_rx, outgoing_tx, cx))?;
+ let this = Arc::new(Self {
+ client,
+ inner_state: Arc::new(Mutex::new(None)),
+ });
+
+ let inner_state = {
+ let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
+ ChannelForwarder::new(incoming_tx, outgoing_rx, cx);
+
+ let (ssh_connection, ssh_process) =
+ Self::establish_connection(connection_options.clone(), delegate.clone(), cx)
+ .await?;
+
+ let multiplex_task = Self::multiplex(
+ this.clone(),
+ ssh_process,
+ proxy_incoming_tx,
+ proxy_outgoing_rx,
cx,
- )
- .await?;
+ );
- let (spawn_process_tx, mut spawn_process_rx) = mpsc::unbounded::<SpawnRequest>();
- let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded::<Envelope>();
- let (incoming_tx, incoming_rx) = mpsc::unbounded::<Envelope>();
+ SshRemoteClientState {
+ ssh_connection,
+ delegate,
+ forwarder: proxy,
+ _multiplex_task: multiplex_task,
+ }
+ };
- let socket = client_state.socket.clone();
- run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
+ this.inner_state.lock().replace(inner_state);
- let mut remote_server_child = socket
- .ssh_command(format!(
- "RUST_LOG={} RUST_BACKTRACE={} {:?} run",
- std::env::var("RUST_LOG").unwrap_or_default(),
- std::env::var("RUST_BACKTRACE").unwrap_or_default(),
- remote_binary_path,
- ))
- .spawn()
- .context("failed to spawn remote server")?;
- let mut child_stderr = remote_server_child.stderr.take().unwrap();
- let mut child_stdout = remote_server_child.stdout.take().unwrap();
- let mut child_stdin = remote_server_child.stdin.take().unwrap();
+ Ok(this)
+ }
+
+ fn reconnect(this: Arc<Self>, cx: &mut AsyncAppContext) -> Result<()> {
+ let Some(state) = this.inner_state.lock().take() else {
+ return Err(anyhow!("reconnect is already in progress"));
+ };
+
+ let SshRemoteClientState {
+ mut ssh_connection,
+ delegate,
+ forwarder: proxy,
+ _multiplex_task,
+ } = state;
+ drop(_multiplex_task);
+
+ cx.spawn(|mut cx| async move {
+ let (incoming_tx, outgoing_rx) = proxy.into_channels().await;
+
+ ssh_connection.master_process.kill()?;
+ ssh_connection
+ .master_process
+ .status()
+ .await
+ .context("Failed to kill ssh process")?;
+
+ let connection_options = ssh_connection.socket.connection_options.clone();
+
+ let (ssh_connection, ssh_process) =
+ Self::establish_connection(connection_options, delegate.clone(), &mut cx).await?;
+
+ let (proxy, proxy_incoming_tx, proxy_outgoing_rx) =
+ ChannelForwarder::new(incoming_tx, outgoing_rx, &mut cx);
+
+ let inner_state = SshRemoteClientState {
+ ssh_connection,
+ delegate,
+ forwarder: proxy,
+ _multiplex_task: Self::multiplex(
+ this.clone(),
+ ssh_process,
+ proxy_incoming_tx,
+ proxy_outgoing_rx,
+ &mut cx,
+ ),
+ };
+ this.inner_state.lock().replace(inner_state);
+
+ anyhow::Ok(())
+ })
+ .detach();
+
+ anyhow::Ok(())
+ }
+
+ fn multiplex(
+ this: Arc<Self>,
+ mut ssh_process: Child,
+ incoming_tx: UnboundedSender<Envelope>,
+ mut outgoing_rx: UnboundedReceiver<Envelope>,
+ cx: &mut AsyncAppContext,
+ ) -> Task<Result<()>> {
+ let mut child_stderr = ssh_process.stderr.take().unwrap();
+ let mut child_stdout = ssh_process.stdout.take().unwrap();
+ let mut child_stdin = ssh_process.stdin.take().unwrap();
let io_task = cx.background_executor().spawn(async move {
let mut stdin_buffer = Vec::new();
@@ -194,27 +378,15 @@ impl SshSession {
write_message(&mut child_stdin, &mut stdin_buffer, outgoing).await?;
}
- request = spawn_process_rx.next().fuse() => {
- let Some(request) = request else {
- return Ok(());
- };
-
- log::info!("spawn process: {:?}", request.command);
- let child = client_state.socket
- .ssh_command(&request.command)
- .spawn()
- .context("failed to create channel")?;
- request.process_tx.send(child).ok();
- }
-
result = child_stdout.read(&mut stdout_buffer).fuse() => {
match result {
Ok(0) => {
child_stdin.close().await?;
outgoing_rx.close();
- let status = remote_server_child.status().await?;
+ let status = ssh_process.status().await?;
if !status.success() {
- log::error!("channel exited with status: {status:?}");
+ log::error!("ssh process exited with status: {status:?}");
+ return Err(anyhow!("ssh process exited with non-zero status code: {:?}", status.code()));
}
return Ok(());
}
@@ -267,239 +439,112 @@ impl SshSession {
}
});
- cx.update(|cx| {
- Self::new(
- incoming_rx,
- outgoing_tx,
- spawn_process_tx,
- Some(socket),
- Some(io_task),
- cx,
- )
- })
- }
+ cx.spawn(|mut cx| async move {
+ let result = io_task.await;
- pub fn server(
- incoming_rx: mpsc::UnboundedReceiver<Envelope>,
- outgoing_tx: mpsc::UnboundedSender<Envelope>,
- cx: &AppContext,
- ) -> Arc<SshSession> {
- let (tx, _rx) = mpsc::unbounded();
- Self::new(incoming_rx, outgoing_tx, tx, None, None, cx)
- }
-
- #[cfg(any(test, feature = "test-support"))]
- pub fn fake(
- client_cx: &mut gpui::TestAppContext,
- server_cx: &mut gpui::TestAppContext,
- ) -> (Arc<Self>, Arc<Self>) {
- let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
- let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
- let (tx, _rx) = mpsc::unbounded();
- (
- client_cx.update(|cx| {
- Self::new(
- server_to_client_rx,
- client_to_server_tx,
- tx.clone(),
- None, // todo()
- None,
- cx,
- )
- }),
- server_cx.update(|cx| {
- Self::new(
- client_to_server_rx,
- server_to_client_tx,
- tx.clone(),
- None,
- None,
- cx,
- )
- }),
- )
- }
-
- fn new(
- mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
- outgoing_tx: mpsc::UnboundedSender<Envelope>,
- spawn_process_tx: mpsc::UnboundedSender<SpawnRequest>,
- client_socket: Option<SshSocket>,
- io_task: Option<Task<Result<()>>>,
- cx: &AppContext,
- ) -> Arc<SshSession> {
- let this = Arc::new(Self {
- next_message_id: AtomicU32::new(0),
- response_channels: ResponseChannels::default(),
- outgoing_tx,
- spawn_process_tx,
- client_socket,
- state: Default::default(),
- _io_task: io_task,
- });
-
- cx.spawn(|cx| {
- let this = Arc::downgrade(&this);
- async move {
- let peer_id = PeerId { owner_id: 0, id: 0 };
- while let Some(incoming) = incoming_rx.next().await {
- let Some(this) = this.upgrade() else {
- return anyhow::Ok(());
- };
-
- if let Some(request_id) = incoming.responding_to {
- let request_id = MessageId(request_id);
- let sender = this.response_channels.lock().remove(&request_id);
- if let Some(sender) = sender {
- let (tx, rx) = oneshot::channel();
- if incoming.payload.is_some() {
- sender.send((incoming, tx)).ok();
- }
- rx.await.ok();
- }
- } else if let Some(envelope) =
- build_typed_envelope(peer_id, Instant::now(), incoming)
- {
- let type_name = envelope.payload_type_name();
- if let Some(future) = ProtoMessageHandlerSet::handle_message(
- &this.state,
- envelope,
- this.clone().into(),
- cx.clone(),
- ) {
- log::debug!("ssh message received. name:{type_name}");
- match future.await {
- Ok(_) => {
- log::debug!("ssh message handled. name:{type_name}");
- }
- Err(error) => {
- log::error!(
- "error handling message. type:{type_name}, error:{error}",
- );
- }
- }
- } else {
- log::error!("unhandled ssh message name:{type_name}");
- }
- }
- }
- anyhow::Ok(())
+ if let Err(error) = result {
+ log::warn!("ssh io task died with error: {:?}. reconnecting...", error);
+ Self::reconnect(this, &mut cx).ok();
}
- })
- .detach();
-
- this
- }
- pub fn request<T: RequestMessage>(
- &self,
- payload: T,
- ) -> impl 'static + Future<Output = Result<T::Response>> {
- log::debug!("ssh request start. name:{}", T::NAME);
- let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
- async move {
- let response = response.await?;
- log::debug!("ssh request finish. name:{}", T::NAME);
- T::Response::from_envelope(response)
- .ok_or_else(|| anyhow!("received a response of the wrong type"))
- }
- }
-
- pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
- log::debug!("ssh send name:{}", T::NAME);
- self.send_dynamic(payload.into_envelope(0, None, None))
+ Ok(())
+ })
}
- pub fn request_dynamic(
- &self,
- mut envelope: proto::Envelope,
- type_name: &'static str,
- ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
- envelope.id = self.next_message_id.fetch_add(1, SeqCst);
- let (tx, rx) = oneshot::channel();
- let mut response_channels_lock = self.response_channels.lock();
- response_channels_lock.insert(MessageId(envelope.id), tx);
- drop(response_channels_lock);
- let result = self.outgoing_tx.unbounded_send(envelope);
- async move {
- if let Err(error) = &result {
- log::error!("failed to send message: {}", error);
- return Err(anyhow!("failed to send message: {}", error));
- }
-
- let response = rx.await.context("connection lost")?.0;
- if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
- return Err(RpcError::from_proto(error, type_name));
- }
- Ok(response)
- }
- }
+ async fn establish_connection(
+ connection_options: SshConnectionOptions,
+ delegate: Arc<dyn SshClientDelegate>,
+ cx: &mut AsyncAppContext,
+ ) -> Result<(SshRemoteConnection, Child)> {
+ let ssh_connection =
+ SshRemoteConnection::new(connection_options, delegate.clone(), cx).await?;
- pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
- envelope.id = self.next_message_id.fetch_add(1, SeqCst);
- self.outgoing_tx.unbounded_send(envelope)?;
- Ok(())
- }
+ let platform = ssh_connection.query_platform().await?;
+ let (local_binary_path, version) = delegate.get_server_binary(platform, cx).await??;
+ let remote_binary_path = delegate.remote_server_binary_path(cx)?;
+ ssh_connection
+ .ensure_server_binary(
+ &delegate,
+ &local_binary_path,
+ &remote_binary_path,
+ version,
+ cx,
+ )
+ .await?;
- pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
- let id = (TypeId::of::<E>(), remote_id);
+ let socket = ssh_connection.socket.clone();
+ run_cmd(socket.ssh_command(&remote_binary_path).arg("version")).await?;
- let mut state = self.state.lock();
- if state.entities_by_type_and_remote_id.contains_key(&id) {
- panic!("already subscribed to entity");
- }
+ let ssh_process = socket
+ .ssh_command(format!(
+ "RUST_LOG={} RUST_BACKTRACE={} {:?} run",
+ std::env::var("RUST_LOG").unwrap_or_default(),
+ std::env::var("RUST_BACKTRACE").unwrap_or_default(),
+ remote_binary_path,
+ ))
+ .spawn()
+ .context("failed to spawn remote server")?;
- state.entities_by_type_and_remote_id.insert(
- id,
- EntityMessageSubscriber::Entity {
- handle: entity.downgrade().into(),
- },
- );
+ Ok((ssh_connection, ssh_process))
}
- pub async fn spawn_process(&self, command: String) -> process::Child {
- let (process_tx, process_rx) = oneshot::channel();
- self.spawn_process_tx
- .unbounded_send(SpawnRequest {
- command,
- process_tx,
- })
- .ok();
- process_rx.await.unwrap()
+ pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
+ self.client.subscribe_to_entity(remote_id, entity);
}
- pub fn ssh_args(&self) -> Vec<String> {
- self.client_socket.as_ref().unwrap().ssh_args()
+ pub fn ssh_args(&self) -> Option<Vec<String>> {
+ let state = self.inner_state.lock();
+ state
+ .as_ref()
+ .map(|state| state.ssh_connection.socket.ssh_args())
}
-}
-impl ProtoClient for SshSession {
- fn request(
- &self,
- envelope: proto::Envelope,
- request_type: &'static str,
- ) -> BoxFuture<'static, Result<proto::Envelope>> {
- self.request_dynamic(envelope, request_type).boxed()
+ pub fn to_proto_client(&self) -> AnyProtoClient {
+ self.client.clone().into()
}
- fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
- self.send_dynamic(envelope)
- }
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn fake(
+ client_cx: &mut gpui::TestAppContext,
+ server_cx: &mut gpui::TestAppContext,
+ ) -> (Arc<Self>, Arc<ChannelClient>) {
+ let (server_to_client_tx, server_to_client_rx) = mpsc::unbounded();
+ let (client_to_server_tx, client_to_server_rx) = mpsc::unbounded();
- fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
- self.send_dynamic(envelope)
+ (
+ client_cx.update(|cx| {
+ let client = ChannelClient::new(server_to_client_rx, client_to_server_tx, cx);
+ Arc::new(Self {
+ client,
+ inner_state: Arc::new(Mutex::new(None)),
+ })
+ }),
+ server_cx.update(|cx| ChannelClient::new(client_to_server_rx, server_to_client_tx, cx)),
+ )
}
+}
- fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
- &self.state
+impl From<SshRemoteClient> for AnyProtoClient {
+ fn from(client: SshRemoteClient) -> Self {
+ AnyProtoClient::new(client.client.clone())
}
+}
- fn is_via_collab(&self) -> bool {
- false
+struct SshRemoteConnection {
+ socket: SshSocket,
+ master_process: process::Child,
+ _temp_dir: TempDir,
+}
+
+impl Drop for SshRemoteConnection {
+ fn drop(&mut self) {
+ if let Err(error) = self.master_process.kill() {
+ log::error!("failed to kill SSH master process: {}", error);
+ }
}
}
-impl SshClientState {
+impl SshRemoteConnection {
#[cfg(not(unix))]
async fn new(
_connection_options: SshConnectionOptions,
@@ -740,74 +785,181 @@ impl SshClientState {
}
}
-#[cfg(unix)]
-async fn read_with_timeout(
- stdout: &mut process::ChildStdout,
- timeout: std::time::Duration,
- output: &mut Vec<u8>,
-) -> Result<(), std::io::Error> {
- smol::future::or(
- async {
- stdout.read_to_end(output).await?;
- Ok::<_, std::io::Error>(())
- },
- async {
- smol::Timer::after(timeout).await;
+type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
- Err(std::io::Error::new(
- std::io::ErrorKind::TimedOut,
- "Read operation timed out",
- ))
- },
- )
- .await
+pub struct ChannelClient {
+ next_message_id: AtomicU32,
+ outgoing_tx: mpsc::UnboundedSender<Envelope>,
+ response_channels: ResponseChannels, // Lock
+ message_handlers: Mutex<ProtoMessageHandlerSet>, // Lock
}
-impl Drop for SshClientState {
- fn drop(&mut self) {
- if let Err(error) = self.master_process.kill() {
- log::error!("failed to kill SSH master process: {}", error);
+impl ChannelClient {
+ pub fn new(
+ incoming_rx: mpsc::UnboundedReceiver<Envelope>,
+ outgoing_tx: mpsc::UnboundedSender<Envelope>,
+ cx: &AppContext,
+ ) -> Arc<Self> {
+ let this = Arc::new(Self {
+ outgoing_tx,
+ next_message_id: AtomicU32::new(0),
+ response_channels: ResponseChannels::default(),
+ message_handlers: Default::default(),
+ });
+
+ Self::start_handling_messages(this.clone(), incoming_rx, cx);
+
+ this
+ }
+
+ fn start_handling_messages(
+ this: Arc<Self>,
+ mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
+ cx: &AppContext,
+ ) {
+ cx.spawn(|cx| {
+ let this = Arc::downgrade(&this);
+ async move {
+ let peer_id = PeerId { owner_id: 0, id: 0 };
+ while let Some(incoming) = incoming_rx.next().await {
+ let Some(this) = this.upgrade() else {
+ return anyhow::Ok(());
+ };
+
+ if let Some(request_id) = incoming.responding_to {
+ let request_id = MessageId(request_id);
+ let sender = this.response_channels.lock().remove(&request_id);
+ if let Some(sender) = sender {
+ let (tx, rx) = oneshot::channel();
+ if incoming.payload.is_some() {
+ sender.send((incoming, tx)).ok();
+ }
+ rx.await.ok();
+ }
+ } else if let Some(envelope) =
+ build_typed_envelope(peer_id, Instant::now(), incoming)
+ {
+ let type_name = envelope.payload_type_name();
+ if let Some(future) = ProtoMessageHandlerSet::handle_message(
+ &this.message_handlers,
+ envelope,
+ this.clone().into(),
+ cx.clone(),
+ ) {
+ log::debug!("ssh message received. name:{type_name}");
+ match future.await {
+ Ok(_) => {
+ log::debug!("ssh message handled. name:{type_name}");
+ }
+ Err(error) => {
+ log::error!(
+ "error handling message. type:{type_name}, error:{error}",
+ );
+ }
+ }
+ } else {
+ log::error!("unhandled ssh message name:{type_name}");
+ }
+ }
+ }
+ anyhow::Ok(())
+ }
+ })
+ .detach();
+ }
+
+ pub fn subscribe_to_entity<E: 'static>(&self, remote_id: u64, entity: &Model<E>) {
+ let id = (TypeId::of::<E>(), remote_id);
+
+ let mut message_handlers = self.message_handlers.lock();
+ if message_handlers
+ .entities_by_type_and_remote_id
+ .contains_key(&id)
+ {
+ panic!("already subscribed to entity");
}
+
+ message_handlers.entities_by_type_and_remote_id.insert(
+ id,
+ EntityMessageSubscriber::Entity {
+ handle: entity.downgrade().into(),
+ },
+ );
}
-}
-impl SshSocket {
- fn ssh_command<S: AsRef<OsStr>>(&self, program: S) -> process::Command {
- let mut command = process::Command::new("ssh");
- self.ssh_options(&mut command)
- .arg(self.connection_options.ssh_url())
- .arg(program);
- command
+ pub fn request<T: RequestMessage>(
+ &self,
+ payload: T,
+ ) -> impl 'static + Future<Output = Result<T::Response>> {
+ log::debug!("ssh request start. name:{}", T::NAME);
+ let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
+ async move {
+ let response = response.await?;
+ log::debug!("ssh request finish. name:{}", T::NAME);
+ T::Response::from_envelope(response)
+ .ok_or_else(|| anyhow!("received a response of the wrong type"))
+ }
}
- fn ssh_options<'a>(&self, command: &'a mut process::Command) -> &'a mut process::Command {
- command
- .stdin(Stdio::piped())
- .stdout(Stdio::piped())
- .stderr(Stdio::piped())
- .args(["-o", "ControlMaster=no", "-o"])
- .arg(format!("ControlPath={}", self.socket_path.display()))
+ pub fn send<T: EnvelopedMessage>(&self, payload: T) -> Result<()> {
+ log::debug!("ssh send name:{}", T::NAME);
+ self.send_dynamic(payload.into_envelope(0, None, None))
}
- fn ssh_args(&self) -> Vec<String> {
- vec![
- "-o".to_string(),
- "ControlMaster=no".to_string(),
- "-o".to_string(),
- format!("ControlPath={}", self.socket_path.display()),
- self.connection_options.ssh_url(),
- ]
+ pub fn request_dynamic(
+ &self,
+ mut envelope: proto::Envelope,
+ type_name: &'static str,
+ ) -> impl 'static + Future<Output = Result<proto::Envelope>> {
+ envelope.id = self.next_message_id.fetch_add(1, SeqCst);
+ let (tx, rx) = oneshot::channel();
+ let mut response_channels_lock = self.response_channels.lock();
+ response_channels_lock.insert(MessageId(envelope.id), tx);
+ drop(response_channels_lock);
+ let result = self.outgoing_tx.unbounded_send(envelope);
+ async move {
+ if let Err(error) = &result {
+ log::error!("failed to send message: {}", error);
+ return Err(anyhow!("failed to send message: {}", error));
+ }
+
+ let response = rx.await.context("connection lost")?.0;
+ if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
+ return Err(RpcError::from_proto(error, type_name));
+ }
+ Ok(response)
+ }
+ }
+
+ pub fn send_dynamic(&self, mut envelope: proto::Envelope) -> Result<()> {
+ envelope.id = self.next_message_id.fetch_add(1, SeqCst);
+ self.outgoing_tx.unbounded_send(envelope)?;
+ Ok(())
}
}
-async fn run_cmd(command: &mut process::Command) -> Result<String> {
- let output = command.output().await?;
- if output.status.success() {
- Ok(String::from_utf8_lossy(&output.stdout).to_string())
- } else {
- Err(anyhow!(
- "failed to run command: {}",
- String::from_utf8_lossy(&output.stderr)
- ))
+impl ProtoClient for ChannelClient {
+ fn request(
+ &self,
+ envelope: proto::Envelope,
+ request_type: &'static str,
+ ) -> BoxFuture<'static, Result<proto::Envelope>> {
+ self.request_dynamic(envelope, request_type).boxed()
+ }
+
+ fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {
+ self.send_dynamic(envelope)
+ }
+
+ fn send_response(&self, envelope: Envelope, _message_type: &'static str) -> anyhow::Result<()> {
+ self.send_dynamic(envelope)
+ }
+
+ fn message_handler_set(&self) -> &Mutex<ProtoMessageHandlerSet> {
+ &self.message_handlers
+ }
+
+ fn is_via_collab(&self) -> bool {
+ false
}
}