client.rs

  1#[cfg(any(test, feature = "test-support"))]
  2pub mod test;
  3
  4pub mod channel;
  5pub mod http;
  6pub mod user;
  7
  8use anyhow::{anyhow, Context, Result};
  9use async_recursion::async_recursion;
 10use async_tungstenite::tungstenite::{
 11    error::Error as WebsocketError,
 12    http::{Request, StatusCode},
 13};
 14use gpui::{action, AsyncAppContext, Entity, ModelContext, MutableAppContext, Task};
 15use lazy_static::lazy_static;
 16use parking_lot::RwLock;
 17use postage::{prelude::Stream, watch};
 18use rand::prelude::*;
 19use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
 20use std::{
 21    any::TypeId,
 22    collections::HashMap,
 23    convert::TryFrom,
 24    fmt::Write as _,
 25    future::Future,
 26    sync::{Arc, Weak},
 27    time::{Duration, Instant},
 28};
 29use surf::Url;
 30use thiserror::Error;
 31use util::{ResultExt, TryFutureExt};
 32
 33pub use channel::*;
 34pub use rpc::*;
 35pub use user::*;
 36
 37lazy_static! {
 38    static ref ZED_SERVER_URL: String =
 39        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
 40    static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
 41        .ok()
 42        .and_then(|s| if s.is_empty() { None } else { Some(s) });
 43}
 44
