Fix remote ping timing out (#39114)

localcc created

Closes #38899

Release Notes:

- N/A

Change summary

crates/proto/proto/zed.proto       |  6 ++
crates/proto/src/proto.rs          |  4 +
crates/remote/src/remote_client.rs | 68 ++++++++++++++++++++++++++++++++
3 files changed, 76 insertions(+), 2 deletions(-)

Detailed changes

crates/proto/proto/zed.proto 🔗

@@ -416,7 +416,9 @@ message Envelope {
         StashDrop stash_drop = 378;
         StashApply stash_apply = 379;
 
-        GitRenameBranch git_rename_branch = 380; // current max
+        GitRenameBranch git_rename_branch = 380;
+
+        RemoteStarted remote_started = 381; // current max
     }
 
     reserved 87 to 88;
@@ -490,3 +492,5 @@ message Test {
 message FlushBufferedMessages {}
 
 message FlushBufferedMessagesResponse {}
+
+message RemoteStarted {}

crates/proto/src/proto.rs 🔗

@@ -324,6 +324,7 @@ messages!(
     (ExternalAgentsUpdated, Background),
     (ExternalAgentLoadingStatusUpdated, Background),
     (NewExternalAgentVersionAvailable, Background),
+    (RemoteStarted, Background),
 );
 
 request_messages!(
@@ -497,7 +498,8 @@ request_messages!(
     (GitClone, GitCloneResponse),
     (ToggleLspLogs, Ack),
     (GetProcesses, GetProcessesResponse),
-    (GetAgentServerCommand, AgentServerCommand)
+    (GetAgentServerCommand, AgentServerCommand),
+    (RemoteStarted, Ack),
 );
 
 lsp_messages!(

crates/remote/src/remote_client.rs 🔗

@@ -358,6 +358,24 @@ impl RemoteClient {
 
                 let multiplex_task = Self::monitor(this.downgrade(), io_task, cx);
 
+                let timeout = cx.background_executor().timer(HEARTBEAT_TIMEOUT).fuse();
+                futures::pin_mut!(timeout);
+
+                select_biased! {
+                    ready = client.wait_for_remote_started() => {
+                        if ready.is_none() {
+                            let error = anyhow::anyhow!("remote client exited before becoming ready");
+                            log::error!("failed to establish connection: {}", error);
+                            return Err(error);
+                        }
+                    },
+                    _ = timeout => {
+                        let error = anyhow::anyhow!("remote client did not become ready within the timeout");
+                        log::error!("failed to establish connection: {}", error);
+                        return Err(error);
+                    }
+                }
+
                 if let Err(error) = client.ping(HEARTBEAT_TIMEOUT).await {
                     log::error!("failed to establish connection: {}", error);
                     return Err(error);
@@ -1080,6 +1098,37 @@ pub(crate) trait RemoteConnection: Send + Sync {
 
 type ResponseChannels = Mutex<HashMap<MessageId, oneshot::Sender<(Envelope, oneshot::Sender<()>)>>>;
 
+struct Signal<T> {
+    tx: Mutex<Option<oneshot::Sender<T>>>,
+    rx: Shared<Task<Option<T>>>,
+}
+
+impl<T: Send + Clone + 'static> Signal<T> {
+    pub fn new(cx: &App) -> Self {
+        let (tx, rx) = oneshot::channel();
+
+        let task = cx
+            .background_executor()
+            .spawn(async move { rx.await.ok() })
+            .shared();
+
+        Self {
+            tx: Mutex::new(Some(tx)),
+            rx: task,
+        }
+    }
+
+    fn set(&self, value: T) {
+        if let Some(tx) = self.tx.lock().take() {
+            let _ = tx.send(value);
+        }
+    }
+
+    fn wait(&self) -> Shared<Task<Option<T>>> {
+        self.rx.clone()
+    }
+}
+
 struct ChannelClient {
     next_message_id: AtomicU32,
     outgoing_tx: Mutex<mpsc::UnboundedSender<Envelope>>,
@@ -1089,6 +1138,7 @@ struct ChannelClient {
     max_received: AtomicU32,
     name: &'static str,
     task: Mutex<Task<Result<()>>>,
+    remote_started: Signal<()>,
 }
 
 impl ChannelClient {
@@ -1111,15 +1161,25 @@ impl ChannelClient {
                 incoming_rx,
                 &cx.to_async(),
             )),
+            remote_started: Signal::new(cx),
         })
     }
 
+    fn wait_for_remote_started(&self) -> Shared<Task<Option<()>>> {
+        self.remote_started.wait()
+    }
+
     fn start_handling_messages(
         this: Weak<Self>,
         mut incoming_rx: mpsc::UnboundedReceiver<Envelope>,
         cx: &AsyncApp,
     ) -> Task<Result<()>> {
         cx.spawn(async move |cx| {
+            if let Some(this) = this.upgrade() {
+                let envelope = proto::RemoteStarted {}.into_envelope(0, None, None);
+                this.outgoing_tx.lock().unbounded_send(envelope).ok();
+            };
+
             let peer_id = PeerId { owner_id: 0, id: 0 };
             while let Some(incoming) = incoming_rx.next().await {
                 let Some(this) = this.upgrade() else {
@@ -1152,6 +1212,14 @@ impl ChannelClient {
                     continue;
                 }
 
+                if let Some(proto::envelope::Payload::RemoteStarted(_)) = &incoming.payload {
+                    this.remote_started.set(());
+                    let mut envelope = proto::Ack {}.into_envelope(0, Some(incoming.id), None);
+                    envelope.id = this.next_message_id.fetch_add(1, SeqCst);
+                    this.outgoing_tx.lock().unbounded_send(envelope).ok();
+                    continue;
+                }
+
                 this.max_received.store(incoming.id, SeqCst);
 
                 if let Some(request_id) = incoming.responding_to {