// Declares the `Authenticate` action; `init` registers the global handler that runs
// `Client::authenticate_and_connect` for it.
action!(Authenticate);
 46
 47pub fn init(rpc: Arc<Client>, cx: &mut MutableAppContext) {
 48    cx.add_global_action(move |_: &Authenticate, cx| {
 49        let rpc = rpc.clone();
 50        cx.spawn(|cx| async move { rpc.authenticate_and_connect(&cx).log_err().await })
 51            .detach();
 52    });
 53}
 54
 55pub struct Client {
 56    peer: Arc<Peer>,
 57    state: RwLock<ClientState>,
 58    authenticate:
 59        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
 60    establish_connection: Option<
 61        Box<
 62            dyn 'static
 63                + Send
 64                + Sync
 65                + Fn(
 66                    &Credentials,
 67                    &AsyncAppContext,
 68                ) -> Task<Result<Connection, EstablishConnectionError>>,
 69        >,
 70    >,
 71}
 72
 73#[derive(Error, Debug)]
 74pub enum EstablishConnectionError {
 75    #[error("upgrade required")]
 76    UpgradeRequired,
 77    #[error("unauthorized")]
 78    Unauthorized,
 79    #[error("{0}")]
 80    Other(#[from] anyhow::Error),
 81    #[error("{0}")]
 82    Io(#[from] std::io::Error),
 83    #[error("{0}")]
 84    Http(#[from] async_tungstenite::tungstenite::http::Error),
 85}
 86
 87impl From<WebsocketError> for EstablishConnectionError {
 88    fn from(error: WebsocketError) -> Self {
 89        if let WebsocketError::Http(response) = &error {
 90            match response.status() {
 91                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 92                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 93                _ => {}
 94            }
 95        }
 96        EstablishConnectionError::Other(error.into())
 97    }
 98}
 99
100impl EstablishConnectionError {
101    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
102        Self::Other(error.into())
103    }
104}
105
106#[derive(Copy, Clone, Debug)]
107pub enum Status {
108    SignedOut,
109    UpgradeRequired,
110    Authenticating,
111    Connecting,
112    ConnectionError,
113    Connected { connection_id: ConnectionId },
114    ConnectionLost,
115    Reauthenticating,
116    Reconnecting,
117    ReconnectionError { next_reconnection: Instant },
118}
119
120struct ClientState {
121    credentials: Option<Credentials>,
122    status: (watch::Sender<Status>, watch::Receiver<Status>),
123    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
124    model_handlers: HashMap<
125        (TypeId, u64),
126        Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
127    >,
128    _maintain_connection: Option<Task<()>>,
129    heartbeat_interval: Duration,
130}
131
/// User id and access token sent to the server in the `Authorization` header.
///
/// Note: there is no `Debug` derive, which keeps the token out of `{:?}` output.
#[derive(Clone)]
pub struct Credentials {
    /// Numeric id of the signed-in user.
    pub user_id: u64,
    /// Access token matching `user_id`.
    pub access_token: String,
}
137
138impl Default for ClientState {
139    fn default() -> Self {
140        Self {
141            credentials: None,
142            status: watch::channel_with(Status::SignedOut),
143            entity_id_extractors: Default::default(),
144            model_handlers: Default::default(),
145            _maintain_connection: None,
146            heartbeat_interval: Duration::from_secs(5),
147        }
148    }
149}
150
151pub struct Subscription {
152    client: Weak<Client>,
153    id: (TypeId, u64),
154}
155
156impl Drop for Subscription {
157    fn drop(&mut self) {
158        if let Some(client) = self.client.upgrade() {
159            drop(
160                client
161                    .state
162                    .write()
163                    .model_handlers
164                    .remove(&self.id)
165                    .unwrap(),
166            );
167        }
168    }
169}
170
171impl Client {
172    pub fn new() -> Arc<Self> {
173        Arc::new(Self {
174            peer: Peer::new(),
175            state: Default::default(),
176            authenticate: None,
177            establish_connection: None,
178        })
179    }
180
181    #[cfg(any(test, feature = "test-support"))]
182    pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
183    where
184        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
185    {
186        self.authenticate = Some(Box::new(authenticate));
187        self
188    }
189
190    #[cfg(any(test, feature = "test-support"))]
191    pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
192    where
193        F: 'static
194            + Send
195            + Sync
196            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
197    {
198        self.establish_connection = Some(Box::new(connect));
199        self
200    }
201
202    pub fn user_id(&self) -> Option<u64> {
203        self.state
204            .read()
205            .credentials
206            .as_ref()
207            .map(|credentials| credentials.user_id)
208    }
209
210    pub fn status(&self) -> watch::Receiver<Status> {
211        self.state.read().status.1.clone()
212    }
213
214    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
215        let mut state = self.state.write();
216        *state.status.0.borrow_mut() = status;
217
218        match status {
219            Status::Connected { .. } => {
220                let heartbeat_interval = state.heartbeat_interval;
221                let this = self.clone();
222                let foreground = cx.foreground();
223                state._maintain_connection = Some(cx.foreground().spawn(async move {
224                    loop {
225                        foreground.timer(heartbeat_interval).await;
226                        let _ = this.request(proto::Ping {}).await;
227                    }
228                }));
229            }
230            Status::ConnectionLost => {
231                let this = self.clone();
232                let foreground = cx.foreground();
233                let heartbeat_interval = state.heartbeat_interval;
234                state._maintain_connection = Some(cx.spawn(|cx| async move {
235                    let mut rng = StdRng::from_entropy();
236                    let mut delay = Duration::from_millis(100);
237                    while let Err(error) = this.authenticate_and_connect(&cx).await {
238                        log::error!("failed to connect {}", error);
239                        this.set_status(
240                            Status::ReconnectionError {
241                                next_reconnection: Instant::now() + delay,
242                            },
243                            &cx,
244                        );
245                        foreground.timer(delay).await;
246                        delay = delay
247                            .mul_f32(rng.gen_range(1.0..=2.0))
248                            .min(heartbeat_interval);
249                    }
250                }));
251            }
252            Status::SignedOut | Status::UpgradeRequired => {
253                state._maintain_connection.take();
254            }
255            _ => {}
256        }
257    }
258
259    pub fn subscribe<T, M, F>(
260        self: &Arc<Self>,
261        cx: &mut ModelContext<M>,
262        mut handler: F,
263    ) -> Subscription
264    where
265        T: EnvelopedMessage,
266        M: Entity,
267        F: 'static
268            + Send
269            + Sync
270            + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
271    {
272        let subscription_id = (TypeId::of::<T>(), Default::default());
273        let client = self.clone();
274        let mut state = self.state.write();
275        let model = cx.weak_handle();
276        let prev_extractor = state
277            .entity_id_extractors
278            .insert(subscription_id.0, Box::new(|_| Default::default()));
279        if prev_extractor.is_some() {
280            panic!("registered a handler for the same entity twice")
281        }
282
283        state.model_handlers.insert(
284            subscription_id,
285            Box::new(move |envelope, cx| {
286                if let Some(model) = model.upgrade(cx) {
287                    let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
288                    model.update(cx, |model, cx| {
289                        if let Err(error) = handler(model, *envelope, client.clone(), cx) {
290                            log::error!("error handling message: {}", error)
291                        }
292                    });
293                }
294            }),
295        );
296
297        Subscription {
298            client: Arc::downgrade(self),
299            id: subscription_id,
300        }
301    }
302
303    pub fn subscribe_to_entity<T, M, F>(
304        self: &Arc<Self>,
305        remote_id: u64,
306        cx: &mut ModelContext<M>,
307        mut handler: F,
308    ) -> Subscription
309    where
310        T: EntityMessage,
311        M: Entity,
312        F: 'static
313            + Send
314            + Sync
315            + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
316    {
317        let subscription_id = (TypeId::of::<T>(), remote_id);
318        let client = self.clone();
319        let mut state = self.state.write();
320        let model = cx.weak_handle();
321        state
322            .entity_id_extractors
323            .entry(subscription_id.0)
324            .or_insert_with(|| {
325                Box::new(|envelope| {
326                    let envelope = envelope
327                        .as_any()
328                        .downcast_ref::<TypedEnvelope<T>>()
329                        .unwrap();
330                    envelope.payload.remote_entity_id()
331                })
332            });
333        let prev_handler = state.model_handlers.insert(
334            subscription_id,
335            Box::new(move |envelope, cx| {
336                if let Some(model) = model.upgrade(cx) {
337                    let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
338                    model.update(cx, |model, cx| {
339                        if let Err(error) = handler(model, *envelope, client.clone(), cx) {
340                            log::error!("error handling message: {}", error)
341                        }
342                    });
343                }
344            }),
345        );
346        if prev_handler.is_some() {
347            panic!("registered a handler for the same entity twice")
348        }
349
350        Subscription {
351            client: Arc::downgrade(self),
352            id: subscription_id,
353        }
354    }
355
356    #[async_recursion(?Send)]
357    pub async fn authenticate_and_connect(
358        self: &Arc<Self>,
359        cx: &AsyncAppContext,
360    ) -> anyhow::Result<()> {
361        let was_disconnected = match *self.status().borrow() {
362            Status::SignedOut => true,
363            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
364                false
365            }
366            Status::Connected { .. }
367            | Status::Connecting { .. }
368            | Status::Reconnecting { .. }
369            | Status::Authenticating
370            | Status::Reauthenticating => return Ok(()),
371            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
372        };
373
374        if was_disconnected {
375            self.set_status(Status::Authenticating, cx);
376        } else {
377            self.set_status(Status::Reauthenticating, cx)
378        }
379
380        let mut used_keychain = false;
381        let credentials = self.state.read().credentials.clone();
382        let credentials = if let Some(credentials) = credentials {
383            credentials
384        } else if let Some(credentials) = read_credentials_from_keychain(cx) {
385            used_keychain = true;
386            credentials
387        } else {
388            let credentials = match self.authenticate(&cx).await {
389                Ok(credentials) => credentials,
390                Err(err) => {
391                    self.set_status(Status::ConnectionError, cx);
392                    return Err(err);
393                }
394            };
395            credentials
396        };
397
398        if was_disconnected {
399            self.set_status(Status::Connecting, cx);
400        } else {
401            self.set_status(Status::Reconnecting, cx);
402        }
403
404        match self.establish_connection(&credentials, cx).await {
405            Ok(conn) => {
406                log::info!("connected to rpc address {}", *ZED_SERVER_URL);
407                self.state.write().credentials = Some(credentials.clone());
408                if !used_keychain && IMPERSONATE_LOGIN.is_none() {
409                    write_credentials_to_keychain(&credentials, cx).log_err();
410                }
411                self.set_connection(conn, cx).await;
412                Ok(())
413            }
414            Err(EstablishConnectionError::Unauthorized) => {
415                self.state.write().credentials.take();
416                if used_keychain {
417                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
418                    self.set_status(Status::SignedOut, cx);
419                    self.authenticate_and_connect(cx).await
420                } else {
421                    self.set_status(Status::ConnectionError, cx);
422                    Err(EstablishConnectionError::Unauthorized)?
423                }
424            }
425            Err(EstablishConnectionError::UpgradeRequired) => {
426                self.set_status(Status::UpgradeRequired, cx);
427                Err(EstablishConnectionError::UpgradeRequired)?
428            }
429            Err(error) => {
430                self.set_status(Status::ConnectionError, cx);
431                Err(error)?
432            }
433        }
434    }
435
436    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
437        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
438        cx.foreground()
439            .spawn({
440                let mut cx = cx.clone();
441                let this = self.clone();
442                async move {
443                    while let Some(message) = incoming.recv().await {
444                        let mut state = this.state.write();
445                        if let Some(extract_entity_id) =
446                            state.entity_id_extractors.get(&message.payload_type_id())
447                        {
448                            let payload_type_id = message.payload_type_id();
449                            let entity_id = (extract_entity_id)(message.as_ref());
450                            let handler_key = (payload_type_id, entity_id);
451                            if let Some(mut handler) = state.model_handlers.remove(&handler_key) {
452                                drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
453                                let start_time = Instant::now();
454                                log::info!("RPC client message {}", message.payload_type_name());
455                                (handler)(message, &mut cx);
456                                log::info!(
457                                    "RPC message handled. duration:{:?}",
458                                    start_time.elapsed()
459                                );
460                                this.state
461                                    .write()
462                                    .model_handlers
463                                    .insert(handler_key, handler);
464                            } else {
465                                log::info!("unhandled message {}", message.payload_type_name());
466                            }
467                        } else {
468                            log::info!("unhandled message {}", message.payload_type_name());
469                        }
470                    }
471                }
472            })
473            .detach();
474
475        self.set_status(Status::Connected { connection_id }, cx);
476
477        let handle_io = cx.background().spawn(handle_io);
478        let this = self.clone();
479        let cx = cx.clone();
480        cx.foreground()
481            .spawn(async move {
482                match handle_io.await {
483                    Ok(()) => this.set_status(Status::SignedOut, &cx),
484                    Err(err) => {
485                        log::error!("connection error: {:?}", err);
486                        this.set_status(Status::ConnectionLost, &cx);
487                    }
488                }
489            })
490            .detach();
491    }
492
493    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
494        if let Some(callback) = self.authenticate.as_ref() {
495            callback(cx)
496        } else {
497            self.authenticate_with_browser(cx)
498        }
499    }
500
501    fn establish_connection(
502        self: &Arc<Self>,
503        credentials: &Credentials,
504        cx: &AsyncAppContext,
505    ) -> Task<Result<Connection, EstablishConnectionError>> {
506        if let Some(callback) = self.establish_connection.as_ref() {
507            callback(credentials, cx)
508        } else {
509            self.establish_websocket_connection(credentials, cx)
510        }
511    }
512
513    fn establish_websocket_connection(
514        self: &Arc<Self>,
515        credentials: &Credentials,
516        cx: &AsyncAppContext,
517    ) -> Task<Result<Connection, EstablishConnectionError>> {
518        let request = Request::builder()
519            .header(
520                "Authorization",
521                format!("{} {}", credentials.user_id, credentials.access_token),
522            )
523            .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
524        cx.background().spawn(async move {
525            if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
526                let stream = smol::net::TcpStream::connect(host).await?;
527                let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
528                let (stream, _) =
529                    async_tungstenite::async_tls::client_async_tls(request, stream).await?;
530                Ok(Connection::new(stream))
531            } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
532                let stream = smol::net::TcpStream::connect(host).await?;
533                let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
534                let (stream, _) = async_tungstenite::client_async(request, stream).await?;
535                Ok(Connection::new(stream))
536            } else {
537                Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?
538            }
539        })
540    }
541
542    pub fn authenticate_with_browser(
543        self: &Arc<Self>,
544        cx: &AsyncAppContext,
545    ) -> Task<Result<Credentials>> {
546        let platform = cx.platform();
547        let executor = cx.background();
548        executor.clone().spawn(async move {
549            // Generate a pair of asymmetric encryption keys. The public key will be used by the
550            // zed server to encrypt the user's access token, so that it can'be intercepted by
551            // any other app running on the user's device.
552            let (public_key, private_key) =
553                rpc::auth::keypair().expect("failed to generate keypair for auth");
554            let public_key_string =
555                String::try_from(public_key).expect("failed to serialize public key for auth");
556
557            // Start an HTTP server to receive the redirect from Zed's sign-in page.
558            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
559            let port = server.server_addr().port();
560
561            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
562            // that the user is signing in from a Zed app running on the same device.
563            let mut url = format!(
564                "{}/sign_in?native_app_port={}&native_app_public_key={}",
565                *ZED_SERVER_URL, port, public_key_string
566            );
567
568            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
569                log::info!("impersonating user @{}", impersonate_login);
570                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
571            }
572
573            platform.open_url(&url);
574
575            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
576            // access token from the query params.
577            //
578            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
579            // custom URL scheme instead of this local HTTP server.
580            let (user_id, access_token) = executor
581                .spawn(async move {
582                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
583                        let path = req.url();
584                        let mut user_id = None;
585                        let mut access_token = None;
586                        let url = Url::parse(&format!("http://example.com{}", path))
587                            .context("failed to parse login notification url")?;
588                        for (key, value) in url.query_pairs() {
589                            if key == "access_token" {
590                                access_token = Some(value.to_string());
591                            } else if key == "user_id" {
592                                user_id = Some(value.to_string());
593                            }
594                        }
595                        req.respond(
596                            tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
597                                tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
598                            ),
599                        )
600                        .context("failed to respond to login http request")?;
601                        Ok((
602                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
603                            access_token
604                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
605                        ))
606                    } else {
607                        Err(anyhow!("didn't receive login redirect"))
608                    }
609                })
610                .await?;
611
612            let access_token = private_key
613                .decrypt_string(&access_token)
614                .context("failed to decrypt access token")?;
615            platform.activate(true);
616
617            Ok(Credentials {
618                user_id: user_id.parse()?,
619                access_token,
620            })
621        })
622    }
623
624    pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
625        let conn_id = self.connection_id()?;
626        self.peer.disconnect(conn_id).await;
627        self.set_status(Status::SignedOut, cx);
628        Ok(())
629    }
630
631    fn connection_id(&self) -> Result<ConnectionId> {
632        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
633            Ok(connection_id)
634        } else {
635            Err(anyhow!("not connected"))
636        }
637    }
638
639    pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
640        self.peer.send(self.connection_id()?, message).await
641    }
642
643    pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
644        self.peer.request(self.connection_id()?, request).await
645    }
646
647    pub fn respond<T: RequestMessage>(
648        &self,
649        receipt: Receipt<T>,
650        response: T::Response,
651    ) -> impl Future<Output = Result<()>> {
652        self.peer.respond(receipt, response)
653    }
654}
655
656fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
657    if IMPERSONATE_LOGIN.is_some() {
658        return None;
659    }
660
661    let (user_id, access_token) = cx
662        .platform()
663        .read_credentials(&ZED_SERVER_URL)
664        .log_err()
665        .flatten()?;
666    Some(Credentials {
667        user_id: user_id.parse().ok()?,
668        access_token: String::from_utf8(access_token).ok()?,
669    })
670}
671
672fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
673    cx.platform().write_credentials(
674        &ZED_SERVER_URL,
675        &credentials.user_id.to_string(),
676        credentials.access_token.as_bytes(),
677    )
678}
679
/// Scheme and path prefix shared by all worktree share URLs.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a `zed://worktrees/<id>/<access_token>` share URL.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token)
}

/// Parses a URL produced by [`encode_worktree_url`], ignoring surrounding whitespace.
///
/// Returns `None` when the prefix is missing, the id is not a `u64`, or the access token is
/// missing or empty. Any path segments after the token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let mut segments = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?.split('/');
    let id = segments.next()?.parse::<u64>().ok()?;
    let access_token = segments.next().filter(|token| !token.is_empty())?;
    Some((id, access_token.to_string()))
}
696
/// HTML served to the browser after the sign-in redirect; its script closes the tab.
const LOGIN_RESPONSE: &str = "
<!DOCTYPE html>
<html>
<script>window.close();</script>
</html>
";
703
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::TestAppContext;

    /// While connected, the client pings once per heartbeat period; after `disconnect`, no
    /// further pings reach the server.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let mut client = Client::new();
        let server = FakeServer::for_client(5, &mut client, &cx).await;

        // Two heartbeat periods, each ping answered with an `Ack`.
        for _ in 0..2 {
            cx.foreground().advance_clock(Duration::from_secs(10));
            let ping = server.receive::<proto::Ping>().await.unwrap();
            server.respond(ping.receipt(), proto::Ack {}).await;
        }

        client.disconnect(&cx.to_async()).await.unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    /// A lost connection is re-established automatically with the cached credentials. Once
    /// the server rejects those credentials, the client authenticates again.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let mut client = Client::new();
        let server = FakeServer::for_client(5, &mut client, &cx).await;
        let mut status_updates = client.status();
        assert!(matches!(
            status_updates.recv().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones, so reconnection attempts fail.
        server.forbid_connections();
        server.disconnect().await;
        while !matches!(status_updates.recv().await, Some(Status::ReconnectionError { .. })) {}

        // Accept connections again: the client reconnects without re-authenticating.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status_updates.recv().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1);

        server.forbid_connections();
        server.disconnect().await;
        while !matches!(status_updates.recv().await, Some(Status::ReconnectionError { .. })) {}

        // Invalidate the token: the cached credentials are rejected and cleared, and a later
        // attempt authenticates from scratch.
        server.roll_access_token();
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        assert_eq!(server.auth_count(), 1);
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status_updates.recv().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2);
    }

    #[test]
    fn test_encode_and_decode_worktree_url() {
        let encoded = encode_worktree_url(5, "deadbeef");
        let expected = Some((5, "deadbeef".to_string()));
        assert_eq!(decode_worktree_url(&encoded), expected);
        assert_eq!(decode_worktree_url(&format!("\n {}\t", encoded)), expected);
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }
}